Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paddle/fluid/framework/block_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ std::vector<VarDesc *> BlockDesc::AllVars() const {
return res;
}

std::vector<std::string> BlockDesc::AllVarsName() const {
std::vector<std::string> res;
for (const auto &p : vars_) {
res.push_back(p.first);
}
return res;
}

OpDesc *BlockDesc::AppendOp() {
need_update_ = true;
ops_.emplace_back(new OpDesc(this));
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/block_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,16 @@ class TEST_API BlockDesc {

std::vector<VarDesc *> AllVars() const;

std::vector<std::string> AllVarsName() const;

BlockDesc *ParentBlock() const;

BlockDesc *ForwardBlock() const;

void SetForwardBlockID(int32_t forward_block_id);

void AppendAllocatedVar(VarDesc *var_desc);

OpDesc *AppendOp();

void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ if(WITH_TENSORRT)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(split_layernorm_to_math_ops_pass inference)
pass_library(trt_remove_amp_strategy_op_pass inference)
pass_library(set_subgraph_edge_pass inference)
endif()

if(WITH_GPU OR WITH_ROCM)
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);

VLOG(3) << "Running conv_bn_fuse_pass.";
if (graph->IsMainGraph()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类似这样的日志判断,推荐用IS_VLOG_ON,默认情况下就可以跳过执行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类似这样的日志判断,推荐用IS_VLOG_ON,默认情况下就可以跳过执行

学习到了,下个pr修改一下

VLOG(3) << "The ID of block running conv_bn_fuse_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running conv_bn_fuse_pass is: "
<< graph->GetBlockId();
}

auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
Expand Down Expand Up @@ -612,6 +620,15 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);

VLOG(3) << "Running conv_eltwiseadd_bn_fuse_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: "
<< graph->GetBlockId();
}

auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
scope,
platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantLinearOpPass should not be null."));

VLOG(3) << "Running delete_quant_dequant_linear_op_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running delete_quant_dequant_linear_op_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running delete_quant_dequant_linear_op_pass is: "
<< graph->GetBlockId();
}

std::unordered_map<std::string, std::vector<float>> var_quant_scales{};

// Create pattern
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,17 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));
VLOG(3) << "Handle delete weight dequant linear op pass ...";

VLOG(3) << "Running delete_weight_dequant_linear_op_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running delete_weight_dequant_linear_op_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running delete_weight_dequant_linear_op_pass is: "
<< graph->GetBlockId();
}

auto& scope = graph->Get<framework::Scope>(kParamScopeAttr);
bool is_int8 = false;

Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@ void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const {
scope,
platform::errors::Fatal("During the fuse_multi_transformer_layer pass, "
"The scope should not be null."));

VLOG(3) << "Running fuse_multi_transformer_layer_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fuse_multi_transformer_layer_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running fuse_multi_transformer_layer_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);

AddStatis(fusion_count);
Expand Down
34 changes: 34 additions & 0 deletions paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,16 @@ void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal("During the multi_transformer pass, "
"The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_decoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fused_multi_transformer_decoder_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running fused_multi_transformer_decoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderPass, new bool(true));
Expand Down Expand Up @@ -2376,6 +2386,17 @@ void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_decoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3)
<< "The ID of block running "
"fused_multi_transformer_decoder_fuse_qkv_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"fused_multi_transformer_decoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
Expand Down Expand Up @@ -3146,6 +3167,19 @@ void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));

VLOG(3)
<< "Running multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running "
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
Expand Down
45 changes: 45 additions & 0 deletions paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2400,6 +2400,16 @@ void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_encoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fused_multi_transformer_encoder_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running fused_multi_transformer_encoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
Expand Down Expand Up @@ -3211,6 +3221,17 @@ void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_encoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3)
<< "The ID of block running "
"fused_multi_transformer_encoder_fuse_qkv_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"fused_multi_transformer_encoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
Expand Down Expand Up @@ -4013,6 +4034,17 @@ void MultiDevicesFusedMultiTransformerEncoderPass::ApplyImpl(
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));

VLOG(3) << "Running multi_devices_fused_multi_transformer_encoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
Expand Down Expand Up @@ -4872,6 +4904,19 @@ void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));

VLOG(3)
<< "Running multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/ir/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ Graph::Graph(const ProgramDesc &program,
}
}

