Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ if(WITH_TENSORRT)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(split_layernorm_to_math_ops_pass inference)
pass_library(trt_remove_amp_strategy_op_pass inference)
pass_library(set_subgraph_edge_pass inference)
endif()

if(WITH_GPU OR WITH_ROCM)
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);

VLOG(3) << "Running conv_bn_fuse_pass.";
if (graph->IsMainGraph()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类似这样的日志判断,推荐用IS_VLOG_ON,默认情况下就可以跳过执行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类似这样的日志判断,推荐用IS_VLOG_ON,默认情况下就可以跳过执行

学习到了,下个pr修改一下

VLOG(3) << "The ID of block running conv_bn_fuse_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running conv_bn_fuse_pass is: "
<< graph->GetBlockId();
}

auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
Expand Down Expand Up @@ -612,6 +620,15 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);

VLOG(3) << "Running conv_eltwiseadd_bn_fuse_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: "
<< graph->GetBlockId();
}

auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
scope,
platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantLinearOpPass should not be null."));

VLOG(3) << "Running delete_quant_dequant_linear_op_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running delete_quant_dequant_linear_op_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running delete_quant_dequant_linear_op_pass is: "
<< graph->GetBlockId();
}

std::unordered_map<std::string, std::vector<float>> var_quant_scales{};

// Create pattern
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,17 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));
VLOG(3) << "Handle delete weight dequant linear op pass ...";

VLOG(3) << "Running delete_weight_dequant_linear_op_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running delete_weight_dequant_linear_op_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running delete_weight_dequant_linear_op_pass is: "
<< graph->GetBlockId();
}

auto& scope = graph->Get<framework::Scope>(kParamScopeAttr);
bool is_int8 = false;

Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@ void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const {
scope,
platform::errors::Fatal("During the fuse_multi_transformer_layer pass, "
"The scope should not be null."));

VLOG(3) << "Running fuse_multi_transformer_layer_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fuse_multi_transformer_layer_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running fuse_multi_transformer_layer_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);

AddStatis(fusion_count);
Expand Down
34 changes: 34 additions & 0 deletions paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,16 @@ void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal("During the multi_transformer pass, "
"The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_decoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fused_multi_transformer_decoder_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running fused_multi_transformer_decoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderPass, new bool(true));
Expand Down Expand Up @@ -2376,6 +2386,17 @@ void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_decoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3)
<< "The ID of block running "
"fused_multi_transformer_decoder_fuse_qkv_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"fused_multi_transformer_decoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
Expand Down Expand Up @@ -3146,6 +3167,19 @@ void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));

VLOG(3)
<< "Running multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running "
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
Expand Down
45 changes: 45 additions & 0 deletions paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2400,6 +2400,16 @@ void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_encoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running fused_multi_transformer_encoder_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running fused_multi_transformer_encoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
Expand Down Expand Up @@ -3211,6 +3221,17 @@ void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));

VLOG(3) << "Running fused_multi_transformer_encoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3)
<< "The ID of block running "
"fused_multi_transformer_encoder_fuse_qkv_pass is: 0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"fused_multi_transformer_encoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
Expand Down Expand Up @@ -4013,6 +4034,17 @@ void MultiDevicesFusedMultiTransformerEncoderPass::ApplyImpl(
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));

VLOG(3) << "Running multi_devices_fused_multi_transformer_encoder_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_pass is: "
"0(main_graph)";
} else {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
Expand Down Expand Up @@ -4872,6 +4904,19 @@ void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));

VLOG(3)
<< "Running multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass.";
if (graph->IsMainGraph()) {
VLOG(3) << "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass "
"is: 0(main_graph)";
} else {
VLOG(3)
<< "The ID of block running "
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass is: "
<< graph->GetBlockId();
}

int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ PADDLE_DEFINE_EXPORTED_bool(convert_all_blocks,
true,
"Convert all blocks in program into SSAgraphs");

// Opt-in flag (default: false): when enabled together with
// FLAGS_convert_all_blocks, ops in every block of the program — not only the
// main block — become candidates for conversion into TensorRT ops.
PADDLE_DEFINE_EXPORTED_bool(all_blocks_convert_trt,
                            false,
                            "Convert all blocks' ops into TensorRT ops");

