Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass record_skip_memory_opt_vars_pass)
12 changes: 6 additions & 6 deletions paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}

// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) {
VLOG(1) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
VLOG(1) << "Add coalesce_grad_tensor_pass";
AppendPass("coalesce_grad_tensor_pass");
}

if (strategy_.fuse_all_optimizer_ops_) {
Expand Down Expand Up @@ -301,7 +301,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
} else if (pass->Type() == "coalesce_grad_tensor_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_momentum_op_pass" ||
Expand All @@ -321,7 +321,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
new bool(use_hierarchical_allreduce_));
#endif
}
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
} else if (pass->Type() == "coalesce_grad_tensor_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
Expand Down Expand Up @@ -389,7 +389,7 @@ USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(coalesce_grad_tensor_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/details/multi_devices_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ typedef std::vector<std::string> FusedGrads;
constexpr char kFusedGrads[] = "fused_gradients";

typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndGrads[] = "params_grads";
constexpr char kParamsAndDenseGrads[] = "params_and_dense_grads";
constexpr char kParamsAndSparseGrads[] = "params_and_sparse_grads";

typedef std::vector<std::vector<std::pair<std::string, std::string>>>
GroupParamsAndGrads;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)

cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)

pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
Expand Down
Loading