Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions sycl/source/detail/device_image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,11 +549,8 @@ class device_image_impl

// Returns the devices stored in MDevices, exposed as a devices_range.
devices_range get_devices() const noexcept { return MDevices; }

bool compatible_with_device(const device &Dev) const {
return std::any_of(MDevices.begin(), MDevices.end(),
[Dev = &*getSyclObjImpl(Dev)](device_impl *DevCand) {
return Dev == DevCand;
});
bool compatible_with_device(device_impl &Dev) const {
return get_devices().contains(Dev);
}

const ur_program_handle_t &get_ur_program_ref() const noexcept {
Expand Down
8 changes: 8 additions & 0 deletions sycl/source/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class variadic_iterator {
// Copy/move construction and copy/move assignment are all compiler-generated.
// NOTE(review): the non-const lvalue constructor overload presumably exists so
// that copying from a non-const variadic_iterator does not select the
// forwarding constructor template declared after these — confirm.
variadic_iterator(const variadic_iterator &) = default;
variadic_iterator(variadic_iterator &&) = default;
variadic_iterator(variadic_iterator &) = default;
variadic_iterator &operator=(const variadic_iterator &) = default;
variadic_iterator &operator=(variadic_iterator &&) = default;

template <typename IterTy>
variadic_iterator(IterTy &&It) : It(std::forward<IterTy>(It)) {}
Expand Down Expand Up @@ -151,6 +153,12 @@ template <typename iterator> class iterator_range {
return Container{std::move(Result)};
}

/// Checks whether \p Other is one of the elements of this range.
///
/// Elements are matched by identity: an element matches only if its address
/// equals the address of \p Other. No value comparison is performed.
///
/// \param Other the object to look for.
/// \return true if some element of [begin(), end()) is the same object as
///         \p Other, false otherwise (including for an empty range).
bool contains(value_type &Other) const {
  const auto IsSameObject = [&Other](value_type &Candidate) {
    return &Candidate == &Other;
  };
  return std::any_of(begin(), end(), IsSameObject);
}

protected:
template <typename Container>
static constexpr bool has_reserve_v = has_reserve<Container>::value;
Expand Down
131 changes: 64 additions & 67 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,17 @@ bool is_source_kernel_bundle_supported(

namespace detail {

static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
inline bool checkAllDevicesAreInContext(devices_range Devices,
const context &Context) {
return std::all_of(
Devices.begin(), Devices.end(), [&Context](const device &Dev) {
return getSyclObjImpl(Context)->isDeviceValid(*getSyclObjImpl(Dev));
});
return std::all_of(Devices.begin(), Devices.end(),
[&Context](device_impl &Dev) {
return getSyclObjImpl(Context)->isDeviceValid(Dev);
});
}

static bool checkAllDevicesHaveAspect(const std::vector<device> &Devices,
aspect Aspect) {
inline bool checkAllDevicesHaveAspect(devices_range Devices, aspect Aspect) {
return std::all_of(Devices.begin(), Devices.end(),
[&Aspect](const device &Dev) { return Dev.has(Aspect); });
[&Aspect](device_impl &Dev) { return Dev.has(Aspect); });
}

namespace syclex = sycl::ext::oneapi::experimental;
Expand Down Expand Up @@ -100,9 +99,10 @@ class kernel_bundle_impl
}

public:
kernel_bundle_impl(context Ctx, std::vector<device> Devs, bundle_state State,
kernel_bundle_impl(context Ctx, devices_range Devs, bundle_state State,
private_tag)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
: MContext(std::move(Ctx)),
MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {

common_ctor_checks();

Expand All @@ -112,8 +112,9 @@ class kernel_bundle_impl
}

// Interop constructor used by make_kernel
kernel_bundle_impl(context Ctx, std::vector<device> Devs, private_tag)
: MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) {
kernel_bundle_impl(context Ctx, devices_range Devs, private_tag)
: MContext(Ctx), MDevices(Devs.to<std::vector<device_impl *>>()),
MState(bundle_state::executable) {
if (!checkAllDevicesAreInContext(Devs, Ctx))
throw sycl::exception(
make_error_code(errc::invalid),
Expand All @@ -122,9 +123,9 @@ class kernel_bundle_impl
}

// Interop constructor
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
kernel_bundle_impl(context Ctx, devices_range Devs,
device_image_plain &DevImage, private_tag Tag)
: kernel_bundle_impl(std::move(Ctx), std::move(Devs), Tag) {
: kernel_bundle_impl(std::move(Ctx), Devs, Tag) {
MDeviceImages.emplace_back(DevImage);
MUniqueDeviceImages.emplace_back(DevImage);
}
Expand All @@ -133,22 +134,19 @@ class kernel_bundle_impl
// Have one constructor because sycl::build and sycl::compile have the same
// signature
kernel_bundle_impl(const kernel_bundle<bundle_state::input> &InputBundle,
std::vector<device> Devs, const property_list &PropList,
devices_range Devs, const property_list &PropList,
bundle_state TargetState, private_tag)
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
MState(TargetState) {
: MContext(InputBundle.get_context()),
MDevices(Devs.to<std::vector<device_impl *>>()), MState(TargetState) {

kernel_bundle_impl &InputBundleImpl = *getSyclObjImpl(InputBundle);
MSpecConstValues = InputBundleImpl.get_spec_const_map_ref();

const std::vector<device> &InputBundleDevices =
InputBundleImpl.get_devices();
devices_range InputBundleDevices = InputBundleImpl.get_devices();
const bool AllDevsAssociatedWithInputBundle =
std::all_of(MDevices.begin(), MDevices.end(),
[&InputBundleDevices](const device &Dev) {
return InputBundleDevices.end() !=
std::find(InputBundleDevices.begin(),
InputBundleDevices.end(), Dev);
std::all_of(get_devices().begin(), get_devices().end(),
[&InputBundleDevices](device_impl &Dev) {
return InputBundleDevices.contains(Dev);
});
if (MDevices.empty() || !AllDevsAssociatedWithInputBundle)
throw sycl::exception(
Expand All @@ -163,8 +161,8 @@ class kernel_bundle_impl
for (const DevImgPlainWithDeps &DevImgWithDeps :
InputBundleImpl.MDeviceImages) {
// Skip images which are not compatible with devices provided
if (std::none_of(MDevices.begin(), MDevices.end(),
[&DevImgWithDeps](const device &Dev) {
if (std::none_of(get_devices().begin(), get_devices().end(),
[&DevImgWithDeps](device_impl &Dev) {
return getSyclObjImpl(DevImgWithDeps.getMain())
->compatible_with_device(Dev);
}))
Expand Down Expand Up @@ -206,8 +204,9 @@ class kernel_bundle_impl
// Matches sycl::link
kernel_bundle_impl(
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
std::vector<device> Devs, const property_list &PropList, private_tag)
: MDevices(std::move(Devs)), MState(bundle_state::executable) {
devices_range Devs, const property_list &PropList, private_tag)
: MDevices(Devs.to<std::vector<device_impl *>>()),
MState(bundle_state::executable) {
if (MDevices.empty())
throw sycl::exception(make_error_code(errc::invalid),
"Vector of devices is empty");
Expand All @@ -226,16 +225,15 @@ class kernel_bundle_impl
// Check if any of the devices in devs are not in the set of associated
// devices for any of the bundles in ObjectBundles
const bool AllDevsAssociatedWithInputBundles = std::all_of(
MDevices.begin(), MDevices.end(), [&ObjectBundles](const device &Dev) {
get_devices().begin(), get_devices().end(),
[&ObjectBundles](device_impl &Dev) {
// Number of devices is expected to be small
return std::all_of(
ObjectBundles.begin(), ObjectBundles.end(),
[&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {
const std::vector<device> &BundleDevices =
devices_range BundleDevices =
getSyclObjImpl(KernelBundle)->get_devices();
return BundleDevices.end() != std::find(BundleDevices.begin(),
BundleDevices.end(),
Dev);
return BundleDevices.contains(Dev);
});
});
if (!AllDevsAssociatedWithInputBundles)
Expand Down Expand Up @@ -363,41 +361,33 @@ class kernel_bundle_impl
}

// Create a link graph and clone it for each device.
device_impl &FirstDevice = *getSyclObjImpl(MDevices[0]);
std::map<std::shared_ptr<device_impl>, LinkGraph<device_image_plain>>
DevImageLinkGraphs;
device_impl &FirstDevice = get_devices().front();
std::map<device_impl *, LinkGraph<device_image_plain>> DevImageLinkGraphs;
const auto &FirstGraph =
DevImageLinkGraphs
.emplace(FirstDevice.shared_from_this(),
.emplace(&FirstDevice,
LinkGraph<device_image_plain>{DevImages, Dependencies})
.first->second;
for (size_t I = 1; I < MDevices.size(); ++I)
DevImageLinkGraphs.emplace(getSyclObjImpl(MDevices[I]),
FirstGraph.Clone());
for (device_impl &Dev : get_devices())
DevImageLinkGraphs.emplace(&Dev, FirstGraph.Clone());

// Poison the images based on whether the corresponding device supports it.
for (auto &GraphIt : DevImageLinkGraphs) {
device Dev = createSyclObjFromImpl<device>(GraphIt.first);
device_impl &Dev = *GraphIt.first;
GraphIt.second.Poison([&Dev](const device_image_plain &DevImg) {
return !getSyclObjImpl(DevImg)->compatible_with_device(Dev);
});
}

// Unify graphs after poisoning.
std::map<std::vector<std::shared_ptr<device_impl>>,
LinkGraph<device_image_plain>>
std::map<std::vector<device_impl *>, LinkGraph<device_image_plain>>
UnifiedGraphs = UnifyGraphs(DevImageLinkGraphs);

// Link based on the resulting graphs.
for (auto &GraphIt : UnifiedGraphs) {
std::vector<device> DeviceGroup;
DeviceGroup.reserve(GraphIt.first.size());
for (const auto &DeviceImgImpl : GraphIt.first)
DeviceGroup.emplace_back(createSyclObjFromImpl<device>(DeviceImgImpl));

std::vector<device_image_plain> LinkedResults =
detail::ProgramManager::getInstance().link(
GraphIt.second.GetNodeValues(), DeviceGroup, PropList);
GraphIt.second.GetNodeValues(), GraphIt.first, PropList);
MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(),
LinkedResults.end());
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
Expand All @@ -410,8 +400,8 @@ class kernel_bundle_impl
for (const DevImgPlainWithDeps *DeviceImageWithDeps :
ImagesWithSpecConsts) {
// Skip images which are not compatible with devices provided
if (std::none_of(MDevices.begin(), MDevices.end(),
[DeviceImageWithDeps](const device &Dev) {
if (std::none_of(get_devices().begin(), get_devices().end(),
[DeviceImageWithDeps](device_impl &Dev) {
return getSyclObjImpl(DeviceImageWithDeps->getMain())
->compatible_with_device(Dev);
}))
Expand All @@ -438,10 +428,11 @@ class kernel_bundle_impl
}
}

kernel_bundle_impl(context Ctx, std::vector<device> Devs,
kernel_bundle_impl(context Ctx, devices_range Devs,
const std::vector<kernel_id> &KernelIDs,
bundle_state State, private_tag)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
: MContext(std::move(Ctx)),
MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {

common_ctor_checks();

Expand All @@ -450,10 +441,11 @@ class kernel_bundle_impl
fillUniqueDeviceImages();
}

kernel_bundle_impl(context Ctx, std::vector<device> Devs,
kernel_bundle_impl(context Ctx, devices_range Devs,
const DevImgSelectorImpl &Selector, bundle_state State,
private_tag)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
: MContext(std::move(Ctx)),
MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {

common_ctor_checks();

Expand Down Expand Up @@ -548,7 +540,9 @@ class kernel_bundle_impl
kernel_bundle_impl(const context &Context, syclex::source_language Lang,
const std::string &Src, include_pairs_t IncludePairsVec,
private_tag)
: MContext(Context), MDevices(Context.get_devices()),
: MContext(Context), MDevices(getSyclObjImpl(Context)
->getDevices()
.to<std::vector<device_impl *>>()),
MDeviceImages{device_image_plain{device_image_impl::create(
Src, MContext, MDevices, Lang, std::move(IncludePairsVec))}},
MUniqueDeviceImages{MDeviceImages[0].getMain()},
Expand All @@ -560,7 +554,9 @@ class kernel_bundle_impl
// construct from source bytes
kernel_bundle_impl(const context &Context, syclex::source_language Lang,
const std::vector<std::byte> &Bytes, private_tag)
: MContext(Context), MDevices(Context.get_devices()),
: MContext(Context), MDevices(getSyclObjImpl(Context)
->getDevices()
.to<std::vector<device_impl *>>()),
MDeviceImages{device_image_plain{
device_image_impl::create(Bytes, MContext, MDevices, Lang)}},
MUniqueDeviceImages{MDeviceImages[0].getMain()},
Expand All @@ -571,11 +567,11 @@ class kernel_bundle_impl
// oneapi_ext_kernel_compiler
// construct from built source files
kernel_bundle_impl(
const context &Context, const std::vector<device> &Devs,
const context &Context, devices_range Devs,
std::vector<device_image_plain> &&DevImgs,
std::vector<std::shared_ptr<ManagedDeviceBinaries>> &&DevBinaries,
bundle_state State, private_tag)
: MContext(Context), MDevices(Devs),
: MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>()),
MSharedDeviceBinaries(std::move(DevBinaries)),
MUniqueDeviceImages(std::move(DevImgs)), MState(State) {
common_ctor_checks();
Expand All @@ -587,10 +583,11 @@ class kernel_bundle_impl
}

// SYCLBIN constructor
kernel_bundle_impl(const context &Context, const std::vector<device> &Devs,
kernel_bundle_impl(const context &Context, devices_range Devs,
const sycl::span<char> Bytes, bundle_state State,
private_tag)
: MContext(Context), MDevices(Devs), MState(State) {
: MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>()),
MState(State) {
common_ctor_checks();

auto &SYCLBIN = MSYCLBINs.emplace_back(
Expand Down Expand Up @@ -622,7 +619,7 @@ class kernel_bundle_impl
}

std::shared_ptr<kernel_bundle_impl> build_from_source(
const std::vector<device> Devices,
devices_range Devices,
const std::vector<sycl::detail::string_view> &BuildOptions,
std::string *LogPtr,
const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
Expand All @@ -645,7 +642,7 @@ class kernel_bundle_impl
}

std::shared_ptr<kernel_bundle_impl> compile_from_source(
const std::vector<device> Devices,
devices_range Devices,
const std::vector<sycl::detail::string_view> &CompileOptions,
std::string *LogPtr,
const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
Expand Down Expand Up @@ -733,8 +730,9 @@ class kernel_bundle_impl
void *ext_oneapi_get_device_global_address(const std::string &Name,
const device &Dev) const {
DeviceGlobalMapEntry *Entry = getDeviceGlobalEntry(Name);
device_impl &DeviceImpl = *getSyclObjImpl(Dev);

if (std::find(MDevices.begin(), MDevices.end(), Dev) == MDevices.end()) {
if (!get_devices().contains(DeviceImpl)) {
throw sycl::exception(make_error_code(errc::invalid),
"kernel_bundle not built for device");
}
Expand All @@ -745,7 +743,6 @@ class kernel_bundle_impl
"'device_image_scope' property");
}

device_impl &DeviceImpl = *getSyclObjImpl(Dev);
bool SupportContextMemcpy = false;
DeviceImpl.getAdapter().call<UrApiKind::urDeviceGetInfo>(
DeviceImpl.getHandleRef(),
Expand All @@ -772,7 +769,7 @@ class kernel_bundle_impl

// Returns the context held in MContext (returned by value).
context get_context() const noexcept { return MContext; }

const std::vector<device> &get_devices() const noexcept { return MDevices; }
devices_range get_devices() const noexcept { return MDevices; }

std::vector<kernel_id> get_kernel_ids() const {
// Collect kernel ids from all device images, then remove duplicates
Expand Down Expand Up @@ -1111,7 +1108,7 @@ class kernel_bundle_impl
}

context MContext;
std::vector<device> MDevices;
std::vector<device_impl *> MDevices;

// For sycl_jit, building from source may have produced sycl binaries that
// the kernel_bundles now manage.
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/kernel_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ kernel_impl::get_backend_info<info::device::version>() const {
"the info::device::version info descriptor can only "
"be queried with an OpenCL backend");
}
auto Devices = MKernelBundleImpl->get_devices();
auto Devices = MKernelBundleImpl->get_devices().to<std::vector<device>>();
if (Devices.empty()) {
return "No available device";
}
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/kernel_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ context kernel_bundle_plain::get_context() const noexcept {
}

std::vector<device> kernel_bundle_plain::get_devices() const noexcept {
return impl->get_devices();
return impl->get_devices().to<std::vector<device>>();
}

std::vector<kernel_id> kernel_bundle_plain::get_kernel_ids() const {
Expand Down