-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[Paddle-TRT] support sub block: while op, condition_block op #59588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
8adf23b
6511140
a51dbb9
26c6bed
a7329c7
3543027
c4d685a
9aa2eff
88618cc
5c600f9
a48e5bd
a61488a
000102e
49e05ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -312,6 +312,14 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { | |
| graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); | ||
| FusePassBase::Init(name_scope_, graph); | ||
|
|
||
| VLOG(3) << "Running conv_bn_fuse_pass."; | ||
| if (graph->IsMainGraph()) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 类似这样的日志判断,推荐用IS_VLOG_ON,默认情况下就可以跳过执行
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
学习到了,下个pr修改一下 |
||
| VLOG(3) << "The ID of block running conv_bn_fuse_pass is: 0(main_graph)"; | ||
| } else { | ||
| VLOG(3) << "The ID of block running conv_bn_fuse_pass is: " | ||
| << graph->GetBlockId(); | ||
| } | ||
|
|
||
| auto* scope = param_scope(); | ||
| PADDLE_ENFORCE_NOT_NULL( | ||
| scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); | ||
|
|
@@ -612,6 +620,15 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { | |
| graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); | ||
| FusePassBase::Init(name_scope_, graph); | ||
|
|
||
| VLOG(3) << "Running conv_eltwiseadd_bn_fuse_pass."; | ||
| if (graph->IsMainGraph()) { | ||
| VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: " | ||
| "0(main_graph)"; | ||
| } else { | ||
| VLOG(3) << "The ID of block running conv_eltwiseadd_bn_fuse_pass is: " | ||
| << graph->GetBlockId(); | ||
| } | ||
|
|
||
| auto* scope = param_scope(); | ||
| PADDLE_ENFORCE_NOT_NULL( | ||
| scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -94,10 +94,10 @@ class Graph { | |
| const int64_t end_op_index); | ||
|
|
||
| // Construct a sub_graph | ||
| Graph(const BlockDesc &block, const Graph *main_graph); | ||
| Graph(BlockDesc &block, const Graph *main_graph); // NOLINT | ||
|
|
||
| // Construct a sub_graph with ops[start_op_index, end_op_index) | ||
| Graph(const BlockDesc &block, | ||
| Graph(BlockDesc &block, // NOLINT | ||
| const Graph *main_graph, | ||
| const int64_t start_op_index, | ||
| const int64_t end_op_index); | ||
|
|
@@ -383,6 +383,8 @@ class Graph { | |
|
|
||
| bool IsMainGraph() const { return main_graph_ == nullptr; } | ||
|
|
||
| const Graph *GetMainGraph() const { return main_graph_; } | ||
|
|
||
| Graph *GetSubGraph(const size_t idx) const { | ||
| PADDLE_ENFORCE_EQ( | ||
| this->IsMainGraph(), | ||
|
|
@@ -425,6 +427,8 @@ class Graph { | |
| } | ||
| return res; | ||
| } | ||
| // The block this SubGraph belongs to. | ||
| int block_id_{0}; | ||
|
||
|
|
||
| private: | ||
| // TODO(levi): delete this interface after when we can convert all | ||
|
|
@@ -435,7 +439,7 @@ class Graph { | |
| const int64_t end_op_index); | ||
|
|
||
| std::map<std::string, std::vector<ir::Node *>> InitFromBlock( | ||
| const BlockDesc &block, | ||
| BlockDesc &block, // NOLINT | ||
| const int64_t start_op_index, | ||
| const int64_t end_op_index); | ||
|
|
||
|
|
@@ -478,8 +482,6 @@ class Graph { | |
| // parts: forward graph and backward graph, which can be executed | ||
| // independently. | ||
| bool is_partial_{false}; | ||
| // The block this SubGraph belongs to. | ||
| int block_id_{0}; | ||
| }; | ||
|
|
||
| bool IsControlDepVar(const ir::Node &var); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里感觉会不会有问题,你append的和实际存储进block的不是同一个var,怎么同步?并且这里第136行完全是冗余的,make_unique内部就会new一个新的出来。。。
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前的IR都只在自己的block上进行graph的修改(可能会修改var_desc的属性),不同的block中的边界var_desc是不同的,为了不干扰其它block上的graph操作。op的执行阶段是通过var_desc中的name去找scope,只要var name一样,就能正常运行