Skip to content

Commit f280f8e

Browse files
authored
[Inference] Pir support input/output hook (#63101)
* add register hook for pir * fix
1 parent a128eca commit f280f8e

15 files changed

Lines changed: 199 additions & 38 deletions

paddle/fluid/framework/naive_executor.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,20 @@ void NaiveExecutor::RegisterInputHook(const HookFunc &hookfunc) {
234234
}
235235
}
236236

237+
void NaiveExecutor::RegisterOutputHook(const PirHookFunc &hookfunc) {
238+
pir_output_hookfuncs_.push_back(hookfunc);
239+
if (interpreter_core_) {
240+
interpreter_core_->SetOutputHooks(pir_output_hookfuncs_);
241+
}
242+
}
243+
244+
void NaiveExecutor::RegisterInputHook(const PirHookFunc &hookfunc) {
245+
pir_input_hookfuncs_.push_back(hookfunc);
246+
if (interpreter_core_) {
247+
interpreter_core_->SetInputHooks(pir_input_hookfuncs_);
248+
}
249+
}
250+
237251
void NaiveExecutor::MakeReusePlan(
238252
const std::unordered_map<std::string, std::string> &reuse_table) {
239253
std::unordered_map<std::string, std::unordered_set<std::string>> clusters;

paddle/fluid/framework/naive_executor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ class NaiveExecutor {
4545
public:
4646
using HookFunc = std::function<void(OperatorBase*, Scope*)>;
4747

48+
using PirHookFunc =
49+
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;
50+
4851
explicit NaiveExecutor(const platform::Place& place) : place_(place) {}
4952

5053
~NaiveExecutor();
@@ -94,6 +97,8 @@ class NaiveExecutor {
9497

9598
void RegisterOutputHook(const HookFunc& hookfunc);
9699
void RegisterInputHook(const HookFunc& hookfunc);
100+
void RegisterOutputHook(const PirHookFunc& hookfunc);
101+
void RegisterInputHook(const PirHookFunc& hookfunc);
97102

98103
private:
99104
void CreateOps(const ProgramDesc& desc, int block_id);
@@ -107,6 +112,9 @@ class NaiveExecutor {
107112
std::vector<HookFunc> output_hookfuncs_;
108113
std::vector<HookFunc> input_hookfuncs_;
109114

115+
std::vector<PirHookFunc> pir_output_hookfuncs_;
116+
std::vector<PirHookFunc> pir_input_hookfuncs_;
117+
110118
// Record information that tensor_a should ShareBufferWith tensor_b.
111119
std::unordered_map<OperatorBase*, std::unordered_map<phi::DenseTensor*, int>>
112120
reuse_cache_;

paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,16 @@ IfInstruction::~IfInstruction() {
198198
}
199199
}
200200

201+
void IfInstruction::SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) {
202+
true_branch_inter_->SetOutputHooks(hookfuncs);
203+
false_branch_inter_->SetOutputHooks(hookfuncs);
204+
}
205+
206+
void IfInstruction::SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) {
207+
true_branch_inter_->SetInputHooks(hookfuncs);
208+
false_branch_inter_->SetInputHooks(hookfuncs);
209+
}
210+
201211
void IfInstruction::Run() {
202212
bool cond = true;
203213
if (cond_var_->IsType<phi::DenseTensor>()) {

paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ class IfInstruction : public InstructionBase {
4848

4949
PirInterpreter* FalseBranchInterpreter() const { return false_branch_inter_; }
5050

51+
void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);
52+
53+
void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);
54+
5155
private:
5256
::pir::Operation* op_;
5357

paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,16 @@ void WhileInstruction::ShareDatasToOutputs() {
240240
}
241241
}
242242

243+
void WhileInstruction::SetOutputHooks(
244+
const std::vector<PirHookFunc>& hookfuncs) {
245+
body_inter_->SetOutputHooks(hookfuncs);
246+
}
247+
248+
void WhileInstruction::SetInputHooks(
249+
const std::vector<PirHookFunc>& hookfuncs) {
250+
body_inter_->SetInputHooks(hookfuncs);
251+
}
252+
243253
void WhileInstruction::Run() {
244254
#ifdef PADDLE_WITH_DNNL
245255
// Executor on being destroyed clears oneDNN cache and resets

paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class WhileInstruction : public InstructionBase {
5050

5151
PirInterpreter* BodyInterpreter() const { return body_inter_.get(); }
5252

53+
void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);
54+
55+
void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);
56+
5357
private:
5458
// 'output' = 'input'
5559
void ShareInputsToOutputs();

paddle/fluid/framework/new_executor/interpreter_base_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ class InterpreterBaseImpl {
104104

105105
virtual void SetInputHooks(const std::vector<HookFunc>& hookfuncs) = 0;
106106

107+
virtual void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) = 0;
108+
109+
virtual void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) = 0;
110+
107111
virtual std::shared_ptr<std::vector<size_t>> GetDependencyCount() const = 0;
108112

109113
virtual bool IsSharedResultsBuild() const = 0;

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,15 @@ void InterpreterCore::SetOutputHooks(const std::vector<HookFunc>& hookfuncs) {
139139
impl_->SetOutputHooks(hookfuncs);
140140
}
141141

142+
void InterpreterCore::SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) {
143+
impl_->SetInputHooks(hookfuncs);
144+
}
145+
146+
void InterpreterCore::SetOutputHooks(
147+
const std::vector<PirHookFunc>& hookfuncs) {
148+
impl_->SetOutputHooks(hookfuncs);
149+
}
150+
142151
void InterpreterCore::Build(
143152
const std::vector<std::string>& feed_names,
144153
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {

paddle/fluid/framework/new_executor/interpretercore.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#pragma once
1515

1616
#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
17+
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
1718

1819
PD_DECLARE_bool(new_executor_use_local_scope);
1920

@@ -88,6 +89,10 @@ class InterpreterCore {
8889

8990
void SetInputHooks(const std::vector<HookFunc>& hookfuncs);
9091

92+
void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);
93+
94+
void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);
95+
9196
void Build(const std::vector<std::string>& feed_names,
9297
std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
9398

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,13 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
4040
namespace paddle {
4141
namespace framework {
4242

43+
class InstructionBase;
44+
class ValueExecutionInfo;
4345
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
4446

4547
using HookFunc = std::function<void(OperatorBase*, Scope*)>;
48+
using PirHookFunc =
49+
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;
4650

4751
using SchedulingPriority = int64_t;
4852

0 commit comments

Comments
 (0)