4 changes: 2 additions & 2 deletions paddle/fluid/eager/backward.cc
@@ -286,8 +286,8 @@ std::vector<paddle::Tensor> RunBackward(
node_input_buffer->Buffers(), create_graph, is_general_grad);

if (!inputs.empty() && is_general_grad) {
- GeneralGrad::Instance().SetResultForEnddingNodes(grad_output_tensors,
- node);
+ GeneralGrad::Instance().SetResultForEndingNodes(grad_output_tensors,
+ node);
}

// retain_grad or not
56 changes: 28 additions & 28 deletions paddle/fluid/eager/general_grad.h
@@ -124,15 +124,15 @@ class GeneralGrad {
}
visited.insert(target_node);
if (!(depending_nodes_)[target_node].empty()) {
- auto precedding_nodes = (depending_nodes_)[target_node];
- for (auto pre_nodes : precedding_nodes) {
+ auto preceding_nodes = (depending_nodes_)[target_node];
+ for (auto pre_nodes : preceding_nodes) {
queue.push_back(pre_nodes);
needed_nodes_.emplace(pre_nodes);
if (IsInputTargetNodes(pre_nodes)) {
input_target_nodes_on_path.emplace(pre_nodes);
}
}
- } else { // startup_ops have no precedding nodes
+ } else { // startup_ops have no preceding nodes
VLOG(6) << "Emplace startup_ops";
startup_ops.emplace(target_node);
needed_nodes_.emplace(target_node);
@@ -143,7 +143,7 @@
input_target_nodes_inputmeta_map_) {
if (!input_target_nodes_on_path.count(
target_nodes_inputmeta_pair.first)) {
- endding_nodes_.emplace(target_nodes_inputmeta_pair.first);
+ ending_nodes_.emplace(target_nodes_inputmeta_pair.first);
}
}

@@ -236,12 +236,12 @@
} // TODO(jiabin): Some check here.
}

- void SetResultForEnddingNodes(
+ void SetResultForEndingNodes(
paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize>
grad_output,
GradNodeBase* node) {
- if (IsEnddingNodes(node)) {
- VLOG(6) << "Set result for endding_nodes_ with grad_output_tensors";
+ if (IsEndingNodes(node)) {
+ VLOG(6) << "Set result for ending_nodes_ with grad_output_tensors";
results_map_[node] = std::make_shared<paddle::Tensor>(grad_output[0][0]);
}
}
@@ -275,9 +275,9 @@
}

// Register Hook to fetch input's gradients, when input's grad node is not an
- // endding node in backward graph. If input's grad node is an endding node in
+ // ending node in backward graph. If input's grad node is an ending node in
// backward graph, use grad node's output as inputs' gradients and no need to
- // register Hook. Please note that endding node must be GradNodeAccumulation
+ // register Hook. Please note that ending node must be GradNodeAccumulation
// after ModifyBackwardGraph function.
void RegisterFetchGradHook(const std::vector<paddle::Tensor>& inputs) {
VLOG(6) << "Running in RegisterFetchGradHook.";
@@ -296,8 +296,8 @@

if (orig_to_copied_node_map_.count(target_node)) {
target_node = orig_to_copied_node_map_[target_node].get();
- if (copied_node_to_endding_node_map_.count(target_node)) {
- VLOG(6) << "No need to call FetchGradForTensor for endding_nodes";
+ if (copied_node_to_ending_node_map_.count(target_node)) {
+ VLOG(6) << "No need to call FetchGradForTensor for ending_nodes";
continue;
}
}
@@ -309,7 +309,7 @@
"stop_gradient=True.",
i));

- if (!IsEnddingNodes(target_node)) {
+ if (!IsEndingNodes(target_node)) {
// Fetch grad for tensor in target_node on path.
auto fetched_grad = FetchGradForTensor(inputs[i], target_node);
results_map_[target_node] = fetched_grad;
@@ -321,9 +321,9 @@
void SetNodeToAccumulationNode(GradNodeBase* node) {
if (dynamic_cast<egr::GradNodeAccumulation*>(node)) return;
if (!(depending_nodes_)[node].empty()) {
- // Find precedding_nodes of current node.
- auto precedding_nodes = (depending_nodes_)[node];
- for (auto pre_nodes : precedding_nodes) {
+ // Find preceding_nodes of current node.
+ auto preceding_nodes = (depending_nodes_)[node];
+ for (auto pre_nodes : preceding_nodes) {
paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
pre_nodes_edges = pre_nodes->MutableOutputMeta();
for (size_t i = 0; i < pre_nodes_edges.size(); i++) {
@@ -332,21 +332,21 @@
if (edge_.GetGradNode() == node) {
Edge& pre_node_edge = pre_nodes_edges[i][j].GetMutableEdge();

- if (copied_node_to_endding_node_map_.count(node)) {
+ if (copied_node_to_ending_node_map_.count(node)) {
pre_node_edge.SetGradNode(
- copied_node_to_endding_node_map_[node]);
+ copied_node_to_ending_node_map_[node]);
} else {
auto autograd_meta = egr::AutogradMeta(edge_);
std::shared_ptr<GradNodeBase> shared_grad_node_accumulation =
std::make_shared<egr::GradNodeAccumulation>(&autograd_meta);
pre_node_edge.SetGradNode(shared_grad_node_accumulation);
- copied_node_to_endding_node_map_[node] =
+ copied_node_to_ending_node_map_[node] =
shared_grad_node_accumulation;
}

auto* grad_node = pre_node_edge.GetGradNode();
needed_nodes_.emplace(grad_node);
- endding_nodes_.emplace(grad_node);
+ ending_nodes_.emplace(grad_node);
input_target_nodes_inputmeta_map_[grad_node] =
input_target_nodes_inputmeta_map_[node];

@@ -384,7 +384,7 @@
}
visited.insert(node);

- if (IsInputTargetNodes(node) && IsEnddingNodes(node)) {
+ if (IsInputTargetNodes(node) && IsEndingNodes(node)) {
SetNodeToAccumulationNode(node);
continue;
}
@@ -413,7 +413,7 @@
}

if (meta.size() != 1 && IsNeededNodes(node) &&
- !IsNeededNodes(next_node.get()) && !IsEnddingNodes(node)) {
+ !IsNeededNodes(next_node.get()) && !IsEndingNodes(node)) {
VLOG(3) << "Get stop edge from grad_node: " << node->name() << " : "
<< node << " to:" << next_node->name() << ", "
<< next_node.get() << " with output rank info: " << i
@@ -448,8 +448,8 @@
auto* target_node = auto_grad_meta->GetMutableGradNode().get();
if (orig_to_copied_node_map_.count(target_node)) {
target_node = orig_to_copied_node_map_[target_node].get();
- if (copied_node_to_endding_node_map_.count(target_node)) {
- target_node = copied_node_to_endding_node_map_[target_node].get();
+ if (copied_node_to_ending_node_map_.count(target_node)) {
+ target_node = copied_node_to_ending_node_map_[target_node].get();
}
} else {
VLOG(6) << "Unable to find target node in "
@@ -480,7 +480,7 @@

bool IsNeededNodes(GradNodeBase* node) { return needed_nodes_.count(node); }

- bool IsEnddingNodes(GradNodeBase* node) { return endding_nodes_.count(node); }
+ bool IsEndingNodes(GradNodeBase* node) { return ending_nodes_.count(node); }

bool IsInputTargetNodes(GradNodeBase* node) {
auto iter = input_target_nodes_inputmeta_map_.find(node);
@@ -621,9 +621,9 @@
results_map_.clear();
copied_grad_nodes_.clear();
orig_to_copied_node_map_.clear();
- copied_node_to_endding_node_map_.clear();
+ copied_node_to_ending_node_map_.clear();
needed_nodes_.clear();
- endding_nodes_.clear();
+ ending_nodes_.clear();
}

private:
Expand All @@ -649,8 +649,8 @@ class GeneralGrad {
std::unordered_set<GradNodeBase*> needed_nodes_;
// Record which grad_node has been transformed to AccumulationNode
std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
- copied_node_to_endding_node_map_;
- std::unordered_set<GradNodeBase*> endding_nodes_;
+ copied_node_to_ending_node_map_;
+ std::unordered_set<GradNodeBase*> ending_nodes_;

DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
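The rename above touches the "ending node" bookkeeping in GeneralGrad, which the comment before RegisterFetchGradHook summarizes: when an input's grad node is an ending node of the backward graph, that node's own output already is the input's gradient and no hook is needed; for any other node on the path, a hook is registered to capture the gradient as it flows through. A minimal standalone C++ sketch of that decision follows; the types and names (Node, Tensor, ending_nodes, results_map, CollectInputGrad) are hypothetical stand-ins for illustration, not Paddle's actual API.

// Conceptual sketch only; all names are hypothetical, not Paddle's API.
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Tensor {};  // stand-in for a gradient tensor
struct Node {
  // Hooks run when a gradient flows through this node during backward.
  std::vector<std::function<void(const Tensor&)>> hooks;
};

std::unordered_map<Node*, std::shared_ptr<Tensor>> results_map;
std::unordered_set<Node*> ending_nodes;

void CollectInputGrad(Node* grad_node) {
  if (ending_nodes.count(grad_node)) {
    // Ending node: its own output is the input's gradient, so it is
    // written into results_map when the node runs; no hook is needed.
    return;
  }
  // Node on the path but not an ending node: register a hook that copies
  // the gradient into results_map as it passes through.
  grad_node->hooks.emplace_back([grad_node](const Tensor& grad) {
    results_map[grad_node] = std::make_shared<Tensor>(grad);
  });
}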
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/op_handle_base.h
@@ -64,7 +64,9 @@ class OpHandleBase {

virtual bool GetSkipRunning() const { return skip_running_; }

- virtual void SetSkipRunning(bool skip_runing) { skip_running_ = skip_runing; }
+ virtual void SetSkipRunning(bool skip_running) {
+ skip_running_ = skip_running;
+ }

virtual std::string Name() const = 0;
