Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,11 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
2.94269347)

"""
if anchor.size == 0:
raise ValueError("The dims of anchor should be greater than 0.")
if positive.size == 0:
raise ValueError("The dims of positive should be greater than 0.")
if in_dynamic_mode():
if anchor.size == 0:
raise ValueError("The dims of anchor should be greater than 0.")
if positive.size == 0:
raise ValueError("The dims of positive should be greater than 0.")
check_variable_and_dtype(
anchor, 'anchor', ['float32', 'float64'], 'npair_loss'
)
Expand Down
130 changes: 69 additions & 61 deletions test/legacy_test/test_npair_loss_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def npairloss(anchor, positive, labels, l2_reg=0.002):
Expand Down Expand Up @@ -58,74 +59,81 @@ def __assert_close(self, tensor, np_array, msg, atol=1e-4):
np.array(tensor), np_array, rtol=1e-05, atol=atol, err_msg=msg
)

@test_with_pir_api
def test_npair_loss(self):
reg_lambda = 0.002
num_data, feat_dim, num_classes = 18, 6, 3

place = core.CPUPlace()
exe = base.Executor(place)
exe.run(base.default_startup_program())

embeddings_anchor = np.random.rand(num_data, feat_dim).astype(
np.float32
)
embeddings_positive = np.random.rand(num_data, feat_dim).astype(
np.float32
)
row_labels = np.random.randint(0, num_classes, size=(num_data)).astype(
np.float32
)
out_loss = npairloss(
embeddings_anchor,
embeddings_positive,
row_labels,
l2_reg=reg_lambda,
)
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
reg_lambda = 0.002
num_data, feat_dim, num_classes = 18, 6, 3

place = core.CPUPlace()
exe = base.Executor(place)
exe.run(startup)

embeddings_anchor = np.random.rand(num_data, feat_dim).astype(
np.float32
)
embeddings_positive = np.random.rand(num_data, feat_dim).astype(
np.float32
)
row_labels = np.random.randint(
0, num_classes, size=(num_data)
).astype(np.float32)
out_loss = npairloss(
embeddings_anchor,
embeddings_positive,
row_labels,
l2_reg=reg_lambda,
)

anc = paddle.static.data(
dtype='float32',
name='anc',
shape=embeddings_anchor.shape,
)
pos = paddle.static.data(
dtype='float32',
name='pos',
shape=embeddings_positive.shape,
)
lab = paddle.static.data(
dtype='float32',
name='lab',
shape=row_labels.shape,
)
anc = paddle.static.data(
dtype='float32',
name='anc',
shape=embeddings_anchor.shape,
)
pos = paddle.static.data(
dtype='float32',
name='pos',
shape=embeddings_positive.shape,
)
lab = paddle.static.data(
dtype='float32',
name='lab',
shape=row_labels.shape,
)

npair_loss_op = paddle.nn.functional.npair_loss(
anchor=anc, positive=pos, labels=lab, l2_reg=reg_lambda
)
out_tensor = exe.run(
feed={
'anc': embeddings_anchor,
'pos': embeddings_positive,
'lab': row_labels,
},
fetch_list=[npair_loss_op.name],
)
npair_loss_op = paddle.nn.functional.npair_loss(
anchor=anc, positive=pos, labels=lab, l2_reg=reg_lambda
)
out_tensor = exe.run(
feed={
'anc': embeddings_anchor,
'pos': embeddings_positive,
'lab': row_labels,
},
fetch_list=[npair_loss_op],
)

self.__assert_close(
out_tensor,
out_loss,
"inference output are different at "
+ str(place)
+ ", "
+ str(np.dtype('float32'))
+ str(np.array(out_tensor))
+ str(out_loss),
atol=1e-3,
)
self.__assert_close(
out_tensor,
out_loss,
"inference output are different at "
+ str(place)
+ ", "
+ str(np.dtype('float32'))
+ str(np.array(out_tensor))
+ str(out_loss),
atol=1e-3,
)


class TestNpairLossOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
anchor_np = np.random.random((2, 4)).astype("float32")
positive_np = np.random.random((2, 4)).astype("float32")
labels_np = np.random.random(2).astype("float32")
Expand Down