|
33 | 33 | #include "paddle/fluid/platform/profiler.h" |
34 | 34 | #include "paddle/fluid/string/string_helper.h" |
35 | 35 |
|
| 36 | +DECLARE_bool(sort_sum_gradient); |
| 37 | + |
36 | 38 | namespace paddle { |
37 | 39 | namespace imperative { |
38 | 40 |
|
@@ -529,8 +531,7 @@ class PartialGradTask { |
529 | 531 | const std::vector<std::shared_ptr<VarBase>> &output_targets, |
530 | 532 | const std::vector<std::shared_ptr<VarBase>> &output_grads, |
531 | 533 | const std::vector<std::shared_ptr<VarBase>> &no_grad_vars, |
532 | | - const platform::Place &place, |
533 | | - const detail::BackwardStrategy &strategy, bool create_graph, |
| 534 | + const platform::Place &place, bool create_graph, |
534 | 535 | bool retain_graph, bool allow_unused, bool only_inputs); |
535 | 536 |
|
536 | 537 | std::vector<std::shared_ptr<VarBase>> Run(); |
@@ -577,23 +578,22 @@ class PartialGradTask { |
577 | 578 | bool retain_graph_; |
578 | 579 | bool allow_unused_; |
579 | 580 | bool only_inputs_; |
580 | | - detail::BackwardStrategy strategy_; |
| 581 | + bool sorted_sum_gradient_{FLAGS_sort_sum_gradient}; |
581 | 582 | }; |
582 | 583 |
|
583 | 584 | PartialGradTask::PartialGradTask( |
584 | 585 | const std::vector<std::shared_ptr<VarBase>> &input_targets, |
585 | 586 | const std::vector<std::shared_ptr<VarBase>> &output_targets, |
586 | 587 | const std::vector<std::shared_ptr<VarBase>> &output_grads, |
587 | 588 | const std::vector<std::shared_ptr<VarBase>> &no_grad_vars, |
588 | | - const platform::Place &place, const detail::BackwardStrategy &strategy, |
589 | | - bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) { |
| 589 | + const platform::Place &place, bool create_graph, bool retain_graph, |
| 590 | + bool allow_unused, bool only_inputs) { |
590 | 591 | input_targets_ = input_targets; |
591 | 592 | place_ = place; |
592 | 593 | create_graph_ = create_graph; |
593 | 594 | retain_graph_ = retain_graph; |
594 | 595 | allow_unused_ = allow_unused; |
595 | 596 | only_inputs_ = only_inputs; |
596 | | - strategy_ = strategy; |
597 | 597 |
|
598 | 598 | PADDLE_ENFORCE_EQ(only_inputs_, true, |
599 | 599 | platform::errors::Unimplemented( |
@@ -981,7 +981,7 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) { |
981 | 981 |
|
982 | 982 | if (!accumulator) { |
983 | 983 | accumulator.reset(new GradientAccumulationInfo( |
984 | | - var, strategy_.sorted_sum_gradient_, create_graph_)); |
| 984 | + var, sorted_sum_gradient_, create_graph_)); |
985 | 985 | } |
986 | 986 |
|
987 | 987 | accumulator->IncreaseTotalRefCnt(); |
@@ -1033,11 +1033,11 @@ PartialGradEngine::PartialGradEngine( |
1033 | 1033 | const std::vector<std::shared_ptr<VarBase>> &output_targets, |
1034 | 1034 | const std::vector<std::shared_ptr<VarBase>> &output_grads, |
1035 | 1035 | const std::vector<std::shared_ptr<VarBase>> &no_grad_vars, |
1036 | | - const platform::Place &place, const detail::BackwardStrategy &strategy, |
1037 | | - bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) |
| 1036 | + const platform::Place &place, bool create_graph, bool retain_graph, |
| 1037 | + bool allow_unused, bool only_inputs) |
1038 | 1038 | : task_(new PartialGradTask(input_targets, output_targets, output_grads, |
1039 | | - no_grad_vars, place, strategy, create_graph, |
1040 | | - retain_graph, allow_unused, only_inputs)) {} |
| 1039 | + no_grad_vars, place, create_graph, retain_graph, |
| 1040 | + allow_unused, only_inputs)) {} |
1041 | 1041 |
|
1042 | 1042 | PartialGradEngine::~PartialGradEngine() { Clear(); } |
1043 | 1043 |
|
|
0 commit comments