-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[Inference]Pir support input/output hook #63101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,8 @@ class InterpreterBaseImpl; | |
| class InterpreterCore { | ||
| using ExecutionConfig = interpreter::ExecutionConfig; | ||
| using HookFunc = std::function<void(OperatorBase*, Scope*)>; | ||
| using PirHookFunc = | ||
| std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>; | ||
|
||
|
|
||
| public: | ||
| InterpreterCore(const platform::Place& place, | ||
|
|
@@ -88,6 +90,10 @@ class InterpreterCore { | |
|
|
||
| void SetInputHooks(const std::vector<HookFunc>& hookfuncs); | ||
|
|
||
| void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs); | ||
|
|
||
| void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs); | ||
|
|
||
| void Build(const std::vector<std::string>& feed_names, | ||
| std::vector<paddle::framework::OpFuncNode>* op_func_nodes); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -723,8 +723,16 @@ void PirInterpreter::BuildInstruction() { | |
| } | ||
| } else if (op.dialect()->name() == "pd_op") { | ||
| if (op.isa<paddle::dialect::IfOp>()) { // NOLINT | ||
| vec_instruction_base_.emplace_back(std::make_unique<IfInstruction>( | ||
| op_idx++, place_, &op, value_exe_info_.get(), execution_config_)); | ||
| std::unique_ptr<IfInstruction> ifInstrPtr = | ||
|
||
| std::make_unique<IfInstruction>(op_idx++, | ||
| place_, | ||
| &op, | ||
| value_exe_info_.get(), | ||
| execution_config_); | ||
| ifInstrPtr->SetOutputHooks(pir_output_hookfuncs_); | ||
| ifInstrPtr->SetInputHooks(pir_input_hookfuncs_); | ||
| vec_instruction_base_.emplace_back(std::move(ifInstrPtr)); | ||
|
|
||
| sub_blocks_.insert( | ||
| {&op.dyn_cast<paddle::dialect::IfOp>().true_block(), | ||
| dynamic_cast<IfInstruction*>(vec_instruction_base_.back().get()) | ||
|
|
@@ -742,8 +750,16 @@ void PirInterpreter::BuildInstruction() { | |
| vec_instruction_base_.back().get()) | ||
| ->ForwardInterpreter()}); | ||
| } else if (op.isa<paddle::dialect::WhileOp>()) { | ||
| vec_instruction_base_.emplace_back(std::make_unique<WhileInstruction>( | ||
| op_idx++, place_, &op, value_exe_info_.get(), execution_config_)); | ||
| std::unique_ptr<WhileInstruction> whileInstrPtr = | ||
| std::make_unique<WhileInstruction>(op_idx++, | ||
| place_, | ||
| &op, | ||
| value_exe_info_.get(), | ||
| execution_config_); | ||
| whileInstrPtr->SetOutputHooks(pir_output_hookfuncs_); | ||
| whileInstrPtr->SetInputHooks(pir_input_hookfuncs_); | ||
| vec_instruction_base_.emplace_back(std::move(whileInstrPtr)); | ||
|
|
||
| sub_blocks_.insert( | ||
| {&op.dyn_cast<paddle::dialect::WhileOp>().body(), | ||
| dynamic_cast<WhileInstruction*>(vec_instruction_base_.back().get()) | ||
|
|
@@ -1764,6 +1780,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { | |
| << " runs on " << platform::GetCurrentThreadName() << "\n" | ||
| << "Before: " << cur_place << " " | ||
| << instr_node->DebugStringEx(scope_, value_exe_info_.get()); | ||
|
|
||
| if (execution_config_.used_for_inference) { | ||
| for (auto& hook : pir_input_hookfuncs_) { | ||
| hook(instr_node, value_exe_info_.get(), scope_); | ||
| } | ||
| } | ||
|
|
||
| if (!instr_node->IsArtificial()) { | ||
| instr_node->Run(); | ||
|
|
||
|
|
@@ -1789,6 +1812,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { | |
| VLOG(4) << "done CheckGC"; | ||
| memory::LogDeviceMemoryStats(cur_place, instr_node->Name()); | ||
| } | ||
|
|
||
| if (execution_config_.used_for_inference) { | ||
| for (auto& hook : pir_output_hookfuncs_) { | ||
| hook(instr_node, value_exe_info_.get(), scope_); | ||
| } | ||
| } | ||
|
|
||
| VLOG(5) << "after run kernel"; | ||
| instr_node->RecordEvent(cur_place); | ||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -101,6 +101,14 @@ class ProgramInterpreter : public InterpreterBaseImpl { | |
| input_hookfuncs_ = hookfuncs; | ||
| } | ||
|
|
||
| void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) override { | ||
| pir_output_hookfuncs_ = hookfuncs; | ||
| } | ||
|
|
||
| void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) override { | ||
| pir_input_hookfuncs_ = hookfuncs; | ||
| } | ||
|
|
||
|
||
| std::unordered_map<std::string, std::shared_ptr<EventInter>>* | ||
| GetForceEventsToWaitInfo() { | ||
| return force_events_to_wait_; | ||
|
|
@@ -239,6 +247,9 @@ class ProgramInterpreter : public InterpreterBaseImpl { | |
| std::vector<HookFunc> output_hookfuncs_; | ||
| std::vector<HookFunc> input_hookfuncs_; | ||
|
|
||
| std::vector<PirHookFunc> pir_output_hookfuncs_; | ||
| std::vector<PirHookFunc> pir_input_hookfuncs_; | ||
|
|
||
|
||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| std::unique_ptr<phi::CalculateStreamTimer> calculate_stream_timer_; | ||
| #endif | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个和上面的IfInstruction里的不需要重复写这个using,paddle/fluid/framework/new_executor/new_executor_defs.h里写过之后,应该都是可以直接用的