-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Generalize AsymmetricUnifiedFocalLoss for multi-class and align interface #8607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ad83444
0ba7863
30e8263
d493815
befcfeb
04a98d4
ac66250
f69a25b
ba8a090
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,218 +23,271 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
| class AsymmetricFocalTverskyLoss(_Loss): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Actually, it's only supported for binary image segmentation now. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| It treats the background class (index 0) differently from all foreground classes (indices 1...N). | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Reimplementation of the Asymmetric Focal Tversky Loss described in: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| to_onehot_y: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| delta: float = 0.7, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma: float = 0.75, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon: float = 1e-7, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: LossReduction | str = LossReduction.MEAN, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta : weight of the background. Defaults to 0.7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background: whether to include loss computation for the background class. Defaults to True. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: specifies the reduction to apply to the output: "none", "mean", "sum". | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(reduction=LossReduction(reduction).value) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.to_onehot_y = to_onehot_y | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.include_background = include_background | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+36
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Breaking API change without deprecation. Removing Either restore |
||||||||||||||||||||||||||||||||||||||||||||||||||
| self.delta = delta | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.gamma = gamma | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.epsilon = epsilon | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| n_pred_ch = y_pred.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.to_onehot_y: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_true.shape != y_pred.shape: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # clip the prediction to avoid NaN | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Exclude background if needed | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.include_background: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn("single channel prediction, `include_background=False` ignored.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred = y_pred[:, 1:] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = y_true[:, 1:] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| axis = list(range(2, len(y_pred.shape))) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Calculate true positives (tp), false negatives (fn) and false positives (fp) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tp = torch.sum(y_true * y_pred, dim=axis) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fn = torch.sum(y_true * (1 - y_pred), dim=axis) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fp = torch.sum((1 - y_true) * y_pred, dim=axis) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Calculate losses separately for each class, enhancing both classes | ||||||||||||||||||||||||||||||||||||||||||||||||||
| back_dice = 1 - dice_class[:, 0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Average class scores | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return loss | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # dice_class shape is (B, C) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| n_classes = dice_class.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.include_background: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # All classes are foreground, apply foreground logic | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, C) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif n_classes == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Single class, must be foreground (BG was excluded or not provided) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Asymmetric logic: class 0 is BG, others are FG | ||||||||||||||||||||||||||||||||||||||||||||||||||
| back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: same exponent error - should be
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: same exponent error - should be
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Same exponent error - should be
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: same exponent error - should be
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Same exponent error - should be
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: exponent formula incorrect - should be Paper specifies AFTL = (1-TI)^(1/γ). With γ=0.75: should be (1-dice)^1.333 (increases penalty), current gives (1-dice)^0.25 (decreases penalty).
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: exponent formula is incorrect - should be Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Current formula gives (1-dice)^0.25, incorrectly decreasing the penalty.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: exponent should be Lines 83 and 86 correctly use
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: exponent should be Lines 83 and 86 correctly use
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1) # (B, C) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Apply reduction | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.MEAN.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.mean(loss) # mean over batch and classes | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.SUM.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.sum(loss) # sum over batch and classes | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.NONE.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return loss # returns (B, C) losses | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| class AsymmetricFocalLoss(_Loss): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Actually, it's only supported for binary image segmentation now. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| AsymmetricFocalLoss is a variant of FocalLoss, which attentions to the foreground class. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| It treats the background class (index 0) differently from all foreground classes (indices 1...N). | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Background class (0): applies gamma exponent to (1-p) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Foreground classes (1..N): no gamma exponent | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Reimplementation of the Asymmetric Focal Loss described in: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| to_onehot_y: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta: float = 0.7, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma: float = 2, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma: float = 2.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon: float = 1e-7, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: LossReduction | str = LossReduction.MEAN, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta : weight of the background. Defaults to 0.7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background: whether to include loss computation for the background class. Defaults to True. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta : weight of the foreground. Defaults to 0.7. (1-delta is weight of background) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: specifies the reduction to apply to the output: "none", "mean", "sum". | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(reduction=LossReduction(reduction).value) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.to_onehot_y = to_onehot_y | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.include_background = include_background | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.delta = delta | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.gamma = gamma | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.epsilon = epsilon | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| n_pred_ch = y_pred.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.to_onehot_y: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_true.shape != y_pred.shape: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cross_entropy = -y_true * torch.log(y_pred) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| back_ce = (1 - self.delta) * back_ce | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| fore_ce = cross_entropy[:, 1] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fore_ce = self.delta * fore_ce | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Exclude background if needed | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.include_background: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn("single channel prediction, `include_background=False` ignored.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred = y_pred[:, 1:] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = y_true[:, 1:] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return loss | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cross_entropy = -y_true * torch.log(y_pred) # Shape (B, C, H, W, [D]) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| n_classes = y_pred.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.include_background: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # All classes are foreground, apply foreground logic | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = self.delta * cross_entropy # (B, C, H, W) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif n_classes == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Single class, must be foreground | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = self.delta * cross_entropy # (B, 1, H, W) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Asymmetric logic: class 0 is BG, others are FG | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # (B, H, W) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # (B, C-1, H, W) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fore_ce = self.delta * cross_entropy[:, 1:] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Apply reduction | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.MEAN.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.mean(loss) # mean over batch, class, and spatial | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.SUM.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.sum(loss) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.NONE.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return loss # returns (B, C, H, W) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| class AsymmetricUnifiedFocalLoss(_Loss): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| AsymmetricUnifiedFocalLoss is a variant of Focal Loss. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Actually, it's only supported for binary image segmentation now | ||||||||||||||||||||||||||||||||||||||||||||||||||
| AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss | ||||||||||||||||||||||||||||||||||||||||||||||||||
| and AsymmetricFocalTverskyLoss. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| to_onehot_y: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| num_classes: int = 2, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight: float = 0.5, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma: float = 0.5, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta: float = 0.7, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| use_sigmoid: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| use_softmax: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| lambda_focal: float = 0.5, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| focal_loss_gamma: float = 2.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| focal_loss_delta: float = 0.7, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tversky_loss_gamma: float = 0.75, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tversky_loss_delta: float = 0.7, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: LossReduction | str = LossReduction.MEAN, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| num_classes : number of classes, it only supports 2 now. Defaults to 2. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta : weight of the background. Defaults to 0.7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight : weight for each loss function, if it's none it's 0.5. Defaults to None. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| use_sigmoid: if True, apply a sigmoid activation to the input y_pred. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| use_softmax: if True, apply a softmax activation to the input y_pred. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based). | ||||||||||||||||||||||||||||||||||||||||||||||||||
| The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background: whether to include loss computation for the background class. Defaults to True. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: specifies the reduction to apply to the output: "none", "mean", "sum". | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Example: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> from monai.losses import AsymmetricUnifiedFocalLoss | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> fl(pred, grnd) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(reduction=LossReduction(reduction).value) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.to_onehot_y = to_onehot_y | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_classes = num_classes | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.gamma = gamma | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.delta = delta | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.weight: float = weight | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_sigmoid = use_sigmoid | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_softmax = use_softmax | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.lambda_focal = lambda_focal | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.include_background = include_background | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_sigmoid and self.use_softmax: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Both use_sigmoid and use_softmax cannot be True.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| self.asy_focal_loss = AsymmetricFocalLoss( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background=self.include_background, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma=focal_loss_gamma, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta=focal_loss_delta, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction=self.reduction, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| include_background=self.include_background, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| gamma=tversky_loss_gamma, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| delta=tversky_loss_delta, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction=self.reduction, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: Implement this function to support multiple classes segmentation | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred : the shape should be BNH[WD], where N is the number of classes. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| It only supports binary segmentation. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| The input should be the original logits since it will be transformed by | ||||||||||||||||||||||||||||||||||||||||||||||||||
| a sigmoid in the forward function. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true : the shape should be BNH[WD], where N is the number of classes. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| It only supports binary segmentation. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When input and target are different shape | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When num_classes | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When the number of classes entered does not match the expected number | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred : the shape should be BNH[WD]. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true : the shape should be BNH[WD] or B1H[WD]. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_pred.shape != y_true.shape: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_pred.shape[1] == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred = one_hot(y_pred, num_classes=self.num_classes) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=self.num_classes) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| n_pred_ch = y_pred.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if torch.max(y_true) != self.num_classes - 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred_act = y_pred | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_sigmoid: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred_act = torch.sigmoid(y_pred) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif self.use_softmax: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn("single channel prediction, use_softmax=True ignored.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred_act = torch.softmax(y_pred, dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| n_pred_ch = y_pred.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.to_onehot_y: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if n_pred_ch == 1 and not self.use_sigmoid: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif n_pred_ch > 1 or self.use_sigmoid: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_true.shape[1] != 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = y_true.unsqueeze(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: unsqueezing unconditionally can cause issues if This adds an extra dimension. Should check
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+271
to
275
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: when When
Comment on lines
+271
to
275
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: when
Suggested change
Comment on lines
+271
to
275
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: when For binary sigmoid (1 channel foreground probability), should convert to 2 classes (background + foreground) for proper loss computation with the asymmetric logic.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| asy_focal_loss = self.asy_focal_loss(y_pred, y_true) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ensure y_true has the same shape as y_pred_act | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_true.shape != y_pred_act.shape: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_true.ndim == y_pred_act.ndim - 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_true = y_true.unsqueeze(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+279
to
+280
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: when
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if y_true.shape != y_pred_act.shape: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"after activations/one-hot" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.SUM.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.sum(loss) # sum over the batch and channel dims | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.NONE.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return loss # returns [N, num_classes] losses | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.reduction == LossReduction.MEAN.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.mean(loss) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f_loss = self.asy_focal_loss(y_pred_act, y_true) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| return loss | ||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is
to_onehot_ybeing removed? I think users will still want this functionality so I would leave it in and addinclude_backgroundas a new last argument. Even if we do want to remove it we need to use the deprecation decorators to mark the argument removed but still keep it for a version or two.