Skip to content

Commit 3b4d943

Browse files
author
chengduozh
committed
Merge branch '140basebatchmerge' of https://github.com/gongweibao/Paddle into cherry_pick_fix_all_reduce_dep_pass_bug
test=release/1.4
2 parents b77fcb1 + 619f9fb commit 3b4d943

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

paddle/fluid/framework/ir/multi_batch_merge_pass.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
8484

8585
// 1. record op nodes of different roles
8686
for (auto node : nodes) {
87-
if (node->IsVar()) continue;
87+
if (!node->IsOp()) continue;
88+
PADDLE_ENFORCE(node->Op(), "must find opdesc");
8889
int op_role = boost::get<int>(node->Op()->GetAttr(
8990
framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
9091
if ((op_role == static_cast<int>(framework::OpRole::kForward)) ||

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ def run_trainer(self, args):
139139
pass_builder = None
140140
if args.batch_merge_repeat > 1:
141141
pass_builder = build_stra._finalize_strategy_and_create_passes()
142-
mypass = pass_builder.insert_pass(
143-
len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
142+
mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
144143
mypass.set("num_repeats", args.batch_merge_repeat)
145144

146145
if args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer":

0 commit comments

Comments
 (0)