Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions paddle/fluid/framework/naive_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ void NaiveExecutor::RegisterInputHook(const HookFunc &hookfunc) {
}
}

void NaiveExecutor::RegisterOutputHook(const PirHookFunc &hookfunc) {
pir_output_hookfuncs_.push_back(hookfunc);
if (interpreter_core_) {
interpreter_core_->SetOutputHooks(pir_output_hookfuncs_);
}
}

void NaiveExecutor::RegisterInputHook(const PirHookFunc &hookfunc) {
pir_input_hookfuncs_.push_back(hookfunc);
if (interpreter_core_) {
interpreter_core_->SetInputHooks(pir_input_hookfuncs_);
}
}

void NaiveExecutor::MakeReusePlan(
const std::unordered_map<std::string, std::string> &reuse_table) {
std::unordered_map<std::string, std::unordered_set<std::string>> clusters;
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/framework/naive_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class NaiveExecutor {
public:
using HookFunc = std::function<void(OperatorBase*, Scope*)>;

using PirHookFunc =
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;

explicit NaiveExecutor(const platform::Place& place) : place_(place) {}

~NaiveExecutor();
Expand Down Expand Up @@ -94,6 +97,8 @@ class NaiveExecutor {

void RegisterOutputHook(const HookFunc& hookfunc);
void RegisterInputHook(const HookFunc& hookfunc);
void RegisterOutputHook(const PirHookFunc& hookfunc);
void RegisterInputHook(const PirHookFunc& hookfunc);

private:
void CreateOps(const ProgramDesc& desc, int block_id);
Expand All @@ -107,6 +112,9 @@ class NaiveExecutor {
std::vector<HookFunc> output_hookfuncs_;
std::vector<HookFunc> input_hookfuncs_;

std::vector<PirHookFunc> pir_output_hookfuncs_;
std::vector<PirHookFunc> pir_input_hookfuncs_;

// Record information that tensor_a should ShareBufferWith tensor_b.
std::unordered_map<OperatorBase*, std::unordered_map<phi::DenseTensor*, int>>
reuse_cache_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ IfInstruction::~IfInstruction() {
}
}

void IfInstruction::SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) {
true_branch_inter_->SetOutputHooks(hookfuncs);
false_branch_inter_->SetOutputHooks(hookfuncs);
}

void IfInstruction::SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) {
true_branch_inter_->SetInputHooks(hookfuncs);
false_branch_inter_->SetInputHooks(hookfuncs);
}

void IfInstruction::Run() {
bool cond = true;
if (cond_var_->IsType<phi::DenseTensor>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class IfInstruction : public InstructionBase {

PirInterpreter* FalseBranchInterpreter() const { return false_branch_inter_; }

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);

private:
::pir::Operation* op_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ void WhileInstruction::ShareDatasToOutputs() {
}
}

void WhileInstruction::SetOutputHooks(
const std::vector<PirHookFunc>& hookfuncs) {
body_inter_->SetOutputHooks(hookfuncs);
}

void WhileInstruction::SetInputHooks(
const std::vector<PirHookFunc>& hookfuncs) {
body_inter_->SetInputHooks(hookfuncs);
}

void WhileInstruction::Run() {
#ifdef PADDLE_WITH_DNNL
// Executor on being destroyed clears oneDNN cache and resets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class WhileInstruction : public InstructionBase {

PirInterpreter* BodyInterpreter() const { return body_inter_.get(); }

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);

private:
// 'output' = 'input'
void ShareInputsToOutputs();
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/new_executor/interpreter_base_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class InterpreterBaseImpl {

virtual void SetInputHooks(const std::vector<HookFunc>& hookfuncs) = 0;

virtual void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) = 0;

virtual void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) = 0;

virtual std::shared_ptr<std::vector<size_t>> GetDependencyCount() const = 0;

virtual bool IsSharedResultsBuild() const = 0;
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ void InterpreterCore::SetOutputHooks(const std::vector<HookFunc>& hookfuncs) {
impl_->SetOutputHooks(hookfuncs);
}

void InterpreterCore::SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) {
impl_->SetInputHooks(hookfuncs);
}

void InterpreterCore::SetOutputHooks(
const std::vector<PirHookFunc>& hookfuncs) {
impl_->SetOutputHooks(hookfuncs);
}

