@@ -3320,6 +3320,30 @@ def test_roi_align(self):
33203320 dy_res_value = dy_res .numpy ()
33213321 self .assertTrue (np .array_equal (static_res , dy_res_value ))
33223322
3323+ def test_dice_loss (self ):
3324+ num_classes = 4
3325+ eps = 1e-6
3326+ input_np = np .random .rand (2 , 3 , num_classes ).astype ('float32' )
3327+ label_np = np .random .randint (0 , num_classes , [2 , 3 , 1 ], dtype = np .int64 )
3328+
3329+ with self .static_graph ():
3330+ input_ = layers .data (
3331+ name = "input" , shape = [None , 3 , num_classes ], dtype = "float32" )
3332+ label_ = layers .data (
3333+ name = "label" , shape = [None , 3 , 1 ], dtype = "int64" )
3334+ output = layers .dice_loss (input_ , label_ , eps )
3335+ static_res = self .get_static_graph_result (
3336+ feed = {'input' : input_np ,
3337+ 'label' : label_np },
3338+ fetch_list = [output ])[0 ]
3339+
3340+ with self .dynamic_graph ():
3341+ input_ = base .to_variable (input_np )
3342+ label_ = base .to_variable (label_np )
3343+ dy_res = layers .dice_loss (input_ , label_ , eps )
3344+ dy_res_value = dy_res .numpy ()
3345+ self .assertTrue (np .array_equal (static_res , dy_res_value ))
3346+
33233347 def test_roi_perspective_transform (self ):
33243348 # TODO(minqiyang): dygraph do not support lod now
33253349 with self .static_graph ():
0 commit comments