1414import  warnings 
1515
1616import  torch 
17+ import  torch .nn .functional  as  F 
1718from  torch .nn .modules .loss  import  _Loss 
1819
1920from  monai .networks  import  one_hot 
2324class  AsymmetricFocalTverskyLoss (_Loss ):
2425    """ 
2526    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on the foreground class. 
26- 
27-     Actually, it's only supported for binary image segmentation now. 
27+     It treats the background class (index 0) differently from all foreground classes (indices 1...N). 
2828
2929    Reimplementation of the Asymmetric Focal Tversky Loss described in: 
3030
3131    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", 
32-     Michael Yeung, Computerized Medical Imaging and Graphics 
32+        Michael Yeung, Computerized Medical Imaging and Graphics 
3333    """ 
3434
3535    def  __init__ (
3636        self ,
37-         to_onehot_y : bool  =  False ,
37+         include_background : bool  =  True ,
3838        delta : float  =  0.7 ,
3939        gamma : float  =  0.75 ,
4040        epsilon : float  =  1e-7 ,
4141        reduction : LossReduction  |  str  =  LossReduction .MEAN ,
4242    ) ->  None :
4343        """ 
4444        Args: 
45-             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. 
46-             delta : weight of the background. Defaults to 0.7. 
47-             gamma : value of the exponent gamma in the definition of the Focal loss  . Defaults to 0.75. 
48-             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. 
45+             include_background: whether to include loss computation for the background class. Defaults to True. 
46+             delta : weight of false negatives in the Tversky index; false positives are weighted by (1 - delta). Defaults to 0.7. 
47+             gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. 
48+             epsilon : small constant used to clip predictions and to smooth the Tversky index, avoiding division by zero. Defaults to 1e-7. 
49+             reduction: specifies the reduction to apply to the output: "none", "mean", "sum". 
4950        """ 
5051        super ().__init__ (reduction = LossReduction (reduction ).value )
51-         self .to_onehot_y  =  to_onehot_y 
52+         self .include_background  =  include_background 
5253        self .delta  =  delta 
5354        self .gamma  =  gamma 
5455        self .epsilon  =  epsilon 
5556
5657    def  forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) ->  torch .Tensor :
5758        n_pred_ch  =  y_pred .shape [1 ]
5859
59-         if  self .to_onehot_y :
60-             if  n_pred_ch  ==  1 :
61-                 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
62-             else :
63-                 y_true  =  one_hot (y_true , num_classes = n_pred_ch )
64- 
6560        if  y_true .shape  !=  y_pred .shape :
6661            raise  ValueError (f"ground truth has different shape ({ y_true .shape }  ) from input ({ y_pred .shape }  )" )
6762
68-         # clip the prediction to avoid NaN 
63+         # Exclude background if needed 
64+         if  not  self .include_background :
65+             if  n_pred_ch  ==  1 :
66+                 warnings .warn ("single channel prediction, `include_background=False` ignored." )
67+             else :
68+                 y_pred  =  y_pred [:, 1 :]
69+                 y_true  =  y_true [:, 1 :]
70+ 
71+         # Clip predictions 
6972        y_pred  =  torch .clamp (y_pred , self .epsilon , 1.0  -  self .epsilon )
7073        axis  =  list (range (2 , len (y_pred .shape )))
7174
@@ -74,167 +77,220 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7477        fn  =  torch .sum (y_true  *  (1  -  y_pred ), dim = axis )
7578        fp  =  torch .sum ((1  -  y_true ) *  y_pred , dim = axis )
7679        dice_class  =  (tp  +  self .epsilon ) /  (tp  +  self .delta  *  fn  +  (1  -  self .delta ) *  fp  +  self .epsilon )
77- 
78-         # Calculate losses separately for each class, enhancing both classes 
79-         back_dice  =  1  -  dice_class [:, 0 ]
80-         fore_dice  =  (1  -  dice_class [:, 1 ]) *  torch .pow (1  -  dice_class [:, 1 ], - self .gamma )
81- 
82-         # Average class scores 
83-         loss  =  torch .mean (torch .stack ([back_dice , fore_dice ], dim = - 1 ))
84-         return  loss 
80+         # dice_class shape is (B, C) 
81+ 
82+         n_classes  =  dice_class .shape [1 ]
83+ 
84+         if  not  self .include_background :
85+             # All classes are foreground, apply foreground logic 
86+             loss  =  torch .pow (1.0  -  dice_class , 1.0  -  self .gamma )  # (B, C) 
87+         elif  n_classes  ==  1 :
88+             # Single-channel prediction with include_background=True: the only channel is treated as foreground 
89+             loss  =  torch .pow (1.0  -  dice_class , 1.0  -  self .gamma )  # (B, 1) 
90+         else :
91+             # Asymmetric logic: class 0 is BG, others are FG 
92+             back_dice_loss  =  (1.0  -  dice_class [:, 0 ]).unsqueeze (1 )  # (B, 1) 
93+             fore_dice_loss  =  torch .pow (1.0  -  dice_class [:, 1 :], 1.0  -  self .gamma )  # (B, C-1) 
94+             loss  =  torch .cat ([back_dice_loss , fore_dice_loss ], dim = 1 )  # (B, C) 
95+ 
96+         # Apply reduction 
97+         if  self .reduction  ==  LossReduction .MEAN .value :
98+             return  torch .mean (loss )  # mean over batch and classes 
99+         if  self .reduction  ==  LossReduction .SUM .value :
100+             return  torch .sum (loss )  # sum over batch and classes 
101+         if  self .reduction  ==  LossReduction .NONE .value :
102+             return  loss   # returns (B, C) losses 
103+         raise  ValueError (f'Unsupported reduction: { self .reduction }  , available options are ["mean", "sum", "none"].' )
85104
86105
87106class  AsymmetricFocalLoss (_Loss ):
88107    """ 
89-     AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. 
90- 
91-     Actually, it's only supported for binary image segmentation now. 
108+     AsymmetricFocalLoss is a variant of FocalLoss that focuses on the foreground class. 
109+     It treats the background class (index 0) differently from all foreground classes (indices 1...N). 
110+     Background class (0): applies gamma exponent to (1-p) 
111+     Foreground classes (1..N): no gamma exponent 
92112
93113    Reimplementation of the Asymmetric Focal Loss described in: 
94114
95115    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", 
96-     Michael Yeung, Computerized Medical Imaging and Graphics 
116+        Michael Yeung, Computerized Medical Imaging and Graphics 
97117    """ 
98118
99119    def  __init__ (
100120        self ,
101-         to_onehot_y : bool  =  False ,
121+         include_background : bool  =  True ,
102122        delta : float  =  0.7 ,
103-         gamma : float  =  2 ,
123+         gamma : float  =  2.0  ,
104124        epsilon : float  =  1e-7 ,
105125        reduction : LossReduction  |  str  =  LossReduction .MEAN ,
106126    ):
107127        """ 
108128        Args: 
109-             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. 
110-             delta : weight of the background. Defaults to 0.7. 
111-             gamma : value of the exponent gamma in the definition of the Focal loss  . Defaults to 0.75. 
112-             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. 
129+             include_background: whether to include loss computation for the background class. Defaults to True. 
130+             delta : weight of the foreground. Defaults to 0.7. (1-delta is weight of background) 
131+             gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0. 
132+             epsilon : small constant used to clip predictions away from 0 and 1 before taking the logarithm. Defaults to 1e-7. 
133+             reduction: specifies the reduction to apply to the output: "none", "mean", "sum". 
113134        """ 
114135        super ().__init__ (reduction = LossReduction (reduction ).value )
115-         self .to_onehot_y  =  to_onehot_y 
136+         self .include_background  =  include_background 
116137        self .delta  =  delta 
117138        self .gamma  =  gamma 
118139        self .epsilon  =  epsilon 
119140
120141    def  forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) ->  torch .Tensor :
121142        n_pred_ch  =  y_pred .shape [1 ]
122143
123-         if  self .to_onehot_y :
124-             if  n_pred_ch  ==  1 :
125-                 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
126-             else :
127-                 y_true  =  one_hot (y_true , num_classes = n_pred_ch )
128- 
129144        if  y_true .shape  !=  y_pred .shape :
130145            raise  ValueError (f"ground truth has different shape ({ y_true .shape }  ) from input ({ y_pred .shape }  )" )
131146
132-         y_pred  =  torch .clamp (y_pred , self .epsilon , 1.0  -  self .epsilon )
133-         cross_entropy  =  - y_true  *  torch .log (y_pred )
134- 
135-         back_ce  =  torch .pow (1  -  y_pred [:, 0 ], self .gamma ) *  cross_entropy [:, 0 ]
136-         back_ce  =  (1  -  self .delta ) *  back_ce 
137- 
138-         fore_ce  =  cross_entropy [:, 1 ]
139-         fore_ce  =  self .delta  *  fore_ce 
147+         # Exclude background if needed 
148+         if  not  self .include_background :
149+             if  n_pred_ch  ==  1 :
150+                 warnings .warn ("single channel prediction, `include_background=False` ignored." )
151+             else :
152+                 y_pred  =  y_pred [:, 1 :]
153+                 y_true  =  y_true [:, 1 :]
140154
141-         loss  =  torch .mean (torch .sum (torch .stack ([back_ce , fore_ce ], dim = 1 ), dim = 1 ))
142-         return  loss 
155+         y_pred  =  torch .clamp (y_pred , self .epsilon , 1.0  -  self .epsilon )
156+         cross_entropy  =  - y_true  *  torch .log (y_pred )  # Shape (B, C, H, W, [D]) 
157+ 
158+         n_classes  =  y_pred .shape [1 ]
159+ 
160+         if  not  self .include_background :
161+             # All classes are foreground, apply foreground logic 
162+             loss  =  self .delta  *  cross_entropy   # (B, C, H, W) 
163+         elif  n_classes  ==  1 :
164+             # Single class, must be foreground 
165+             loss  =  self .delta  *  cross_entropy   # (B, 1, H, W) 
166+         else :
167+             # Asymmetric logic: class 0 is BG, others are FG 
168+             # (B, H, W) 
169+             back_ce  =  (1.0  -  self .delta ) *  torch .pow (1.0  -  y_pred [:, 0 ], self .gamma ) *  cross_entropy [:, 0 ]
170+             # (B, C-1, H, W) 
171+             fore_ce  =  self .delta  *  cross_entropy [:, 1 :]
172+             
173+             loss  =  torch .cat ([back_ce .unsqueeze (1 ), fore_ce ], dim = 1 )  # (B, C, H, W) 
174+ 
175+         # Apply reduction 
176+         if  self .reduction  ==  LossReduction .MEAN .value :
177+             return  torch .mean (loss )  # mean over batch, class, and spatial 
178+         if  self .reduction  ==  LossReduction .SUM .value :
179+             return  torch .sum (loss )
180+         if  self .reduction  ==  LossReduction .NONE .value :
181+             return  loss   # returns (B, C, H, W) 
182+         raise  ValueError (f'Unsupported reduction: { self .reduction }  , available options are ["mean", "sum", "none"].' )
143183
144184
145185class  AsymmetricUnifiedFocalLoss (_Loss ):
146186    """ 
147-     AsymmetricUnifiedFocalLoss is a variant of Focal Loss. 
148- 
149-     Actually, it's only supported for binary image segmentation now 
187+     AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss 
188+     and AsymmetricFocalTverskyLoss. 
150189
151190    Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: 
152191
153192    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", 
154-     Michael Yeung, Computerized Medical Imaging and Graphics 
193+        Michael Yeung, Computerized Medical Imaging and Graphics 
155194    """ 
156195
157196    def  __init__ (
158197        self ,
198+         include_background : bool  =  True ,
159199        to_onehot_y : bool  =  False ,
160-         num_classes : int  =  2 ,
161-         weight : float  =  0.5 ,
162-         gamma : float  =  0.5 ,
163-         delta : float  =  0.7 ,
200+         sigmoid : bool  =  False ,
201+         softmax : bool  =  False ,
202+         lambda_focal : float  =  0.5 ,
203+         focal_loss_gamma : float  =  2.0 ,
204+         focal_loss_delta : float  =  0.7 ,
205+         tversky_loss_gamma : float  =  0.75 ,
206+         tversky_loss_delta : float  =  0.7 ,
164207        reduction : LossReduction  |  str  =  LossReduction .MEAN ,
165208    ):
166209        """ 
167210        Args: 
211+             include_background: whether to include loss computation for the background class. Defaults to True. 
168212            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. 
169-             num_classes : number of classes, it only supports 2 now. Defaults to 2. 
170-             delta : weight of the background. Defaults to 0.7. 
171-             gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. 
172-             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. 
173-             weight : weight for each loss function, if it's none it's 0.5. Defaults to None. 
213+             sigmoid: if True, apply a sigmoid activation to the input y_pred. 
214+             softmax: if True, apply a softmax activation to the input y_pred. 
215+             lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based). 
216+                 The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5. 
217+             focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0. 
218+             focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7. 
219+             tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75. 
220+             tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7. 
221+             reduction: specifies the reduction to apply to the output: "none", "mean", "sum". 
174222
175223        Example: 
176224            >>> import torch 
177225            >>> from monai.losses import AsymmetricUnifiedFocalLoss 
178-             >>> pred = torch.ones ((1,1, 32,32), dtype=torch.float32) 
179-             >>> grnd = torch.ones( (1,1, 32,32), dtype=torch.int64) 
180-             >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) 
226+             >>> pred = torch.randn ((1, 2,  32,  32), dtype=torch.float32) 
227+             >>> grnd = torch.randint(0, 2,  (1, 1,  32,  32), dtype=torch.int64) 
228+             >>> fl = AsymmetricUnifiedFocalLoss(softmax=True,  to_onehot_y=True) 
181229            >>> fl(pred, grnd) 
182230        """ 
183231        super ().__init__ (reduction = LossReduction (reduction ).value )
232+         self .include_background  =  include_background 
184233        self .to_onehot_y  =  to_onehot_y 
185-         self .num_classes  =  num_classes 
186-         self .gamma  =  gamma 
187-         self .delta  =  delta 
188-         self .weight : float  =  weight 
189-         self .asy_focal_loss  =  AsymmetricFocalLoss (gamma = self .gamma , delta = self .delta )
190-         self .asy_focal_tversky_loss  =  AsymmetricFocalTverskyLoss (gamma = self .gamma , delta = self .delta )
234+         self .sigmoid  =  sigmoid 
235+         self .softmax  =  softmax 
236+         self .lambda_focal  =  lambda_focal 
237+ 
238+         if  sigmoid  and  softmax :
239+             raise  ValueError ("Both sigmoid and softmax cannot be True." )
240+ 
241+         self .asy_focal_loss  =  AsymmetricFocalLoss (
242+             include_background = self .include_background ,
243+             gamma = focal_loss_gamma ,
244+             delta = focal_loss_delta ,
245+             reduction = self .reduction ,
246+         )
247+         self .asy_focal_tversky_loss  =  AsymmetricFocalTverskyLoss (
248+             include_background = self .include_background ,
249+             gamma = tversky_loss_gamma ,
250+             delta = tversky_loss_delta ,
251+             reduction = self .reduction ,
252+         )
191253
192-     # TODO: Implement this  function to support multiple classes segmentation 
193254    def  forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) ->  torch .Tensor :
194255        """ 
195256        Args: 
196-             y_pred : the shape should be BNH[WD], where N is the number of classes. 
197-                 It only supports binary segmentation. 
198-                 The input should be the original logits since it will be transformed by 
199-                     a sigmoid in the forward function. 
200-             y_true : the shape should be BNH[WD], where N is the number of classes. 
201-                 It only supports binary segmentation. 
202- 
203-         Raises: 
204-             ValueError: When input and target are different shape 
205-             ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 
206-             ValueError: When num_classes 
207-             ValueError: When the number of classes entered does not match the expected number 
257+             y_pred : the shape should be BNH[WD]. 
258+             y_true : the shape should be BNH[WD] or B1H[WD]. 
208259        """ 
209-         if  y_pred .shape  !=  y_true .shape :
210-             raise  ValueError (f"ground truth has different shape ({ y_true .shape }  ) from input ({ y_pred .shape }  )" )
211- 
212-         if  len (y_pred .shape ) !=  4  and  len (y_pred .shape ) !=  5 :
213-             raise  ValueError (f"input shape must be 4 or 5, but got { y_pred .shape }  " )
214- 
215-         if  y_pred .shape [1 ] ==  1 :
216-             y_pred  =  one_hot (y_pred , num_classes = self .num_classes )
217-             y_true  =  one_hot (y_true , num_classes = self .num_classes )
260+         n_pred_ch  =  y_pred .shape [1 ]
218261
219-         if  torch .max (y_true ) !=  self .num_classes  -  1 :
220-             raise  ValueError (f"Please make sure the number of classes is { self .num_classes - 1 }  " )
262+         y_pred_act  =  y_pred 
263+         if  self .sigmoid :
264+             y_pred_act  =  torch .sigmoid (y_pred )
265+         elif  self .softmax :
266+             if  n_pred_ch  ==  1 :
267+                 warnings .warn ("single channel prediction, softmax=True ignored." )
268+             else :
269+                 y_pred_act  =  torch .softmax (y_pred , dim = 1 )
221270
222-         n_pred_ch  =  y_pred .shape [1 ]
223271        if  self .to_onehot_y :
224-             if  n_pred_ch  ==  1 :
272+             if  n_pred_ch  ==  1   and   not   self . sigmoid :
225273                warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
226-             else :
274+             elif  n_pred_ch  >  1  or  self .sigmoid :
275+                 # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion 
276+                 if  y_true .shape [1 ] !=  1 :
277+                     y_true  =  y_true .unsqueeze (1 )
227278                y_true  =  one_hot (y_true , num_classes = n_pred_ch )
279+         
280+         # Ensure y_true has the same shape as y_pred_act 
281+         if  y_true .shape  !=  y_pred_act .shape :
282+              # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid 
283+             if  y_true .shape [1 ] !=  y_pred_act .shape [1 ] and  y_true .ndim  ==  y_pred_act .ndim  -  1 :
284+                 y_true  =  y_true .unsqueeze (1 ) # Add channel dim 
285+             
286+             if  y_true .shape  !=  y_pred_act .shape :
287+                 raise  ValueError (f"ground truth has different shape ({ y_true .shape }  ) from input ({ y_pred_act .shape }  ) 
288+                                  after  activations / one - hot ")
228289
229-         asy_focal_loss  =  self .asy_focal_loss (y_pred , y_true )
230-         asy_focal_tversky_loss  =  self .asy_focal_tversky_loss (y_pred , y_true )
231290
232-         loss : torch .Tensor  =  self .weight  *  asy_focal_loss  +  (1  -  self .weight ) *  asy_focal_tversky_loss 
291+         f_loss  =  self .asy_focal_loss (y_pred_act , y_true )
292+         t_loss  =  self .asy_focal_tversky_loss (y_pred_act , y_true )
233293
234-         if  self .reduction  ==  LossReduction .SUM .value :
235-             return  torch .sum (loss )  # sum over the batch and channel dims 
236-         if  self .reduction  ==  LossReduction .NONE .value :
237-             return  loss   # returns [N, num_classes] losses 
238-         if  self .reduction  ==  LossReduction .MEAN .value :
239-             return  torch .mean (loss )
240-         raise  ValueError (f'Unsupported reduction: { self .reduction }  , available options are ["mean", "sum", "none"].' )
294+         loss : torch .Tensor  =  self .lambda_focal  *  f_loss  +  (1  -  self .lambda_focal ) *  t_loss 
295+ 
296+         return  loss 
0 commit comments