
Commit dffb0b2

fix set_grad_ivar bug of Tensor.backward (#34819)

1 parent 6326c3e · commit dffb0b2

File tree: 4 files changed, +94 -33 lines changed
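
What the title refers to: after presetting a parameter's gradient through the internal `_set_grad_ivar` hook, a following `Tensor.backward()` should accumulate onto that preset value; before this commit the preset gradient could be treated as uninitialized and simply overwritten (see the `SetIsEmpty(false)` call added in layer.h below). A minimal sketch of the flow, assuming dygraph mode and hypothetical shapes, not code from this commit:

    import paddle

    linear = paddle.nn.Linear(4, 3)
    x = paddle.rand([2, 4])

    linear.weight._set_grad_ivar(paddle.ones([4, 3]))  # preset leaf gradient
    linear(x).sum().backward()
    # Expected after this fix: weight.grad == 1.0 + d(loss)/d(weight),
    # i.e. backward accumulates onto the preset value instead of discarding it.
    print(linear.weight.grad)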

paddle/fluid/imperative/gradient_accumulator.cc

Lines changed: 12 additions & 5 deletions
@@ -184,6 +184,12 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
   auto data_type = src_tensor.type();
   auto place = src_tensor.place();
 
+  PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
+                    platform::errors::PreconditionNotMet(
+                        "The data type of source tensor and destination tensor "
+                        "should be equal. Otherwise, the calculation results "
+                        "will be incorrect."));
+
 #define PADDLE_TENSOR_ADD(cpp_type)                                  \
   if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
     TensorAddFunctor<cpp_type> func(                                 \
@@ -422,9 +428,9 @@ void GradientAccumulator::AccumulateGrad() {
   auto* src = inner_var_->MutableVar();
   auto* dst = var_->MutableVar();
   if (!var_->IsEmpty()) {
-    VLOG(6) << "Leaf Gradient Var(" << var_->Name()
-            << ") has been calculated by previous graph, will accumulate on "
-               "previous graph.";
+    VLOG(6) << "Leaf Var(" << var_->Name()
+            << ")'s Gradient has been initialized, will accumulate on "
+               "previous gradient.";
     if (dst->IsType<framework::LoDTensor>()) {
       if (src->IsType<framework::LoDTensor>()) {
         TensorAdd(*src, dst);
@@ -444,8 +450,9 @@ void GradientAccumulator::AccumulateGrad() {
           "Only support LoDTensor and SelectedRows for gradient var"));
     }
   } else {
-    VLOG(6) << "Leaf Gradient Var(" << var_->Name()
-            << ") has not been initialized, not accumulate. Just move";
+    VLOG(6)
+        << "Leaf Var(" << var_->Name()
+        << ")'s Gradient has not been initialized, not accumulate. Just move";
     *(dst) = std::move(*src);
     var_->SetType(inner_var_->Type());
     var_->SetDataType(inner_var_->DataType());
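
The AccumulateGrad log changes above distinguish the two leaf-gradient states: already initialized (add onto it via TensorAdd) versus empty (just move the new gradient in). A minimal sketch of that accumulation as seen from the Python side, assuming dygraph mode; the new PADDLE_ENFORCE_EQ additionally rejects adding tensors of different data types:

    import paddle

    x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)

    (2 * x).sum().backward()   # leaf grad is empty -> gradient is moved in
    print(x.grad)              # expected: [2., 2., 2.]

    (3 * x).sum().backward()   # leaf grad already initialized -> TensorAdd path
    print(x.grad)              # expected: [5., 5., 5.]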

paddle/fluid/imperative/layer.cc

Lines changed: 64 additions & 23 deletions
@@ -277,32 +277,73 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
 }
 
 void VarBase::CopyFrom(const VarBase& src, const bool blocking) {
-  if (SharedVar()->IsEmpty()) {
-    VLOG(3) << "deep copy Variable from " << src.Name() << " to " << Name();
-    SetPersistable(src.Persistable());
+  if (src.SharedVar()->IsEmpty()) {
+    return;
+  }
+
+  VLOG(3) << "Deep copy Tensor from " << src.Name() << " to " << Name();
+  if (Var().IsInitialized()) {
+    PADDLE_ENFORCE_EQ(DataType(), src.DataType(),
+                      platform::errors::PreconditionNotMet(
+                          "Tensor %s has different data type with Tensor %s, "
+                          "Tensor Copy cannot be performed!",
+                          Name(), src.Name()));
+    PADDLE_ENFORCE_EQ(Type(), src.Type(),
+                      platform::errors::PreconditionNotMet(
+                          "Tensor %s has different type with Tensor %s, Tensor "
+                          "Copy cannot be performed!",
+                          Name(), src.Name()));
+  } else {
     SetDataType(src.DataType());
     SetType(src.Type());
-    SetOverridedStopGradient(src.OverridedStopGradient());
-    if (!src.SharedVar()->IsEmpty()) {
-      const platform::Place& place = src.Place();
-      if (src.Var().IsType<framework::LoDTensor>()) {
-        auto& src_tensor = src.Var().Get<framework::LoDTensor>();
-        auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
-        dst_tensor->set_lod(src_tensor.lod());
-        framework::TensorCopy(src_tensor, place, dst_tensor);
-      } else if (src.Var().IsType<framework::SelectedRows>()) {
-        auto& src_selected_rows = src.Var().Get<framework::SelectedRows>();
-        auto* dst_selected_rows =
-            MutableVar()->GetMutable<framework::SelectedRows>();
-        dst_selected_rows->set_height(src_selected_rows.height());
-        dst_selected_rows->set_rows(src_selected_rows.rows());
-        framework::TensorCopy(src_selected_rows.value(), place,
-                              dst_selected_rows->mutable_value());
-      }
-      if (blocking) {
-        platform::DeviceContextPool::Instance().Get(place)->Wait();
-      }
+    SetPersistable(src.Persistable());
+    InnerSetOverridedStopGradient(src.OverridedStopGradient());
+  }
+
+  platform::Place place = src.Place();
+  if (src.Var().IsType<framework::LoDTensor>()) {
+    auto& src_tensor = src.Var().Get<framework::LoDTensor>();
+    auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
+    if (dst_tensor && dst_tensor->IsInitialized()) {
+      PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
+                        platform::errors::PreconditionNotMet(
+                            "Tensor %s has different dims with Tensor %s, "
+                            "Tensor Copy cannot be performed!",
+                            Name(), src.Name()));
+      PADDLE_ENFORCE_EQ(dst_tensor->lod(), src_tensor.lod(),
+                        platform::errors::PreconditionNotMet(
+                            "Tensor %s has different lod with Tensor %s, "
+                            "Tensor Copy cannot be performed!",
+                            Name(), src.Name()));
+      place = Place();
+    } else {
+      dst_tensor->set_lod(src_tensor.lod());
+      dst_tensor->Resize(src_tensor.dims());
+    }
+    framework::TensorCopy(src_tensor, place, dst_tensor);
+  } else if (src.Var().IsType<framework::SelectedRows>()) {
+    auto& src_selected_rows = src.Var().Get<framework::SelectedRows>();
+    auto* dst_selected_rows =
+        MutableVar()->GetMutable<framework::SelectedRows>();
+    dst_selected_rows->set_height(src_selected_rows.height());
+    dst_selected_rows->set_rows(src_selected_rows.rows());
+
+    auto& src_tensor = src_selected_rows.value();
+    auto* dst_tensor = dst_selected_rows->mutable_value();
+    if (dst_tensor && dst_tensor->IsInitialized()) {
+      PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
+                        platform::errors::PreconditionNotMet(
+                            "Tensor %s has different dims with Tensor %s, "
+                            "Tensor Copy cannot be performed!",
+                            Name(), src.Name()));
+      place = Place();
+    } else {
+      dst_tensor->Resize(src_tensor.dims());
     }
+    framework::TensorCopy(src_tensor, place, dst_tensor);
+  }
+  if (blocking) {
+    platform::DeviceContextPool::Instance().Get(place)->Wait();
   }
 }
 
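
The rewritten CopyFrom validates type, data type, dims, and lod before copying into an already initialized destination. From the Python side this means a tensor passed to the grad setter has to match the existing gradient; a small sketch, assuming the internal `_set_grad_ivar` hook used by the test below and that a mismatch surfaces as a PreconditionNotMet error:

    import paddle

    linear = paddle.nn.Linear(4, 3)

    # First call: the grad storage is empty, so shape and dtype are taken
    # from the source tensor.
    linear.weight._set_grad_ivar(paddle.ones([4, 3]))

    # Once the gradient is initialized, the new checks require matching dims
    # and dtype; a call like the one below would be expected to raise
    # PreconditionNotMet rather than silently corrupt the gradient:
    # linear.weight._set_grad_ivar(paddle.ones([3, 4]))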

paddle/fluid/imperative/layer.h

Lines changed: 8 additions & 3 deletions
@@ -110,6 +110,7 @@ class VarBase {
 
   void SetGradVarBase(const VarBase& grad_var) {
     MutableGradVarBase()->CopyFrom(grad_var, true);
+    MutableGradVarBase()->SharedVar()->SetIsEmpty(false);
   }
 
   const std::shared_ptr<VarBase>& MutableGradVarBase() {
@@ -142,6 +143,8 @@ class VarBase {
     return grad_var_->MutableVar();
   }
 
+  bool IsLeaf() const { return var_->IsLeaf(); }
+
   void SetOverridedStopGradient(bool stop_gradient) {
     var_->SetOverridedStopGradient(stop_gradient);
     if (grad_var_) {
@@ -151,17 +154,19 @@ class VarBase {
 
   bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }
 
-  bool IsLeaf() const { return var_->IsLeaf(); }
-
   void InnerSetOverridedStopGradient(bool stop_gradient) {
-    if (var_->InnerOverridedStopGradient() == -1) {
+    if (InnerOverridedStopGradient() == -1) {
       var_->InnerSetOverridedStopGradient(stop_gradient);
       if (grad_var_) {
         grad_var_->InnerSetOverridedStopGradient(stop_gradient);
       }
     }
   }
 
+  int InnerOverridedStopGradient() const {
+    return var_->InnerOverridedStopGradient();
+  }
+
   void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }
 
   bool Persistable() const { return var_->Persistable(); }
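
The header change moves IsLeaf() next to the other accessors and adds a public InnerOverridedStopGradient() getter. A short sketch of the leaf notion these map to, assuming the public Tensor.is_leaf property wraps VarBase::IsLeaf:

    import paddle

    x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
    y = x * 2

    print(x.is_leaf)  # True: created directly by the user
    print(y.is_leaf)  # False: produced by an operator in the graph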

python/paddle/fluid/tests/unittests/test_imperative_basic.py

Lines changed: 10 additions & 2 deletions
@@ -41,7 +41,6 @@ def forward(self, inputs):
 class MLP(fluid.Layer):
     def __init__(self, input_size):
         super(MLP, self).__init__()
-        self._linear1 = None
         self._linear1 = Linear(
             input_size,
             3,
@@ -607,12 +606,21 @@ def test_mlp(sort_sum_gradient):
 
                 mlp2.clear_gradients()
                 self.assertTrue(np.array_equal(clear_loss.grad.numpy(), [1]))
-                if ((batch_id + 1) % 10) == 0:
+                if ((batch_id + 1) % 10) % 2 == 0:
                     mlp1.clear_gradients()
                     expected_weight1_grad = 0.
                     expected_bias1_grad = 0.
                     expected_weight2_grad = 0.
                     expected_bias2_grad = 0.
+                elif ((batch_id + 1) % 10) % 2 == 1:
+                    mlp1.clear_gradients()
+                    mlp1._linear1.weight._set_grad_ivar(
+                        paddle.ones([input_size, 3]))
+                    mlp1._linear2.weight._set_grad_ivar(paddle.ones([3, 4]))
+                    expected_weight1_grad = 1.
+                    expected_bias1_grad = 0.
+                    expected_weight2_grad = 1.
+                    expected_bias2_grad = 0.
 
         with fluid.dygraph.guard():
             test_single_api(False)
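
The new elif branch clears the gradients, presets both weights' gradients to ones via `_set_grad_ivar`, and expects later backward passes to accumulate on top of them. A condensed, standalone sketch of that expectation with hypothetical shapes, not the test itself:

    import paddle

    linear = paddle.nn.Linear(2, 2)
    x = paddle.ones([1, 2])

    linear.clear_gradients()
    linear.weight._set_grad_ivar(paddle.ones([2, 2]))  # preset leaf gradient

    linear(x).sum().backward()
    # d(sum)/d(weight) is all ones for x == ones, so after accumulating onto
    # the preset gradient every entry should be 2.0.
    print(linear.weight.grad)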
