Skip to content

Commit c3ef9a5

Browse files
ooooo-createwanghuancoder
authored andcommitted
[Accuracy diff No.21] Fix accuracy diff for heaviside API (PaddlePaddle#72894)
1 parent 84b2680 commit c3ef9a5

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

paddle/phi/kernels/cpu/elementwise_kernel.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,15 @@ void HeavisideKernel(const Context& dev_ctx,
7474
DenseTensor* out) {
7575
// allocate memory for out
7676
dev_ctx.template Alloc<T>(out);
77-
funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
78-
dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out);
77+
auto x_dims = x.dims();
78+
auto y_dims = y.dims();
79+
if (x_dims.size() >= y_dims.size()) {
80+
funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
81+
dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out);
82+
} else {
83+
funcs::ElementwiseCompute<funcs::ElementwiseInverseHeavisideFunctor<T>, T>(
84+
dev_ctx, x, y, funcs::ElementwiseInverseHeavisideFunctor<T>(), out);
85+
}
7986
}
8087

8188
template <typename T, typename Context>

paddle/phi/kernels/funcs/elementwise_functor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,13 @@ struct ElementwiseHeavisideFunctor {
738738
}
739739
};
740740

741+
template <typename T>
742+
struct ElementwiseInverseHeavisideFunctor {
743+
inline HOSTDEVICE T operator()(const T a, const T b) const {
744+
return b == static_cast<T>(0) ? a : static_cast<T>(b > static_cast<T>(0));
745+
}
746+
};
747+
741748
template <typename T, typename Enable = void>
742749
struct FloorDivideFunctor {
743750
inline HOSTDEVICE T operator()(const T a, const T b) const {

test/legacy_test/test_elementwise_heaviside_op.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,30 @@ def setUp(self):
205205
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])}
206206

207207

208+
class TestElementwiseOp3(TestElementwiseOp):
209+
def setUp(self):
210+
self.op_type = "elementwise_heaviside"
211+
x = np.random.uniform(-10, 10, [100]).astype("float64")
212+
y = np.random.uniform(-10, 10, [3, 100]).astype("float64")
213+
self.python_api = paddle.heaviside
214+
self.prim_op_type = "comp"
215+
self.public_python_api = paddle.heaviside
216+
self.inputs = {'X': x, 'Y': y}
217+
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])}
218+
219+
220+
class TestElementwiseOp4(TestElementwiseOp):
221+
def setUp(self):
222+
self.op_type = "elementwise_heaviside"
223+
x = np.random.uniform(0, 10, []).astype("float64")
224+
y = np.random.uniform(-10, 0, [2, 3, 20]).astype("float64")
225+
self.python_api = paddle.heaviside
226+
self.prim_op_type = "comp"
227+
self.public_python_api = paddle.heaviside
228+
self.inputs = {'X': x, 'Y': y}
229+
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])}
230+
231+
208232
class TestHeavisideFP16Op(OpTest):
209233
def setUp(self):
210234
self.dtype = np.float16

0 commit comments

Comments
 (0)