Skip to content

Commit 3cddaa8

Browse files
committed
Generalize AsymmetricUnifiedFocalLoss for multi-class and align interface
1 parent 69f3dd2 commit 3cddaa8

File tree

1 file changed

+165
-110
lines changed

1 file changed

+165
-110
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 165 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import warnings
1515

1616
import torch
17+
import torch.nn.functional as F
1718
from torch.nn.modules.loss import _Loss
1819

1920
from monai.networks import one_hot
@@ -23,49 +24,51 @@
2324
class AsymmetricFocalTverskyLoss(_Loss):
2425
"""
2526
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on the foreground class.
26-
27-
Actually, it's only supported for binary image segmentation now.
27+
It treats the background class (index 0) differently from all foreground classes (indices 1...N).
2828
2929
Reimplementation of the Asymmetric Focal Tversky Loss described in:
3030
3131
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
32-
Michael Yeung, Computerized Medical Imaging and Graphics
32+
Michael Yeung, Computerized Medical Imaging and Graphics
3333
"""
3434

3535
def __init__(
3636
self,
37-
to_onehot_y: bool = False,
37+
include_background: bool = True,
3838
delta: float = 0.7,
3939
gamma: float = 0.75,
4040
epsilon: float = 1e-7,
4141
reduction: LossReduction | str = LossReduction.MEAN,
4242
) -> None:
4343
"""
4444
Args:
45-
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
46-
delta : weight of the background. Defaults to 0.7.
47-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
48-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
45+
include_background: whether to include loss computation for the background class. Defaults to True.
46+
delta : weight applied to false negatives in the Tversky index; false positives are weighted by (1 - delta). Defaults to 0.7.
47+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
48+
epsilon : small constant used to clamp predictions to [epsilon, 1 - epsilon] and to smooth the Tversky ratio. Defaults to 1e-7.
49+
reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
4950
"""
5051
super().__init__(reduction=LossReduction(reduction).value)
51-
self.to_onehot_y = to_onehot_y
52+
self.include_background = include_background
5253
self.delta = delta
5354
self.gamma = gamma
5455
self.epsilon = epsilon
5556

5657
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
5758
n_pred_ch = y_pred.shape[1]
5859

59-
if self.to_onehot_y:
60-
if n_pred_ch == 1:
61-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
62-
else:
63-
y_true = one_hot(y_true, num_classes=n_pred_ch)
64-
6560
if y_true.shape != y_pred.shape:
6661
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
6762

68-
# clip the prediction to avoid NaN
63+
# Exclude background if needed
64+
if not self.include_background:
65+
if n_pred_ch == 1:
66+
warnings.warn("single channel prediction, `include_background=False` ignored.")
67+
else:
68+
y_pred = y_pred[:, 1:]
69+
y_true = y_true[:, 1:]
70+
71+
# Clip predictions
6972
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
7073
axis = list(range(2, len(y_pred.shape)))
7174

@@ -74,167 +77,219 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7477
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
7578
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
7679
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
77-
78-
# Calculate losses separately for each class, enhancing both classes
79-
back_dice = 1 - dice_class[:, 0]
80-
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
81-
82-
# Average class scores
83-
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
84-
return loss
80+
# dice_class shape is (B, C)
81+
82+
n_classes = dice_class.shape[1]
83+
84+
if not self.include_background:
85+
# All classes are foreground, apply foreground logic
86+
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
87+
elif n_classes == 1:
88+
# Single channel with include_background=True: treat it as foreground
89+
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)
90+
else:
91+
# Asymmetric logic: class 0 is BG, others are FG
92+
back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1)
93+
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
94+
loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1) # (B, C)
95+
96+
# Apply reduction
97+
if self.reduction == LossReduction.MEAN.value:
98+
return torch.mean(loss) # mean over batch and classes
99+
if self.reduction == LossReduction.SUM.value:
100+
return torch.sum(loss) # sum over batch and classes
101+
if self.reduction == LossReduction.NONE.value:
102+
return loss # returns (B, C) losses
103+
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
85104

86105

87106
class AsymmetricFocalLoss(_Loss):
88107
"""
89-
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
90-
91-
Actually, it's only supported for binary image segmentation now.
108+
AsymmetricFocalLoss is a variant of FocalLoss that focuses on the foreground class.
109+
It treats the background class (index 0) differently from all foreground classes (indices 1...N).
110+
Background class (0): applies gamma exponent to (1-p)
111+
Foreground classes (1..N): no gamma exponent
92112
93113
Reimplementation of the Asymmetric Focal Loss described in:
94114
95115
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
96-
Michael Yeung, Computerized Medical Imaging and Graphics
116+
Michael Yeung, Computerized Medical Imaging and Graphics
97117
"""
98118

99119
def __init__(
100120
self,
101-
to_onehot_y: bool = False,
121+
include_background: bool = True,
102122
delta: float = 0.7,
103-
gamma: float = 2,
123+
gamma: float = 2.0,
104124
epsilon: float = 1e-7,
105125
reduction: LossReduction | str = LossReduction.MEAN,
106126
):
107127
"""
108128
Args:
109-
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110-
delta : weight of the background. Defaults to 0.7.
111-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
112-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
129+
include_background: whether to include loss computation for the background class. Defaults to True.
130+
delta : weight of the foreground. Defaults to 0.7. (1-delta is weight of background)
131+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0.
132+
epsilon : small constant used to clamp predictions to [epsilon, 1 - epsilon] before taking the logarithm. Defaults to 1e-7.
133+
reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
113134
"""
114135
super().__init__(reduction=LossReduction(reduction).value)
115-
self.to_onehot_y = to_onehot_y
136+
self.include_background = include_background
116137
self.delta = delta
117138
self.gamma = gamma
118139
self.epsilon = epsilon
119140

120141
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
121142
n_pred_ch = y_pred.shape[1]
122143

123-
if self.to_onehot_y:
124-
if n_pred_ch == 1:
125-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
126-
else:
127-
y_true = one_hot(y_true, num_classes=n_pred_ch)
128-
129144
if y_true.shape != y_pred.shape:
130145
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
131146

132-
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
133-
cross_entropy = -y_true * torch.log(y_pred)
134-
135-
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
136-
back_ce = (1 - self.delta) * back_ce
137-
138-
fore_ce = cross_entropy[:, 1]
139-
fore_ce = self.delta * fore_ce
147+
# Exclude background if needed
148+
if not self.include_background:
149+
if n_pred_ch == 1:
150+
warnings.warn("single channel prediction, `include_background=False` ignored.")
151+
else:
152+
y_pred = y_pred[:, 1:]
153+
y_true = y_true[:, 1:]
140154

141-
loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
142-
return loss
155+
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
156+
cross_entropy = -y_true * torch.log(y_pred) # Shape (B, C, H, W, [D])
157+
158+
n_classes = y_pred.shape[1]
159+
160+
if not self.include_background:
161+
# All classes are foreground, apply foreground logic
162+
loss = self.delta * cross_entropy # (B, C, H, W)
163+
elif n_classes == 1:
164+
# Single class, must be foreground
165+
loss = self.delta * cross_entropy # (B, 1, H, W)
166+
else:
167+
# Asymmetric logic: class 0 is BG, others are FG
168+
# (B, H, W)
169+
back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
170+
# (B, C-1, H, W)
171+
fore_ce = self.delta * cross_entropy[:, 1:]
172+
173+
loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W)
174+
175+
# Apply reduction
176+
if self.reduction == LossReduction.MEAN.value:
177+
return torch.mean(loss) # mean over batch, class, and spatial
178+
if self.reduction == LossReduction.SUM.value:
179+
return torch.sum(loss)
180+
if self.reduction == LossReduction.NONE.value:
181+
return loss # returns (B, C, H, W)
182+
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
143183

144184

145185
class AsymmetricUnifiedFocalLoss(_Loss):
146186
"""
147-
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
148-
149-
Actually, it's only supported for binary image segmentation now
187+
AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss
188+
and AsymmetricFocalTverskyLoss.
150189
151190
Reimplementation of the Asymmetric Unified Focal Loss described in:
152191
153192
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
154-
Michael Yeung, Computerized Medical Imaging and Graphics
193+
Michael Yeung, Computerized Medical Imaging and Graphics
155194
"""
156195

157196
def __init__(
158197
self,
198+
include_background: bool = True,
159199
to_onehot_y: bool = False,
160-
num_classes: int = 2,
161-
weight: float = 0.5,
162-
gamma: float = 0.5,
163-
delta: float = 0.7,
200+
sigmoid: bool = False,
201+
softmax: bool = False,
202+
lambda_focal: float = 0.5,
203+
focal_loss_gamma: float = 2.0,
204+
focal_loss_delta: float = 0.7,
205+
tversky_loss_gamma: float = 0.75,
206+
tversky_loss_delta: float = 0.7,
164207
reduction: LossReduction | str = LossReduction.MEAN,
165208
):
166209
"""
167210
Args:
211+
include_background: whether to include loss computation for the background class. Defaults to True.
168212
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
169-
num_classes : number of classes, it only supports 2 now. Defaults to 2.
170-
delta : weight of the background. Defaults to 0.7.
171-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
172-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
173-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
213+
sigmoid: if True, apply a sigmoid activation to the input y_pred.
214+
softmax: if True, apply a softmax activation to the input y_pred.
215+
lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based).
216+
The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
217+
focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
218+
focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
219+
tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
220+
tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
221+
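reduction: specifies the reduction to apply to the output: "none", "mean", "sum". NOTE(review): with "none" the focal term is (B, C, H[, W, D]) while the Tversky term is (B, C), and `forward` adds them directly, so the two shapes do not generally broadcast; confirm the intended behavior.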
174222
175223
Example:
176224
>>> import torch
177225
>>> from monai.losses import AsymmetricUnifiedFocalLoss
178-
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
179-
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
180-
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
226+
>>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32)
227+
>>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64)
228+
>>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
181229
>>> fl(pred, grnd)
182230
"""
183231
super().__init__(reduction=LossReduction(reduction).value)
232+
self.include_background = include_background
184233
self.to_onehot_y = to_onehot_y
185-
self.num_classes = num_classes
186-
self.gamma = gamma
187-
self.delta = delta
188-
self.weight: float = weight
189-
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
190-
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
234+
self.sigmoid = sigmoid
235+
self.softmax = softmax
236+
self.lambda_focal = lambda_focal
237+
238+
if sigmoid and softmax:
239+
raise ValueError("Both sigmoid and softmax cannot be True.")
240+
241+
self.asy_focal_loss = AsymmetricFocalLoss(
242+
include_background=self.include_background,
243+
gamma=focal_loss_gamma,
244+
delta=focal_loss_delta,
245+
reduction=self.reduction,
246+
)
247+
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
248+
include_background=self.include_background,
249+
gamma=tversky_loss_gamma,
250+
delta=tversky_loss_delta,
251+
reduction=self.reduction,
252+
)
191253

192-
# TODO: Implement this function to support multiple classes segmentation
193254
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
194255
"""
195256
Args:
196-
y_pred : the shape should be BNH[WD], where N is the number of classes.
197-
It only supports binary segmentation.
198-
The input should be the original logits since it will be transformed by
199-
a sigmoid in the forward function.
200-
y_true : the shape should be BNH[WD], where N is the number of classes.
201-
It only supports binary segmentation.
202-
203-
Raises:
204-
ValueError: When input and target are different shape
205-
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
206-
ValueError: When num_classes
207-
ValueError: When the number of classes entered does not match the expected number
257+
y_pred : the shape should be BNH[WD], where N is the number of classes.
258+
y_true : the shape should be BNH[WD] (matching y_pred), or B1H[WD] class indices when `to_onehot_y=True`.
208259
"""
209-
if y_pred.shape != y_true.shape:
210-
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
211-
212-
if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
213-
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
214-
215-
if y_pred.shape[1] == 1:
216-
y_pred = one_hot(y_pred, num_classes=self.num_classes)
217-
y_true = one_hot(y_true, num_classes=self.num_classes)
260+
n_pred_ch = y_pred.shape[1]
218261

219-
if torch.max(y_true) != self.num_classes - 1:
220-
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
262+
y_pred_act = y_pred
263+
if self.sigmoid:
264+
y_pred_act = torch.sigmoid(y_pred)
265+
elif self.softmax:
266+
if n_pred_ch == 1:
267+
warnings.warn("single channel prediction, softmax=True ignored.")
268+
else:
269+
y_pred_act = torch.softmax(y_pred, dim=1)
221270

222-
n_pred_ch = y_pred.shape[1]
223271
if self.to_onehot_y:
224-
if n_pred_ch == 1:
272+
if n_pred_ch == 1 and not self.sigmoid:
225273
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
226-
else:
274+
elif n_pred_ch > 1 or self.sigmoid:
275+
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
276+
if y_true.shape[1] != 1:
277+
y_true = y_true.unsqueeze(1)
227278
y_true = one_hot(y_true, num_classes=n_pred_ch)
279+
280+
# Ensure y_true has the same shape as y_pred_act
281+
if y_true.shape != y_pred_act.shape:
282+
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283+
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
284+
y_true = y_true.unsqueeze(1) # Add channel dim
285+
286+
if y_true.shape != y_pred_act.shape:
287+
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) after activations/one-hot")
228288

229-
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
230-
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
231289

232-
loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
290+
f_loss = self.asy_focal_loss(y_pred_act, y_true)
291+
t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true)
233292

234-
if self.reduction == LossReduction.SUM.value:
235-
return torch.sum(loss) # sum over the batch and channel dims
236-
if self.reduction == LossReduction.NONE.value:
237-
return loss # returns [N, num_classes] losses
238-
if self.reduction == LossReduction.MEAN.value:
239-
return torch.mean(loss)
240-
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
293+
loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss
294+
295+
return loss

0 commit comments

Comments
 (0)