add custom init grad for backward function #31540
Changes from 1 commit: d0915f8
This PR lets `backward()` in dygraph mode start from a user-supplied initial gradient (`grad_tensor`) instead of the implicit all-ones seed. The diff touches four files. First, the C++ imperative engine, where `BasicEngine::Init` gains the new parameter:
```diff
@@ -36,7 +36,7 @@ DECLARE_bool(sort_sum_gradient);
 namespace paddle {
 namespace imperative {
 
-void BasicEngine::Init(VarBase* var, bool retain_graph) {
+void BasicEngine::Init(VarBase* var, bool retain_graph, VarBase* grad_tensor) {
   retain_graph_ = retain_graph;
   init_node_ = var->GradVarBase()->GradNode();
   PADDLE_ENFORCE_EQ(var->GradVarBase()->GraphIsFreed(), false,
@@ -75,9 +75,15 @@ void BasicEngine::Init(VarBase* var, bool retain_graph) {
          << " as stop_gradient false";
   var->GradVarBase()->InnerSetOverridedStopGradient(false);
   auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place());
-  grad_var->Resize(fwd_var.dims());
-  grad_var->mutable_data(fwd_var.place(), fwd_var.type());
-  operators::math::set_constant(*dev_ctx, grad_var, 1.0);
+  if (grad_tensor == nullptr) {
+    grad_var->Resize(fwd_var.dims());
+    grad_var->mutable_data(fwd_var.place(), fwd_var.type());
+    operators::math::set_constant(*dev_ctx, grad_var, 1.0);
+  } else {
+    paddle::framework::TensorCopy(
+        grad_tensor->Var().Get<framework::LoDTensor>(), fwd_var.place(),
+        *dev_ctx, grad_var);
+  }
 }
 
 void BasicEngine::CheckBackwardInputs(const OpBase& op) {
```
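Seeding the backward pass this way computes a vector-Jacobian product: when `grad_tensor` is non-null, its values are copied into the output's gradient slot instead of the all-ones fill. A minimal sketch of what that means at the Python level (my own example, not part of the PR; it assumes the `backward(grad_tensor=...)` API added below):

```python
import paddle

x = paddle.to_tensor([1., 2., 3.], stop_gradient=False)
y = x * x  # dy/dx = 2x

# Seeding with g scales each output element's contribution, i.e. it
# computes the vector-Jacobian product g * (dy/dx).
g = paddle.to_tensor([10., 20., 30.])
y.backward(grad_tensor=g)

print(x.grad)  # [20., 80., 180.] == g * 2x
```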
Second, the pybind layer, where the private `_run_backward` binding forwards the seed to the engine:
```diff
@@ -920,11 +920,11 @@ void BindImperative(py::module *m_ptr) {
       )DOC")
       .def("_run_backward",
            [](imperative::VarBase &self, const imperative::Tracer &tracer,
-              bool retain_graph) {
+              bool retain_graph, imperative::VarBase &grad_tensor) {
              // TODO(jiabin): when we impl more backward execution we can
              // select them
              auto *engine = tracer.GetEngine();
-             engine->Init(&self, retain_graph);
+             engine->Init(&self, retain_graph, &grad_tensor);
              VLOG(3) << "Start backward";
              engine->Execute();
              VLOG(3) << "Finish backward";
```
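Note that the lambda takes `imperative::VarBase &grad_tensor` by reference, so at this commit the Python caller appears to need a real tensor in every call; the review comments at the bottom ask for a `nullptr` default to keep the no-seed form working. The binding itself is private; the call path it serves looks roughly like this (a sketch using only the names that appear in the diff; `_run_backward` and `_dygraph_tracer` are internal APIs):

```python
import paddle
import paddle.fluid.framework as framework

x = paddle.to_tensor(5., stop_gradient=False)
y = paddle.pow(x, 2.0)
seed = paddle.to_tensor(3.)

# What Tensor.backward(grad_tensor=seed) does under the hood:
# forward the tracer, the retain_graph flag, and the seed to the binding.
y._run_backward(framework._dygraph_tracer(), False, seed)

print(x.grad)  # 30. == seed * dy/dx = 3 * 2 * 5
```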
Third, the Python `backward()` API, which documents the new argument, adds a usage example, and validates the seed's shape:
```diff
@@ -133,7 +133,7 @@ def set_value(self, value):
                 framework._current_expected_place())
 
     @framework.dygraph_only
-    def backward(self, retain_graph=False):
+    def backward(self, retain_graph=False, grad_tensor=None):
         """
         Run backward of current Graph which starts from current Tensor.
 
@@ -147,6 +147,10 @@ def backward(self, retain_graph=False):
                 :code:`retain_graph` to True, then the grads will be retained. Thus, setting it to False is much more memory-efficient.
                 Defaults to False.
+            grad_tensor(Tensor, optional): initial gradient values of `outputs`. If `grad_tensor` is None,
+                the initial gradient values of `outputs` would be Tensor filled with 1;
+                if `grad_tensor` is not None, it must have the same length as `outputs`.
+                Default None.
 
         Returns:
             NoneType: None
@@ -168,6 +172,17 @@ def backward(self, retain_graph=False):
                     print("{}".format(x.grad))
                     # 0.
 
+                grad_tensor = paddle.to_tensor(2.)
+                for i in range(5):
+                    y = paddle.pow(x, 4.0)
+                    y.backward(grad_tensor=grad_tensor)
+
+                    print("{}: {}".format(i, x.grad))
+                    # 0: [1000.]
+                    # 1: [2000.]
+                    # 2: [3000.]
+                    # 3: [4000.]
+                    # 4: [5000.]
         """
         if framework.in_dygraph_mode():
             if paddle.is_compiled_with_xpu():
@@ -176,7 +191,12 @@ def backward(self, retain_graph=False):
                 scaled_loss._run_backward(framework._dygraph_tracer(),
                                           retain_graph)
             else:
-                self._run_backward(framework._dygraph_tracer(), retain_graph)
+                if grad_tensor is not None:
+                    assert grad_tensor.shape == self.shape, "Variable Shape not match, Variable of grad_tensor [ {} ] with shape {} mismatch Variable [ {} ] with shape {}".format(
+                        grad_tensor.name, grad_tensor.shape, self.name,
+                        self.shape)
+                self._run_backward(framework._dygraph_tracer(), retain_graph,
+                                   grad_tensor)
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")
```
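The new `assert` rejects a seed whose shape differs from the tensor being differentiated, before anything reaches the C++ engine. A quick illustration of both paths (my example, not from the PR):

```python
import paddle

x = paddle.to_tensor([1., 2.], stop_gradient=False)
y = x * 3.  # shape [2]

# Matching shape: accepted; gradients are scaled elementwise by the seed.
y.backward(grad_tensor=paddle.to_tensor([1., 10.]))
print(x.grad)  # [3., 30.]

# Mismatched shape: rejected by the assert with the
# "Variable Shape not match ..." message.
y2 = paddle.to_tensor([1., 2.], stop_gradient=False) * 3.
try:
    y2.backward(grad_tensor=paddle.to_tensor([1., 2., 3.]))
except AssertionError as e:
    print(e)
```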
Finally, a new unit test (a new 53-line file) checks the seeded backward pass against an analytic gradient:

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
import paddle.fluid.dygraph as dg
from op_test import OpTest


class TestBackward(unittest.TestCase):
    def setUp(self):
        self._dtypes = ["float32", "float64"]
        self._places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_cuda():
            self._places.append(paddle.CUDAPlace(0))

    def test_all_positive(self):
        for dtype in self._dtypes:
            x = np.random.random([2, 100]).astype(dtype)
            y = np.random.random([100, 2]).astype(dtype)
            z = np.matmul(x, y)
            grad = np.random.random(z.shape).astype(dtype)
            for place in self._places:
                with dg.guard(place):
                    x_tensor = paddle.to_tensor(x, stop_gradient=False)
                    y_tensor = paddle.to_tensor(y)
                    z_tensor = paddle.matmul(x_tensor, y_tensor)

                    grad_tensor = paddle.to_tensor(grad)
                    z_tensor.backward(grad_tensor=grad_tensor)

                    x_grad = np.matmul(grad, y.T)

                    self.assertTrue(np.allclose(x_grad, x_tensor.grad))


if __name__ == '__main__':
    unittest.main()
```
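The test checks only the gradient of `x` (`y_tensor` is created with the default `stop_gradient=True`). For `z = x·y` seeded with `g`, the analytic vector-Jacobian products are `dL/dx = g·yᵀ` (checked above) and `dL/dy = xᵀ·g`; a possible extension covering the second one (my sketch, not part of the PR):

```python
import numpy as np
import paddle

x = np.random.random([2, 100]).astype("float32")
y = np.random.random([100, 2]).astype("float32")
grad = np.random.random([2, 2]).astype("float32")

x_t = paddle.to_tensor(x, stop_gradient=False)
y_t = paddle.to_tensor(y, stop_gradient=False)  # track y's gradient too
z_t = paddle.matmul(x_t, y_t)
z_t.backward(grad_tensor=paddle.to_tensor(grad))

# Analytic vector-Jacobian products for matmul.
assert np.allclose(np.matmul(grad, y.T), x_t.grad, atol=1e-5)
assert np.allclose(np.matmul(x.T, grad), y_t.grad, atol=1e-5)
```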
Review comment: `grad_tensor` could be given a default argument of `nullptr`.
Review comment: the default argument of `nullptr` should go on the declaration.
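With the requested `nullptr` default in place, omitting the argument keeps the existing behavior, since a null seed falls back to the all-ones fill in `BasicEngine::Init`. In Python terms, these two calls should then produce identical gradients (a sketch, assuming the default lands as requested):

```python
import paddle

x = paddle.to_tensor(5., stop_gradient=False)

y = paddle.pow(x, 4.0)
y.backward()  # no seed: the engine fills the initial grad with 1.0
g_default = x.grad

x.clear_gradient()
y = paddle.pow(x, 4.0)
y.backward(grad_tensor=paddle.ones_like(y))  # explicit all-ones seed
g_explicit = x.grad

# Both equal 4 * 5**3 = 500.
```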