Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
fast_threaded_ssa_graph_executor variable_helper)

cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor)
cc_library(event_based_executor SRCS event_based_executor.cc runtime_graph.cc DEPS framework_proto)
if(WITH_PSCORE)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
Expand Down Expand Up @@ -440,7 +441,7 @@ endif()
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)

set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator event_based_executor)

cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})

Expand Down
49 changes: 49 additions & 0 deletions paddle/fluid/framework/event_based_executor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/event_based_executor.h"
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/runtime_graph.h"

namespace paddle {
namespace framework {

EventBasedExecutor::~EventBasedExecutor() {
std::cout << "In EventBased Deconstructor" << std::endl;
}

void EventBasedExecutor::Compile(const ProgramDesc& program,
const std::string& grain) {
if (grain == "coarse") {
CompileCoarseGrainGraph(program);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VLOG(0)

} else {
CompileFineGrainGraph(program);
}
}

void EventBasedExecutor::CompileCoarseGrainGraph(const ProgramDesc& program) {
runtime_graph_.reset(new RuntimeGraph(program));
runtime_graph_->PrintGraph();
}

void EventBasedExecutor::CompileFineGrainGraph(const ProgramDesc& program) {
std::cout << "Compile Fine Grain Graph" << std::endl;
}

void EventBasedExecutor::Run() {
std::cout << "In Event Based Executor Run" << std::endl;
}
} // namespace framework
} // namespace paddle
42 changes: 42 additions & 0 deletions paddle/fluid/framework/event_based_executor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放到fleet下,比如sectionworker平级。


#include <memory>
#include <string>
#include <thread>
#include "paddle/fluid/framework/runtime_graph.h"

