Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions paddle/fluid/framework/naive_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ void NaiveExecutor::RegisterInputHook(const HookFunc &hookfunc) {
}
}

void NaiveExecutor::RegisterOutputHook(const PirHookFunc &hookfunc) {
pir_output_hookfuncs_.push_back(hookfunc);
if (interpreter_core_) {
interpreter_core_->SetOutputHooks(pir_output_hookfuncs_);
}
}

void NaiveExecutor::RegisterInputHook(const PirHookFunc &hookfunc) {
pir_input_hookfuncs_.push_back(hookfunc);
if (interpreter_core_) {
interpreter_core_->SetInputHooks(pir_input_hookfuncs_);
}
}

void NaiveExecutor::MakeReusePlan(
const std::unordered_map<std::string, std::string> &reuse_table) {
std::unordered_map<std::string, std::unordered_set<std::string>> clusters;
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/framework/naive_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class NaiveExecutor {
public:
using HookFunc = std::function<void(OperatorBase*, Scope*)>;

using PirHookFunc =
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;

explicit NaiveExecutor(const platform::Place& place) : place_(place) {}

~NaiveExecutor();
Expand Down Expand Up @@ -94,6 +97,8 @@ class NaiveExecutor {

void RegisterOutputHook(const HookFunc& hookfunc);
void RegisterInputHook(const HookFunc& hookfunc);
void RegisterOutputHook(const PirHookFunc& hookfunc);
void RegisterInputHook(const PirHookFunc& hookfunc);

private:
void CreateOps(const ProgramDesc& desc, int block_id);
Expand All @@ -107,6 +112,9 @@ class NaiveExecutor {
std::vector<HookFunc> output_hookfuncs_;
std::vector<HookFunc> input_hookfuncs_;

std::vector<PirHookFunc> pir_output_hookfuncs_;
std::vector<PirHookFunc> pir_input_hookfuncs_;

// Record information that tensor_a should ShareBufferWith tensor_b.
std::unordered_map<OperatorBase*, std::unordered_map<phi::DenseTensor*, int>>
reuse_cache_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ IfInstruction::~IfInstruction() {
}
}

void IfInstruction::SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) {
true_branch_inter_->SetOutputHooks(hookfuncs);
false_branch_inter_->SetOutputHooks(hookfuncs);
}

void IfInstruction::SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) {
true_branch_inter_->SetInputHooks(hookfuncs);
false_branch_inter_->SetInputHooks(hookfuncs);
}

void IfInstruction::Run() {
bool cond = true;
if (cond_var_->IsType<phi::DenseTensor>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class PirInterpreter;
class ValueExecutionInfo;

class IfInstruction : public InstructionBase {
using PirHookFunc =
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;

public:
IfInstruction(size_t id,
const platform::Place& place,
Expand All @@ -48,6 +51,10 @@ class IfInstruction : public InstructionBase {

PirInterpreter* FalseBranchInterpreter() const { return false_branch_inter_; }

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);

private:
::pir::Operation* op_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ void WhileInstruction::ShareDatasToOutputs() {
}
}

void WhileInstruction::SetOutputHooks(
const std::vector<PirHookFunc>& hookfuncs) {
body_inter_->SetOutputHooks(hookfuncs);
}

void WhileInstruction::SetInputHooks(
const std::vector<PirHookFunc>& hookfuncs) {
body_inter_->SetInputHooks(hookfuncs);
}

void WhileInstruction::Run() {
#ifdef PADDLE_WITH_DNNL
// Executor on being destroyed clears oneDNN cache and resets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class ValueExecutionInfo;
/// 'cond', 'output' = body_block('output');
/// }
class WhileInstruction : public InstructionBase {
using PirHookFunc =
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和上面的IfInstruction里的不需要重复写这个using,paddle/fluid/framework/new_executor/new_executor_defs.h里写过之后,应该都是可以直接用的


public:
WhileInstruction(size_t id,
const platform::Place& place,
Expand All @@ -50,6 +53,10 @@ class WhileInstruction : public InstructionBase {

PirInterpreter* BodyInterpreter() const { return body_inter_.get(); }

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);

private:
// 'output' = 'input'
void ShareInputsToOutputs();
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/new_executor/interpreter_base_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class InterpreterBaseImpl {

virtual void SetInputHooks(const std::vector<HookFunc>& hookfuncs) = 0;

virtual void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) = 0;

virtual void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) = 0;

virtual std::shared_ptr<std::vector<size_t>> GetDependencyCount() const = 0;

virtual bool IsSharedResultsBuild() const = 0;
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ void InterpreterCore::SetOutputHooks(const std::vector<HookFunc>& hookfuncs) {
impl_->SetOutputHooks(hookfuncs);
}

void InterpreterCore::SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) {
impl_->SetInputHooks(hookfuncs);
}

void InterpreterCore::SetOutputHooks(
const std::vector<PirHookFunc>& hookfuncs) {
impl_->SetOutputHooks(hookfuncs);
}

void InterpreterCore::Build(
const std::vector<std::string>& feed_names,
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/new_executor/interpretercore.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class InterpreterBaseImpl;
class InterpreterCore {
using ExecutionConfig = interpreter::ExecutionConfig;
using HookFunc = std::function<void(OperatorBase*, Scope*)>;
using PirHookFunc =
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里尝试下能不能也删去,包含paddle/fluid/framework/new_executor/new_executor_defs.h头文件


public:
InterpreterCore(const platform::Place& place,
Expand Down Expand Up @@ -88,6 +90,10 @@ class InterpreterCore {

void SetInputHooks(const std::vector<HookFunc>& hookfuncs);

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);

void Build(const std::vector<std::string>& feed_names,
std::vector<paddle::framework::OpFuncNode>* op_func_nodes);

Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace paddle {
namespace framework {

class InstructionBase;
class ValueExecutionInfo;
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;

using HookFunc = std::function<void(OperatorBase*, Scope*)>;
using PirHookFunc =
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;

using SchedulingPriority = int64_t;

Expand Down
38 changes: 34 additions & 4 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,16 @@ void PirInterpreter::BuildInstruction() {
}
} else if (op.dialect()->name() == "pd_op") {
if (op.isa<paddle::dialect::IfOp>()) { // NOLINT
vec_instruction_base_.emplace_back(std::make_unique<IfInstruction>(
op_idx++, place_, &op, value_exe_info_.get(), execution_config_));
std::unique_ptr<IfInstruction> ifInstrPtr =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量的命名我们一般都是小写+下划线,其他地方一并改下

std::make_unique<IfInstruction>(op_idx++,
place_,
&op,
value_exe_info_.get(),
execution_config_);
ifInstrPtr->SetOutputHooks(pir_output_hookfuncs_);
ifInstrPtr->SetInputHooks(pir_input_hookfuncs_);
vec_instruction_base_.emplace_back(std::move(ifInstrPtr));

sub_blocks_.insert(
{&op.dyn_cast<paddle::dialect::IfOp>().true_block(),
dynamic_cast<IfInstruction*>(vec_instruction_base_.back().get())
Expand All @@ -742,8 +750,16 @@ void PirInterpreter::BuildInstruction() {
vec_instruction_base_.back().get())
->ForwardInterpreter()});
} else if (op.isa<paddle::dialect::WhileOp>()) {
vec_instruction_base_.emplace_back(std::make_unique<WhileInstruction>(
op_idx++, place_, &op, value_exe_info_.get(), execution_config_));
std::unique_ptr<WhileInstruction> whileInstrPtr =
std::make_unique<WhileInstruction>(op_idx++,
place_,
&op,
value_exe_info_.get(),
execution_config_);
whileInstrPtr->SetOutputHooks(pir_output_hookfuncs_);
whileInstrPtr->SetInputHooks(pir_input_hookfuncs_);
vec_instruction_base_.emplace_back(std::move(whileInstrPtr));

sub_blocks_.insert(
{&op.dyn_cast<paddle::dialect::WhileOp>().body(),
dynamic_cast<WhileInstruction*>(vec_instruction_base_.back().get())
Expand Down Expand Up @@ -1764,6 +1780,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
<< " runs on " << platform::GetCurrentThreadName() << "\n"
<< "Before: " << cur_place << " "
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());

if (execution_config_.used_for_inference) {
for (auto& hook : pir_input_hookfuncs_) {
hook(instr_node, value_exe_info_.get(), scope_);
}
}

if (!instr_node->IsArtificial()) {
instr_node->Run();

Expand All @@ -1789,6 +1812,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
VLOG(4) << "done CheckGC";
memory::LogDeviceMemoryStats(cur_place, instr_node->Name());
}

if (execution_config_.used_for_inference) {
for (auto& hook : pir_output_hookfuncs_) {
hook(instr_node, value_exe_info_.get(), scope_);
}
}

VLOG(5) << "after run kernel";
instr_node->RecordEvent(cur_place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ class PirInterpreter : public InterpreterBaseImpl {
input_hookfuncs_ = hookfuncs;
}

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) override {
pir_output_hookfuncs_ = hookfuncs;
}

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) override {
pir_input_hookfuncs_ = hookfuncs;
}

std::string GetNameByValue(::pir::Value value) const;

// Only for debug
Expand Down Expand Up @@ -203,6 +211,9 @@ class PirInterpreter : public InterpreterBaseImpl {
std::vector<HookFunc> output_hookfuncs_;
std::vector<HookFunc> input_hookfuncs_;

std::vector<PirHookFunc> pir_output_hookfuncs_;
std::vector<PirHookFunc> pir_input_hookfuncs_;

/// ======================== ///
/// For new ir ///
/// ======================== ///
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/framework/new_executor/program_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ class ProgramInterpreter : public InterpreterBaseImpl {
input_hookfuncs_ = hookfuncs;
}

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) override {
pir_output_hookfuncs_ = hookfuncs;
}

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) override {
pir_input_hookfuncs_ = hookfuncs;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件的修改是完全不需要的哈,这是旧IR相关的

std::unordered_map<std::string, std::shared_ptr<EventInter>>*
GetForceEventsToWaitInfo() {
return force_events_to_wait_;
Expand Down Expand Up @@ -239,6 +247,9 @@ class ProgramInterpreter : public InterpreterBaseImpl {
std::vector<HookFunc> output_hookfuncs_;
std::vector<HookFunc> input_hookfuncs_;

std::vector<PirHookFunc> pir_output_hookfuncs_;
std::vector<PirHookFunc> pir_input_hookfuncs_;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::unique_ptr<phi::CalculateStreamTimer> calculate_stream_timer_;
#endif
Expand Down
Loading