@@ -129,6 +129,10 @@ static std::vector<paddle::Tensor> Trans2ContiguousTensors(
129129 return res;
130130}
131131
132+ inline int64_t hash_int_value (int64_t value) {
133+ return value + 0x9e3779b9 + (value << 6 ) + (value >> 2 );
134+ }
135+
132136inline void run_program_ad_func (
133137 const std::vector<paddle::Tensor>& x,
134138 const std::vector<paddle::Tensor>& params,
@@ -155,7 +159,22 @@ inline void run_program_ad_func(
155159 auto params_tmp = Trans2ContiguousTensors (params);
156160 // Call forward function
157161 // if require_any_grad is False, don't save any middle vars.
158- RunProgramAPI (x_tmp, params_tmp, out, step_scope, require_any_grad, attrs);
162+ std::vector<int64_t > place_hash_keys = std::vector<int64_t >();
163+ for (const paddle::Tensor& tensor : x) {
164+ int64_t device_type = static_cast <int64_t >(tensor.place ().GetType ());
165+ place_hash_keys.emplace_back (hash_int_value (device_type));
166+ }
167+ for (const paddle::Tensor& tensor : params) {
168+ int64_t device_type = static_cast <int64_t >(tensor.place ().GetType ());
169+ place_hash_keys.emplace_back (hash_int_value (device_type));
170+ }
171+ RunProgramAPI (x_tmp,
172+ params_tmp,
173+ out,
174+ step_scope,
175+ require_any_grad,
176+ attrs,
177+ place_hash_keys);
159178 VLOG (2 ) << " start run run_program grad" ;
160179 auto is_test = false ;
161180 if (attrs.count (" is_test" )) {
@@ -168,6 +187,9 @@ inline void run_program_ad_func(
168187 // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
169188 auto grad_node = std::make_shared<GradNodeRunProgram>(1 , 2 );
170189
190+ // Set place hash keys for backward
191+ grad_node->SetPlaceHashKeys (place_hash_keys);
192+
171193 // Set Attributes
172194 grad_node->SetAttrMap (attrs);
173195
@@ -266,9 +288,27 @@ inline void pir_run_program_ad_func(
266288
267289 // Call forward function
268290 // if require_any_grad is False, don't save any middle vars.
269- PirRunProgramAPI (
270- x, params, out, middles, step_scope, require_any_grad, attrs);
291+ std::vector<int64_t > place_hash_keys = std::vector<int64_t >();
292+ for (const paddle::Tensor& tensor : x) {
293+ int64_t device_type = static_cast <int64_t >(tensor.place ().GetType ());
294+ place_hash_keys.emplace_back (hash_int_value (device_type));
295+ }
296+ for (const paddle::Tensor& tensor : params) {
297+ int64_t device_type = static_cast <int64_t >(tensor.place ().GetType ());
298+ place_hash_keys.emplace_back (hash_int_value (device_type));
299+ }
300+ PirRunProgramAPI (x,
301+ params,
302+ out,
303+ middles,
304+ step_scope,
305+ require_any_grad,
306+ attrs,
307+ place_hash_keys);
271308 if (!is_test && require_any_grad) {
309+ // Set place hash keys for backward
310+ grad_node->SetPlaceHashKeys (place_hash_keys);
311+
272312 // Set Attributes
273313 grad_node->SetAttrMap (attrs);
274314
0 commit comments