Graph::Graph(const BlockDesc &block, const Graph *main_graph)
Graph::Graph(BlockDesc &block, const Graph *main_graph)
: Graph(block, main_graph, 0, static_cast<int64_t>(block.AllOps().size())) {
}

Graph::Graph(const BlockDesc &block,
Graph::Graph(BlockDesc &block,
const Graph *main_graph,
const int64_t start_op_index,
const int64_t end_op_index)
Expand All @@ -103,7 +103,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
}

std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
const BlockDesc &block,
BlockDesc &block,
const int64_t start_op_index,
const int64_t end_op_index) {
std::unordered_map<std::string, std::pair<VarDesc *, int>>
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/framework/ir/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ class Graph {
const int64_t end_op_index);

// Construct a sub_graph
Graph(const BlockDesc &block, const Graph *main_graph);
Graph(BlockDesc &block, const Graph *main_graph); // NOLINT

// Construct a sub_graph with ops[start_op_index, end_op_index)
Graph(const BlockDesc &block,
Graph(BlockDesc &block, // NOLINT
const Graph *main_graph,
const int64_t start_op_index,
const int64_t end_op_index);
Expand Down Expand Up @@ -383,6 +383,8 @@ class Graph {

bool IsMainGraph() const { return main_graph_ == nullptr; }

const Graph *GetMainGraph() const { return main_graph_; }

Graph *GetSubGraph(const size_t idx) const {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(),
Expand Down Expand Up @@ -435,7 +437,7 @@ class Graph {
const int64_t end_op_index);

std::map<std::string, std::vector<ir::Node *>> InitFromBlock(
const BlockDesc &block,
BlockDesc &block, // NOLINT
const int64_t start_op_index,
const int64_t end_op_index);

Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4623,6 +4623,12 @@ void patterns::MulMatmulMatmulV2::operator()(
ops->LinksTo({ops_out});
}

// subgraph_edge_pattern
PDNode *patterns::SubgraphEdgePattern::operator()(
const std::unordered_set<std::string> &ops_type) {
auto ops = pattern->NewNode(ops_repr())->assert_is_ops(ops_type);
return ops;
}
} // namespace ir
} // namespace framework
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,14 @@ struct AddSupportInt8 : public PatternBase {
PATTERN_DECL_NODE(quant_out);
};

// subgraph_edge_pattern
struct SubgraphEdgePattern : public PatternBase {
SubgraphEdgePattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "subgraph_edge_pattern") {}
PDNode* operator()(const std::unordered_set<std::string>& ops_type);
PATTERN_DECL_NODE(ops);
};

// The following patterns are used to fuse feedforward in forward
// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
// -> residual_add (pre_layer_norm)
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ class Node {
// so expose it is a good idea
static constexpr int NO_DESC_ORDER = INT_MAX;

// Set whether the node is an edge of the subgraph.
void SetSubgraphOutput() { subgraph_output_ = true; }
void SetSubgraphInput() { subgraph_input_ = true; }

// Get whether the node is an edge of the subgraph.
bool IsSubgraphOutput() { return subgraph_output_; }
bool IsSubgraphInput() { return subgraph_input_; }

protected:
std::string name_;
std::unique_ptr<VarDesc> var_desc_;
Expand All @@ -268,6 +276,10 @@ class Node {
uint64_t original_desc_id_{0};
int graph_id_{-1};

// Is it the edge of the subgraph.
bool subgraph_output_ = false;
bool subgraph_input_ = false;

private:
// ID can only set by a Graph.
void SetId(int id) { id_ = id; }
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/framework/ir/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ namespace ir {
static const char kParamScopeAttr[] = "__param_scope__"; // NOLINT

static const std::vector<std::string> support_subgraph_passes = {
"feed_fetch_subgraph_pass",
"set_subgraph_edge_pass",
"trt_map_ops_to_matrix_multiply_pass",
"tensorrt_subgraph_pass",
"simplify_with_basic_ops_pass",
"fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass",
Expand Down Expand Up @@ -128,7 +132,7 @@ Graph *Pass::Apply(Graph *graph) const {
} else {
subgraph_passes = support_subgraph_passes;
}
if (graph->IsMainGraph() &&
if (FLAGS_convert_all_blocks && graph->IsMainGraph() &&
(std::count(subgraph_passes.begin(), subgraph_passes.end(), Type()) ||
std::count(support_subgraph_generate_passes.begin(),
support_subgraph_generate_passes.end(),
Expand Down
Loading