1414import warnings
1515
1616import torch
17+ import torch .nn .functional as F
1718from torch .nn .modules .loss import _Loss
1819
1920from monai .networks import one_hot
2324class AsymmetricFocalTverskyLoss (_Loss ):
2425 """
2526 AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
26-
27- Actually, it's only supported for binary image segmentation now.
27+ It treats the background class (index 0) differently from all foreground classes (indices 1...N).
2828
2929 Reimplementation of the Asymmetric Focal Tversky Loss described in:
3030
3131 - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
32- Michael Yeung, Computerized Medical Imaging and Graphics
32+ Michael Yeung, Computerized Medical Imaging and Graphics
3333 """
3434
3535 def __init__ (
3636 self ,
37- to_onehot_y : bool = False ,
37+ include_background : bool = True ,
3838 delta : float = 0.7 ,
3939 gamma : float = 0.75 ,
4040 epsilon : float = 1e-7 ,
4141 reduction : LossReduction | str = LossReduction .MEAN ,
4242 ) -> None :
4343 """
4444 Args:
45- to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
46- delta : weight of the background. Defaults to 0.7.
47- gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
48- epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
45+ include_background: whether to include loss computation for the background class. Defaults to True.
46+ delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index)
47+ gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
48+ epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
49+ reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
4950 """
5051 super ().__init__ (reduction = LossReduction (reduction ).value )
51- self .to_onehot_y = to_onehot_y
52+ self .include_background = include_background
5253 self .delta = delta
5354 self .gamma = gamma
5455 self .epsilon = epsilon
5556
5657 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
5758 n_pred_ch = y_pred .shape [1 ]
5859
59- if self .to_onehot_y :
60- if n_pred_ch == 1 :
61- warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
62- else :
63- y_true = one_hot (y_true , num_classes = n_pred_ch )
64-
6560 if y_true .shape != y_pred .shape :
6661 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
6762
68- # clip the prediction to avoid NaN
63+ # Exclude background if needed
64+ if not self .include_background :
65+ if n_pred_ch == 1 :
66+ warnings .warn ("single channel prediction, `include_background=False` ignored." )
67+ else :
68+ y_pred = y_pred [:, 1 :]
69+ y_true = y_true [:, 1 :]
70+
71+ # Clip predictions
6972 y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
7073 axis = list (range (2 , len (y_pred .shape )))
7174
@@ -74,167 +77,219 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7477 fn = torch .sum (y_true * (1 - y_pred ), dim = axis )
7578 fp = torch .sum ((1 - y_true ) * y_pred , dim = axis )
7679 dice_class = (tp + self .epsilon ) / (tp + self .delta * fn + (1 - self .delta ) * fp + self .epsilon )
77-
78- # Calculate losses separately for each class, enhancing both classes
79- back_dice = 1 - dice_class [:, 0 ]
80- fore_dice = (1 - dice_class [:, 1 ]) * torch .pow (1 - dice_class [:, 1 ], - self .gamma )
81-
82- # Average class scores
83- loss = torch .mean (torch .stack ([back_dice , fore_dice ], dim = - 1 ))
84- return loss
80+ # dice_class shape is (B, C)
81+
82+ n_classes = dice_class .shape [1 ]
83+
84+ if not self .include_background :
85+ # All classes are foreground, apply foreground logic
86+ loss = torch .pow (1.0 - dice_class , 1.0 - self .gamma ) # (B, C)
87+ elif n_classes == 1 :
88+ # Single class, must be foreground (BG was excluded or not provided)
89+ loss = torch .pow (1.0 - dice_class , 1.0 - self .gamma ) # (B, 1)
90+ else :
91+ # Asymmetric logic: class 0 is BG, others are FG
92+ back_dice_loss = (1.0 - dice_class [:, 0 ]).unsqueeze (1 ) # (B, 1)
93+ fore_dice_loss = torch .pow (1.0 - dice_class [:, 1 :], 1.0 - self .gamma ) # (B, C-1)
94+ loss = torch .cat ([back_dice_loss , fore_dice_loss ], dim = 1 ) # (B, C)
95+
96+ # Apply reduction
97+ if self .reduction == LossReduction .MEAN .value :
98+ return torch .mean (loss ) # mean over batch and classes
99+ if self .reduction == LossReduction .SUM .value :
100+ return torch .sum (loss ) # sum over batch and classes
101+ if self .reduction == LossReduction .NONE .value :
102+ return loss # returns (B, C) losses
103+ raise ValueError (f'Unsupported reduction: { self .reduction } , available options are ["mean", "sum", "none"].' )
85104
86105
87106class AsymmetricFocalLoss (_Loss ):
88107 """
89- AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
90-
91- Actually, it's only supported for binary image segmentation now.
108+ AsymmetricFocalLoss is a variant of FocalLoss, which attentions to the foreground class.
109+ It treats the background class (index 0) differently from all foreground classes (indices 1...N).
110+ Background class (0): applies gamma exponent to (1-p)
111+ Foreground classes (1..N): no gamma exponent
92112
93113 Reimplementation of the Asymmetric Focal Loss described in:
94114
95115 - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
96- Michael Yeung, Computerized Medical Imaging and Graphics
116+ Michael Yeung, Computerized Medical Imaging and Graphics
97117 """
98118
99119 def __init__ (
100120 self ,
101- to_onehot_y : bool = False ,
121+ include_background : bool = True ,
102122 delta : float = 0.7 ,
103- gamma : float = 2 ,
123+ gamma : float = 2.0 ,
104124 epsilon : float = 1e-7 ,
105125 reduction : LossReduction | str = LossReduction .MEAN ,
106126 ):
107127 """
108128 Args:
109- to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110- delta : weight of the background. Defaults to 0.7.
111- gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
112- epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
129+ include_background: whether to include loss computation for the background class. Defaults to True.
130+ delta : weight of the foreground. Defaults to 0.7. (1-delta is weight of background)
131+ gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0.
132+ epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
133+ reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
113134 """
114135 super ().__init__ (reduction = LossReduction (reduction ).value )
115- self .to_onehot_y = to_onehot_y
136+ self .include_background = include_background
116137 self .delta = delta
117138 self .gamma = gamma
118139 self .epsilon = epsilon
119140
120141 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
121142 n_pred_ch = y_pred .shape [1 ]
122143
123- if self .to_onehot_y :
124- if n_pred_ch == 1 :
125- warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
126- else :
127- y_true = one_hot (y_true , num_classes = n_pred_ch )
128-
129144 if y_true .shape != y_pred .shape :
130145 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
131146
132- y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
133- cross_entropy = - y_true * torch .log (y_pred )
134-
135- back_ce = torch .pow (1 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
136- back_ce = (1 - self .delta ) * back_ce
137-
138- fore_ce = cross_entropy [:, 1 ]
139- fore_ce = self .delta * fore_ce
147+ # Exclude background if needed
148+ if not self .include_background :
149+ if n_pred_ch == 1 :
150+ warnings .warn ("single channel prediction, `include_background=False` ignored." )
151+ else :
152+ y_pred = y_pred [:, 1 :]
153+ y_true = y_true [:, 1 :]
140154
141- loss = torch .mean (torch .sum (torch .stack ([back_ce , fore_ce ], dim = 1 ), dim = 1 ))
142- return loss
155+ y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
156+ cross_entropy = - y_true * torch .log (y_pred ) # Shape (B, C, H, W, [D])
157+
158+ n_classes = y_pred .shape [1 ]
159+
160+ if not self .include_background :
161+ # All classes are foreground, apply foreground logic
162+ loss = self .delta * cross_entropy # (B, C, H, W)
163+ elif n_classes == 1 :
164+ # Single class, must be foreground
165+ loss = self .delta * cross_entropy # (B, 1, H, W)
166+ else :
167+ # Asymmetric logic: class 0 is BG, others are FG
168+ # (B, H, W)
169+ back_ce = (1.0 - self .delta ) * torch .pow (1.0 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
170+ # (B, C-1, H, W)
171+ fore_ce = self .delta * cross_entropy [:, 1 :]
172+
173+ loss = torch .cat ([back_ce .unsqueeze (1 ), fore_ce ], dim = 1 ) # (B, C, H, W)
174+
175+ # Apply reduction
176+ if self .reduction == LossReduction .MEAN .value :
177+ return torch .mean (loss ) # mean over batch, class, and spatial
178+ if self .reduction == LossReduction .SUM .value :
179+ return torch .sum (loss )
180+ if self .reduction == LossReduction .NONE .value :
181+ return loss # returns (B, C, H, W)
182+ raise ValueError (f'Unsupported reduction: { self .reduction } , available options are ["mean", "sum", "none"].' )
143183
144184
145185class AsymmetricUnifiedFocalLoss (_Loss ):
146186 """
147- AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
148-
149- Actually, it's only supported for binary image segmentation now
187+ AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss
188+ and AsymmetricFocalTverskyLoss.
150189
151190 Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
152191
153192 - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
154- Michael Yeung, Computerized Medical Imaging and Graphics
193+ Michael Yeung, Computerized Medical Imaging and Graphics
155194 """
156195
157196 def __init__ (
158197 self ,
198+ include_background : bool = True ,
159199 to_onehot_y : bool = False ,
160- num_classes : int = 2 ,
161- weight : float = 0.5 ,
162- gamma : float = 0.5 ,
163- delta : float = 0.7 ,
200+ sigmoid : bool = False ,
201+ softmax : bool = False ,
202+ lambda_focal : float = 0.5 ,
203+ focal_loss_gamma : float = 2.0 ,
204+ focal_loss_delta : float = 0.7 ,
205+ tversky_loss_gamma : float = 0.75 ,
206+ tversky_loss_delta : float = 0.7 ,
164207 reduction : LossReduction | str = LossReduction .MEAN ,
165208 ):
166209 """
167210 Args:
211+ include_background: whether to include loss computation for the background class. Defaults to True.
168212 to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
169- num_classes : number of classes, it only supports 2 now. Defaults to 2.
170- delta : weight of the background. Defaults to 0.7.
171- gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
172- epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
173- weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
213+ sigmoid: if True, apply a sigmoid activation to the input y_pred.
214+ softmax: if True, apply a softmax activation to the input y_pred.
215+ lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based).
216+ The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
217+ focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
218+ focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
219+ tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
220+ tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
221+ reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
174222
175223 Example:
176224 >>> import torch
177225 >>> from monai.losses import AsymmetricUnifiedFocalLoss
178- >>> pred = torch.ones ((1,1, 32,32), dtype=torch.float32)
179- >>> grnd = torch.ones( (1,1, 32,32), dtype=torch.int64)
180- >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
226+ >>> pred = torch.randn ((1, 2, 32, 32), dtype=torch.float32)
227+ >>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64)
228+ >>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
181229 >>> fl(pred, grnd)
182230 """
183231 super ().__init__ (reduction = LossReduction (reduction ).value )
232+ self .include_background = include_background
184233 self .to_onehot_y = to_onehot_y
185- self .num_classes = num_classes
186- self .gamma = gamma
187- self .delta = delta
188- self .weight : float = weight
189- self .asy_focal_loss = AsymmetricFocalLoss (gamma = self .gamma , delta = self .delta )
190- self .asy_focal_tversky_loss = AsymmetricFocalTverskyLoss (gamma = self .gamma , delta = self .delta )
234+ self .sigmoid = sigmoid
235+ self .softmax = softmax
236+ self .lambda_focal = lambda_focal
237+
238+ if sigmoid and softmax :
239+ raise ValueError ("Both sigmoid and softmax cannot be True." )
240+
241+ self .asy_focal_loss = AsymmetricFocalLoss (
242+ include_background = self .include_background ,
243+ gamma = focal_loss_gamma ,
244+ delta = focal_loss_delta ,
245+ reduction = self .reduction ,
246+ )
247+ self .asy_focal_tversky_loss = AsymmetricFocalTverskyLoss (
248+ include_background = self .include_background ,
249+ gamma = tversky_loss_gamma ,
250+ delta = tversky_loss_delta ,
251+ reduction = self .reduction ,
252+ )
191253
192- # TODO: Implement this function to support multiple classes segmentation
193254 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
194255 """
195256 Args:
196- y_pred : the shape should be BNH[WD], where N is the number of classes.
197- It only supports binary segmentation.
198- The input should be the original logits since it will be transformed by
199- a sigmoid in the forward function.
200- y_true : the shape should be BNH[WD], where N is the number of classes.
201- It only supports binary segmentation.
202-
203- Raises:
204- ValueError: When input and target are different shape
205- ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
206- ValueError: When num_classes
207- ValueError: When the number of classes entered does not match the expected number
257+ y_pred : the shape should be BNH[WD].
258+ y_true : the shape should be BNH[WD] or B1H[WD].
208259 """
209- if y_pred .shape != y_true .shape :
210- raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
211-
212- if len (y_pred .shape ) != 4 and len (y_pred .shape ) != 5 :
213- raise ValueError (f"input shape must be 4 or 5, but got { y_pred .shape } " )
214-
215- if y_pred .shape [1 ] == 1 :
216- y_pred = one_hot (y_pred , num_classes = self .num_classes )
217- y_true = one_hot (y_true , num_classes = self .num_classes )
260+ n_pred_ch = y_pred .shape [1 ]
218261
219- if torch .max (y_true ) != self .num_classes - 1 :
220- raise ValueError (f"Please make sure the number of classes is { self .num_classes - 1 } " )
262+ y_pred_act = y_pred
263+ if self .sigmoid :
264+ y_pred_act = torch .sigmoid (y_pred )
265+ elif self .softmax :
266+ if n_pred_ch == 1 :
267+ warnings .warn ("single channel prediction, softmax=True ignored." )
268+ else :
269+ y_pred_act = torch .softmax (y_pred , dim = 1 )
221270
222- n_pred_ch = y_pred .shape [1 ]
223271 if self .to_onehot_y :
224- if n_pred_ch == 1 :
272+ if n_pred_ch == 1 and not self . sigmoid :
225273 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
226- else :
274+ elif n_pred_ch > 1 or self .sigmoid :
275+ # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
276+ if y_true .shape [1 ] != 1 :
277+ y_true = y_true .unsqueeze (1 )
227278 y_true = one_hot (y_true , num_classes = n_pred_ch )
279+
280+ # Ensure y_true has the same shape as y_pred_act
281+ if y_true .shape != y_pred_act .shape :
282+ # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283+ if y_true .shape [1 ] != y_pred_act .shape [1 ] and y_true .ndim == y_pred_act .ndim - 1 :
284+ y_true = y_true .unsqueeze (1 ) # Add channel dim
285+
286+ if y_true .shape != y_pred_act .shape :
287+ raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred_act .shape } ) after activations/one-hot" )
228288
229- asy_focal_loss = self .asy_focal_loss (y_pred , y_true )
230- asy_focal_tversky_loss = self .asy_focal_tversky_loss (y_pred , y_true )
231289
232- loss : torch .Tensor = self .weight * asy_focal_loss + (1 - self .weight ) * asy_focal_tversky_loss
290+ f_loss = self .asy_focal_loss (y_pred_act , y_true )
291+ t_loss = self .asy_focal_tversky_loss (y_pred_act , y_true )
233292
234- if self .reduction == LossReduction .SUM .value :
235- return torch .sum (loss ) # sum over the batch and channel dims
236- if self .reduction == LossReduction .NONE .value :
237- return loss # returns [N, num_classes] losses
238- if self .reduction == LossReduction .MEAN .value :
239- return torch .mean (loss )
240- raise ValueError (f'Unsupported reduction: { self .reduction } , available options are ["mean", "sum", "none"].' )
293+ loss : torch .Tensor = self .lambda_focal * f_loss + (1 - self .lambda_focal ) * t_loss
294+
295+ return loss
0 commit comments