namespace paddle {
namespace framework {

class ProgramDesc;
class EventBasedWorker;

class EventBasedExecutor {
public:
EventBasedExecutor() = default;
~EventBasedExecutor();

void Compile(const ProgramDesc& program_desc, const std::string& grain);
void Run();

private:
void CompileCoarseGrainGraph(const ProgramDesc& program_desc);
void CompileFineGrainGraph(const ProgramDesc& program_desc);
std::unique_ptr<RuntimeGraph> runtime_graph_;
};
} // namespace framework
} // namespace paddle
164 changes: 164 additions & 0 deletions paddle/fluid/framework/runtime_graph.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/runtime_graph.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {

namespace {
typedef std::unordered_map<std::string, std::vector<int64_t>> VarList;

void FilterAndAddOutputVars(const BlockDesc &block,
const VarList &prev_input_vars,
VarList *cur_output_vars) {
const auto &ops_in_block = block.AllOps();
for (const OpDesc *op : ops_in_block) {
const auto &output_names = op->OutputArgumentNames();
for (const auto &output_name : output_names) {
if (prev_input_vars.find(output_name) != prev_input_vars.end()) {
for (int64_t consumer_id : prev_input_vars.at(output_name)) {
if (cur_output_vars->find(output_name) == cur_output_vars->end()) {
cur_output_vars->emplace(output_name, std::vector<int64_t>());
}
cur_output_vars->at(output_name).emplace_back(consumer_id);
}
}
}
}
}

void CreateVarNodesAndAddDeps(const VarList &cur_output_vars,
int64_t producer_id,
RuntimeGraph *runtime_graph) {
TaskNode *producer = runtime_graph->GetTaskNode(producer_id);
for (const auto &output_var : cur_output_vars) {
const auto &var_name = output_var.first;
if (!runtime_graph->HasVarNode(var_name)) {
VarDesc *var = producer->FindVar(var_name);
runtime_graph->CreateAndAddVarNode(*var);
}
InterVarNode *var_node = runtime_graph->FindVarNode(var_name);
for (int64_t consumer_id : cur_output_vars.at(var_name)) {
TaskNode *consumer = runtime_graph->GetTaskNode(consumer_id);
var_node->AddConsumedTask(consumer_id);
var_node->SetProducedTask(producer_id);
producer->AddProducedVarNode(var_name);
consumer->AddConsumedVarNode(var_name);
}
}
}

void FilterAndAddInputVars(const BlockDesc &block, VarList *prev_input_vars) {
const auto &ops_in_block = block.AllOps();
for (const OpDesc *op : ops_in_block) {
const auto &var_names = op->InputArgumentNames();
for (const auto &name : var_names) {
if (prev_input_vars->find(name) == prev_input_vars->end()) {
prev_input_vars->emplace(name, std::vector<int64_t>());
}
prev_input_vars->at(name).emplace_back(block.ID());
}
}
}
} // namespace

InterVarNode::InterVarNode(const VarDesc &var) : name_(var.Name()), var_(&var) {
task_produce_this_var_ = -1;
}

void InterVarNode::AddConsumedTask(int64_t task_id) {
tasks_consume_this_var_.insert(task_id);
}

void InterVarNode::SetProducedTask(int64_t task_id) {
task_produce_this_var_ = task_id;
}

TaskNode::TaskNode(const BlockDesc &block)
: task_id_(block.ID()), block_(&block) {}

void TaskNode::AddConsumedVarNode(const std::string &var_node_name) {
consumed_var_node_names_.insert(var_node_name);
}

void TaskNode::AddProducedVarNode(const std::string &var_node_name) {
produced_var_node_names_.insert(var_node_name);
}

bool TaskNode::IsSrcTask() const { return consumed_var_node_names_.empty(); }

VarDesc *TaskNode::FindVar(const std::string &name) const {
return block_->FindVar(name);
}

void TaskNode::PrintTaskNode() const {
std::cout << "consumed variables"
<< ": ";
for (const auto &name : consumed_var_node_names_) {
std::cout << name << " ";
}
std::cout << std::endl;
std::cout << "produced variables"
<< ": ";
for (const auto &name : produced_var_node_names_) {
std::cout << name << " ";
}
std::cout << std::endl;
}

RuntimeGraph::RuntimeGraph(const ProgramDesc &program) {
int64_t block_size = program.Size();
task_nodes_.resize(block_size);
VarList prev_input_vars;
for (int64_t i = block_size - 1; i >= 0; --i) {
const auto &block = program.Block(i);
task_nodes_[i].reset(new TaskNode(block));
VarList cur_output_vars;
FilterAndAddOutputVars(block, prev_input_vars, &cur_output_vars);
CreateVarNodesAndAddDeps(cur_output_vars, block.ID(), this);
FilterAndAddInputVars(block, &prev_input_vars);
}
}

bool RuntimeGraph::HasVarNode(const std::string &name) const {
return var_nodes_.find(name) != var_nodes_.end();
}

InterVarNode *RuntimeGraph::FindVarNode(const std::string &name) const {
CHECK(var_nodes_.find(name) != var_nodes_.end());
return var_nodes_.at(name).get();
}

TaskNode *RuntimeGraph::GetTaskNode(int64_t id) const {
CHECK_LT(id, (int)task_nodes_.size());
return task_nodes_[id].get();
}

InterVarNode *RuntimeGraph::CreateAndAddVarNode(const VarDesc &var) {
InterVarNode *var_node = new InterVarNode(var);
var_nodes_.emplace(var.Name(), var_node);
return var_node;
}

int64_t RuntimeGraph::TaskNodesNum() const { return task_nodes_.size(); }

void RuntimeGraph::PrintGraph() const {
for (const auto &task : task_nodes_) {
task->PrintTaskNode();
}
}
} // namespace framework
} // namespace paddle
94 changes: 94 additions & 0 deletions paddle/fluid/framework/runtime_graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace paddle {
namespace framework {

class ProgramDesc;
class OpDesc;
class BlockDesc;
class VarDesc;

class InterVarNode final {
public:
InterVarNode() = delete;
explicit InterVarNode(const VarDesc &var);
~InterVarNode() = default;
InterVarNode(const InterVarNode &) = delete;
InterVarNode(InterVarNode &&) = delete;
InterVarNode &operator=(const InterVarNode &) = delete;
InterVarNode &operator=(InterVarNode &&) = delete;

void AddConsumedTask(int64_t task_id);
void SetProducedTask(int64_t task_id);

private:
std::string name_;
std::unordered_set<int64_t> tasks_consume_this_var_;
int64_t task_produce_this_var_;
const VarDesc *var_;
};

class TaskNode final {
public:
TaskNode() = delete;
explicit TaskNode(const BlockDesc &block);
~TaskNode() = default;
TaskNode(const TaskNode &) = delete;
TaskNode(TaskNode &&) = delete;
TaskNode &operator=(const TaskNode &) = delete;
TaskNode &operator=(TaskNode &&) = delete;

bool IsSrcTask() const;
void AddConsumedVarNode(const std::string &var_node_name);
void AddProducedVarNode(const std::string &var_node_name);
VarDesc *FindVar(const std::string &name) const;
void PrintTaskNode() const;

private:
int64_t task_id_;
const BlockDesc *block_;
std::unordered_set<std::string> consumed_var_node_names_;
std::unordered_set<std::string> produced_var_node_names_;
};

class RuntimeGraph final {
public:
RuntimeGraph() = delete;
explicit RuntimeGraph(const ProgramDesc &program);
~RuntimeGraph() = default;
RuntimeGraph(const RuntimeGraph &) = delete;
RuntimeGraph(RuntimeGraph &&) = delete;
RuntimeGraph &operator=(const RuntimeGraph &) = delete;
RuntimeGraph &operator=(RuntimeGraph &&) = delete;

TaskNode *GetTaskNode(int64_t id) const;
InterVarNode *FindVarNode(const std::string &name) const;
bool HasVarNode(const std::string &name) const;
InterVarNode *CreateAndAddVarNode(const VarDesc &var);
int64_t TaskNodesNum() const;
void PrintGraph() const;

private:
std::vector<std::unique_ptr<TaskNode>> task_nodes_;
std::unordered_map<std::string, std::unique_ptr<InterVarNode>> var_nodes_;
};
} // namespace framework
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/event_based_executor.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
Expand Down Expand Up @@ -1907,6 +1908,11 @@ All parameter, weight, gradient are variables in Paddle.

m.def("_get_eager_deletion_vars", &framework::GetEagerDeletionCleanVars);

py::class_<framework::EventBasedExecutor>(m, "EventBasedExecutor")
.def(py::init<>())
.def("compile", &EventBasedExecutor::Compile)
.def("run", &EventBasedExecutor::Run);

py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
.def("close", &Executor::Close)
Expand Down
Loading