
Commit 7ba85ac

Add inner register backward hook method for Tensor (#32171)
* add register backward hook method
* add leaf grad accumulated test
1 parent f3e49c4 commit 7ba85ac

9 files changed: 280 additions, 135 deletions

paddle/fluid/imperative/basic_engine.cc

Lines changed: 5 additions & 5 deletions
@@ -284,15 +284,15 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
   for (const auto& pair : bwd_ins) {
     for (size_t i = 0; i < pair.second.size(); ++i) {
       auto& var = pair.second[i];
-      if (var->HasHook()) {
+      if (var->HasVariableWrapperHook()) {
         if (tmp_ins_ptr == nullptr) {
           tmp_ins_ptr = std::make_shared<NameVarMap<VariableWrapper>>(bwd_ins);
         }
-        VLOG(3) << "Call " << var->GetHooks().size() << " hooks of " << op_type
-                << "'s input `" << pair.first << "`'s var `" << var->Name()
-                << "`.";
+        VLOG(3) << "Call " << var->GetVariableWrapperHooks().size()
+                << " hooks of " << op_type << "'s input `" << pair.first
+                << "`'s var `" << var->Name() << "`.";
         auto tmp_var = var;
-        for (const auto& hook_pair : var->GetHooks()) {
+        for (const auto& hook_pair : var->GetVariableWrapperHooks()) {
           tmp_var = (*hook_pair.second)(tmp_var);
         }
         (*tmp_ins_ptr)[pair.first][i] = tmp_var;
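
The chaining contract is unchanged by the rename: hooks run in ascending id order, and each hook consumes the output of the previous one. Below is a minimal standalone sketch of that loop, assuming a hypothetical Var struct in place of VariableWrapper (not Paddle code):

#include <cstdint>
#include <functional>
#include <map>
#include <memory>

struct Var {
  float value;
};  // hypothetical stand-in for VariableWrapper

using VarHook =
    std::function<std::shared_ptr<Var>(const std::shared_ptr<Var>&)>;

// Mirrors the loop above: tmp_var = (*hook_pair.second)(tmp_var);
std::shared_ptr<Var> RunHooks(const std::map<int64_t, VarHook>& hooks,
                              std::shared_ptr<Var> var) {
  for (const auto& hook_pair : hooks) {
    var = hook_pair.second(var);  // each hook may return a new Var
  }
  return var;
}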

paddle/fluid/imperative/gradient_accumulator.cc

Lines changed: 7 additions & 7 deletions
@@ -467,14 +467,14 @@ void GradientAccumulator::CallGradientHooks() {
                     platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                                          "is not initialized when "
                                                          "call gradient hook."));
-  if (var_->HasHook()) {
-    VLOG(3) << "Call " << var_->GetHooks().size()
+  if (var_->HasVariableWrapperHook()) {
+    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
             << " hooks of leaf gradient accumulator's inner var `"
             << var_->Name() << "`.";
     auto tmp_var = inner_var_;
     VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
-            << var_->GetHooks().size();
-    for (const auto& hook_pair : var_->GetHooks()) {
+            << var_->GetVariableWrapperHooks().size();
+    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
       tmp_var = (*hook_pair.second)(tmp_var);
     }
     inner_var_ = tmp_var;
@@ -495,10 +495,10 @@ void GradientAccumulator::CallReduceHooks() {
                     "Only can call reduce hooks after the "
                     "gradient accumulation is completed in "
                     "current batch or across batchs."));
-  if (var_->HasMutableHook()) {
-    for (const auto& hook : var_->GetMutableHooks()) {
+  if (var_->HasVoidHook()) {
+    for (const auto& hook : var_->GetVoidHooks()) {
       VLOG(3) << "call gradient accumulator backward hooks.";
-      (*hook)(var_);
+      (*hook)();
     }
   }
 }
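
For leaf tensors, ordering is the point: VariableWrapper hooks rebind inner_var_ after accumulation finishes, and only then do the no-argument reduce hooks fire. A simplified model of that sequencing, reusing the hypothetical Var and VarHook stand-ins from the sketch above:

#include <vector>

struct LeafAccumulator {  // simplified model, not the Paddle class
  std::shared_ptr<Var> inner_var;
  std::map<int64_t, VarHook> var_hooks;
  std::vector<std::shared_ptr<std::function<void()>>> void_hooks;

  void Finish() {
    // 1. gradient hooks transform the accumulated value
    for (const auto& hook_pair : var_hooks) {
      inner_var = hook_pair.second(inner_var);
    }
    // 2. void hooks only signal completion: (*hook)() instead of (*hook)(var_)
    for (const auto& hook : void_hooks) {
      (*hook)();
    }
  }
};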

paddle/fluid/imperative/hooks.h

Lines changed: 27 additions & 37 deletions
@@ -23,32 +23,34 @@ namespace imperative {
 
 class VariableWrapper;
 
-/** [ Const VariableWrapper Hook: Pre hook functor of OpBase ]
+/** [ VariableWrapper Hook ]
  *
- * @brief This hook functor is executed before the grad OpBase is executed,
- *        taking the input of the current grad OpBase as input, and
- *        executing python hooks (user-defined) or C++ hooks (developer-defined)
- *        to achieve the purpose of custom operations on the interior VarBase
- *        gradient.
+ * @brief This hook functor is executed before the grad OpBase is executed or
+ *        after gradient accumulation completed in current batch.
+ *        1. For interior var, VariableWrapper Hook take the input of the
+ *           current grad OpBase as input.
+ *        2. For leaf var, VariableWrapper Hook take the inner_var_ of
+ *           GradientAccumulator as input.
  *
- * @note This hook functor will not change the input gradient VarBase.
+ * @note This hook functor will not change the input gradient VariableWrapper,
+ *       but if you copy the input VariableWrapper and change the value of
+ *       Variable in VariableWrapper, the value of input will also be changed,
+ *       because they shared same PlaceHolder.
  *
- * @note [Why need to be OpBase `PreHook`, why not `PostHook`?]
+ * @note [ Why need to be OpBase `PreHook`, why not `PostHook`? ]
  *
- * 1. We expect If set OpBase post hook, when the op executed end, the
+ *       If set OpBase post hook, when the op executed end, the
  *       op's output gradient may not be the final state, because it may need
  *       other op's gradient output to accumulated to it. But before op can
  *       be executed, the gradient output must have been accumulated to final
  *       value.
- * 2. We don't want the hook to change its input Tensor value, so now
-  *    we can't call all hooks in GradAccumulator.
  *
- * @note [Why only can be used for interior VarBase?]
+ * @note [ Why Leaf gradient is special? ]
  *
  *       Because the leaf VarBase's GradVarBase has no GradOpNode, so leaf
  *       GradVarBase has no next OpBase to executed, so if need to deal with
- *       the leaf GradVarBase, cannot use this hook functor. For this case, we
- *       deal with by other inplace hook method.
+ *       the leaf GradVarBase, we should call hooks after gradient accumulation
+ *       completed.
  */
 class VariableWrapperHook {
  public:
@@ -57,34 +59,22 @@ class VariableWrapperHook {
       const std::shared_ptr<VariableWrapper>& var) = 0;
 };
 
-/** [ Inplace VariableWrapper Hook: Post hook functor of GradAccumulator ]
- *
- * @brief This hook functor is the Hook that operates on the current
- *        gradient after the GradientAccumulator has accumulated the gradient.
- *        Leaf GradVarBase has no next OpBase, if we want to register hook
- *        for it, we also need to wait until the leaf GradVarBase accumulation
- *        is completed, so we can add post hook to GradientAccumulator.
- *
- * @note This hook functor will change the grad VarBase value.
- *
- * @note Only allow leaf VarBase hold call this hook functor.
- */
-class InplaceVariableWrapperHook {
- public:
-  virtual ~InplaceVariableWrapperHook() = default;
-  virtual void operator()(VariableWrapper* var) = 0;
-};
-
-class LambdaInplaceVariableWrapperHook : public InplaceVariableWrapperHook {
+class CppVariableWrapperHook : public VariableWrapperHook {
  public:
-  explicit LambdaInplaceVariableWrapperHook(
-      std::function<void(VariableWrapper*)>&& fn)
+  explicit CppVariableWrapperHook(
+      std::function<std::shared_ptr<VariableWrapper>(
+          const std::shared_ptr<VariableWrapper>&)>&& fn)
       : fn_(std::move(fn)) {}
 
-  void operator()(VariableWrapper* var) override { fn_(var); }
+  std::shared_ptr<VariableWrapper> operator()(
+      const std::shared_ptr<VariableWrapper>& var) override {
+    return fn_(var);
+  }
 
  private:
-  std::function<void(VariableWrapper*)> fn_;
+  std::function<std::shared_ptr<VariableWrapper>(
+      const std::shared_ptr<VariableWrapper>&)>
+      fn_;
 };
 
 }  // namespace imperative
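
With the inplace hook classes removed, C++ callers wrap a lambda in CppVariableWrapperHook and return either the input unchanged or a freshly built VariableWrapper. A usage sketch, assuming a full Paddle build (the identity lambda is only illustrative):

#include "paddle/fluid/imperative/hooks.h"

using paddle::imperative::CppVariableWrapperHook;
using paddle::imperative::VariableWrapper;

auto hook = std::make_shared<CppVariableWrapperHook>(
    [](const std::shared_ptr<VariableWrapper>& var)
        -> std::shared_ptr<VariableWrapper> {
      // Return `var` unchanged, or build a new VariableWrapper holding the
      // transformed gradient, as DoubleHook does in test_hooks.cc below.
      return var;
    });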

paddle/fluid/imperative/layer.h

Lines changed: 12 additions & 10 deletions
@@ -226,23 +226,25 @@ class VarBase {
   void BumpInplaceVersion();
 
   /* Hook related method: now only used for GradVarBase */
-  bool HasHook() const { return var_->HasHook(); }
+  bool HasVariableWrapperHook() const { return var_->HasVariableWrapperHook(); }
 
-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    return var_->AddHook(
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    return var_->AddVariableWrapperHook(
         std::forward<std::shared_ptr<VariableWrapperHook>>(hook));
   }
 
-  bool RemoveHook(const int64_t& hook_id) { return var_->RemoveHook(hook_id); }
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    return var_->RemoveVariableWrapperHook(hook_id);
+  }
 
-  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
-      const {
-    return var_->GetHooks();
+  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
+  GetVariableWrapperHooks() const {
+    return var_->GetVariableWrapperHooks();
   }
 
-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    var_->AddMutableHook(
-        std::forward<std::shared_ptr<InplaceVariableWrapperHook>>(hook));
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    var_->AddVoidHook(
+        std::forward<std::shared_ptr<std::function<void()>>>(hook));
   }
 
  private:
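
Taken together, the VarBase surface now separates removable functor hooks, which return an id, from fire-and-forget void hooks. A usage sketch, assuming a Paddle build, an existing std::shared_ptr<VarBase> x with a grad var, and the DoubleHook helper defined in test_hooks.cc below:

// Removable functor hook: AddVariableWrapperHook returns an id.
int64_t id = x->GradVarBase()->AddVariableWrapperHook(
    std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));

// Fire-and-forget void hook: no arguments, no removal.
x->GradVarBase()->AddVoidHook(std::make_shared<std::function<void()>>(
    []() { /* e.g. mark this grad as ready */ }));

// Returns false if the id is unknown or already removed.
bool removed = x->GradVarBase()->RemoveVariableWrapperHook(id);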

paddle/fluid/imperative/reducer.cc

Lines changed: 2 additions & 3 deletions
@@ -310,9 +310,8 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
   for (size_t global_var_index = 0; global_var_index < vars_.size();
        ++global_var_index) {
     auto var = vars_[global_var_index];
-    var->GradVarBase()->AddMutableHook(
-        std::make_shared<LambdaInplaceVariableWrapperHook>([=](
-            VariableWrapper *grad) { this->AddDistHook(global_var_index); }));
+    var->GradVarBase()->AddVoidHook(std::make_shared<std::function<void()>>(
+        [=]() { this->AddDistHook(global_var_index); }));
     var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index;
   }
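
The reducer needs only a completion signal per parameter, not the gradient value, so std::function<void()> suffices and the variable index travels in the lambda capture. A standalone toy model of that registration, with no Paddle dependency (names are hypothetical):

#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  std::vector<std::shared_ptr<std::function<void()>>> hooks;
  for (size_t global_var_index = 0; global_var_index < 3; ++global_var_index) {
    // capture the index by value, as the [=] capture above does
    hooks.push_back(std::make_shared<std::function<void()>>(
        [global_var_index]() {
          std::cout << "var " << global_var_index << " grad done\n";
        }));
  }
  // backward completion would trigger these one by one
  for (const auto& hook : hooks) (*hook)();
}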

paddle/fluid/imperative/tests/test_hooks.cc

Lines changed: 46 additions & 20 deletions
@@ -37,6 +37,30 @@ namespace imperative {
 using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
 using var_pair = std::pair<std::string, vb_vector>;
 
+std::shared_ptr<imperative::VariableWrapper> DoubleHook(
+    const std::shared_ptr<imperative::VariableWrapper>& var) {
+  // 1. create out var
+  auto out_var = std::make_shared<imperative::VariableWrapper>(var->Name());
+  out_var->SetType(var->Type());
+  out_var->SetDataType(var->DataType());
+  out_var->SetForwardDataType(var->ForwardDataType());
+  out_var->InnerSetOverridedStopGradient(var->InnerOverridedStopGradient());
+
+  // 2. get input and output var's tensor
+  auto* out_tensor = out_var->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto& tensor = var->Var().Get<framework::LoDTensor>();
+  out_tensor->Resize(tensor.dims());
+
+  // 3. double calc
+  auto* data = tensor.data<float>();
+  auto* out_data = out_tensor->mutable_data<float>(platform::CPUPlace());
+  for (int64_t i = 0; i < out_tensor->numel(); ++i) {
+    out_data[i] = data[i] * 2.0;
+  }
+
+  return out_var;
+}
+
 TEST(TestHooks, TestGradVarLeafBackwardHook) {
   // 1. prepare
   Tracer tracer;
@@ -73,16 +97,14 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   framework::AttributeMap mul_attr_map;
   mul_attr_map["use_mkldnn"] = false;
 
-  // add GradAccumulatorPostHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 10; }));
 
   // 2. forward
   tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
@@ -98,12 +120,15 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   engine.Init(tensors, grad_tensors);
   engine.Execute();
 
+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 8.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 10);
 
   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
@@ -152,16 +177,14 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   memory::Copy(place, mutable_z, place, src_data.data(),
                sizeof(float) * src_data.size());
 
-  // add ReduceBackwardHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 100; }));
 
   // 2. forward
   var_pair x_pair = var_pair("X", vb_vector(1, x));
@@ -199,12 +222,15 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   engine.Init(tensors, grad_tensors);
   engine.Execute();
 
+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 16.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 100);
 
   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
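
Both tests assert the doubled leaf gradient plus the void-hook side effect. Stripped of Paddle types, the logic behind the first assertion reduces to the sketch below; the 4.0 input is a hypothetical accumulated gradient chosen so that doubling yields the asserted 8.0:

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<float> DoubleHook(const std::vector<float>& grad) {
  std::vector<float> out(grad.size());
  for (size_t i = 0; i < grad.size(); ++i) out[i] = grad[i] * 2.0f;
  return out;
}

int main() {
  std::vector<float> x_grad(4, 4.0f);  // hypothetical accumulated gradient
  int64_t hook_value = 0;
  x_grad = DoubleHook(x_grad);  // VariableWrapper hook, runs once
  hook_value = 10;              // void hook side effect
  for (float g : x_grad) assert(g == 8.0f);  // mirrors ASSERT_EQ(..., 8.0)
  assert(hook_value == 10);
}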

paddle/fluid/imperative/variable_wrapper.h

Lines changed: 26 additions & 21 deletions
@@ -220,35 +220,35 @@ class VariableWrapper {
   }
 
   /* Hook related methods */
-  bool HasHook() const { return !hooks_.empty(); }
+  bool HasVariableWrapperHook() const { return !var_hooks_.empty(); }
 
-  bool HasMutableHook() const { return !mutable_hooks_.empty(); }
-
-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    hooks_.emplace(next_hook_id_, std::move(hook));
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    var_hooks_.emplace(next_hook_id_, std::move(hook));
     return next_hook_id_++;
   }
 
-  bool RemoveHook(const int64_t& hook_id) {
-    auto remove_cnt = hooks_.erase(hook_id);
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    auto remove_cnt = var_hooks_.erase(hook_id);
     if (remove_cnt == 0) {
       return false;
     }
     return true;
   }
 
-  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
-      const {
-    return hooks_;
+  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
+  GetVariableWrapperHooks() const {
+    return var_hooks_;
   }
 
-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    mutable_hooks_.emplace_back(std::move(hook));
+  bool HasVoidHook() const { return !void_hooks_.empty(); }
+
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    void_hooks_.emplace_back(std::move(hook));
   }
 
-  const std::vector<std::shared_ptr<InplaceVariableWrapperHook>>&
-  GetMutableHooks() const {
-    return mutable_hooks_;
+  const std::vector<std::shared_ptr<std::function<void()>>>& GetVoidHooks()
+      const {
+    return void_hooks_;
   }
 
  private:
@@ -319,14 +319,19 @@ class VariableWrapper {
   // isn't need
   bool is_empty_{false};
 
-  // NOTE(chenweihang): only grad var can hold hooks now
+  // NOTE(chenweihang): only grad var will hold hooks now
   int64_t next_hook_id_{0};
-  // Hooks used to register hook for grad var, support adding and removing,
+  // [ Hooks with VariableWrapper as input and output ]
+  // NOTE: Now registered for grad var, support adding and removing,
   // key is the accumulated int64_t value
-  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> hooks_;
-  // Hooks executed after the execution of the entire backward process is over,
-  // currently only supported for reducing in distributed training
-  std::vector<std::shared_ptr<InplaceVariableWrapperHook>> mutable_hooks_;
+  // NOTE: Var hook need to support removing, so need hook id
+  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> var_hooks_;
+  // [ Hooks without input and output ]
+  // NOTE: Now registered after the execution of the entire backward
+  // process is over, currently only used for reducing in distributed
+  // training
+  // NOTE: Now no need to support remove void hook
+  std::vector<std::shared_ptr<std::function<void()>>> void_hooks_;
 };
 
 }  // namespace imperative
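
The two containers encode the lifecycle difference: removable hooks need stable ids, so they live in a std::map keyed by a monotonically increasing counter, while void hooks are append-only and a std::vector suffices. A compact standalone model of the id scheme:

#include <cstdint>
#include <functional>
#include <map>

struct HookRegistry {  // simplified model of the members above
  int64_t next_hook_id{0};
  std::map<int64_t, std::function<void()>> var_hooks;

  int64_t Add(std::function<void()> hook) {
    var_hooks.emplace(next_hook_id, std::move(hook));
    return next_hook_id++;  // ids never reused, so removal is unambiguous
  }
  bool Remove(int64_t hook_id) { return var_hooks.erase(hook_id) > 0; }
};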
