Skip to content

Commit f15a534

Browse files
committed
add dice_loss unittest
1 parent 66c3bfa commit f15a534

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3320,6 +3320,30 @@ def test_roi_align(self):
33203320
dy_res_value = dy_res.numpy()
33213321
self.assertTrue(np.array_equal(static_res, dy_res_value))
33223322

3323+
def test_dice_loss(self):
3324+
num_classes = 4
3325+
eps = 1e-6
3326+
input_np = np.random.rand(2, 3, num_classes).astype('float32')
3327+
label_np = np.random.randint(0, num_classes, [2, 3, 1], dtype=np.int64)
3328+
3329+
with self.static_graph():
3330+
input_ = layers.data(
3331+
name="input", shape=[None, 3, num_classes], dtype="float32")
3332+
label_ = layers.data(
3333+
name="label", shape=[None, 3, 1], dtype="int64")
3334+
output = layers.dice_loss(input_, label_, eps)
3335+
static_res = self.get_static_graph_result(
3336+
feed={'input': input_np,
3337+
'label': label_np},
3338+
fetch_list=[output])[0]
3339+
3340+
with self.dynamic_graph():
3341+
input_ = base.to_variable(input_np)
3342+
label_ = base.to_variable(label_np)
3343+
dy_res = layers.dice_loss(input_, label_, eps)
3344+
dy_res_value = dy_res.numpy()
3345+
self.assertTrue(np.array_equal(static_res, dy_res_value))
3346+
33233347
def test_roi_perspective_transform(self):
33243348
# TODO(minqiyang): dygraph do not support lod now
33253349
with self.static_graph():

0 commit comments

Comments
 (0)