diff --git a/sycl/source/detail/device_image_impl.hpp b/sycl/source/detail/device_image_impl.hpp index c39d00e9b7c27..38bcb19e240ae 100644 --- a/sycl/source/detail/device_image_impl.hpp +++ b/sycl/source/detail/device_image_impl.hpp @@ -549,11 +549,8 @@ class device_image_impl devices_range get_devices() const noexcept { return MDevices; } - bool compatible_with_device(const device &Dev) const { - return std::any_of(MDevices.begin(), MDevices.end(), - [Dev = &*getSyclObjImpl(Dev)](device_impl *DevCand) { - return Dev == DevCand; - }); + bool compatible_with_device(device_impl &Dev) const { + return get_devices().contains(Dev); } const ur_program_handle_t &get_ur_program_ref() const noexcept { diff --git a/sycl/source/detail/helpers.hpp b/sycl/source/detail/helpers.hpp index 3f002bf41b8e2..a1a49361e5755 100644 --- a/sycl/source/detail/helpers.hpp +++ b/sycl/source/detail/helpers.hpp @@ -54,6 +54,8 @@ class variadic_iterator { variadic_iterator(const variadic_iterator &) = default; variadic_iterator(variadic_iterator &&) = default; variadic_iterator(variadic_iterator &) = default; + variadic_iterator &operator=(const variadic_iterator &) = default; + variadic_iterator &operator=(variadic_iterator &&) = default; template variadic_iterator(IterTy &&It) : It(std::forward(It)) {} @@ -151,6 +153,12 @@ template class iterator_range { return Container{std::move(Result)}; } + bool contains(value_type &Other) const { + return std::find_if(begin(), end(), [&Other](value_type &Elem) { + return &Elem == &Other; + }) != end(); + } + protected: template static constexpr bool has_reserve_v = has_reserve::value; diff --git a/sycl/source/detail/kernel_bundle_impl.hpp b/sycl/source/detail/kernel_bundle_impl.hpp index e01ff15b4b8d4..6ce4d38a3420a 100644 --- a/sycl/source/detail/kernel_bundle_impl.hpp +++ b/sycl/source/detail/kernel_bundle_impl.hpp @@ -48,18 +48,17 @@ bool is_source_kernel_bundle_supported( namespace detail { -static bool checkAllDevicesAreInContext(const std::vector &Devices, +inline bool checkAllDevicesAreInContext(devices_range Devices, const context &Context) { - return std::all_of( - Devices.begin(), Devices.end(), [&Context](const device &Dev) { - return getSyclObjImpl(Context)->isDeviceValid(*getSyclObjImpl(Dev)); - }); + return std::all_of(Devices.begin(), Devices.end(), + [&Context](device_impl &Dev) { + return getSyclObjImpl(Context)->isDeviceValid(Dev); + }); } -static bool checkAllDevicesHaveAspect(const std::vector &Devices, - aspect Aspect) { +inline bool checkAllDevicesHaveAspect(devices_range Devices, aspect Aspect) { return std::all_of(Devices.begin(), Devices.end(), - [&Aspect](const device &Dev) { return Dev.has(Aspect); }); + [&Aspect](device_impl &Dev) { return Dev.has(Aspect); }); } namespace syclex = sycl::ext::oneapi::experimental; @@ -100,9 +99,10 @@ class kernel_bundle_impl } public: - kernel_bundle_impl(context Ctx, std::vector Devs, bundle_state State, + kernel_bundle_impl(context Ctx, devices_range Devs, bundle_state State, private_tag) - : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { + : MContext(std::move(Ctx)), + MDevices(Devs.to>()), MState(State) { common_ctor_checks(); @@ -112,8 +112,9 @@ class kernel_bundle_impl } // Interop constructor used by make_kernel - kernel_bundle_impl(context Ctx, std::vector Devs, private_tag) - : MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) { + kernel_bundle_impl(context Ctx, devices_range Devs, private_tag) + : MContext(Ctx), MDevices(Devs.to>()), + MState(bundle_state::executable) { if (!checkAllDevicesAreInContext(Devs, Ctx)) throw sycl::exception( make_error_code(errc::invalid), @@ -122,9 +123,9 @@ class kernel_bundle_impl } // Interop constructor - kernel_bundle_impl(context Ctx, std::vector Devs, + kernel_bundle_impl(context Ctx, devices_range Devs, device_image_plain &DevImage, private_tag Tag) - : kernel_bundle_impl(std::move(Ctx), std::move(Devs), Tag) { + : kernel_bundle_impl(std::move(Ctx), Devs, Tag) { MDeviceImages.emplace_back(DevImage); MUniqueDeviceImages.emplace_back(DevImage); } @@ -133,22 +134,19 @@ class kernel_bundle_impl // Have one constructor because sycl::build and sycl::compile have the same // signature kernel_bundle_impl(const kernel_bundle &InputBundle, - std::vector Devs, const property_list &PropList, + devices_range Devs, const property_list &PropList, bundle_state TargetState, private_tag) - : MContext(InputBundle.get_context()), MDevices(std::move(Devs)), - MState(TargetState) { + : MContext(InputBundle.get_context()), + MDevices(Devs.to>()), MState(TargetState) { kernel_bundle_impl &InputBundleImpl = *getSyclObjImpl(InputBundle); MSpecConstValues = InputBundleImpl.get_spec_const_map_ref(); - const std::vector &InputBundleDevices = - InputBundleImpl.get_devices(); + devices_range InputBundleDevices = InputBundleImpl.get_devices(); const bool AllDevsAssociatedWithInputBundle = - std::all_of(MDevices.begin(), MDevices.end(), - [&InputBundleDevices](const device &Dev) { - return InputBundleDevices.end() != - std::find(InputBundleDevices.begin(), - InputBundleDevices.end(), Dev); + std::all_of(get_devices().begin(), get_devices().end(), + [&InputBundleDevices](device_impl &Dev) { + return InputBundleDevices.contains(Dev); }); if (MDevices.empty() || !AllDevsAssociatedWithInputBundle) throw sycl::exception( @@ -163,8 +161,8 @@ class kernel_bundle_impl for (const DevImgPlainWithDeps &DevImgWithDeps : InputBundleImpl.MDeviceImages) { // Skip images which are not compatible with devices provided - if (std::none_of(MDevices.begin(), MDevices.end(), - [&DevImgWithDeps](const device &Dev) { + if (std::none_of(get_devices().begin(), get_devices().end(), + [&DevImgWithDeps](device_impl &Dev) { return getSyclObjImpl(DevImgWithDeps.getMain()) ->compatible_with_device(Dev); })) @@ -206,8 +204,9 @@ class kernel_bundle_impl // Matches sycl::link kernel_bundle_impl( const std::vector> &ObjectBundles, - std::vector Devs, const property_list &PropList, private_tag) - : MDevices(std::move(Devs)), MState(bundle_state::executable) { + devices_range Devs, const property_list &PropList, private_tag) + : MDevices(Devs.to>()), + MState(bundle_state::executable) { if (MDevices.empty()) throw sycl::exception(make_error_code(errc::invalid), "Vector of devices is empty"); @@ -226,16 +225,15 @@ class kernel_bundle_impl // Check if any of the devices in devs are not in the set of associated // devices for any of the bundles in ObjectBundles const bool AllDevsAssociatedWithInputBundles = std::all_of( - MDevices.begin(), MDevices.end(), [&ObjectBundles](const device &Dev) { + get_devices().begin(), get_devices().end(), + [&ObjectBundles](device_impl &Dev) { // Number of devices is expected to be small return std::all_of( ObjectBundles.begin(), ObjectBundles.end(), [&Dev](const kernel_bundle &KernelBundle) { - const std::vector &BundleDevices = + devices_range BundleDevices = getSyclObjImpl(KernelBundle)->get_devices(); - return BundleDevices.end() != std::find(BundleDevices.begin(), - BundleDevices.end(), - Dev); + return BundleDevices.contains(Dev); }); }); if (!AllDevsAssociatedWithInputBundles) @@ -363,41 +361,33 @@ class kernel_bundle_impl } // Create a link graph and clone it for each device. - device_impl &FirstDevice = *getSyclObjImpl(MDevices[0]); - std::map, LinkGraph> - DevImageLinkGraphs; + device_impl &FirstDevice = get_devices().front(); + std::map> DevImageLinkGraphs; const auto &FirstGraph = DevImageLinkGraphs - .emplace(FirstDevice.shared_from_this(), + .emplace(&FirstDevice, LinkGraph{DevImages, Dependencies}) .first->second; - for (size_t I = 1; I < MDevices.size(); ++I) - DevImageLinkGraphs.emplace(getSyclObjImpl(MDevices[I]), - FirstGraph.Clone()); + for (device_impl &Dev : get_devices()) + DevImageLinkGraphs.emplace(&Dev, FirstGraph.Clone()); // Poison the images based on whether the corresponding device supports it. for (auto &GraphIt : DevImageLinkGraphs) { - device Dev = createSyclObjFromImpl(GraphIt.first); + device_impl &Dev = *GraphIt.first; GraphIt.second.Poison([&Dev](const device_image_plain &DevImg) { return !getSyclObjImpl(DevImg)->compatible_with_device(Dev); }); } // Unify graphs after poisoning. - std::map>, - LinkGraph> + std::map, LinkGraph> UnifiedGraphs = UnifyGraphs(DevImageLinkGraphs); // Link based on the resulting graphs. for (auto &GraphIt : UnifiedGraphs) { - std::vector DeviceGroup; - DeviceGroup.reserve(GraphIt.first.size()); - for (const auto &DeviceImgImpl : GraphIt.first) - DeviceGroup.emplace_back(createSyclObjFromImpl(DeviceImgImpl)); - std::vector LinkedResults = detail::ProgramManager::getInstance().link( - GraphIt.second.GetNodeValues(), DeviceGroup, PropList); + GraphIt.second.GetNodeValues(), GraphIt.first, PropList); MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(), LinkedResults.end()); MUniqueDeviceImages.insert(MUniqueDeviceImages.end(), @@ -410,8 +400,8 @@ class kernel_bundle_impl for (const DevImgPlainWithDeps *DeviceImageWithDeps : ImagesWithSpecConsts) { // Skip images which are not compatible with devices provided - if (std::none_of(MDevices.begin(), MDevices.end(), - [DeviceImageWithDeps](const device &Dev) { + if (std::none_of(get_devices().begin(), get_devices().end(), + [DeviceImageWithDeps](device_impl &Dev) { return getSyclObjImpl(DeviceImageWithDeps->getMain()) ->compatible_with_device(Dev); })) @@ -438,10 +428,11 @@ class kernel_bundle_impl } } - kernel_bundle_impl(context Ctx, std::vector Devs, + kernel_bundle_impl(context Ctx, devices_range Devs, const std::vector &KernelIDs, bundle_state State, private_tag) - : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { + : MContext(std::move(Ctx)), + MDevices(Devs.to>()), MState(State) { common_ctor_checks(); @@ -450,10 +441,11 @@ class kernel_bundle_impl fillUniqueDeviceImages(); } - kernel_bundle_impl(context Ctx, std::vector Devs, + kernel_bundle_impl(context Ctx, devices_range Devs, const DevImgSelectorImpl &Selector, bundle_state State, private_tag) - : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { + : MContext(std::move(Ctx)), + MDevices(Devs.to>()), MState(State) { common_ctor_checks(); @@ -548,7 +540,9 @@ class kernel_bundle_impl kernel_bundle_impl(const context &Context, syclex::source_language Lang, const std::string &Src, include_pairs_t IncludePairsVec, private_tag) - : MContext(Context), MDevices(Context.get_devices()), + : MContext(Context), MDevices(getSyclObjImpl(Context) + ->getDevices() + .to>()), MDeviceImages{device_image_plain{device_image_impl::create( Src, MContext, MDevices, Lang, std::move(IncludePairsVec))}}, MUniqueDeviceImages{MDeviceImages[0].getMain()}, @@ -560,7 +554,9 @@ class kernel_bundle_impl // construct from source bytes kernel_bundle_impl(const context &Context, syclex::source_language Lang, const std::vector &Bytes, private_tag) - : MContext(Context), MDevices(Context.get_devices()), + : MContext(Context), MDevices(getSyclObjImpl(Context) + ->getDevices() + .to>()), MDeviceImages{device_image_plain{ device_image_impl::create(Bytes, MContext, MDevices, Lang)}}, MUniqueDeviceImages{MDeviceImages[0].getMain()}, @@ -571,11 +567,11 @@ class kernel_bundle_impl // oneapi_ext_kernel_compiler // construct from built source files kernel_bundle_impl( - const context &Context, const std::vector &Devs, + const context &Context, devices_range Devs, std::vector &&DevImgs, std::vector> &&DevBinaries, bundle_state State, private_tag) - : MContext(Context), MDevices(Devs), + : MContext(Context), MDevices(Devs.to>()), MSharedDeviceBinaries(std::move(DevBinaries)), MUniqueDeviceImages(std::move(DevImgs)), MState(State) { common_ctor_checks(); @@ -587,10 +583,11 @@ class kernel_bundle_impl } // SYCLBIN constructor - kernel_bundle_impl(const context &Context, const std::vector &Devs, + kernel_bundle_impl(const context &Context, devices_range Devs, const sycl::span Bytes, bundle_state State, private_tag) - : MContext(Context), MDevices(Devs), MState(State) { + : MContext(Context), MDevices(Devs.to>()), + MState(State) { common_ctor_checks(); auto &SYCLBIN = MSYCLBINs.emplace_back( @@ -622,7 +619,7 @@ class kernel_bundle_impl } std::shared_ptr build_from_source( - const std::vector Devices, + devices_range Devices, const std::vector &BuildOptions, std::string *LogPtr, const std::vector &RegisteredKernelNames) { @@ -645,7 +642,7 @@ class kernel_bundle_impl } std::shared_ptr compile_from_source( - const std::vector Devices, + devices_range Devices, const std::vector &CompileOptions, std::string *LogPtr, const std::vector &RegisteredKernelNames) { @@ -733,8 +730,9 @@ class kernel_bundle_impl void *ext_oneapi_get_device_global_address(const std::string &Name, const device &Dev) const { DeviceGlobalMapEntry *Entry = getDeviceGlobalEntry(Name); + device_impl &DeviceImpl = *getSyclObjImpl(Dev); - if (std::find(MDevices.begin(), MDevices.end(), Dev) == MDevices.end()) { + if (!get_devices().contains(DeviceImpl)) { throw sycl::exception(make_error_code(errc::invalid), "kernel_bundle not built for device"); } @@ -745,7 +743,6 @@ class kernel_bundle_impl "'device_image_scope' property"); } - device_impl &DeviceImpl = *getSyclObjImpl(Dev); bool SupportContextMemcpy = false; DeviceImpl.getAdapter().call( DeviceImpl.getHandleRef(), @@ -772,7 +769,7 @@ class kernel_bundle_impl context get_context() const noexcept { return MContext; } - const std::vector &get_devices() const noexcept { return MDevices; } + devices_range get_devices() const noexcept { return MDevices; } std::vector get_kernel_ids() const { // Collect kernel ids from all device images, then remove duplicates @@ -1111,7 +1108,7 @@ class kernel_bundle_impl } context MContext; - std::vector MDevices; + std::vector MDevices; // For sycl_jit, building from source may have produced sycl binaries that // the kernel_bundles now manage. diff --git a/sycl/source/detail/kernel_impl.cpp b/sycl/source/detail/kernel_impl.cpp index da22bb28c9922..0cb679f1f0fc3 100644 --- a/sycl/source/detail/kernel_impl.cpp +++ b/sycl/source/detail/kernel_impl.cpp @@ -167,7 +167,7 @@ kernel_impl::get_backend_info() const { "the info::device::version info descriptor can only " "be queried with an OpenCL backend"); } - auto Devices = MKernelBundleImpl->get_devices(); + auto Devices = MKernelBundleImpl->get_devices().to>(); if (Devices.empty()) { return "No available device"; } diff --git a/sycl/source/kernel_bundle.cpp b/sycl/source/kernel_bundle.cpp index b290f43a5ad93..3476c8747102f 100644 --- a/sycl/source/kernel_bundle.cpp +++ b/sycl/source/kernel_bundle.cpp @@ -74,7 +74,7 @@ context kernel_bundle_plain::get_context() const noexcept { } std::vector kernel_bundle_plain::get_devices() const noexcept { - return impl->get_devices(); + return impl->get_devices().to>(); } std::vector kernel_bundle_plain::get_kernel_ids() const {