@@ -91,28 +91,37 @@ def test_script(self):
9191 test_input = torch .ones (2 , 1 , 8 , 8 )
9292 test_script_save (loss , test_input , test_input )
9393
94- def test_result_with_alpha (self ):
94+ @parameterized .expand ([
95+ ("sum_None_0.5_0.25" , "sum" , None , 0.5 , 0.25 ),
96+ ("sum_weight_0.5_0.25" , "sum" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
97+ ("sum_weight_tuple_0.5_0.25" , "sum" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
98+ ("mean_None_0.5_0.25" , "mean" , None , 0.5 , 0.25 ),
99+ ("mean_weight_0.5_0.25" , "mean" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
100+ ("mean_weight_tuple_0.5_0.25" , "mean" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
101+ ("none_None_0.5_0.25" , "none" , None , 0.5 , 0.25 ),
102+ ("none_weight_0.5_0.25" , "none" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
103+ ("none_weight_tuple_0.5_0.25" , "none" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
104+ ])
105+ def test_with_alpha (self , name , reduction , weight , lambda_focal , alpha ):
95106 size = [3 , 3 , 5 , 5 ]
96107 label = torch .randint (low = 0 , high = 2 , size = size )
97108 pred = torch .randn (size )
98- alpha_values = [0.25 , 0.5 , 0.75 ]
99- for reduction in ["sum" , "mean" , "none" ]:
100- for weight in [None , torch .tensor ([1.0 , 1.0 , 2.0 ]), (3 , 2.0 , 1 )]:
101- common_params = {
102- "include_background" : True ,
103- "to_onehot_y" : False ,
104- "reduction" : reduction ,
105- "weight" : weight ,
106- }
107- for lambda_focal in [0.5 , 1.0 , 1.5 ]:
108- for alpha in alpha_values :
109- dice_focal = DiceFocalLoss (gamma = 1.0 , lambda_focal = lambda_focal , alpha = alpha , ** common_params )
110- dice = DiceLoss (** common_params )
111- focal = FocalLoss (gamma = 1.0 , alpha = alpha , ** common_params )
112- result = dice_focal (pred , label )
113- expected_val = dice (pred , label ) + lambda_focal * focal (pred , label )
114- np .testing .assert_allclose (result , expected_val )
115109
110+ common_params = {
111+ "include_background" : True ,
112+ "to_onehot_y" : False ,
113+ "reduction" : reduction ,
114+ "weight" : weight ,
115+ }
116+
117+ dice_focal = DiceFocalLoss (gamma = 1.0 , lambda_focal = lambda_focal , alpha = alpha , ** common_params )
118+ dice = DiceLoss (** common_params )
119+ focal = FocalLoss (gamma = 1.0 , alpha = alpha , ** common_params )
120+
121+ result = dice_focal (pred , label )
122+ expected_val = dice (pred , label ) + lambda_focal * focal (pred , label )
123+
124+ np .testing .assert_allclose (result , expected_val , err_msg = f"Failed on case: { name } " )
116125
117126if __name__ == "__main__" :
118127 unittest .main ()
0 commit comments