namespace paddle {
namespace framework {
namespace ir {
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/ir/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/utils/flags.h"

PD_DECLARE_bool(convert_all_blocks);
PD_DECLARE_bool(all_blocks_convert_trt);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -383,6 +384,8 @@ class Graph {

// True iff this graph is itself the main graph: main_graph_ is only set on
// sub-graphs, so a null back-pointer identifies the main graph.
bool IsMainGraph() const { return main_graph_ == nullptr; }

// Returns the owning main graph, or nullptr when this graph is itself the
// main graph (the same condition IsMainGraph() tests).
const Graph *GetMainGraph() const { return main_graph_; }

Graph *GetSubGraph(const size_t idx) const {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(),
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/graph_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/platform/flags.h"
PD_DECLARE_bool(convert_all_blocks);
PD_DECLARE_bool(all_blocks_convert_trt);
PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
"",
"FLAGS_print_sub_graph_dir is used "
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4937,6 +4937,13 @@ void patterns::MulMatmulMatmulV2::operator()(
ops->LinksTo({ops_out});
}

// subgraph_edge_pattern: matches a single operator node whose op type is
// contained in |ops_type|.
PDNode *patterns::SubgraphEdgePattern::operator()(
    const std::unordered_set<std::string> &ops_type) {
  // A single pattern node, constrained to the supplied set of op types.
  return pattern->NewNode(ops_repr())->assert_is_ops(ops_type);
}

PDNode *patterns::ConvBNAddAct::operator()(
const std::unordered_set<std::string> &act_types,
bool shortcut,
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,14 @@ struct AddSupportInt8 : public PatternBase {
PATTERN_DECL_NODE(quant_out);
};

// subgraph_edge_pattern
// subgraph_edge_pattern: matches one operator node whose op type belongs to
// a caller-supplied set (implementation in graph_pattern_detector.cc).
struct SubgraphEdgePattern : public PatternBase {
  SubgraphEdgePattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "subgraph_edge_pattern") {}
  // Builds the pattern node constrained to |ops_type| and returns it.
  PDNode* operator()(const std::unordered_set<std::string>& ops_type);
  // Accessor for the matched op node ("ops").
  PATTERN_DECL_NODE(ops);
};

// The following patterns are used to fuse feedforward in forward
// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
// -> residual_add (pre_layer_norm)
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,14 @@ class Node {
// so expose it is a good idea
static constexpr int NO_DESC_ORDER = INT_MAX;

// Set whether the node is an edge (input/output boundary) of the subgraph.
// Generalized with a defaulted parameter so the flag can also be cleared;
// existing no-argument callers keep their original set-to-true behavior.
void SetSubgraphOutput(bool is_output = true) { subgraph_output_ = is_output; }
// Mark (or, with an explicit false, unmark) this node as a subgraph input
// edge; the defaulted parameter keeps existing no-argument callers working.
void SetSubgraphInput(bool is_input = true) { subgraph_input_ = is_input; }

// Get whether the node is an edge of the subgraph.
bool IsSubgraphOutput() { return subgraph_output_; }
bool IsSubgraphInput() { return subgraph_input_; }

protected:
std::string name_;
std::unique_ptr<VarDesc> var_desc_;
Expand All @@ -269,6 +277,10 @@ class Node {
uint64_t original_desc_id_{0};
int graph_id_{-1};

// Is it the edge of the subgraph.
bool subgraph_output_ = false;
bool subgraph_input_ = false;

private:
// ID can only set by a Graph.
void SetId(int id) { id_ = id; }
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/framework/ir/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ namespace ir {
static const char kParamScopeAttr[] = "__param_scope__"; // NOLINT

static const std::vector<std::string> support_subgraph_passes = {
"feed_fetch_subgraph_pass",
"set_subgraph_edge_pass",
"trt_map_ops_to_matrix_multiply_pass",
"tensorrt_subgraph_pass",
"simplify_with_basic_ops_pass",
"fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass",
Expand Down Expand Up @@ -128,7 +132,8 @@ Graph *Pass::Apply(Graph *graph) const {
} else {
subgraph_passes = support_subgraph_passes;
}
if (graph->IsMainGraph() &&
if (FLAGS_all_blocks_convert_trt && FLAGS_convert_all_blocks &&
graph->IsMainGraph() &&
(std::count(subgraph_passes.begin(), subgraph_passes.end(), Type()) ||
std::count(support_subgraph_generate_passes.begin(),
support_subgraph_generate_passes.end(),
Expand Down
Loading