Skip to content
275 changes: 164 additions & 111 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,218 +23,271 @@
class AsymmetricFocalTverskyLoss(_Loss):
"""
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
Actually, it's only supported for binary image segmentation now.
It treats the background class (index 0) differently from all foreground classes (indices 1...N).
Reimplementation of the Asymmetric Focal Tversky Loss described in:
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
self,
to_onehot_y: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is to_onehot_y being removed? I think users will still want this functionality so I would leave it in and add include_background as a new last argument. Even if we do want to remove it we need to use the deprecation decorators to mark the argument removed but still keep it for a version or two.

include_background: bool = True,
delta: float = 0.7,
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
include_background: whether to include loss computation for the background class. Defaults to True.
delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index)
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.include_background = include_background
Comment on lines +36 to +51
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Breaking API change without deprecation.

Removing to_onehot_y from AsymmetricFocalTverskyLoss breaks existing code. Per MONAI conventions and prior reviewer feedback (ericspod), deprecated parameters must be preserved with deprecation decorators for 1-2 versions.

Either restore to_onehot_y with deprecation warnings or explicitly document this as a breaking change in the changelog.

self.delta = delta
self.gamma = gamma
self.epsilon = epsilon

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

# clip the prediction to avoid NaN
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
# Exclude background if needed
if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
else:
y_pred = y_pred[:, 1:]
y_true = y_true[:, 1:]

axis = list(range(2, len(y_pred.shape)))

# Calculate true positives (tp), false negatives (fn) and false positives (fp)
tp = torch.sum(y_true * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

# Calculate losses separately for each class, enhancing both classes
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)

# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
return loss
# dice_class shape is (B, C)

n_classes = dice_class.shape[1]

if not self.include_background:
# All classes are foreground, apply foreground logic
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, C)
elif n_classes == 1:
# Single class, must be foreground (BG was excluded or not provided)
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, 1)
else:
# Asymmetric logic: class 0 is BG, others are FG
back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: same exponent error - should be 1 + self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 + self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: same exponent error - should be 1.0 / self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Same exponent error - should be 1.0 / self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: same exponent error - should be 1.0 / self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Same exponent error - should be 1.0 / self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: exponent formula incorrect - should be 1.0 / self.gamma not 1.0 - self.gamma

Paper specifies AFTL = (1-TI)^(1/γ). With γ=0.75: should be (1-dice)^1.333 (increases penalty), current gives (1-dice)^0.25 (decreases penalty).

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: exponent formula is incorrect - should be 1.0 / self.gamma not 1.0 - self.gamma

Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Current formula gives (1-dice)^0.25, incorrectly decreasing the penalty.

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: exponent should be 1.0 / self.gamma not 1.0 - self.gamma (inconsistent with lines 83, 86)

Lines 83 and 86 correctly use 1.0 / self.gamma, but line 90 incorrectly uses 1.0 - self.gamma. Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With γ=0.75, should be (1-dice)^1.333 to increase penalty for low dice scores.

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: exponent should be 1.0 / self.gamma not 1.0 - self.gamma (inconsistent with lines 83, 86)

Lines 83 and 86 correctly use 1.0 / self.gamma, but line 90 incorrectly uses 1.0 - self.gamma. Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With γ=0.75, should be (1-dice)^1.333 to increase penalty for low dice scores.

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1) # (B, C)

# Apply reduction
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss) # mean over batch and classes
if self.reduction == LossReduction.SUM.value:
return torch.sum(loss) # sum over batch and classes
if self.reduction == LossReduction.NONE.value:
return loss # returns (B, C) losses
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')


class AsymmetricFocalLoss(_Loss):
"""
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
Actually, it's only supported for binary image segmentation now.
AsymmetricFocalLoss is a variant of FocalLoss, which attentions to the foreground class.
It treats the background class (index 0) differently from all foreground classes (indices 1...N).
Background class (0): applies gamma exponent to (1-p)
Foreground classes (1..N): no gamma exponent
Reimplementation of the Asymmetric Focal Loss described in:
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
self,
to_onehot_y: bool = False,
include_background: bool = True,
delta: float = 0.7,
gamma: float = 2,
gamma: float = 2.0,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
include_background: whether to include loss computation for the background class. Defaults to True.
delta : weight of the foreground. Defaults to 0.7. (1-delta is weight of background)
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.include_background = include_background
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred)

back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
fore_ce = self.delta * fore_ce
# Exclude background if needed
if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
else:
y_pred = y_pred[:, 1:]
y_true = y_true[:, 1:]

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
return loss
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred) # Shape (B, C, H, W, [D])

n_classes = y_pred.shape[1]

if not self.include_background:
# All classes are foreground, apply foreground logic
loss = self.delta * cross_entropy # (B, C, H, W)
elif n_classes == 1:
# Single class, must be foreground
loss = self.delta * cross_entropy # (B, 1, H, W)
else:
# Asymmetric logic: class 0 is BG, others are FG
# (B, H, W)
back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
# (B, C-1, H, W)
fore_ce = self.delta * cross_entropy[:, 1:]

loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W)

# Apply reduction
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss) # mean over batch, class, and spatial
if self.reduction == LossReduction.SUM.value:
return torch.sum(loss)
if self.reduction == LossReduction.NONE.value:
return loss # returns (B, C, H, W)
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')


