Skip to content

Commit f28e01d

Browse files
committed
copy beta pow to same place when skip_update=1
1 parent f29a3c6 commit f28e01d

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

paddle/fluid/operators/optimizers/adam_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,11 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
198198
*mom2, ctx.GetPlace(),
199199
ctx.template device_context<platform::DeviceContext>(), mom2_out);
200200
framework::TensorCopy(
201-
*beta1_pow, ctx.GetPlace(),
201+
*beta1_pow, beta1_pow->place(),
202202
ctx.template device_context<platform::DeviceContext>(),
203203
beta1_pow_out);
204204
framework::TensorCopy(
205-
*beta2_pow, ctx.GetPlace(),
205+
*beta2_pow, beta2_pow->place(),
206206
ctx.template device_context<platform::DeviceContext>(),
207207
beta2_pow_out);
208208
return;

paddle/fluid/operators/optimizers/adam_op_npu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ class AdamNPUKernel : public framework::OpKernel<T> {
8484
*mom2, ctx.GetPlace(),
8585
ctx.template device_context<platform::DeviceContext>(), mom2_out);
8686
framework::TensorCopy(
87-
*beta1_pow, ctx.GetPlace(),
87+
*beta1_pow, beta1_pow->place(),
8888
ctx.template device_context<platform::DeviceContext>(),
8989
beta1_pow_out);
9090
framework::TensorCopy(
91-
*beta2_pow, ctx.GetPlace(),
91+
*beta2_pow, beta2_pow->place(),
9292
ctx.template device_context<platform::DeviceContext>(),
9393
beta2_pow_out);
9494
return;

paddle/fluid/operators/optimizers/adam_op_xpu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
8686
mom2, ctx.GetPlace(),
8787
ctx.template device_context<platform::DeviceContext>(), &mom2_out);
8888
framework::TensorCopy(
89-
beta1_pow, ctx.GetPlace(),
89+
*beta1_pow, beta1_pow->place(),
9090
ctx.template device_context<platform::DeviceContext>(),
9191
beta1_pow_out);
9292
framework::TensorCopy(
93-
beta2_pow, ctx.GetPlace(),
93+
*beta2_pow, beta2_pow->place(),
9494
ctx.template device_context<platform::DeviceContext>(),
9595
beta2_pow_out);
9696
return;

0 commit comments

Comments
 (0)