diff --git a/sycl/include/CL/sycl/detail/cg.hpp b/sycl/include/CL/sycl/detail/cg.hpp index 170c0f39906c3..8c642ce04789a 100644 --- a/sycl/include/CL/sycl/detail/cg.hpp +++ b/sycl/include/CL/sycl/detail/cg.hpp @@ -248,6 +248,7 @@ class CGExecKernel : public CG { std::string MKernelName; detail::OSModuleHandle MOSModuleHandle; std::vector> MStreams; + std::vector> MReductionResources; CGExecKernel(NDRDescT NDRDesc, std::unique_ptr HKernel, std::shared_ptr SyclKernel, @@ -259,6 +260,7 @@ class CGExecKernel : public CG { std::vector Args, std::string KernelName, detail::OSModuleHandle OSModuleHandle, std::vector> Streams, + std::vector> ReductionResources, CGTYPE Type, detail::code_location loc = {}) : CG(Type, std::move(ArgsStorage), std::move(AccStorage), std::move(SharedPtrStorage), std::move(Requirements), @@ -266,7 +268,8 @@ class CGExecKernel : public CG { MNDRDesc(std::move(NDRDesc)), MHostKernel(std::move(HKernel)), MSyclKernel(std::move(SyclKernel)), MArgs(std::move(Args)), MKernelName(std::move(KernelName)), MOSModuleHandle(OSModuleHandle), - MStreams(std::move(Streams)) { + MStreams(std::move(Streams)), + MReductionResources(std::move(ReductionResources)) { assert((getType() == RunOnHostIntel || getType() == Kernel) && "Wrong type of exec kernel CG."); } @@ -277,6 +280,10 @@ class CGExecKernel : public CG { return MStreams; } + std::vector> getReductionResources() const { + return MReductionResources; + } + std::shared_ptr getKernelBundle() { const std::shared_ptr> &ExtendedMembers = getExtendedMembers(); @@ -290,6 +297,7 @@ class CGExecKernel : public CG { } void clearStreams() { MStreams.clear(); } + void clearReductionResources() { MReductionResources.clear(); } }; /// "Copy memory" command group class. diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index 157fe9ff4620f..73e4c85301ec0 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -472,12 +472,9 @@ class __SYCL_EXPORT handler { /// Saves buffers created by handling reduction feature in handler. /// They are then forwarded to command group and destroyed only after /// the command group finishes the work on device/host. - /// The 'MSharedPtrStorage' suits that need. /// /// @param ReduObj is a pointer to object that must be stored. - void addReduction(const std::shared_ptr &ReduObj) { - MSharedPtrStorage.push_back(ReduObj); - } + void addReduction(const std::shared_ptr &ReduObj); ~handler() = default; @@ -1271,6 +1268,7 @@ class __SYCL_EXPORT handler { } std::shared_ptr getHandlerImpl() const; + std::shared_ptr evictHandlerImpl() const; void setStateExplicitKernelBundle(); void setStateSpecConstSet(); diff --git a/sycl/source/detail/handler_impl.hpp b/sycl/source/detail/handler_impl.hpp index d4171e8d4d1d6..f77c2a44f9955 100644 --- a/sycl/source/detail/handler_impl.hpp +++ b/sycl/source/detail/handler_impl.hpp @@ -65,6 +65,12 @@ class handler_impl { /// equal to the queue associated with the handler if the corresponding /// submission is a fallback from a previous submission. std::shared_ptr MSubmissionSecondaryQueue; + + // Protects reduction resources + std::mutex MReductionResourcesMutex; + + // Stores additional resources used by reductions. + std::vector> MReductionResources; }; } // namespace detail diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 6724608227699..d8e735bb463b1 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -1307,11 +1307,23 @@ std::vector ExecCGCommand::getStreams() const { return {}; } +std::vector> +ExecCGCommand::getReductionResources() const { + if (MCommandGroup->getType() == CG::Kernel) + return ((CGExecKernel *)MCommandGroup.get())->getReductionResources(); + return {}; +} + void ExecCGCommand::clearStreams() { if (MCommandGroup->getType() == CG::Kernel) ((CGExecKernel *)MCommandGroup.get())->clearStreams(); } +void ExecCGCommand::clearReductionResources() { + if (MCommandGroup->getType() == CG::Kernel) + ((CGExecKernel *)MCommandGroup.get())->clearReductionResources(); +} + cl_int UpdateHostRequirementCommand::enqueueImp() { waitForPreparedHostEvents(); std::vector EventImpls = MPreparedDepsEvents; diff --git a/sycl/source/detail/scheduler/commands.hpp b/sycl/source/detail/scheduler/commands.hpp index 4a556f8a5567e..958e0836d22fe 100644 --- a/sycl/source/detail/scheduler/commands.hpp +++ b/sycl/source/detail/scheduler/commands.hpp @@ -518,8 +518,10 @@ class ExecCGCommand : public Command { ExecCGCommand(std::unique_ptr CommandGroup, QueueImplPtr Queue); std::vector getStreams() const; + std::vector> getReductionResources() const; void clearStreams(); + void clearReductionResources(); void printDot(std::ostream &Stream) const final; void emitInstrumentationData() final; diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 61f10b41845cd..b6c49a751f368 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -1001,7 +1001,8 @@ void Scheduler::GraphBuilder::decrementLeafCountersForRecord( void Scheduler::GraphBuilder::cleanupCommandsForRecord( MemObjRecord *Record, - std::vector> &StreamsToDeallocate) { + std::vector> &StreamsToDeallocate, + std::vector> &ReduResourcesToDeallocate) { std::vector &AllocaCommands = Record->MAllocaCommands; if (AllocaCommands.empty()) return; @@ -1053,10 +1054,20 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord( // Collect stream objects for a visited command. if (Cmd->getType() == Command::CommandType::RUN_CG) { auto ExecCmd = static_cast(Cmd); + + // Transfer ownership of stream implementations. std::vector> Streams = ExecCmd->getStreams(); ExecCmd->clearStreams(); StreamsToDeallocate.insert(StreamsToDeallocate.end(), Streams.begin(), Streams.end()); + + // Transfer ownership of reduction resources. + std::vector> ReduResources = + ExecCmd->getReductionResources(); + ExecCmd->clearReductionResources(); + ReduResourcesToDeallocate.insert(ReduResourcesToDeallocate.end(), + ReduResources.begin(), + ReduResources.end()); } for (Command *UserCmd : Cmd->MUsers) @@ -1098,7 +1109,8 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord( void Scheduler::GraphBuilder::cleanupFinishedCommands( Command *FinishedCmd, - std::vector> &StreamsToDeallocate) { + std::vector> &StreamsToDeallocate, + std::vector> &ReduResourcesToDeallocate) { assert(MCmdsToVisit.empty()); MCmdsToVisit.push(FinishedCmd); MVisitedCmds.clear(); @@ -1114,10 +1126,20 @@ void Scheduler::GraphBuilder::cleanupFinishedCommands( // Collect stream objects for a visited command. if (Cmd->getType() == Command::CommandType::RUN_CG) { auto ExecCmd = static_cast(Cmd); + + // Transfer ownership of stream implementations. std::vector> Streams = ExecCmd->getStreams(); ExecCmd->clearStreams(); StreamsToDeallocate.insert(StreamsToDeallocate.end(), Streams.begin(), Streams.end()); + + // Transfer ownership of reduction resources. + std::vector> ReduResources = + ExecCmd->getReductionResources(); + ExecCmd->clearReductionResources(); + ReduResourcesToDeallocate.insert(ReduResourcesToDeallocate.end(), + ReduResources.begin(), + ReduResources.end()); } for (const DepDesc &Dep : Cmd->MDeps) { diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp index c17beb8a3621d..311df88f3f4ee 100644 --- a/sycl/source/detail/scheduler/scheduler.cpp +++ b/sycl/source/detail/scheduler/scheduler.cpp @@ -229,6 +229,11 @@ void Scheduler::cleanupFinishedCommands(EventImplPtr FinishedEvent) { // objects, this is needed to guarantee that streamed data is printed and // resources are released. std::vector> StreamsToDeallocate; + // Similar to streams, we also collect the reduction resources used by the + // commands. Cleanup will make sure the commands do not own the resources + // anymore, so we just need them to survive the graph lock then they can die + // as they go out of scope. + std::vector> ReduResourcesToDeallocate; { // Avoiding deadlock situation, where one thread is in the process of // enqueueing (with a locked mutex) a currently blocked task that waits for @@ -239,7 +244,8 @@ void Scheduler::cleanupFinishedCommands(EventImplPtr FinishedEvent) { // The command might have been cleaned up (and set to nullptr) by another // thread if (FinishedCmd) - MGraphBuilder.cleanupFinishedCommands(FinishedCmd, StreamsToDeallocate); + MGraphBuilder.cleanupFinishedCommands(FinishedCmd, StreamsToDeallocate, + ReduResourcesToDeallocate); } } deallocateStreams(StreamsToDeallocate); @@ -251,6 +257,11 @@ void Scheduler::removeMemoryObject(detail::SYCLMemObjI *MemObj) { // objects, this is needed to guarantee that streamed data is printed and // resources are released. std::vector> StreamsToDeallocate; + // Similar to streams, we also collect the reduction resources used by the + // commands. Cleanup will make sure the commands do not own the resources + // anymore, so we just need them to survive the graph lock then they can die + // as they go out of scope. + std::vector> ReduResourcesToDeallocate; { MemObjRecord *Record = nullptr; @@ -272,7 +283,8 @@ void Scheduler::removeMemoryObject(detail::SYCLMemObjI *MemObj) { WriteLockT Lock(MGraphLock, std::defer_lock); acquireWriteLock(Lock); MGraphBuilder.decrementLeafCountersForRecord(Record); - MGraphBuilder.cleanupCommandsForRecord(Record, StreamsToDeallocate); + MGraphBuilder.cleanupCommandsForRecord(Record, StreamsToDeallocate, + ReduResourcesToDeallocate); MGraphBuilder.removeRecordForMemObj(MemObj); } } diff --git a/sycl/source/detail/scheduler/scheduler.hpp b/sycl/source/detail/scheduler/scheduler.hpp index 66e17d7862301..15266923a6fcc 100644 --- a/sycl/source/detail/scheduler/scheduler.hpp +++ b/sycl/source/detail/scheduler/scheduler.hpp @@ -509,7 +509,8 @@ class Scheduler { /// (assuming that all its commands have been waited for). void cleanupFinishedCommands( Command *FinishedCmd, - std::vector> &); + std::vector> &, + std::vector> &); /// Reschedules the command passed using Queue provided. /// @@ -535,7 +536,8 @@ class Scheduler { /// Removes commands that use the given MemObjRecord from the graph. void cleanupCommandsForRecord( MemObjRecord *Record, - std::vector> &); + std::vector> &, + std::vector> &); /// Removes the MemObjRecord for the memory object passed. void removeRecordForMemObj(SYCLMemObjI *MemObject); diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index e1c27d2650898..2381c71c0650c 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -48,24 +48,40 @@ handler::handler(std::shared_ptr Queue, MSharedPtrStorage.push_back(std::move(ExtendedMembers)); } +static detail::ExtendedMemberT &getHandlerImplMember( + std::vector> &SharedPtrStorage) { + assert(!SharedPtrStorage.empty()); + std::shared_ptr> ExtendedMembersVec = + detail::convertToExtendedMembers(SharedPtrStorage[0]); + assert(ExtendedMembersVec->size() > 0); + auto &HandlerImplMember = (*ExtendedMembersVec)[0]; + assert(detail::ExtendedMembersType::HANDLER_IMPL == HandlerImplMember.MType); + return HandlerImplMember; +} + /// Gets the handler_impl at the start of the extended members. std::shared_ptr handler::getHandlerImpl() const { std::lock_guard Lock( detail::GlobalHandler::instance().getHandlerExtendedMembersMutex()); + return std::static_pointer_cast( + getHandlerImplMember(MSharedPtrStorage).MData); +} - assert(!MSharedPtrStorage.empty()); - - std::shared_ptr> ExtendedMembersVec = - detail::convertToExtendedMembers(MSharedPtrStorage[0]); - - assert(ExtendedMembersVec->size() > 0); - - auto HandlerImplMember = (*ExtendedMembersVec)[0]; +/// Gets the handler_impl at the start of the extended members and removes it. +std::shared_ptr handler::evictHandlerImpl() const { + std::lock_guard Lock( + detail::GlobalHandler::instance().getHandlerExtendedMembersMutex()); + auto &HandlerImplMember = getHandlerImplMember(MSharedPtrStorage); + auto Impl = + std::static_pointer_cast(HandlerImplMember.MData); - assert(detail::ExtendedMembersType::HANDLER_IMPL == HandlerImplMember.MType); + // Reset the data of the member. + // NOTE: We let it stay because removing the front can be expensive. This will + // be improved when the impl is made a member of handler. In fact eviction is + // likely to not be needed when that happens. + HandlerImplMember.MData.reset(); - return std::static_pointer_cast( - HandlerImplMember.MData); + return Impl; } // Sets the submission state to indicate that an explicit kernel bundle has been @@ -220,6 +236,10 @@ event handler::finalize() { return MLastEvent; } + // Evict handler_impl from extended members to make sure the command group + // does not keep it alive. + std::shared_ptr Impl = evictHandlerImpl(); + std::unique_ptr CommandGroup; switch (type) { case detail::CG::Kernel: @@ -232,7 +252,8 @@ event handler::finalize() { std::move(MArgsStorage), std::move(MAccStorage), std::move(MSharedPtrStorage), std::move(MRequirements), std::move(MEvents), std::move(MArgs), MKernelName, MOSModuleHandle, - std::move(MStreamStorage), MCGType, MCodeLoc)); + std::move(MStreamStorage), std::move(Impl->MReductionResources), + MCGType, MCodeLoc)); break; } case detail::CG::CodeplayInteropTask: @@ -321,6 +342,12 @@ event handler::finalize() { return MLastEvent; } +void handler::addReduction(const std::shared_ptr &ReduObj) { + std::shared_ptr Impl = getHandlerImpl(); + std::lock_guard Lock(Impl->MReductionResourcesMutex); + Impl->MReductionResources.push_back(ReduObj); +} + void handler::associateWithHandler(detail::AccessorBaseHost *AccBase, access::target AccTarget) { detail::AccessorImplPtr AccImpl = detail::getSyclObjImpl(*AccBase); diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index 5423c563c710f..62959c1c92d0a 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3940,6 +3940,7 @@ _ZN2cl4sycl7contextC2ESt10shared_ptrINS0_6detail12context_implEE _ZN2cl4sycl7handler10mem_adviseEPKvmi _ZN2cl4sycl7handler10processArgEPvRKNS0_6detail19kernel_param_kind_tEimRmb _ZN2cl4sycl7handler10processArgEPvRKNS0_6detail19kernel_param_kind_tEimRmbb +_ZN2cl4sycl7handler12addReductionERKSt10shared_ptrIKvE _ZN2cl4sycl7handler13getKernelNameB5cxx11Ev _ZN2cl4sycl7handler17use_kernel_bundleERKNS0_13kernel_bundleILNS0_12bundle_stateE2EEE _ZN2cl4sycl7handler18RangeRoundingTraceEv @@ -4231,12 +4232,12 @@ _ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65574EEENS3_12param_traitsIS4_XT_ _ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65575EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65808EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65809EEENS3_12param_traitsIS4_XT_EE11return_typeEv +_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65810EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl6device9getNativeEv _ZNK2cl4sycl6kernel11get_backendEv _ZNK2cl4sycl6kernel11get_contextEv _ZNK2cl4sycl6kernel11get_programEv _ZNK2cl4sycl6kernel13getNativeImplEv -_ZNK2cl4sycl6kernel9getNativeEv _ZNK2cl4sycl6kernel17get_kernel_bundleEv _ZNK2cl4sycl6kernel18get_sub_group_infoILNS0_4info16kernel_sub_groupE16650EEENS3_12param_traitsIS4_XT_EE11return_typeERKNS0_6deviceE _ZNK2cl4sycl6kernel18get_sub_group_infoILNS0_4info16kernel_sub_groupE4537EEENS3_12param_traitsIS4_XT_EE11return_typeERKNS0_6deviceE @@ -4265,6 +4266,7 @@ _ZNK2cl4sycl6kernel8get_infoILNS0_4info6kernelE4498EEENS3_12param_traitsIS4_XT_E _ZNK2cl4sycl6kernel8get_infoILNS0_4info6kernelE4499EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl6kernel8get_infoILNS0_4info6kernelE4500EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl6kernel8get_infoILNS0_4info6kernelE4501EEENS3_12param_traitsIS4_XT_EE11return_typeEv +_ZNK2cl4sycl6kernel9getNativeEv _ZNK2cl4sycl6stream22get_max_statement_sizeEv _ZNK2cl4sycl6stream8get_sizeEv _ZNK2cl4sycl6streameqERKS1_ @@ -4306,6 +4308,7 @@ _ZNK2cl4sycl7context8get_infoILNS0_4info7contextE4228EEENS3_12param_traitsIS4_XT _ZNK2cl4sycl7context8get_infoILNS0_4info7contextE65552EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl7context9getNativeEv _ZNK2cl4sycl7handler14getHandlerImplEv +_ZNK2cl4sycl7handler16evictHandlerImplEv _ZNK2cl4sycl7handler27isStateExplicitKernelBundleEv _ZNK2cl4sycl7handler30getOrInsertHandlerKernelBundleEb _ZNK2cl4sycl7program10get_kernelENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE @@ -4405,8 +4408,3 @@ _ZNK2cl4sycl9exception8categoryEv _ZNK2cl4sycl9kernel_id8get_nameEv __sycl_register_lib __sycl_unregister_lib -_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE131072EEENS3_12param_traitsIS4_XT_EE11return_typeEv -_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE131075EEENS3_12param_traitsIS4_XT_EE11return_typeEv -_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE131074EEENS3_12param_traitsIS4_XT_EE11return_typeEv -_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE131073EEENS3_12param_traitsIS4_XT_EE11return_typeEv -_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65810EEENS3_12param_traitsIS4_XT_EE11return_typeEv diff --git a/sycl/unittests/program_manager/EliminatedArgMask.cpp b/sycl/unittests/program_manager/EliminatedArgMask.cpp index db7f380e1e31c..80e40f8f80700 100644 --- a/sycl/unittests/program_manager/EliminatedArgMask.cpp +++ b/sycl/unittests/program_manager/EliminatedArgMask.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include +#include #include #include #include @@ -126,6 +127,7 @@ class MockHandler : public sycl::handler { std::unique_ptr finalize() { auto CGH = static_cast(this); + std::shared_ptr Impl = evictHandlerImpl(); std::unique_ptr CommandGroup; switch (getType()) { case sycl::detail::CG::Kernel: { @@ -136,7 +138,7 @@ class MockHandler : public sycl::handler { std::move(CGH->MRequirements), std::move(CGH->MEvents), std::move(CGH->MArgs), std::move(CGH->MKernelName), std::move(CGH->MOSModuleHandle), std::move(CGH->MStreamStorage), - CGH->MCGType, CGH->MCodeLoc)); + std::move(Impl->MReductionResources), CGH->MCGType, CGH->MCodeLoc)); break; } default: diff --git a/sycl/unittests/scheduler/SchedulerTestUtils.hpp b/sycl/unittests/scheduler/SchedulerTestUtils.hpp index bd80f24820f8f..633dda5cce65a 100644 --- a/sycl/unittests/scheduler/SchedulerTestUtils.hpp +++ b/sycl/unittests/scheduler/SchedulerTestUtils.hpp @@ -113,7 +113,9 @@ class MockScheduler : public cl::sycl::detail::Scheduler { void cleanupCommandsForRecord(cl::sycl::detail::MemObjRecord *Rec) { std::vector> StreamsToDeallocate; - MGraphBuilder.cleanupCommandsForRecord(Rec, StreamsToDeallocate); + std::vector> ReductionResourcesToDeallocate; + MGraphBuilder.cleanupCommandsForRecord(Rec, StreamsToDeallocate, + ReductionResourcesToDeallocate); } void addNodeToLeaves(cl::sycl::detail::MemObjRecord *Rec, diff --git a/sycl/unittests/scheduler/StreamInitDependencyOnHost.cpp b/sycl/unittests/scheduler/StreamInitDependencyOnHost.cpp index e1e87d8464f57..6747a87c38a65 100644 --- a/sycl/unittests/scheduler/StreamInitDependencyOnHost.cpp +++ b/sycl/unittests/scheduler/StreamInitDependencyOnHost.cpp @@ -9,6 +9,7 @@ #include "SchedulerTest.hpp" #include "SchedulerTestUtils.hpp" +#include #include using namespace cl::sycl; @@ -39,6 +40,7 @@ class MockHandler : public sycl::handler { std::unique_ptr finalize() { auto CGH = static_cast(this); + std::shared_ptr Impl = evictHandlerImpl(); std::unique_ptr CommandGroup; switch (CGH->MCGType) { case detail::CG::Kernel: @@ -50,7 +52,7 @@ class MockHandler : public sycl::handler { std::move(CGH->MRequirements), std::move(CGH->MEvents), std::move(CGH->MArgs), std::move(CGH->MKernelName), std::move(CGH->MOSModuleHandle), std::move(CGH->MStreamStorage), - CGH->MCGType, CGH->MCodeLoc)); + std::move(Impl->MReductionResources), CGH->MCGType, CGH->MCodeLoc)); break; } default: