diff --git a/sycl/source/detail/device_image_impl.hpp b/sycl/source/detail/device_image_impl.hpp index 38bcb19e240ae..d8c538f7ef2a2 100644 --- a/sycl/source/detail/device_image_impl.hpp +++ b/sycl/source/detail/device_image_impl.hpp @@ -311,9 +311,7 @@ class device_image_impl private_tag) : MBinImage(BinImage), MContext(std::move(Context)), MDevices(Devices.to>()), MState(State), - MProgram(Program), - MKernelIDs(std::make_shared>()), - MKernelNames{std::move(KernelNames)}, + MProgram(Program), MKernelNames{std::move(KernelNames)}, MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)}, MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(ImageOriginKernelCompiler), @@ -347,7 +345,6 @@ class device_image_impl : MBinImage(Src), MContext(std::move(Context)), MDevices(Devices.to>()), MState(bundle_state::ext_oneapi_source), MProgram(nullptr), - MKernelIDs(std::make_shared>()), MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(ImageOriginKernelCompiler), MRTCBinInfo( @@ -361,7 +358,6 @@ class device_image_impl : MBinImage(Bytes), MContext(std::move(Context)), MDevices(Devices.to>()), MState(bundle_state::ext_oneapi_source), MProgram(nullptr), - MKernelIDs(std::make_shared>()), MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(ImageOriginKernelCompiler), MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) { @@ -375,9 +371,7 @@ class device_image_impl : MBinImage(static_cast(nullptr)), MContext(std::move(Context)), MDevices(Devices.to>()), MState(State), - MProgram(Program), - MKernelIDs(std::make_shared>()), - MKernelNames{std::move(KernelNames)}, + MProgram(Program), MKernelNames{std::move(KernelNames)}, MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(ImageOriginKernelCompiler), MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {} @@ -389,6 +383,8 @@ class device_image_impl } bool has_kernel(const kernel_id &KernelIDCand) const noexcept { + if (!MKernelIDs) + return false; return std::binary_search(MKernelIDs->begin(), MKernelIDs->end(), KernelIDCand, LessByHash{}); } @@ -414,8 +410,18 @@ class device_image_impl return false; } - const std::vector &get_kernel_ids() const noexcept { - return *MKernelIDs; + iterator_range::const_iterator> + get_kernel_ids() const noexcept { + if (MKernelIDs) + return *MKernelIDs; + else + return {}; + } + // This should only be used when creating new device_image_impls that have the + // exact same set of kernels as the source one. In all other scenarios the + // getter above is the one needed: + std::shared_ptr> &get_kernel_ids_ptr() noexcept { + return MKernelIDs; } bool has_specialization_constants() const noexcept { @@ -563,10 +569,6 @@ class device_image_impl const context &get_context() const noexcept { return MContext; } - std::shared_ptr> &get_kernel_ids_ptr() noexcept { - return MKernelIDs; - } - std::vector &get_spec_const_blob_ref() noexcept { return MSpecConstsBlob; } @@ -1300,7 +1302,9 @@ class device_image_impl ur_program_handle_t MProgram = nullptr; // List of kernel ids available in this image, elements should be sorted - // according to LessByNameComp + // according to LessByNameComp. Shared between images for performance reasons + // (e.g. when we compile a single image it keeps the same kernels in it as the + // original source image). std::shared_ptr> MKernelIDs; // List of known kernel names. diff --git a/sycl/source/detail/helpers.hpp b/sycl/source/detail/helpers.hpp index a1a49361e5755..2f413bca2994e 100644 --- a/sycl/source/detail/helpers.hpp +++ b/sycl/source/detail/helpers.hpp @@ -51,6 +51,7 @@ class variadic_iterator { using pointer = value_type *; static_assert(std::is_same_v); + variadic_iterator() = default; variadic_iterator(const variadic_iterator &) = default; variadic_iterator(variadic_iterator &&) = default; variadic_iterator(variadic_iterator &) = default; @@ -88,7 +89,6 @@ class variadic_iterator { // Non-owning! template class iterator_range { using value_type = typename iterator::value_type; - using sycl_type = typename iterator::sycl_type; template struct has_reserve : public std::false_type {}; @@ -104,16 +104,20 @@ template class iterator_range { iterator_range(IterTy Begin, IterTy End, size_t Size) : Begin(Begin), End(End), Size(Size) {} - iterator_range() - : iterator_range(static_cast(nullptr), - static_cast(nullptr), 0) {} + iterator_range() : iterator_range(iterator{}, iterator{}, 0) {} - template + template ().begin()})>> iterator_range(const ContainerTy &Container) : iterator_range(Container.begin(), Container.end(), Container.size()) {} iterator_range(value_type &Obj) : iterator_range(&Obj, &Obj + 1, 1) {} + template ())})>, + // To make it different from `ContainerTy` overload above: + typename = void> iterator_range(const sycl_type &Obj) : iterator_range(&*getSyclObjImpl(Obj), (&*getSyclObjImpl(Obj) + 1), 1) {} @@ -123,13 +127,15 @@ template class iterator_range { bool empty() const { return Size == 0; } decltype(auto) front() const { return *begin(); } - template - std::enable_if_t< - check_type_in_v, - std::queue, std::vector, - std::vector>>, - Container> - to() const { + // Only enable for ranges of `variadic_iterator` and for the containers with + // proper `value_type`. The last part is important so that descendent + // `devices_range` could provide its own specialization for + // `to>()`. + template , typename iterator_::sycl_type>>> + Container to() const { std::conditional_t>, typename std::queue::container_type, Container> @@ -138,14 +144,14 @@ template class iterator_range { Result.reserve(size()); std::transform( begin(), end(), std::back_inserter(Result), [](value_type &E) { - if constexpr (std::is_same_v>) - return createSyclObjFromImpl(E); - else if constexpr (std::is_same_v< - Container, - std::vector>>) + using container_value_type = typename Container::value_type; + if constexpr (std::is_same_v>) return E.shared_from_this(); - else + else if constexpr (std::is_same_v) return &E; + else + return createSyclObjFromImpl(E); }); if constexpr (std::is_same_v) return Result; @@ -153,16 +159,15 @@ template class iterator_range { return Container{std::move(Result)}; } + // Only enable for ranges of `variadic_iterator` above. + template > bool contains(value_type &Other) const { return std::find_if(begin(), end(), [&Other](value_type &Elem) { return &Elem == &Other; }) != end(); } -protected: - template - static constexpr bool has_reserve_v = has_reserve::value; - private: iterator Begin; iterator End; diff --git a/sycl/source/detail/kernel_bundle_impl.hpp b/sycl/source/detail/kernel_bundle_impl.hpp index 6ce4d38a3420a..4000bafaf96b2 100644 --- a/sycl/source/detail/kernel_bundle_impl.hpp +++ b/sycl/source/detail/kernel_bundle_impl.hpp @@ -782,7 +782,7 @@ class kernel_bundle_impl if (DevImgImpl.getRTCInfo()) continue; - const std::vector &KernelIDs = DevImgImpl.get_kernel_ids(); + auto KernelIDs = DevImgImpl.get_kernel_ids(); Result.insert(Result.end(), KernelIDs.begin(), KernelIDs.end()); } diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index a07c0f24ab3d4..5cc31e8a38a39 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -2690,9 +2690,9 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs, assert(DepState == getBinImageState(DepImage) && "State mismatch between main image and its dependency"); - DeviceImageImplPtr DepImpl = - device_image_impl::create(DepImage, Ctx, Devs, DepState, DepKernelIDs, - /*PIProgram=*/nullptr, ImageOriginSYCLOffline); + DeviceImageImplPtr DepImpl = device_image_impl::create( + DepImage, Ctx, Devs, DepState, std::move(DepKernelIDs), + /*PIProgram=*/nullptr, ImageOriginSYCLOffline); return createSyclObjFromImpl(std::move(DepImpl)); } @@ -2905,10 +2905,8 @@ mergeImageData(const std::vector &Imgs, for (const device_image_plain &Img : Imgs) { device_image_impl &DeviceImageImpl = *getSyclObjImpl(Img); // Duplicates are not expected here, otherwise urProgramLink should fail - if (DeviceImageImpl.get_kernel_ids_ptr()) - KernelIDs.insert(KernelIDs.end(), - DeviceImageImpl.get_kernel_ids_ptr()->begin(), - DeviceImageImpl.get_kernel_ids_ptr()->end()); + KernelIDs.insert(KernelIDs.end(), DeviceImageImpl.get_kernel_ids().begin(), + DeviceImageImpl.get_kernel_ids().end()); // To be able to answer queries about specialziation constants, the new // device image should have the specialization constants from all the linked // images. diff --git a/sycl/source/kernel_bundle.cpp b/sycl/source/kernel_bundle.cpp index 3476c8747102f..831277b1cc818 100644 --- a/sycl/source/kernel_bundle.cpp +++ b/sycl/source/kernel_bundle.cpp @@ -288,8 +288,8 @@ bool has_kernel_bundle_impl(const context &Ctx, const std::vector &Devs, const std::shared_ptr &DeviceImageImpl = getSyclObjImpl(DeviceImage); - CombinedKernelIDs.insert(DeviceImageImpl->get_kernel_ids_ptr()->begin(), - DeviceImageImpl->get_kernel_ids_ptr()->end()); + CombinedKernelIDs.insert(DeviceImageImpl->get_kernel_ids().begin(), + DeviceImageImpl->get_kernel_ids().end()); } }