19 changes: 19 additions & 0 deletions paddle/fluid/framework/block_desc.cc
@@ -119,6 +119,25 @@ std::vector<VarDesc *> BlockDesc::AllVars() const {
return res;
}

std::vector<std::string> BlockDesc::AllVarsName() const {
std::vector<std::string> res;
for (const auto &p : vars_) {
res.push_back(p.first);
}
return res;
}

void BlockDesc::AppendAllocatedVar(VarDesc *var_desc) {
auto name = var_desc->Name();
if (this->HasVar(name)) {
return;
} else {
need_update_ = true;
VarDesc *new_var = new VarDesc(*var_desc);
vars_[name] = std::make_unique<VarDesc>(*new_var);
[Comment from Contributor]
This looks like it could be a problem: the var you append is not the same object as the one actually stored in the block, so how do the two stay in sync? Also, line 136 here is completely redundant, since make_unique already news a fresh object internally...

[Reply from @Wangzheee (Contributor Author), Dec 4, 2023]
The current IR passes only modify the graph of their own block (and may modify var_desc attributes). The boundary var_descs in different blocks are intentionally distinct copies, so that graph operations on other blocks are not disturbed. At op execution time, the scope is looked up by the name stored in var_desc, so as long as the var names match, execution works correctly.

}
}

OpDesc *BlockDesc::AppendOp() {
need_update_ = true;
ops_.emplace_back(new OpDesc(this));
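For reference, here is a minimal sketch of the simplification the reviewer is pointing at; this is not the code the PR merged. Dropping the intermediate raw new avoids leaking new_var, and the map entry is copy-constructed directly from the source desc.

void BlockDesc::AppendAllocatedVar(VarDesc *var_desc) {
  auto name = var_desc->Name();
  if (this->HasVar(name)) {
    return;  // An existing VarDesc with this name wins; nothing to append.
  }
  need_update_ = true;
  // make_unique allocates the copy itself, so no raw new is needed. The
  // copy is intentional: per the author's reply, each block keeps its own
  // boundary VarDesc, and execution resolves vars by name through the scope.
  vars_[name] = std::make_unique<VarDesc>(*var_desc);
}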
4 changes: 4 additions & 0 deletions paddle/fluid/framework/block_desc.h
@@ -78,12 +78,16 @@ class TEST_API BlockDesc {

std::vector<VarDesc *> AllVars() const;

std::vector<std::string> AllVarsName() const;

BlockDesc *ParentBlock() const;

BlockDesc *ForwardBlock() const;

void SetForwardBlockID(int32_t forward_block_id);

void AppendAllocatedVar(VarDesc *var_desc);

OpDesc *AppendOp();

void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -167,6 +167,7 @@ if(WITH_TENSORRT)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(split_layernorm_to_math_ops_pass inference)
pass_library(trt_remove_amp_strategy_op_pass inference)
pass_library(set_subgraph_edge_pass inference)
endif()

if(WITH_GPU OR WITH_ROCM)
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -312,6 +312,14 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);

VLOG(3) << "Running conv_bn_fuse_pass.";
if (graph->IsMainGraph()) {
[Comment from Contributor]
For log checks like this, I recommend IS_VLOG_ON, so that in the default case the logging code is skipped entirely.

[Reply from Contributor Author]
Good to know; I will change it in the next PR.

VLOG(3) << "The ID of block running conv_bn_fuse_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running conv_bn_fuse_pass is: "
<< graph->GetBlockId();
}

auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
@@ -612,6 +620,15 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);

VLOG(3) << "Running conv_eltwiseadd_bn_fuse_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: "
<< graph->GetBlockId();
}

auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
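A minimal sketch of the guard the reviewer suggests above, under two caveats: glog spells the macro VLOG_IS_ON(level) (the comment writes IS_VLOG_ON), and the merged code does not include this change. VLOG(3) already checks the level per statement; the outer guard lets the whole block be skipped with a single check when verbose logging is off.

// Sketch only, assuming glog's VLOG_IS_ON macro.
if (VLOG_IS_ON(3)) {
  VLOG(3) << "Running conv_bn_fuse_pass.";
  if (graph->IsMainGraph()) {
    VLOG(3) << "The ID of block running conv_bn_fuse_pass is: 0(main_graph)";
  } else {
    VLOG(3) << "The ID of block running conv_bn_fuse_pass is: "
            << graph->GetBlockId();
  }
}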
11 changes: 11 additions & 0 deletions paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
@@ -95,6 +95,17 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
scope,
platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantLinearOpPass should not be null."));

VLOG(3) << "Running delete_quant_dequant_linear_op_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running delete_quant_dequant_linear_op_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running delete_quant_dequant_linear_op_pass is: "
<< graph->GetBlockId();
}

std::unordered_map<std::string, std::vector<float>> var_quant_scales{};

// Create pattern
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
@@ -36,7 +36,17 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));
VLOG(3) << "Handle delete weight dequant linear op pass ...";

VLOG(3) << "Running delete_weight_dequant_linear_op_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running delete_weight_dequant_linear_op_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running delete_weight_dequant_linear_op_pass is: "
<< graph->GetBlockId();
}

auto& scope = graph->Get<framework::Scope>(kParamScopeAttr);
bool is_int8 = false;

10 changes: 10 additions & 0 deletions paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
@@ -291,6 +291,16 @@ void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const {
scope,
platform::errors::Fatal("During the fuse_multi_transformer_layer pass, "
"The scope should not be null."));

VLOG(3) << "Running fuse_multi_transformer_layer_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fuse_multi_transformer_layer_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running fuse_multi_transformer_layer_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);

AddStatis(fusion_count);
34 changes: 34 additions & 0 deletions paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
@@ -1671,6 +1671,16 @@ void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal("During the multi_transformer pass, "
"The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_decoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fused_multi_transformer_decoder_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running fused_multi_transformer_decoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderPass, new bool(true));
@@ -2376,6 +2386,17 @@ void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_decoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3)
<< "The ID of block running "
"fused_multi_transformer_decoder_fuse_qkv_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"fused_multi_transformer_decoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
@@ -3146,6 +3167,19 @@ void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));

VLOG(3)
<< "Running multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running "
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
45 changes: 45 additions & 0 deletions paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
@@ -2400,6 +2400,16 @@ void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_encoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fused_multi_transformer_encoder_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running fused_multi_transformer_encoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
@@ -3211,6 +3221,17 @@ void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_encoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3)
<< "The ID of block running "
"fused_multi_transformer_encoder_fuse_qkv_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"fused_multi_transformer_encoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
@@ -4013,6 +4034,17 @@ void MultiDevicesFusedMultiTransformerEncoderPass::ApplyImpl(
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));

VLOG(3) << "Running multi_devices_fused_multi_transformer_encoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
@@ -4872,6 +4904,19 @@ void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));

VLOG(3)
<< "Running multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
9 changes: 6 additions & 3 deletions paddle/fluid/framework/ir/graph.cc
@@ -80,11 +80,11 @@ Graph::Graph(const ProgramDesc &program,
}
}

Graph::Graph(const BlockDesc &block, const Graph *main_graph)
Graph::Graph(BlockDesc &block, const Graph *main_graph)
: Graph(block, main_graph, 0, static_cast<int64_t>(block.AllOps().size())) {
}

Graph::Graph(const BlockDesc &block,
Graph::Graph(BlockDesc &block,
const Graph *main_graph,
const int64_t start_op_index,
const int64_t end_op_index)
@@ -103,7 +103,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
}

std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
const BlockDesc &block,
BlockDesc &block,
const int64_t start_op_index,
const int64_t end_op_index) {
std::unordered_map<std::string, std::pair<VarDesc *, int>>
@@ -159,6 +159,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
var_nodes[each_var_name].push_back(var);

// append all var
block.AppendAllocatedVar(desc_and_block_id.first);
} else {
// Operation input var can be optional (dispensable). Which means
// the operation doesn't really need the var at runtime. In this
12 changes: 7 additions & 5 deletions paddle/fluid/framework/ir/graph.h
@@ -94,10 +94,10 @@ class Graph {
const int64_t end_op_index);

// Construct a sub_graph
Graph(const BlockDesc &block, const Graph *main_graph);
Graph(BlockDesc &block, const Graph *main_graph); // NOLINT

// Construct a sub_graph with ops[start_op_index, end_op_index)
Graph(const BlockDesc &block,
Graph(BlockDesc &block, // NOLINT
const Graph *main_graph,
const int64_t start_op_index,
const int64_t end_op_index);
@@ -383,6 +383,8 @@ class Graph {

bool IsMainGraph() const { return main_graph_ == nullptr; }

const Graph *GetMainGraph() const { return main_graph_; }

Graph *GetSubGraph(const size_t idx) const {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(),
@@ -425,6 +427,8 @@
}
return res;
}
// The block this SubGraph belongs to.
int block_id_{0};
[Comment from Contributor]
What is the reason for moving this from private to public?

[Reply from Contributor Author]
It was for debugging; I forgot to remove it.


private:
// TODO(levi): delete this interface after when we can convert all
@@ -435,7 +439,7 @@
const int64_t end_op_index);

std::map<std::string, std::vector<ir::Node *>> InitFromBlock(
const BlockDesc &block,
BlockDesc &block, // NOLINT
const int64_t start_op_index,
const int64_t end_op_index);

@@ -478,8 +482,6 @@
// parts: forward graph and backward graph, which can be executed
// independently.
bool is_partial_{false};
// The block this SubGraph belongs to.
int block_id_{0};
};

bool IsControlDepVar(const ir::Node &var);
6 changes: 6 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -4623,6 +4623,12 @@ void patterns::MulMatmulMatmulV2::operator()(
ops->LinksTo({ops_out});
}

// subgraph_edge_pattern
PDNode *patterns::SubgraphEdgePattern::operator()(
const std::unordered_set<std::string> &ops_type) {
auto ops = pattern->NewNode(ops_repr())->assert_is_ops(ops_type);
return ops;
}
} // namespace ir
} // namespace framework
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -2210,6 +2210,14 @@ struct AddSupportInt8 : public PatternBase {
PATTERN_DECL_NODE(quant_out);
};

// subgraph_edge_pattern
struct SubgraphEdgePattern : public PatternBase {
SubgraphEdgePattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "subgraph_edge_pattern") {}
PDNode* operator()(const std::unordered_set<std::string>& ops_type);
PATTERN_DECL_NODE(ops);
};

// The following patterns are used to fuse feedforward in forward
// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
// -> residual_add (pre_layer_norm)
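To show how the new SubgraphEdgePattern is meant to be driven, here is an illustrative sketch in the usual GraphPatternDetector style. The op-type set and the marking logic are placeholders, not necessarily what the new set_subgraph_edge_pass actually does.

// Illustrative sketch only; the op types below are hypothetical.
GraphPatternDetector gpd;
patterns::SubgraphEdgePattern edge_pattern(gpd.mutable_pattern(),
                                           "subgraph_edge_pattern");
edge_pattern({"while", "conditional_block"});

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                   Graph* g) {
  GET_IR_NODE_FROM_SUBGRAPH(ops, ops, edge_pattern);
  // Mark the matched op's boundary vars with the new Node flags (node.h).
  for (auto* in : ops->inputs) in->SetSubgraphInput();
  for (auto* out : ops->outputs) out->SetSubgraphOutput();
};
gpd(graph, handler);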
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/node.h
@@ -253,6 +253,14 @@ class Node {
// so expose it is a good idea
static constexpr int NO_DESC_ORDER = INT_MAX;

// Set whether the node is an edge of the subgraph.
void SetSubgraphOutput() { subgraph_output_ = true; }
void SetSubgraphInput() { subgraph_input_ = true; }

// Get whether the node is an edge of the subgraph.
bool IsSubgraphOutput() { return subgraph_output_; }
bool IsSubgraphInput() { return subgraph_input_; }

protected:
std::string name_;
std::unique_ptr<VarDesc> var_desc_;
@@ -268,6 +276,10 @@
uint64_t original_desc_id_{0};
int graph_id_{-1};

// Is it the edge of the subgraph.
bool subgraph_output_ = false;
bool subgraph_input_ = false;

private:
// ID can only set by a Graph.
void SetId(int id) { id_ = id; }
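A hypothetical consumer of the new flags, to illustrate the query side; this helper does not exist in the PR.

// Sketch only: collect the names of vars marked as subgraph boundary edges.
std::vector<std::string> CollectSubgraphEdgeVarNames(ir::Graph* graph) {
  std::vector<std::string> names;
  for (ir::Node* n : graph->Nodes()) {
    if (n->IsVar() && (n->IsSubgraphInput() || n->IsSubgraphOutput())) {
      names.push_back(n->Name());
    }
  }
  return names;
}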