Skip to content

Commit c073f0e

Browse files
committed
Fix some errors in log messages
1 parent 672c052 commit c073f0e

File tree

4 files changed

+31
-25
lines changed

4 files changed

+31
-25
lines changed

paddle/fluid/operators/amp/check_finite_and_unscale_op_xpu.cc

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -69,10 +69,11 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
6969
r = xpu::logical_not(dev_ctx.x_context(), reinterpret_cast<const bool*>(
7070
is_finite.data<bool>()),
7171
is_finite.data<bool>(), x->numel());
72-
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
73-
"XPU API(isfinite) return wrong "
74-
"value[%d %s]",
75-
r, XPUAPIErrorMsg[r]));
72+
PADDLE_ENFORCE_EQ(
73+
r, XPU_SUCCESS,
74+
platform::errors::External("XPU API(logical_not) return wrong "
75+
"value[%d %s]",
76+
r, XPUAPIErrorMsg[r]));
7677
r = xpu::isnan(dev_ctx.x_context(),
7778
reinterpret_cast<const XPUTyp*>(x->data<T>()),
7879
is_nan.data<bool>(), x->numel());
@@ -83,10 +84,11 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
8384
r = xpu::logical_or(dev_ctx.x_context(), is_finite.data<bool>(),
8485
is_nan.data<bool>(), is_finite.data<bool>(),
8586
x->numel());
86-
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
87-
"XPU API(any) return wrong "
88-
"value[%d %s]",
89-
r, XPUAPIErrorMsg[r]));
87+
PADDLE_ENFORCE_EQ(
88+
r, XPU_SUCCESS,
89+
platform::errors::External("XPU API(logical_or) return wrong "
90+
"value[%d %s]",
91+
r, XPUAPIErrorMsg[r]));
9092
r = xpu::any(dev_ctx.x_context(), is_finite.data<bool>(),
9193
found_inf_data, x->numel());
9294
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
@@ -135,6 +137,9 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
135137
"XPU API(cast_v2) return wrong "
136138
"value[%d %s]",
137139
r, XPUAPIErrorMsg[r]));
140+
if (dev_ctx.x_context()->xpu_stream) {
141+
dev_ctx.Wait();
142+
}
138143

139144
} else {
140145
int r = xpu::scale(dev_ctx.x_context(),

paddle/fluid/operators/amp/update_loss_scaling_op_xpu.cc

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,11 +58,10 @@ class UpdateLossScalingXPUKernel : public framework::OpKernel<T> {
5858
r = xpu::constant(dev_ctx.x_context(),
5959
reinterpret_cast<XPUTyp*>(out_data), num,
6060
XPUTyp(0.0));
61-
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
62-
platform::errors::External(
63-
"XPU API return wrong value[%d], please check "
64-
"where Baidu Kunlun Card is properly installed.",
65-
r));
61+
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
62+
"XPU API(constant) return wrong "
63+
"value[%d %s]",
64+
r, XPUAPIErrorMsg[r]));
6665
}
6766
}
6867
const bool stop_update = ctx.Attr<bool>("stop_update");

paddle/fluid/operators/dropout_op_xpu.cc

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -118,25 +118,21 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
118118
reinterpret_cast<const XPUTyp*>(mask->data<T>()),
119119
reinterpret_cast<XPUTyp*>(mask_new.data<T>()),
120120
mask->numel(), false, scale, 0.0f);
121-
PADDLE_ENFORCE_EQ(
122-
r, xpu::Error_t::SUCCESS,
123-
platform::errors::External(
124-
"XPU dropout return wrong value[%d], please check whether "
125-
"Baidu Kunlun Card is properly installed.",
126-
r));
121+
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
122+
"XPU API(scale) return wrong "
123+
"value[%d %s]",
124+
r, XPUAPIErrorMsg[r]));
127125
mask_data = mask_new.data<T>();
128126
}
129127

130128
int r = xpu::mul(
131129
dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(grad_y->data<T>()),
132130
reinterpret_cast<const XPUTyp*>(mask_data),
133131
reinterpret_cast<XPUTyp*>(grad_x->data<T>()), grad_y->numel());
134-
PADDLE_ENFORCE_EQ(
135-
r, xpu::Error_t::SUCCESS,
136-
platform::errors::External(
137-
"XPU dropout return wrong value[%d], please check whether "
138-
"Baidu Kunlun Card is properly installed.",
139-
r));
132+
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
133+
platform::errors::External("XPU API(mul) return wrong "
134+
"value[%d %s]",
135+
r, XPUAPIErrorMsg[r]));
140136
}
141137
};
142138
} // namespace operators

paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,19 +150,22 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
150150
const T* dz_data = dz->data<T>();
151151
framework::Tensor dx_local_tensor;
152152
framework::Tensor dy_local_tensor;
153+
bool need_wait = false;
153154
T* dx_data = nullptr;
154155
T* dy_data = nullptr;
155156
if (dx) {
156157
dx_data = dx->mutable_data<T>(ctx.GetPlace());
157158
} else {
158159
dx_data =
159160
dx_local_tensor.mutable_data<T>(ctx.GetPlace(), x_len * sizeof(T));
161+
need_wait = true;
160162
}
161163
if (dy) {
162164
dy_data = dy->mutable_data<T>(ctx.GetPlace());
163165
} else {
164166
dy_data =
165167
dy_local_tensor.mutable_data<T>(ctx.GetPlace(), y_len * sizeof(T));
168+
need_wait = true;
166169
}
167170

168171
auto& dev_ctx =
@@ -175,6 +178,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
175178
platform::errors::External(
176179
"XPU kernel Elementwise occur error in XPUElementwise error code ",
177180
ret, XPUAPIErrorMsg[ret]));
181+
if (need_wait && dev_ctx.x_context()->xpu_stream) {
182+
dev_ctx.Wait();
183+
}
178184
}
179185
};
180186

0 commit comments

Comments
 (0)