@@ -37,6 +37,30 @@ namespace imperative {
 using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
 using var_pair = std::pair<std::string, vb_vector>;
 
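+// DoubleHook mirrors the input var's metadata into a new VariableWrapper and
+// fills its tensor with each input element multiplied by 2; it is registered
+// below through CppVariableWrapperHook.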
+std::shared_ptr<imperative::VariableWrapper> DoubleHook(
+    const std::shared_ptr<imperative::VariableWrapper>& var) {
+  // 1. create out var
+  auto out_var = std::make_shared<imperative::VariableWrapper>(var->Name());
+  out_var->SetType(var->Type());
+  out_var->SetDataType(var->DataType());
+  out_var->SetForwardDataType(var->ForwardDataType());
+  out_var->InnerSetOverridedStopGradient(var->InnerOverridedStopGradient());
+
+  // 2. get input and output var's tensor
+  auto* out_tensor = out_var->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto& tensor = var->Var().Get<framework::LoDTensor>();
+  out_tensor->Resize(tensor.dims());
+
+  // 3. double calc
+  auto* data = tensor.data<float>();
+  auto* out_data = out_tensor->mutable_data<float>(platform::CPUPlace());
+  for (int64_t i = 0; i < out_tensor->numel(); ++i) {
+    out_data[i] = data[i] * 2.0;
+  }
+
+  return out_var;
+}
+
 TEST(TestHooks, TestGradVarLeafBackwardHook) {
   // 1. prepare
   Tracer tracer;
@@ -73,16 +97,14 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   framework::AttributeMap mul_attr_map;
   mul_attr_map["use_mkldnn"] = false;
 
-  // add GradAccumulatorPostHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 10; }));
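+  // hook_value is asserted after backward to confirm the Void hook ran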
 
   // 2. forward
   tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
@@ -98,12 +120,15 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   engine.Init(tensors, grad_tensors);
   engine.Execute();
 
+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 8.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 10);
 
   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
@@ -152,16 +177,14 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   memory::Copy(place, mutable_z, place, src_data.data(),
                sizeof(float) * src_data.size());
 
-  // add ReduceBackwardHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 100; }));
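+  // as in the first test, hook_value is asserted after backward; the distinct
+  // value (100) separates this test's Void hook from the first one's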
 
   // 2. forward
   var_pair x_pair = var_pair("X", vb_vector(1, x));
@@ -199,12 +222,15 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   engine.Init(tensors, grad_tensors);
   engine.Execute();
 
+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 16.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 100);
 
   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,