class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
Actually, it's only supported for binary image segmentation now
AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss
and AsymmetricFocalTverskyLoss.
Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
self,
to_onehot_y: bool = False,
num_classes: int = 2,
weight: float = 0.5,
gamma: float = 0.5,
delta: float = 0.7,
use_sigmoid: bool = False,
use_softmax: bool = False,
lambda_focal: float = 0.5,
focal_loss_gamma: float = 2.0,
focal_loss_delta: float = 0.7,
tversky_loss_gamma: float = 0.75,
tversky_loss_delta: float = 0.7,
include_background: bool = True,
reduction: LossReduction | str = LossReduction.MEAN,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
num_classes : number of classes, it only supports 2 now. Defaults to 2.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
use_sigmoid: if True, apply a sigmoid activation to the input y_pred.
use_softmax: if True, apply a softmax activation to the input y_pred.
lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based).
The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
include_background: whether to include loss computation for the background class. Defaults to True.
reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
Example:
>>> import torch
>>> from monai.losses import AsymmetricUnifiedFocalLoss
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
>>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32)
>>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64)
>>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
>>> fl(pred, grnd)
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.use_sigmoid = use_sigmoid
self.use_softmax = use_softmax
self.lambda_focal = lambda_focal
self.include_background = include_background

if self.use_sigmoid and self.use_softmax:
raise ValueError("Both use_sigmoid and use_softmax cannot be True.")

self.asy_focal_loss = AsymmetricFocalLoss(
include_background=self.include_background,
gamma=focal_loss_gamma,
delta=focal_loss_delta,
reduction=self.reduction,
)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
include_background=self.include_background,
gamma=tversky_loss_gamma,
delta=tversky_loss_delta,
reduction=self.reduction,
)

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
Raises:
ValueError: When input and target are different shape
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
y_pred : the shape should be BNH[WD].
y_true : the shape should be BNH[WD] or B1H[WD].
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)
n_pred_ch = y_pred.shape[1]

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
y_pred_act = y_pred
if self.use_sigmoid:
y_pred_act = torch.sigmoid(y_pred)
elif self.use_softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, use_softmax=True ignored.")
else:
y_pred_act = torch.softmax(y_pred, dim=1)

n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
if n_pred_ch == 1:
if n_pred_ch == 1 and not self.use_sigmoid:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
elif n_pred_ch > 1 or self.use_sigmoid:
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
if y_true.shape[1] != 1:
y_true = y_true.unsqueeze(1)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: unsqueezing unconditionally can cause issues if y_true is already (B, 1, H, W)

This adds an extra dimension. Should check y_true.ndim first or only unsqueeze when y_true.ndim == 3.

Suggested change
y_true = y_true.unsqueeze(1)
if y_true.ndim == 3: # (B, H, W) -> (B, 1, H, W)
y_true = y_true.unsqueeze(1)

y_true = one_hot(y_true, num_classes=n_pred_ch)
Comment on lines +271 to 275
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: when use_sigmoid=True and n_pred_ch=1, one-hot with num_classes=1 doesn't make sense

When use_sigmoid=True with single channel, line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates to True (because self.use_sigmoid is True). This leads to one-hot encoding with num_classes=1, which is illogical. For binary sigmoid, should convert to 2 classes (background + foreground).

Comment on lines +271 to 275
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: when use_sigmoid=True and n_pred_ch=1, line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates to True, leading to one-hot with num_classes=1 on line 275. For binary sigmoid, should convert to 2 classes (background + foreground).

Suggested change
elif n_pred_ch > 1 or self.use_sigmoid:
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
if y_true.shape[1] != 1:
y_true = y_true.unsqueeze(1)
y_true = one_hot(y_true, num_classes=n_pred_ch)
elif n_pred_ch > 1 or (self.use_sigmoid and n_pred_ch == 1):
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
if y_true.shape[1] != 1:
y_true = y_true.unsqueeze(1)
# For binary sigmoid case, need 2 classes
num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
y_true = one_hot(y_true, num_classes=num_classes)

Comment on lines +271 to 275
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: when use_sigmoid=True and n_pred_ch=1, one-hot with num_classes=1 is illogical

For binary sigmoid (1 channel foreground probability), should convert to 2 classes (background + foreground) for proper loss computation with the asymmetric logic.

Suggested change
elif n_pred_ch > 1 or self.use_sigmoid:
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
if y_true.shape[1] != 1:
y_true = y_true.unsqueeze(1)
y_true = one_hot(y_true, num_classes=n_pred_ch)
elif n_pred_ch > 1 or (self.use_sigmoid and n_pred_ch == 1):
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
if y_true.shape[1] != 1:
y_true = y_true.unsqueeze(1)
# For binary sigmoid case, need 2 classes
num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
y_true = one_hot(y_true, num_classes=num_classes)


asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
# Ensure y_true has the same shape as y_pred_act
if y_true.shape != y_pred_act.shape:
if y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1)
Comment on lines +279 to +280
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: when y_true.ndim == y_pred_act.ndim - 1 (e.g., (B,H,W) vs (B,C,H,W)), should immediately unsqueeze. The AND condition causes the unsqueeze to be skipped when dimensions match exactly at position 1.

Suggested change
if y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1)
if y_true.ndim == y_pred_act.ndim - 1: # e.g., (B, H, W) vs (B, C, H, W)
y_true = y_true.unsqueeze(1) # Add channel dim
elif y_true.shape[1] != y_pred_act.shape[1]:
y_true = y_true.unsqueeze(1)


loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
if y_true.shape != y_pred_act.shape:
raise ValueError(
f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
f"after activations/one-hot"
)

if self.reduction == LossReduction.SUM.value:
return torch.sum(loss) # sum over the batch and channel dims
if self.reduction == LossReduction.NONE.value:
return loss # returns [N, num_classes] losses
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss)
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
f_loss = self.asy_focal_loss(y_pred_act, y_true)
t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true)

loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss

return loss
Loading