-
Notifications
You must be signed in to change notification settings - Fork 6k
add custom init grad for backward function #31540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
d0915f8
0bccce6
5dac8e9
ef4c7b9
33b0416
837e26b
1901970
55e0cfb
8271dc0
5af3bd0
1467feb
eb267fa
b80f449
2bb8f3c
41b375f
6974e5c
1e3e975
c7de011
2f2824c
8415df4
be065e4
7f8e58c
0374c0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,48 +36,74 @@ DECLARE_bool(sort_sum_gradient); | |
| namespace paddle { | ||
| namespace imperative { | ||
|
|
||
| void BasicEngine::Init(VarBase* var, bool retain_graph) { | ||
| void BasicEngine::Init( | ||
| const std::vector<std::shared_ptr<VarBase>>& tensors, | ||
| const std::vector<std::shared_ptr<VarBase>>& grad_tensors, | ||
| bool retain_graph) { | ||
| retain_graph_ = retain_graph; | ||
| init_node_ = var->GradVarBase()->GradNode(); | ||
| PADDLE_ENFORCE_EQ(var->GradVarBase()->GraphIsFreed(), false, | ||
| platform::errors::Unavailable( | ||
| "%s trying to backward through the same graph a second " | ||
| "time, but this graph have already been freed. Please " | ||
| "specify Tensor.backward(retain_graph=True) when " | ||
| "calling backward at the first time.", | ||
| var->Name())); | ||
|
|
||
| if (!retain_graph) { | ||
| VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name() | ||
| << " because of retain_graph=False when calling backward"; | ||
| var->GradVarBase()->SetGraphIsFreed(true); | ||
| var->GradVarBase()->ClearGradNode(); | ||
| } | ||
|
|
||
| if (init_node_ == nullptr || var->OverridedStopGradient()) { | ||
| VLOG(3) << "Skip auto grad since there is no grad op for var or loss is " | ||
| "stop_gradient=True: " | ||
| << var->Name(); | ||
| return; | ||
| } | ||
| PADDLE_ENFORCE_EQ( | ||
| tensors.size(), grad_tensors.size(), | ||
| platform::errors::Unavailable( | ||
| "the size of tensors must equal the size of grad_tensors, but" | ||
| "the size of tensors is %s, and the size of grad_tensors is %s.", | ||
| tensors.size(), grad_tensors.size())); | ||
|
|
||
| for (size_t i = 0; i < tensors.size(); ++i) { | ||
| auto var = tensors[i]; | ||
| auto grad_tensor = grad_tensors[i]; | ||
|
|
||
| auto init_node_ = var->GradVarBase()->GradNode(); | ||
|
||
| PADDLE_ENFORCE_EQ( | ||
| var->GradVarBase()->GraphIsFreed(), false, | ||
| platform::errors::Unavailable( | ||
| "%s trying to backward through the same graph a second " | ||
| "time, but this graph have already been freed. Please " | ||
| "specify Tensor.backward(retain_graph=True) when " | ||
| "calling backward at the first time.", | ||
| var->Name())); | ||
|
|
||
| if (!retain_graph) { | ||
| VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name() | ||
| << " because of retain_graph=False when calling backward"; | ||
| var->GradVarBase()->SetGraphIsFreed(true); | ||
| var->GradVarBase()->ClearGradNode(); | ||
| } | ||
|
|
||
| VLOG(3) << "Init first node of backward"; | ||
| if (init_node_ == nullptr || var->OverridedStopGradient()) { | ||
| VLOG(3) << "Skip auto grad since there is no grad op for var or loss is " | ||
| "stop_gradient=True: " | ||
| << var->Name(); | ||
| continue; | ||
| } | ||
|
|
||
| PADDLE_ENFORCE_EQ( | ||
| var->HasGradVar(), true, | ||
| platform::errors::NotFound("Grad variable not exist for variable %s", | ||
| var->Name())); | ||
|
|
||
| auto& fwd_var = var->Var().Get<framework::LoDTensor>(); | ||
| auto* grad_var = | ||
| var->GradVarBase()->MutableVar()->GetMutable<framework::LoDTensor>(); | ||
| VLOG(6) << "init loss grad:" << var->GradVarBase()->Name() | ||
| << " as stop_gradient false"; | ||
| var->GradVarBase()->InnerSetOverridedStopGradient(false); | ||
| auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place()); | ||
| grad_var->Resize(fwd_var.dims()); | ||
| grad_var->mutable_data(fwd_var.place(), fwd_var.type()); | ||
| operators::math::set_constant(*dev_ctx, grad_var, 1.0); | ||
| VLOG(3) << "Init node of backward"; | ||
|
|
||
| PADDLE_ENFORCE_EQ( | ||
| var->HasGradVar(), true, | ||
| platform::errors::NotFound("Grad variable not exist for variable %s", | ||
|
||
| var->Name())); | ||
|
|
||
| auto& fwd_var = var->Var().Get<framework::LoDTensor>(); | ||
| auto* grad_var = | ||
| var->GradVarBase()->MutableVar()->GetMutable<framework::LoDTensor>(); | ||
| VLOG(6) << "init loss grad:" << var->GradVarBase()->Name() | ||
| << " as stop_gradient false"; | ||
| var->GradVarBase()->InnerSetOverridedStopGradient(false); | ||
| auto* dev_ctx = | ||
| platform::DeviceContextPool::Instance().Get(fwd_var.place()); | ||
| if (grad_tensor == nullptr) { | ||
| grad_var->Resize(fwd_var.dims()); | ||
| grad_var->mutable_data(fwd_var.place(), fwd_var.type()); | ||
| operators::math::set_constant(*dev_ctx, grad_var, 1.0); | ||
| } else { | ||
| paddle::framework::TensorCopy( | ||
| grad_tensor->Var().Get<framework::LoDTensor>(), fwd_var.place(), | ||
| *dev_ctx, grad_var); | ||
| } | ||
|
|
||
| init_nodes_.push_back(init_node_); | ||
| } | ||
| } | ||
|
|
||
| void BasicEngine::CheckBackwardInputs(const OpBase& op) { | ||
|
|
@@ -235,8 +261,10 @@ void BasicEngine::PrepareDeps() { | |
| std::queue<GradOpNode*> q; | ||
| std::unordered_set<GradOpNode*> visited; | ||
|
|
||
| q.push(init_node_.get()); | ||
| visited.insert(init_node_.get()); | ||
| for (size_t i = 0; i < init_nodes_.size(); ++i) { | ||
| q.push(init_nodes_[i].get()); | ||
| visited.insert(init_nodes_[i].get()); | ||
| } | ||
|
|
||
| while (!q.empty()) { | ||
| auto* cur_node = q.front(); | ||
|
|
@@ -263,14 +291,16 @@ void BasicEngine::PrepareDeps() { | |
| } | ||
|
|
||
| void BasicEngine::Execute() { | ||
| if (init_node_ == nullptr) { | ||
| if (init_nodes_.empty()) { | ||
| return; | ||
| } | ||
|
|
||
| PrepareDeps(); | ||
| // Start execute Computation graph | ||
| std::queue<std::shared_ptr<GradOpNode>> q; | ||
| q.push(std::move(init_node_)); | ||
| for (size_t i = 0; i < init_nodes_.size(); ++i) { | ||
| q.push(std::move(init_nodes_[i])); | ||
| } | ||
|
|
||
| size_t op_num = 0; | ||
|
|
||
|
|
@@ -470,7 +500,7 @@ void BasicEngine::Execute() { | |
| } | ||
|
|
||
| void BasicEngine::Clear() { | ||
| init_node_.reset(); | ||
| init_nodes_.clear(); | ||
| node_deps_.clear(); | ||
| accumulators_.clear(); | ||
| accumulators_with_grad_node_.clear(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,9 @@ class OpBase; | |
|
|
||
| class BasicEngine : public Engine { | ||
| public: | ||
| void Init(VarBase* var, bool retain_graph = false); | ||
| void Init(const std::vector<std::shared_ptr<VarBase>>& tensors, | ||
| const std::vector<std::shared_ptr<VarBase>>& grad_tensors, | ||
| bool retain_graph = false); | ||
|
|
||
| void Execute() override; | ||
|
|
||
|
|
@@ -46,7 +48,7 @@ class BasicEngine : public Engine { | |
| void Clear(); | ||
|
|
||
| private: | ||
| std::shared_ptr<GradOpNode> init_node_; | ||
| std::vector<std::shared_ptr<GradOpNode>> init_nodes_; | ||
| std::unordered_map<GradOpNode*, size_t> node_deps_; | ||
| // The input and output of Inplace op are the same. If only `var` is used | ||
| // as the key, then the input and output of inplace op must be gradient | ||
|
|
@@ -74,6 +76,7 @@ class BasicEngine : public Engine { | |
| std::vector<GradientAccumulator*> leaf_accumulators_; | ||
|
|
||
| bool retain_graph_; | ||
| bool create_graph_; | ||
|
||
| }; | ||
|
|
||
| } // namespace imperative | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -720,6 +720,7 @@ void BindImperative(py::module *m_ptr) { | |
| Bump the version whenever the Tensor is modified through an inplace operation. | ||
| )DOC") | ||
| .def("numpy", | ||
|
|
||
| [](imperative::VarBase &self) -> py::array { | ||
| const auto &tensor = | ||
| self.MutableVar()->Get<framework::LoDTensor>(); | ||
|
|
@@ -919,12 +920,17 @@ void BindImperative(py::module *m_ptr) { | |
| print(x.grad) # None | ||
| )DOC") | ||
| .def("_run_backward", | ||
| [](imperative::VarBase &self, const imperative::Tracer &tracer, | ||
| bool retain_graph) { | ||
| [](std::shared_ptr<imperative::VarBase> &self, | ||
|
||
| const imperative::Tracer &tracer, bool retain_graph, | ||
| std::shared_ptr<imperative::VarBase> &grad_tensor) { | ||
| // TODO(jiabin): when we impl more backward execution we can | ||
| // select them | ||
| std::vector<std::shared_ptr<imperative::VarBase>> tensors{self}; | ||
| std::vector<std::shared_ptr<imperative::VarBase>> grad_tensors{ | ||
| grad_tensor}; | ||
|
|
||
| auto *engine = tracer.GetEngine(); | ||
| engine->Init(&self, retain_graph); | ||
| engine->Init(tensors, grad_tensors, retain_graph); | ||
| VLOG(3) << "Start backward"; | ||
| engine->Execute(); | ||
| VLOG(3) << "Finish backward"; | ||
|
|
@@ -1412,6 +1418,19 @@ void BindImperative(py::module *m_ptr) { | |
| }, | ||
| py::call_guard<py::gil_scoped_release>()); | ||
|
|
||
| m.def( | ||
| "dygraph_run_backward", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This method does not need to be shown to users; use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| [](const std::vector<std::shared_ptr<imperative::VarBase>> &tensors, | ||
| const std::vector<std::shared_ptr<imperative::VarBase>> &grad_tensors, | ||
| bool retain_graph, const imperative::Tracer &tracer) { | ||
| auto *engine = tracer.GetEngine(); | ||
| engine->Init(tensors, grad_tensors, retain_graph); | ||
| VLOG(3) << "Start backward"; | ||
| engine->Execute(); | ||
| VLOG(3) << "Finish backward"; | ||
| }, | ||
| py::call_guard<py::gil_scoped_release>()); | ||
|
|
||
| #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ | ||
| defined(PADDLE_WITH_XPU_BKCL) | ||
| py::class_<imperative::ParallelContext, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from ..fluid.dygraph.base import grad #DEFINE_ALIAS | ||
|
|
||
| from . import backward_mode | ||
| from .backward_mode import backward | ||
|
|
||
| __all__ = ['grad'] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. also need
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the next line |
||
|
|
||
| __all__ += backward_mode.__all__ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recommend capitalizing the first letter:
the -> The. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done