void InterpreterCore::Build(
const std::vector<std::string>& feed_names,
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/new_executor/interpretercore.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#pragma once

#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"

PD_DECLARE_bool(new_executor_use_local_scope);

Expand Down Expand Up @@ -88,6 +89,10 @@ class InterpreterCore {

void SetInputHooks(const std::vector<HookFunc>& hookfuncs);

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);

void Build(const std::vector<std::string>& feed_names,
std::vector<paddle::framework::OpFuncNode>* op_func_nodes);

Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace paddle {
namespace framework {

class InstructionBase;
class ValueExecutionInfo;
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;

using HookFunc = std::function<void(OperatorBase*, Scope*)>;
using PirHookFunc =
std::function<void(InstructionBase*, ValueExecutionInfo*, Scope*)>;

using SchedulingPriority = int64_t;

Expand Down
38 changes: 34 additions & 4 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,16 @@ void PirInterpreter::BuildInstruction() {
}
} else if (op.dialect()->name() == "pd_op") {
if (op.isa<paddle::dialect::IfOp>()) { // NOLINT
vec_instruction_base_.emplace_back(std::make_unique<IfInstruction>(
op_idx++, place_, &op, value_exe_info_.get(), execution_config_));
std::unique_ptr<IfInstruction> if_instr_ptr =
std::make_unique<IfInstruction>(op_idx++,
place_,
&op,
value_exe_info_.get(),
execution_config_);
if_instr_ptr->SetOutputHooks(pir_output_hookfuncs_);
if_instr_ptr->SetInputHooks(pir_input_hookfuncs_);
vec_instruction_base_.emplace_back(std::move(if_instr_ptr));

sub_blocks_.insert(
{&op.dyn_cast<paddle::dialect::IfOp>().true_block(),
dynamic_cast<IfInstruction*>(vec_instruction_base_.back().get())
Expand All @@ -742,8 +750,16 @@ void PirInterpreter::BuildInstruction() {
vec_instruction_base_.back().get())
->ForwardInterpreter()});
} else if (op.isa<paddle::dialect::WhileOp>()) {
vec_instruction_base_.emplace_back(std::make_unique<WhileInstruction>(
op_idx++, place_, &op, value_exe_info_.get(), execution_config_));
std::unique_ptr<WhileInstruction> while_instr_ptr =
std::make_unique<WhileInstruction>(op_idx++,
place_,
&op,
value_exe_info_.get(),
execution_config_);
while_instr_ptr->SetOutputHooks(pir_output_hookfuncs_);
while_instr_ptr->SetInputHooks(pir_input_hookfuncs_);
vec_instruction_base_.emplace_back(std::move(while_instr_ptr));

sub_blocks_.insert(
{&op.dyn_cast<paddle::dialect::WhileOp>().body(),
dynamic_cast<WhileInstruction*>(vec_instruction_base_.back().get())
Expand Down Expand Up @@ -1764,6 +1780,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
<< " runs on " << platform::GetCurrentThreadName() << "\n"
<< "Before: " << cur_place << " "
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());

if (execution_config_.used_for_inference) {
for (auto& hook : pir_input_hookfuncs_) {
hook(instr_node, value_exe_info_.get(), scope_);
}
}

if (!instr_node->IsArtificial()) {
instr_node->Run();

Expand All @@ -1789,6 +1812,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
VLOG(4) << "done CheckGC";
memory::LogDeviceMemoryStats(cur_place, instr_node->Name());
}

if (execution_config_.used_for_inference) {
for (auto& hook : pir_output_hookfuncs_) {
hook(instr_node, value_exe_info_.get(), scope_);
}
}

VLOG(5) << "after run kernel";
instr_node->RecordEvent(cur_place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Expand Down
16 changes: 10 additions & 6 deletions paddle/fluid/framework/new_executor/pir_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,16 @@ class PirInterpreter : public InterpreterBaseImpl {

const platform::Place& GetPlace() const override { return place_; }

void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
output_hookfuncs_ = hookfuncs;
void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {}

void SetInputHooks(const std::vector<HookFunc>& hookfuncs) override {}

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) override {
pir_output_hookfuncs_ = hookfuncs;
}

void SetInputHooks(const std::vector<HookFunc>& hookfuncs) override {
input_hookfuncs_ = hookfuncs;
void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) override {
pir_input_hookfuncs_ = hookfuncs;
}

std::string GetNameByValue(::pir::Value value) const;
Expand Down Expand Up @@ -200,8 +204,8 @@ class PirInterpreter : public InterpreterBaseImpl {
int64_t onednn_op_num_{-1};
std::vector<size_t> trace_execute_order_;

std::vector<HookFunc> output_hookfuncs_;
std::vector<HookFunc> input_hookfuncs_;
std::vector<PirHookFunc> pir_output_hookfuncs_;
std::vector<PirHookFunc> pir_input_hookfuncs_;

/// ======================== ///
/// For new ir ///
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/new_executor/program_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class ProgramInterpreter : public InterpreterBaseImpl {
input_hookfuncs_ = hookfuncs;
}

void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs) override {}

void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs) override {}

std::unordered_map<std::string, std::shared_ptr<EventInter>>*
GetForceEventsToWaitInfo() {
return force_events_to_wait_;
Expand Down
Loading