Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)

cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)

Expand Down Expand Up @@ -126,4 +127,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass)
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
record_skip_memory_opt_vars_pass)
34 changes: 19 additions & 15 deletions paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("record_skip_memory_opt_vars_pass");

if (strategy_.enable_sequential_execution_) {
VLOG(10) << "Add sequential_execution_pass";
VLOG(5) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
}

Expand All @@ -68,7 +68,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {

// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
VLOG(10) << "Add fuse_relu_depthwise_conv_pass";
VLOG(5) << "Add fuse_relu_depthwise_conv_pass";
AppendPass("fuse_relu_depthwise_conv_pass");
}

Expand All @@ -80,19 +80,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {

// Add automatically inplace.
if (strategy_.enable_inplace_) {
VLOG(10) << "Add inplace_pass";
VLOG(5) << "Add inplace_pass";
AppendPass("inplace_pass");
}

if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(10) << "Add fuse_elewise_add_act_pass";
VLOG(5) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}

// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
VLOG(5) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
}

Expand All @@ -107,10 +107,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(10) << "Add fuse_adam_op_pass";
VLOG(5) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(10) << "Add fuse_sgd_op_pass";
VLOG(5) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
VLOG(5) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass");
}
}

Expand Down Expand Up @@ -139,15 +141,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy_.memory_optimize_) {
VLOG(10) << "Add memory_optimize_pass";
VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}

// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(10) << "Add runtime_context_cache_pass";
VLOG(5) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}

Expand All @@ -161,7 +163,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(10) << "Add fuse_all_reduce_op_pass";
VLOG(5) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass");
}

Expand All @@ -182,12 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(10) << "Add all_reduce_deps_pass";
VLOG(5) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}

if (strategy_.remove_unnecessary_lock_) {
VLOG(10) << "Add modify_op_lock_and_record_event_pass";
VLOG(5) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass");
}

Expand All @@ -202,16 +204,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy_.async_mode_) {
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
VLOG(10)
VLOG(5)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(10) << "Add all_reduce_mode_multi_devices_pass";
VLOG(5) << "Add all_reduce_mode_multi_devices_pass";
multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(10) << "Add reduce_mode_multi_devices_pass";
VLOG(5) << "Add reduce_mode_multi_devices_pass";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
Expand Down Expand Up @@ -277,6 +279,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_momentum_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
Expand Down Expand Up @@ -341,6 +344,7 @@ USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
Expand Down
Loading