diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc
index 925c04f658b0a7..90f5b93dcb2efa 100644
--- a/paddle/fluid/framework/naive_executor.cc
+++ b/paddle/fluid/framework/naive_executor.cc
@@ -72,12 +72,14 @@ void NaiveExecutor::PrepareInterpreterCore(
 }
 
 void NaiveExecutor::RunInterpreterCore(
-    const std::vector<std::string> &feed_names, bool need_fetch) {
+    const std::vector<std::string> &feed_names,
+    bool need_fetch,
+    bool switch_stream) {
   platform::ScopedFlushDenormal flush;
 #ifdef PADDLE_WITH_NVTX
   platform::CudaNvtxRangePush("model", platform::NvtxRangeColor::Yellow);
 #endif
-  interpreter_core_->Run(feed_names, need_fetch);
+  interpreter_core_->Run(feed_names, need_fetch, false, false, switch_stream);
 #ifdef PADDLE_WITH_NVTX
   platform::CudaNvtxRangePop();
 #endif
diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h
index 5a558f3bd69216..8388bfe3a37fc1 100644
--- a/paddle/fluid/framework/naive_executor.h
+++ b/paddle/fluid/framework/naive_executor.h
@@ -77,7 +77,8 @@ class NaiveExecutor {
   void Run();
 
   void RunInterpreterCore(const std::vector<std::string>& feed_names = {},
-                          bool need_fetch = false);
+                          bool need_fetch = false,
+                          bool switch_stream = false);
 
   // Get an tensor to operating directly, without the need for feed_ops.
   phi::DenseTensor* FindTensor(const std::string& name);
diff --git a/paddle/fluid/framework/new_executor/interpreter_base_impl.h b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
index ff5832ba8335e6..a7a618ac90284e 100644
--- a/paddle/fluid/framework/new_executor/interpreter_base_impl.h
+++ b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
@@ -68,13 +68,15 @@ class InterpreterBaseImpl {
       const std::vector<std::string>& feed_names,
       const std::vector<phi::DenseTensor>& feed_tensors,
       bool need_fetch = true,
-      bool enable_job_schedule_profiler = false) = 0;
+      bool enable_job_schedule_profiler = false,
+      bool switch_stream = false) = 0;
 
   virtual paddle::framework::FetchList Run(
       const std::vector<std::string>& feed_names,
       bool need_fetch = true,
       bool enable_job_schedule_profiler = false,
-      bool enable_op_profiling = false) = 0;
+      bool enable_op_profiling = false,
+      bool switch_stream = false) = 0;
 
   virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index b0bbd11aef0dbd..8fdddb1548d9d2 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -67,19 +67,25 @@ FetchList InterpreterCore::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<phi::DenseTensor>& feed_tensors,
     bool need_fetch,
-    bool enable_job_schedule_profiler) {
-  return impl_->Run(
-      feed_names, feed_tensors, need_fetch, enable_job_schedule_profiler);
+    bool enable_job_schedule_profiler,
+    bool switch_stream) {
+  return impl_->Run(feed_names,
+                    feed_tensors,
+                    need_fetch,
+                    enable_job_schedule_profiler,
+                    switch_stream);
 }
 
 FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
                                bool need_fetch,
                                bool enable_job_schedule_profiler,
-                               bool enable_op_profiling) {
+                               bool enable_op_profiling,
+                               bool switch_stream) {
   return impl_->Run(feed_names,
                     need_fetch,
                     enable_job_schedule_profiler,
-                    enable_op_profiling);
+                    enable_op_profiling,
+                    switch_stream);
 }
 
 void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index b8c1913d931dcb..7731620565fb82 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -49,12 +49,14 @@ class InterpreterCore {
       const std::vector<std::string>& feed_names,
       const std::vector<phi::DenseTensor>& feed_tensors,
       bool need_fetch = true,
-      bool enable_job_schedule_profiler = false);
+      bool enable_job_schedule_profiler = false,
+      bool switch_stream = false);
 
   paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                    bool need_fetch = true,
                                    bool enable_job_schedule_profiler = false,
-                                   bool enable_op_profiling = false);
+                                   bool enable_op_profiling = false,
+                                   bool switch_stream = false);
 
   void RunProfile(const std::vector<std::string>& feed_names);
diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc
index 19e3d6e86ebdeb..84c8fa753eb31d 100644
--- a/paddle/fluid/framework/new_executor/pir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -1255,7 +1255,8 @@ paddle::framework::FetchList PirInterpreter::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<phi::DenseTensor>& feed_tensors,
     bool need_fetch,
-    bool enable_job_schedule_profiler) {
+    bool enable_job_schedule_profiler,
+    bool switch_stream) {
   enable_job_schedule_profiler_ = enable_job_schedule_profiler;
 
   auto FeedInput = [&] {
@@ -1318,6 +1319,12 @@ paddle::framework::FetchList PirInterpreter::Run(
     is_build_ = true;
     is_shared_results_build_ = true;
   } else {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (switch_stream) {
+      BuildInstruction();
+      VLOG(4) << "Done BuildInstruction";
+    }
+#endif
     if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
         execution_config_.used_for_inference ||
         ((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
@@ -1350,7 +1357,8 @@ paddle::framework::FetchList PirInterpreter::Run(
 FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
                               bool need_fetch,
                               bool enable_job_schedule_profiler,
-                              bool enable_op_profiling) {
+                              bool enable_op_profiling,
+                              bool switch_stream) {
   enable_job_schedule_profiler_ = enable_job_schedule_profiler;
 
   if (enable_op_profiling) {
@@ -1401,6 +1409,12 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
     is_build_ = true;
     is_shared_results_build_ = true;
   } else {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (switch_stream) {
+      BuildInstruction();
+      VLOG(4) << "Done BuildInstruction";
+    }
+#endif
     if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
         execution_config_.used_for_inference ||
         ((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.h b/paddle/fluid/framework/new_executor/pir_interpreter.h
index 1684aeffef8cfa..3f197f53e12f8c 100644
--- a/paddle/fluid/framework/new_executor/pir_interpreter.h
+++ b/paddle/fluid/framework/new_executor/pir_interpreter.h
@@ -57,12 +57,14 @@ class PirInterpreter : public InterpreterBaseImpl {
       const std::vector<std::string>& feed_names,
       const std::vector<phi::DenseTensor>& feed_tensors,
       bool need_fetch = true,
-      bool enable_job_schedule_profiler = false) override;
+      bool enable_job_schedule_profiler = false,
+      bool switch_stream = false) override;
 
   paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                    bool need_fetch = true,
                                    bool enable_job_schedule_profiler = false,
-                                   bool enable_op_profiling = false) override;
+                                   bool enable_op_profiling = false,
+                                   bool switch_stream = false) override;
 
   void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc
index bc41742437ff9c..0f50665e1621e8 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/program_interpreter.cc
@@ -144,7 +144,8 @@ void ProgramInterpreter::RunImpl() {
 FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
                                   bool need_fetch,
                                   bool enable_job_schedule_profiler,
-                                  bool enable_op_profiling) {
+                                  bool enable_op_profiling,
+                                  bool switch_stream) {
   enable_job_schedule_profiler_ = enable_job_schedule_profiler;
   is_in_op_profiling_mode_ = enable_op_profiling;
 
@@ -163,6 +164,11 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
     is_build_ = true;
     is_shared_results_build_ = true;
   } else {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (switch_stream) {
+      BuildOpFuncNode(&op_func_nodes);
+    }
+#endif
     RunImpl();
   }
 
@@ -233,7 +239,8 @@ FetchList ProgramInterpreter::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<phi::DenseTensor>& feed_tensors,
     bool need_fetch,
-    bool enable_job_schedule_profiler) {
+    bool enable_job_schedule_profiler,
+    bool switch_stream) {
   enable_job_schedule_profiler_ = enable_job_schedule_profiler;
 
   SetDeviceId(place_);
@@ -244,7 +251,7 @@ FetchList ProgramInterpreter::Run(
 #endif
 
   bool is_build = is_build_;
-  Prepare(feed_names, feed_tensors, is_build);
+  Prepare(feed_names, feed_tensors, is_build, switch_stream);
 
   if (is_build) {
     RunImpl();
@@ -671,42 +678,7 @@ std::tuple<double, double> ProgramInterpreter::InterpreterRunTime() {
 void ProgramInterpreter::Convert(
     std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
   auto& vec_meta_info = var_scope_.MutableVecMetaInfo();
-  auto nodes = *op_func_nodes;
-  auto op_nums = nodes.size();
-  vec_instruction_.clear();
-  vec_instruction_.reserve(op_nums);
-  for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
-    auto& op_func_node = nodes[op_idx];
-    stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_);
-    auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
-#ifdef PADDLE_WITH_CUDA
-    if (FLAGS_new_executor_use_cuda_graph) {
-      auto& op = op_func_node.operator_base_;
-      auto& op_type = op->Type();
-      if (op_type == interpreter::kMemcpyD2H ||
-          op_type == interpreter::kMemcpyH2D) {
-        PADDLE_THROW(paddle::platform::errors::Fatal(
-            "Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
-      }
-      PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
-                        true,
-                        platform::errors::InvalidArgument(
-                            "Device context of op %s must be [%s] while using "
-                            "cuda graph, but got [%s].",
-                            op_type,
-                            typeid(phi::GPUContext).name(),
-                            typeid(*dev_ctx_).name()));
-      // cuda graph needs to record all stream
-      phi::backends::gpu::CUDAGraphContextManager::Instance()
-          .RecordCapturingDeviceContext(dev_ctx_);
-    }
-#endif
-    vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    vec_instruction_.back().UpdataRecordStreamForGcInfo();
-#endif
-  }
+  BuildOpFuncNode(op_func_nodes);
 
   BuildOperatorDependences();
 
@@ -743,6 +715,7 @@ void ProgramInterpreter::Convert(
   }
 
   // calculate last_live_ops_
+  auto op_nums = (*op_func_nodes).size();
   for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
     Instruction& instr = vec_instruction_[op_idx];
     OpInOutInfo info;
@@ -879,6 +852,46 @@ void ProgramInterpreter::Convert(
   AnalyseExecuteOrderForTrace();
 }
 
+void ProgramInterpreter::BuildOpFuncNode(
+    std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
+  auto nodes = *op_func_nodes;
+  auto op_nums = nodes.size();
+  vec_instruction_.clear();
+  vec_instruction_.reserve(op_nums);
+  for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
+    auto& op_func_node = nodes[op_idx];
+    stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_);
+    auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
+#ifdef PADDLE_WITH_CUDA
+    if (FLAGS_new_executor_use_cuda_graph) {
+      auto& op = op_func_node.operator_base_;
+      auto& op_type = op->Type();
+      if (op_type == interpreter::kMemcpyD2H ||
+          op_type == interpreter::kMemcpyH2D) {
+        PADDLE_THROW(paddle::platform::errors::Fatal(
+            "Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
+      }
+      PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
+                        true,
+                        platform::errors::InvalidArgument(
+                            "Device context of op %s must be [%s] while using "
+                            "cuda graph, but got [%s].",
+                            op_type,
+                            typeid(phi::GPUContext).name(),
+                            typeid(*dev_ctx_).name()));
+      // cuda graph needs to record all stream
+      phi::backends::gpu::CUDAGraphContextManager::Instance()
+          .RecordCapturingDeviceContext(dev_ctx_);
+    }
+#endif
+    vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    vec_instruction_.back().UpdataRecordStreamForGcInfo();
+#endif
+  }
+}
+
 void ProgramInterpreter::BuildSkipShareLoDInfo() {
   for (size_t i = 0; i < vec_instruction_.size(); ++i) {
     bool can_skip_lod = true;
@@ -1494,7 +1507,8 @@ void ProgramInterpreter::CheckGC(const Instruction& instr) {
 void ProgramInterpreter::Prepare(
     const std::vector<std::string>& feed_names,
     const std::vector<phi::DenseTensor>& feed_tensors,
-    bool prepare_feed) {
+    bool prepare_feed,
+    bool switch_stream) {
   PADDLE_ENFORCE_EQ(feed_names.size(),
                     feed_tensors.size(),
                     platform::errors::PreconditionNotMet(
@@ -1517,7 +1531,7 @@ void ProgramInterpreter::Prepare(
     }
   };
 
-  if (!is_build_) {
+  if (!is_build_ || switch_stream) {
     paddle::framework::interpreter::BuildVariableScope(
         block_, execution_config_, &var_scope_);
     FeedInput();
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h
index b19e3a06a42588..5359c41fddcdc6 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.h
+++ b/paddle/fluid/framework/new_executor/program_interpreter.h
@@ -49,12 +49,14 @@ class ProgramInterpreter : public InterpreterBaseImpl {
       const std::vector<std::string>& feed_names,
       const std::vector<phi::DenseTensor>& feed_tensors,
       bool need_fetch = true,
-      bool enable_job_schedule_profiler = false) override;
+      bool enable_job_schedule_profiler = false,
+      bool switch_stream = false) override;
 
   paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                    bool need_fetch = true,
                                    bool enable_job_schedule_profiler = false,
-                                   bool enable_op_profiling = false) override;
+                                   bool enable_op_profiling = false,
+                                   bool switch_stream = false) override;
 
   std::shared_ptr<ProgramDesc> GetMutableCopyProgram() override;
@@ -125,6 +127,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
   void BuildSkipShareLoDInfo();
   void UpdateSyncOpNum();
   void AnalyseExecuteOrderForTrace();
+  void BuildOpFuncNode(
+      std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
 
   // inplace
   void BuildInplace();
@@ -150,7 +154,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
   // only used when program contains no feed op
   void Prepare(const std::vector<std::string>& feed_names,
                const std::vector<phi::DenseTensor>& feed_tensors,
-               bool prepare_feed);
+               bool prepare_feed,
+               bool switch_stream = false);
 
   void RecordMemcpyD2H(const Instruction& instr_node);
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index e042f358c9874c..7682e392bf2089 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2250,7 +2250,7 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
   return res;
 }
 
-bool AnalysisPredictor::ZeroCopyRun() {
+bool AnalysisPredictor::ZeroCopyRun(bool switch_stream) {
   inference::DisplayMemoryInfo(place_, "before run");
 #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
   if (config_.dist_config().use_dist_model()) {
@@ -2313,7 +2313,7 @@ bool AnalysisPredictor::ZeroCopyRun() {
 #endif
 
   if (config_.new_executor_enabled()) {
-    executor_->RunInterpreterCore();
+    executor_->RunInterpreterCore({}, false, switch_stream);
   } else {
     executor_->Run();
   }
@@ -2354,7 +2354,7 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
         "Please use config.SetExecStream to init gpu resources, and then we "
         "will bind gpu resources to execution stream."));
   }
-
+  bool switch_stream = false;
   if (stream != predictor_stream_) {
 #ifdef PADDLE_WITH_HIP
     hipStreamSynchronize(static_cast<hipStream_t>(predictor_stream_));
@@ -2384,9 +2384,9 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
         }));
     auto &pool = paddle::experimental::DeviceContextPool::Instance();
     pool.SyncDeviceContext(place_);
+    switch_stream = true;
   }
-
-  return ZeroCopyRun();
+  return ZeroCopyRun(switch_stream);
 }
 #endif
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index 4a5cfb229a459e..0f2091478af2a1 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -204,9 +204,10 @@ class AnalysisPredictor : public PaddlePredictor {
   ///
   /// \brief Run the prediction engine
   ///
+  /// \param switch_stream Whether the stream is switched
   /// \return Whether the function executed successfully
   ///
-  bool ZeroCopyRun() override;
+  bool ZeroCopyRun(bool switch_stream = false) override;
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   // Note: Can only be used under thread_local semantics.
diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.cc b/paddle/fluid/inference/api/onnxruntime_predictor.cc
index 25970440469168..f2d8f7478d9024 100644
--- a/paddle/fluid/inference/api/onnxruntime_predictor.cc
+++ b/paddle/fluid/inference/api/onnxruntime_predictor.cc
@@ -333,7 +333,7 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
   return false;
 }
 
-bool ONNXRuntimePredictor::ZeroCopyRun() {
+bool ONNXRuntimePredictor::ZeroCopyRun(bool switch_stream) {
   try {
     const char *device_name = platform::is_cpu_place(place_) ? "Cpu" : "Cuda";
     std::vector<Ort::Value> inputs;
diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.h b/paddle/fluid/inference/api/onnxruntime_predictor.h
index 971632c4b3c7a6..c983f8acdae281 100644
--- a/paddle/fluid/inference/api/onnxruntime_predictor.h
+++ b/paddle/fluid/inference/api/onnxruntime_predictor.h
@@ -175,9 +175,10 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   ///
   /// \brief Run the prediction engine
   ///
+  /// \param switch_stream Whether the stream is switched
   /// \return Whether the function executed successfully
   ///
-  bool ZeroCopyRun() override;
+  bool ZeroCopyRun(bool switch_stream = false) override;
 
   ///
   /// \brief Release all tmp tensor to compress the size of the memory pool.
diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h
index 3fefba9ef22be8..89540a91e37895 100644
--- a/paddle/fluid/inference/api/paddle_api.h
+++ b/paddle/fluid/inference/api/paddle_api.h
@@ -295,8 +295,9 @@ class PD_INFER_DECL PaddlePredictor {
   /// To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(false)
   /// and then use the `GetInputTensor` and `GetOutputTensor`
   /// to directly write or read the input/output tensors.
+  /// \param switch_stream Whether the stream is switched.
   /// \return Whether the run is successful
-  virtual bool ZeroCopyRun() { return false; }
+  virtual bool ZeroCopyRun(bool switch_stream = false) { return false; }
 
   ///
   /// \brief Clear the intermediate tensors of the predictor
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 03a95e870b8105..94df6a0ee0d418 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -691,7 +691,9 @@ void BindPaddlePredictor(py::module *m) {
       .def("get_output_tensor", &PaddlePredictor::GetOutputTensor)
       .def("get_input_names", &PaddlePredictor::GetInputNames)
       .def("get_output_names", &PaddlePredictor::GetOutputNames)
-      .def("zero_copy_run", &PaddlePredictor::ZeroCopyRun)
+      .def("zero_copy_run",
+           &PaddlePredictor::ZeroCopyRun,
+           py::arg("switch_stream") = false)
       .def("clone", [](PaddlePredictor &self) { return self.Clone(nullptr); })
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       .def("clone",
@@ -740,7 +742,9 @@ void BindNativePredictor(py::module *m) {
       })
       .def("get_input_tensor", &NativePaddlePredictor::GetInputTensor)
       .def("get_output_tensor", &NativePaddlePredictor::GetOutputTensor)
-      .def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun)
+      .def("zero_copy_run",
+           &NativePaddlePredictor::ZeroCopyRun,
+           py::arg("switch_stream") = false)
       .def("clone",
           [](NativePaddlePredictor &self) { return self.Clone(nullptr); })
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -1130,7 +1134,9 @@ void BindAnalysisPredictor(py::module *m) {
      .def("get_input_names", &AnalysisPredictor::GetInputNames)
      .def("get_output_names", &AnalysisPredictor::GetOutputNames)
      .def("get_input_tensor_shape", &AnalysisPredictor::GetInputTensorShape)
-      .def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun)
+      .def("zero_copy_run",
+           &AnalysisPredictor::ZeroCopyRun,
+           py::arg("switch_stream") = false)
       .def("clear_intermediate_tensor",
           &AnalysisPredictor::ClearIntermediateTensor)
      .def("try_shrink_memory", &AnalysisPredictor::TryShrinkMemory)
diff --git a/test/cpp/inference/api/analysis_predictor_tester.cc b/test/cpp/inference/api/analysis_predictor_tester.cc
index 3d841954a89d65..3d87140d9c05a7 100644
--- a/test/cpp/inference/api/analysis_predictor_tester.cc
+++ b/test/cpp/inference/api/analysis_predictor_tester.cc
@@ -668,6 +668,7 @@ TEST(Tensor, RunWithExternalStream) {
   cudaStream_t stream;
   cudaStreamCreate(&stream);
   config.SetExecStream(stream);
+  config.EnableNewExecutor();
   auto predictor = CreatePredictor(config);
 
   auto w0 = predictor->GetInputHandle("firstw");
@@ -703,8 +704,7 @@ TEST(Tensor, RunWithExternalStream) {
   cudaStream_t external_stream;
   cudaStreamCreate(&external_stream);
-  Config tmp_config(config);
-  tmp_config.SetExecStream(external_stream);
+  predictor->Run();
   paddle_infer::experimental::InternalUtils::RunWithExternalStream(
       predictor.get(), external_stream);
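
A minimal usage sketch (not part of the patch) of the path this change enables: running a predictor bound to one CUDA stream and then handing it a different external stream, which makes `ExpRunWithExternalStream` detect the change, set `switch_stream = true`, and let the new executor rebuild its per-stream state (`BuildOpFuncNode` / `BuildInstruction`) before running. The model paths and the `main` scaffolding below are hypothetical; the API calls mirror those exercised in the updated test.

```cpp
#include <cuda_runtime.h>
#include "paddle_inference_api.h"  // include path depends on your install layout

int main() {
  // Hypothetical model files; replace with a real inference model.
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/256, /*device_id=*/0);

  cudaStream_t init_stream;
  cudaStreamCreate(&init_stream);
  config.SetExecStream(init_stream);  // bind GPU resources to this stream
  config.EnableNewExecutor();         // switch_stream is handled by the new executor
  auto predictor = paddle_infer::CreatePredictor(config);

  // ... copy inputs via predictor->GetInputHandle(...) ...

  // First run executes on the stream configured above.
  predictor->Run();

  // Run again on a different stream: internally the predictor sees
  // stream != predictor_stream_, passes switch_stream = true down to
  // ZeroCopyRun and the interpreter, and rebuilds instructions for the
  // new stream before executing.
  cudaStream_t external_stream;
  cudaStreamCreate(&external_stream);
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), external_stream);
  return 0;
}
```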