
Commit 1221360

Eddie-Wang1120 authored and pull[bot] committed

[Prim][PIR] support bce_loss op forward in prim pir (#63918)

* add bce forward prim
* support dynamic
* Update composite.h
* Update test_prim_sub_graph_dynamic_shape.py
* Update composite.h

1 parent 981a734 commit 1221360

File tree

4 files changed (+35, −2 lines)

paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py

Lines changed: 2 additions & 0 deletions

@@ -23,6 +23,7 @@
     "any",
     "batch_norm",
     "batch_norm_",
+    "bce_loss",
     "bmm",
     "dropout",
     "elu",
@@ -60,6 +61,7 @@
 decomp_interface_implementation_gen_op_list = [
     "any",
     "add_n",
+    "bce_loss",
     "bmm",
     "dropout",
     "elu",

paddle/fluid/primitive/composite/composite.h

Lines changed: 8 additions & 0 deletions

@@ -251,6 +251,14 @@ Tensor reciprocal_decomp(const Tensor& x) {
   return full<T>(empty_shape, 1.0, x.dtype()) / x;
 }
 
+template <typename T>
+Tensor bce_loss_decomp(const Tensor& x, const Tensor& label) {
+  auto one = full<T>(empty_shape, 1, x.dtype());
+  auto ans = full<T>(empty_shape, -1, x.dtype()) *
+             (label * log<T>(x) + (one - label) * log<T>(one - x));
+  return ans;
+}
+
 template <typename T>
 Tensor bmm_decomp(const Tensor& x, const Tensor& y) {
   std::size_t x_ndims = x.dims().size();
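For reference, bce_loss_decomp implements the standard elementwise binary cross-entropy, loss = -(label * log(x) + (1 - label) * log(1 - x)). A minimal NumPy sketch of the same computation (bce_loss_ref is an illustrative name, not part of the patch):

```python
import numpy as np

def bce_loss_ref(x, label):
    # Mirrors bce_loss_decomp: -1 * (label*log(x) + (1-label)*log(1-x)),
    # applied elementwise; x must lie strictly in (0, 1) so log() is finite.
    one = np.ones_like(x)
    return -(label * np.log(x) + (one - label) * np.log(one - x))

x = np.random.uniform(0.1, 0.8, (4, 5)).astype("float32")
label = np.random.randint(0, 2, (4, 5)).astype("float32")
print(bce_loss_ref(x, label))
```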

test/deprecated/legacy_test/test_bce_loss.py

Lines changed: 4 additions & 2 deletions

@@ -256,7 +256,9 @@ def setUp(self):
         self.init_test_dtype()
         self.init_test_case()
         self.op_type = "bce_loss"
+        self.prim_op_type = "comp"
         self.python_api = bce_wrapper
+        self.public_python_api = bce_wrapper
         input_np = np.random.uniform(0.1, 0.8, self.shape).astype(self.dtype)
         label_np = np.random.randint(0, 2, self.shape).astype(self.dtype)
         output_np = bce_loss(input_np, label_np)
@@ -265,7 +267,7 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output(check_pir=True)
+        self.check_output(check_pir=True, check_prim_pir=True)
 
     def test_check_grad(self):
         self.check_grad(['X'], 'Out', check_pir=True)
@@ -289,7 +291,7 @@ def init_test_cast(self):
 
 class TestBceLossOpFP16(TestBceLossOp):
     def test_check_output(self):
-        self.check_output(check_pir=True)
+        self.check_output(check_pir=True, check_prim_pir=True)
 
     def test_check_grad(self):
         self.check_grad(['X'], 'Out', check_pir=True)
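Passing check_prim_pir=True makes OpTest additionally validate the composite (decomposed) program against the original kernel under PIR. Outside the OpTest harness, the decomposition formula can also be sanity-checked in eager mode against Paddle's public API, since binary_cross_entropy with reduction="none" is the elementwise BCE; a hedged sketch:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(np.random.uniform(0.1, 0.8, (4, 5)).astype("float32"))
label = paddle.to_tensor(np.random.randint(0, 2, (4, 5)).astype("float32"))

# Public-API elementwise BCE (reduction="none" keeps per-element losses).
out_api = F.binary_cross_entropy(x, label, reduction="none")
# The formula used by bce_loss_decomp.
out_ref = -(label * paddle.log(x) + (1 - label) * paddle.log(1 - x))

np.testing.assert_allclose(out_api.numpy(), out_ref.numpy(), rtol=1e-5)
```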

test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py

Lines changed: 21 additions & 0 deletions

@@ -74,6 +74,10 @@ def index_sample_net(x, index):
     return paddle.index_sample(x, index)
 
 
+def bce_loss_net(x, label):
+    return paddle._C_ops.bce_loss(x, label)
+
+
 def swiglu_net1(x, y):
     return paddle.incubate.nn.functional.swiglu(x, y)
 
 
@@ -318,6 +322,23 @@ def setUp(self):
         self.tol = 1e-6
 
 
+class TestPrimBceLoss(TestPrimTwo):
+    def setUp(self):
+        np.random.seed(2023)
+        self.x_shape = [20, 30, 40, 50]
+        self.y_shape = [20, 30, 40, 50]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.init_x_shape = [None, None]
+        self.init_y_shape = [None, None]
+        self.x = np.random.uniform(0.1, 0.8, self.x_shape).astype(self.dtype_x)
+        self.y = np.random.randint(0, 2, self.x_shape).astype(self.dtype_y)
+        self.net = bce_loss_net
+        self.necessary_ops = "pd_op.bce_loss"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 class TestPrimSwiglu1(TestPrimTwo):
     def setUp(self):
         np.random.seed(2023)
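TestPrimTwo's internals are not shown in this diff, but init_x_shape/init_y_shape marking dimensions as None suggests the net is re-traced with unknown dims, so the bce_loss decomposition must not bake in static extents. A standalone sketch of that idea, assuming tracing via paddle.jit.to_static with InputSpec placeholders (everything here except bce_loss_net is illustrative, not the harness's actual mechanism):

```python
import numpy as np
import paddle
from paddle.static import InputSpec

def bce_loss_net(x, label):
    return paddle._C_ops.bce_loss(x, label)

# All four dimensions are left dynamic so the traced program cannot
# hard-code the concrete 20x30x40x50 extents.
specs = [
    InputSpec(shape=[None, None, None, None], dtype="float32"),
    InputSpec(shape=[None, None, None, None], dtype="float32"),
]
static_net = paddle.jit.to_static(bce_loss_net, input_spec=specs)

x = paddle.to_tensor(
    np.random.uniform(0.1, 0.8, (20, 30, 40, 50)).astype("float32")
)
label = paddle.to_tensor(
    np.random.randint(0, 2, (20, 30, 40, 50)).astype("float32")
)
out = static_net(x, label)
```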
