
Conversation

Contributor

@ytl0623 ytl0623 commented Oct 30, 2025

Fixes #8603 .

Description

Addresses the limitations of AsymmetricUnifiedFocalLoss and its dependencies (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss). The previous implementations were hard-coded for binary segmentation and lacked the standard sigmoid/softmax/to_onehot_y interface common to other MONAI losses.
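As a rough sketch of the intended call pattern (parameter names follow this description; defaults and exact signatures may differ in the final implementation):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# logits for a 3-class problem, channel-first: (batch, classes, H, W)
y_pred = torch.randn(2, 3, 32, 32)
# integer labels with a channel dimension: (batch, 1, H, W)
y_true = torch.randint(0, 3, (2, 1, 32, 32))

# softmax over channels plus one-hot conversion of the labels,
# mirroring the interface of other MONAI losses such as FocalLoss
loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, softmax=True)
loss = loss_fn(y_pred, y_true)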

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Contributor

coderabbitai bot commented Oct 30, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The unified focal loss module undergoes restructuring across three classes. AsymmetricFocalTverskyLoss and AsymmetricFocalLoss replace the to_onehot_y parameter with include_background, removing one-hot conversion logic and adding optional background exclusion with warnings for edge cases. AsymmetricUnifiedFocalLoss is refactored to support sigmoid/softmax activation controls, replacing weight-based configuration with separate gamma/delta parameters for focal and Tversky loss components, fused via lambda_focal. Forward methods updated to handle background exclusion, apply activations conditionally, and compute combined losses.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Forward method logic: Each class's forward pass now contains conditional branches for include_background and activation modes; verify correctness of shape handling and loss computation for all paths.
  • Parameter migration: the to_onehot_y → include_background rename affects initialization and behavior; check backward compatibility implications and default value semantics.
  • Activation handling: Sigmoid/softmax mutual exclusivity, interaction with to_onehot_y, and per-activation loss shape expectations require careful validation.
  • Fused loss computation: AsymmetricUnifiedFocalLoss combines two sub-losses via lambda_focal; verify reduction semantics and combined loss shape across all reduction modes (MEAN, SUM, NONE).
  • Edge case warnings: Single-channel scenarios emit warnings; confirm warning conditions are accurate and non-disruptive.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 66.67%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (3 passed)
  • Title Check ✅ Passed: The title "Generalize AsymmetricUnifiedFocalLoss for multi-class and align interface" accurately captures the primary changes in the PR. The changes include adding sigmoid/softmax activation controls, renaming the to_onehot_y parameter to include_background for better alignment with MONAI conventions, and updating all three loss classes to support multi-class segmentation beyond binary cases. The title is concise, specific, and clearly conveys the main intent without being overly verbose.
  • Linked Issues Check ✅ Passed: The code changes address all primary coding objectives from issue #8603: the PR adds sigmoid/softmax activation controls, generalizes the loss classes to support multi-class segmentation via the include_background parameter, aligns the API with standard MONAI conventions, and reuses existing loss components by composing AsymmetricFocalLoss and AsymmetricFocalTverskyLoss within the unified loss. All modifications directly support the stated goals of removing binary-only assumptions and matching the FocalLoss interface.
  • Out of Scope Changes Check ✅ Passed: All changes remain within the scope of issue #8603. Parameter renaming, activation controls, multi-class generalization logic, and forward method updates all directly support the goal of adding the sigmoid/softmax interface and removing binary-only constraints. The composition of sub-losses using existing components aligns with the preference to reuse implementations rather than reinvent functionality. No extraneous modifications to unrelated functionality were identified.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR successfully modernizes three asymmetric loss classes (AsymmetricFocalTverskyLoss, AsymmetricFocalLoss, and AsymmetricUnifiedFocalLoss) to align with MONAI's standard loss interface, addressing issue #8603. The changes generalize these losses from binary-only to multi-class segmentation support.

Key improvements:

  • Replaced deprecated to_onehot_y parameter with include_background in base loss classes
  • Added sigmoid/softmax activation support to AsymmetricUnifiedFocalLoss
  • Implemented proper multi-class handling with asymmetric foreground/background treatment
  • Improved documentation and parameter naming clarity
  • Refactored reduction logic to properly support "mean", "sum", and "none" modes

Issues found:

  • Critical: Logic error in one-hot conversion when sigmoid=True (lines 272-278) - attempts to create one-hot with num_classes=1 which produces incorrect shapes
  • Critical: Shape validation logic error (lines 281-287) - the condition y_true.ndim == y_pred_act.ndim - 1 evaluates incorrectly
  • Unused import (torch.nn.functional as F) should be removed
  • Shape validation before one-hot conversion could be more robust

The refactoring follows the pattern established by FocalLoss but contains logic errors that will cause runtime failures with certain activation configurations.

Confidence Score: 2/5

  • This PR has critical logic errors that will cause runtime failures with sigmoid activation and certain input shapes
  • Score of 2 reflects two critical logic errors: (1) one-hot conversion with sigmoid creates invalid shapes by using num_classes=1, and (2) shape dimension checking logic is incorrect. These will cause runtime failures in common use cases. The refactoring approach is sound and aligns well with MONAI patterns, but the implementation bugs must be fixed before merging.
  • monai/losses/unified_focal_loss.py lines 272-287 contain critical logic errors in shape handling and one-hot conversion that must be resolved

Important Files Changed

  • monai/losses/unified_focal_loss.py (score 3/5): Generalizes loss functions from binary to multi-class with new sigmoid/softmax interface. Found critical logic errors in one-hot conversion with sigmoid and shape handling that will cause runtime failures.

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid activation
    else softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax activation
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot format
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Validate shapes match
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
    
    alt Multi-class with background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting<br/>(gamma for BG, no gamma for FG)
    else Foreground only
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply foreground weighting
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: Return focal_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute TP, FN, FP
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
    
    alt Multi-class with background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric focal<br/>(no gamma for BG, gamma for FG)
    else Foreground only
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply foreground focal
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: Return tversky_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine losses<br/>lambda_focal * focal_loss + (1-lambda_focal) * tversky_loss
    AsymmetricUnifiedFocalLoss-->>User: Return combined_loss

1 file reviewed, 4 comments



if not self.include_background:
# All classes are foreground, apply foreground logic
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)

logic: exponent should be 1 + self.gamma (not 1 - self.gamma)

The original paper's formula uses (1 - dice)^(1/gamma); when gamma < 1 the exponent is greater than 1, which increases the penalty. With 1 - self.gamma and gamma=0.75 you get (1-dice)^0.25, which decreases the penalty for low dice scores.

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
loss = torch.pow(1.0 - dice_class, 1.0 + self.gamma) # (B, C)

loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
elif n_classes == 1:
# Single class, must be foreground (BG was excluded or not provided)
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)

logic: same exponent error here - should be 1 + self.gamma

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)
loss = torch.pow(1.0 - dice_class, 1.0 + self.gamma) # (B, 1)

else:
# Asymmetric logic: class 0 is BG, others are FG
back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)

logic: same exponent error - should be 1 + self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 + self.gamma) # (B, C-1)

elif n_pred_ch > 1 or self.sigmoid:
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
if y_true.shape[1] != 1:
y_true = y_true.unsqueeze(1)

logic: unsqueezing unconditionally can cause issues if y_true is already (B, 1, H, W)

This adds an extra dimension. Should check y_true.ndim first or only unsqueeze when y_true.ndim == 3.

Suggested change
y_true = y_true.unsqueeze(1)
if y_true.ndim == 3: # (B, H, W) -> (B, 1, H, W)
y_true = y_true.unsqueeze(1)
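For a concrete sense of the concern (shapes here are hypothetical, not from the PR's tests): unsqueezing a target that already has a channel dimension inserts a spurious axis, while the same call is exactly what a channel-less target needs.

import torch

y_true = torch.zeros(2, 1, 32, 32)        # already (B, 1, H, W)
print(y_true.unsqueeze(1).shape)          # torch.Size([2, 1, 1, 32, 32]) - extra axis

y_true_3d = torch.zeros(2, 32, 32)        # (B, H, W), channel dim missing
print(y_true_3d.unsqueeze(1).shape)       # torch.Size([2, 1, 32, 32]) - intended result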

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 69f3dd2 and ec36c14.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py

61-61: Avoid specifying long messages outside the exception class

(TRY003)


66-66: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


103-103: Avoid specifying long messages outside the exception class

(TRY003)


145-145: Avoid specifying long messages outside the exception class

(TRY003)


150-150: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


182-182: Avoid specifying long messages outside the exception class

(TRY003)


239-239: Avoid specifying long messages outside the exception class

(TRY003)


267-267: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


273-273: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


287-287: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: packaging

@ytl0623 ytl0623 force-pushed the fix-issue-8603 branch 2 times, most recently from 2959015 to dca758d on October 30, 2025 at 02:43

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Generalizes AsymmetricUnifiedFocalLoss and its component losses for multi-class segmentation. Replaces to_onehot_y parameter with standard include_background/sigmoid/softmax interface matching other MONAI losses. Adds proper handling for background vs foreground classes using asymmetric weighting.

Critical Issues Found:

  • Mathematical errors in gamma exponents (lines 86, 89, 93) - uses 1.0 - self.gamma instead of 1.0 / self.gamma, inverting the focal weighting behavior
  • Syntax error in exception message (line 287) - unclosed string literal
  • Logic error in shape validation (line 283) - incorrect dimension checking that will cause IndexError

Changes:

  • Replaced to_onehot_y parameter with include_background in component losses
  • Added sigmoid/softmax activation support to AsymmetricUnifiedFocalLoss
  • Split gamma/delta parameters into separate focal and Tversky variants
  • Generalized from binary to multi-class segmentation support
  • Added proper reduction handling (none/mean/sum)

Confidence Score: 0/5

  • This PR is unsafe to merge due to critical mathematical errors and syntax error that will cause runtime failures
  • Score of 0 reflects three critical blocking issues: (1) mathematical formula errors in lines 86, 89, 93 where 1.0 - self.gamma should be 1.0 / self.gamma, completely inverting the focal weighting behavior, (2) syntax error on line 287 with unclosed string literal that will prevent code from running, and (3) logic error on line 283 that will cause IndexError when accessing shape dimensions
  • monai/losses/unified_focal_loss.py requires immediate attention - all identified errors must be fixed before merge

Important Files Changed

  • monai/losses/unified_focal_loss.py (score 1/5): Critical mathematical errors in gamma exponents (lines 86, 89, 93), syntax error in exception message (line 287), and logic error in shape handling (line 283).

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: torch.sigmoid(y_pred)
    else softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: torch.softmax(y_pred, dim=1)
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: one_hot(y_true)
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Shape validation
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy
    
    alt include_background=False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
    else n_classes > 1
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Asymmetric weighting<br/>(BG: (1-delta)*pow(1-p, gamma)*CE)<br/>(FG: delta*CE)
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
    
    alt include_background=False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
    else n_classes > 1
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Asymmetric weighting<br/>(BG: 1-dice)<br/>(FG: pow(1-dice, 1/gamma))
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: lambda_focal * f_loss +<br/>(1 - lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: combined_loss

1 file reviewed, 5 comments



if not self.include_background:
# All classes are foreground, apply foreground logic
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)

logic: exponent formula is incorrect - should be 1.0 / self.gamma (not 1.0 - self.gamma)

The paper uses (1 - dice)^(1/gamma). When gamma=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Your formula gives (1-dice)^0.25, decreasing it.

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, C)

loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
elif n_classes == 1:
# Single class, must be foreground (BG was excluded or not provided)
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)

logic: same exponent error - should be 1.0 / self.gamma

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, 1)

else:
# Asymmetric logic: class 0 is BG, others are FG
back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)

logic: same exponent error - should be 1.0 / self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Comment on lines 287 to 288
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape})
after activations/one-hot")

syntax: syntax error - string literal not properly closed

Missing closing quote on line 287.

Suggested change
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape})
after activations/one-hot")
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) after activations/one-hot")

Comment on lines 283 to 284
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1) # Add channel dim

logic: logic error - condition check is incorrect

When y_true.ndim == y_pred_act.ndim - 1, y_true has fewer dimensions (e.g., (B,H,W) vs (B,C,H,W)), so the check y_true.shape[1] != y_pred_act.shape[1] will fail with IndexError.

Suggested change
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1) # Add channel dim
if y_true.ndim == y_pred_act.ndim - 1: # e.g., (B, H, W) vs (B, C, H, W)
y_true = y_true.unsqueeze(1) # Add channel dim
elif y_true.shape[1] != y_pred_act.shape[1]:
y_true = y_true.unsqueeze(1) # Add channel dim if missing


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR successfully generalizes the AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation, replacing the previous binary-only limitation. The interface has been modernized to align with other MONAI losses by adding sigmoid/softmax/to_onehot_y parameters and replacing to_onehot_y with include_background in the component losses.

Major Changes:

  • Added sigmoid/softmax activation options to AsymmetricUnifiedFocalLoss
  • Replaced to_onehot_y with include_background parameter in component losses
  • Generalized all three loss classes to handle arbitrary number of classes
  • Improved parameter documentation and added clearer examples
  • Properly implemented reduction modes (mean, sum, none)

Critical Issues Found:

  • Incorrect exponent formula in AsymmetricFocalTverskyLoss: Uses 1.0 - self.gamma instead of 1.0 / self.gamma per the paper's formula (1-TI)^(1/γ). With default γ=0.75, this produces exponent 0.25 instead of 1.333, fundamentally breaking the loss behavior by reducing instead of increasing the penalty for low Dice scores.
  • Shape handling logic error in AsymmetricUnifiedFocalLoss line 282: When y_true.ndim == y_pred_act.ndim - 1, the code incorrectly tries to compare y_true.shape[1] (spatial dim H) with y_pred_act.shape[1] (channel dim C) before unsqueezing.
  • String syntax error on line 287: Missing closing quote on the error message string.

Confidence Score: 1/5

  • This PR has critical mathematical errors that break the loss function's core behavior
  • The incorrect exponent formula (1-γ instead of 1/γ) in AsymmetricFocalTverskyLoss fundamentally changes the loss behavior, causing it to decrease penalties for low Dice scores when it should increase them. Combined with the shape handling bug and syntax error, these issues will cause runtime failures and incorrect training behavior. The mathematical error affects all three loss classes and deviates from the published paper.
  • monai/losses/unified_focal_loss.py requires immediate attention - all four exponent formulas must be corrected (lines 85, 88, 92), the shape logic fixed (line 282), and the string syntax error resolved (line 287)

Important Files Changed

  • monai/losses/unified_focal_loss.py (score 1/5): Generalizes losses for multi-class segmentation and adds sigmoid/softmax interface. Critical issues: incorrect exponent formula in AsymmetricFocalTverskyLoss (uses 1-γ instead of 1/γ), shape handling logic error in AsymmetricUnifiedFocalLoss, and string syntax error.

Sequence Diagram

sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt sigmoid=True
        AUFL->>AUFL: Apply sigmoid activation
    else softmax=True
        AUFL->>AUFL: Apply softmax activation
    end
    
    alt to_onehot_y=True
        AUFL->>AUFL: Convert y_true to one-hot
    end
    
    AUFL->>AUFL: Shape alignment checks
    
    AUFL->>AFL: forward(y_pred_act, y_true)
    AFL->>AFL: Exclude background if needed
    AFL->>AFL: Compute cross_entropy
    
    alt Multi-class with background
        AFL->>AFL: Apply asymmetric weights<br/>BG: (1-δ)*(1-p)^γ*CE<br/>FG: δ*CE
    else Foreground only
        AFL->>AFL: Apply foreground weight δ*CE
    end
    
    AFL->>AFL: Apply reduction
    AFL-->>AUFL: Return focal_loss
    
    AUFL->>AFTL: forward(y_pred_act, y_true)
    AFTL->>AFTL: Exclude background if needed
    AFTL->>AFTL: Compute Tversky Index
    
    alt Multi-class with background
        AFTL->>AFTL: Apply asymmetric logic<br/>BG: 1-TI<br/>FG: (1-TI)^(1-γ) [INCORRECT]
    else Foreground only
        AFTL->>AFTL: Apply (1-TI)^(1-γ) [INCORRECT]
    end
    
    AFTL->>AFTL: Apply reduction
    AFTL-->>AUFL: Return tversky_loss
    
    AUFL->>AUFL: Combine losses<br/>λ*focal_loss + (1-λ)*tversky_loss
    AUFL-->>User: Return combined loss

1 file reviewed, 5 comments


Comment on lines 282 to 283
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1) # Add channel dim

logic: IndexError when y_true.ndim == y_pred_act.ndim - 1

When y_true is (B, H, W) and y_pred_act is (B, C, H, W), the condition y_true.ndim == y_pred_act.ndim - 1 is True, but then checking y_true.shape[1] != y_pred_act.shape[1] compares H with C, which is nonsensical. The logic should unsqueeze immediately when ndim differs.

Suggested change
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1) # Add channel dim
if y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1) # Add channel dim
elif y_true.shape[1] != y_pred_act.shape[1]:


if not self.include_background:
# All classes are foreground, apply foreground logic
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)

logic: Incorrect exponent formula - should be 1.0 / self.gamma

The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives exponent 1.333, increasing penalty for low Dice scores. Current formula 1.0 - self.gamma gives 0.25, incorrectly decreasing the penalty.

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, C)

loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
elif n_classes == 1:
# Single class, must be foreground (BG was excluded or not provided)
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)

logic: Same exponent error - should be 1.0 / self.gamma

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, 1)

else:
# Asymmetric logic: class 0 is BG, others are FG
back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)

logic: Same exponent error - should be 1.0 / self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
if y_true.shape != y_pred_act.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
f"after activations/one-hot")

syntax: String literal syntax error - missing closing quote

Suggested change
f"after activations/one-hot")
"after activations/one-hot")

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)

85-92: Critical: Incorrect exponent formula in all three branches.

The paper specifies (1 - dice)^(1/gamma), but lines 85, 88, and 92 use 1.0 - self.gamma. With gamma=0.75, you get (1-dice)^0.25 (reduces penalty for low Dice), when you should get (1-dice)^1.333 (increases penalty).

Apply this diff:

         if not self.include_background:
             # All classes are foreground, apply foreground logic
-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
         elif n_classes == 1:
             # Single class, must be foreground (BG was excluded or not provided)
-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
         else:
             # Asymmetric logic: class 0 is BG, others are FG
             back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
             loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)  # (B, C)
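For context, a small numeric comparison of the two exponents (plain arithmetic, not taken from the PR):

gamma = 0.75

for residual in (0.9, 0.5, 0.1):             # residual = 1 - dice: hard, medium, easy classes
    as_written = residual ** (1.0 - gamma)   # exponent 0.25, code as currently written
    per_paper = residual ** (1.0 / gamma)    # exponent ~1.333, reviewers' reading of the paper
    print(f"1-dice={residual}: exp 0.25 -> {as_written:.3f}, exp 1.333 -> {per_paper:.3f}")

With the 1/gamma exponent the loss stays close to 1 for poorly segmented classes (about 0.87 at residual 0.9) but drops towards zero much faster for well-segmented ones (about 0.05 at residual 0.1), which is the focusing behaviour described above; the 0.25 exponent compresses all three cases towards 1.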

270-277: Critical: sigmoid+one-hot causes ValueError with binary classification.

When sigmoid=True and n_pred_ch=1, line 273's condition passes, triggering one_hot(y_true, num_classes=1) at line 277. Any ground truth voxel with value 1 raises ValueError: class values must be smaller than num_classes.

Apply this diff:

         if self.to_onehot_y:
-            if n_pred_ch == 1 and not self.sigmoid:
+            if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            elif n_pred_ch > 1 or self.sigmoid:
+            else:
                 # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                 if y_true.shape[1] != 1:
                     y_true = y_true.unsqueeze(1)
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
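To make the failure mode concrete, a minimal sketch (shapes chosen for illustration; the exact exception raised by MONAI's one_hot helper may differ from the message quoted above):

import torch
from monai.networks import one_hot

# binary case with a single-channel sigmoid head: labels contain foreground voxels of value 1
y_true = torch.tensor([[[[0, 1], [1, 0]]]])   # (B=1, 1, 2, 2)

# with n_pred_ch == 1 the current branch effectively calls one_hot(..., num_classes=1),
# which has no channel available for label value 1 and fails at runtime
try:
    one_hot(y_true, num_classes=1)
except Exception as exc:                      # exact exception type depends on the helper
    print(type(exc).__name__, exc)

# binary sigmoid targets are usually left as-is (already 0/1 in a single channel),
# which is what the suggested diff restores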

280-287: Critical: Logic error in shape alignment causes incorrect comparison.

Line 282 evaluates y_true.shape[1] != y_pred_act.shape[1] even when y_true.ndim == y_pred_act.ndim - 1. If y_true is (B, H, W) and y_pred_act is (B, C, H, W), you're comparing H (spatial dim) to C (channels), which is nonsensical.

Additionally, line 287 has a syntax error—f-string is split across lines incorrectly per Ruff.

Apply this diff:

         # Ensure y_true has the same shape as y_pred_act
         if y_true.shape != y_pred_act.shape:
-             # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
-            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
+            if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
                 y_true = y_true.unsqueeze(1) # Add channel dim
-
-            if y_true.shape != y_pred_act.shape:
-                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
-                                 f"after activations/one-hot")
+            elif y_true.shape != y_pred_act.shape:
+                raise ValueError(
+                    f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
+                    f"after activations/one-hot"
+                )
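As a standalone sketch of the alignment rule being discussed (hypothetical helper, not part of the PR; it adds the missing channel axis first and only then compares full shapes):

import torch

def align_target_shape(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Give y_true a channel dimension when it is missing, then verify shapes match."""
    if y_true.ndim == y_pred.ndim - 1:        # e.g. (B, H, W) vs (B, C, H, W)
        y_true = y_true.unsqueeze(1)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
        )
    return y_true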
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

44-48: Improve parameter documentation clarity.

Line 45 says delta is "weight of the background," but line 129 says it's "weight of the foreground." The docstrings should clarify that delta controls the FN/FP trade-off in Tversky (line 45) vs. background/foreground weight in focal CE (line 129). Also missing: documentation of raised exceptions (ValueError) per coding guidelines.
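For reference, in the existing implementation delta enters the per-class Tversky index as the relative weight of false negatives versus false positives (notation assumed from the Unified Focal Loss paper rather than quoted from this diff):

\mathrm{TI}_c = \frac{\mathrm{TP}_c}{\mathrm{TP}_c + \delta\,\mathrm{FN}_c + (1-\delta)\,\mathrm{FP}_c}

So in the Tversky component a larger delta penalises false negatives more heavily, whereas in the focal cross-entropy component delta weights the foreground terms; that is the distinction the docstrings should spell out.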

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between ec36c14 and 6b0c14c.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py

287-287: f-string: unterminated string

(invalid-syntax)


287-288: Expected FStringEnd, found newline

(invalid-syntax)


288-288: Expected ',', found indent

(invalid-syntax)


288-288: Expected ',', found name

(invalid-syntax)


288-288: missing closing quote in string literal

(invalid-syntax)


288-289: Expected ',', found newline

(invalid-syntax)


291-291: Expected ')', found dedent

(invalid-syntax)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: packaging
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)

146-152: LGTM on background exclusion logic.

The include_background handling is consistent with AsymmetricFocalTverskyLoss.


237-238: LGTM on mutual exclusivity check.

Correctly prevents both sigmoid and softmax from being enabled simultaneously.


290-295: LGTM on loss fusion logic.

Correctly delegates to sub-loss instances and combines via weighted sum. Assuming sub-losses return scalar values when reduction="mean" or reduction="sum", this will work.
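Spelled out, the fusion under review is a convex combination of the two sub-losses (a sketch of the relationship described above, written for scalar sub-losses under mean or sum reduction):

L_{\mathrm{unified}} = \lambda_{\mathrm{focal}} \cdot L_{\mathrm{asym.\ focal}} + (1 - \lambda_{\mathrm{focal}}) \cdot L_{\mathrm{asym.\ focal\ Tversky}}

with lambda_focal = 1 recovering the pure asymmetric focal (cross-entropy) term and lambda_focal = 0 the pure asymmetric focal Tversky term.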


195-251: Add unit tests for new sigmoid/softmax interface.

The PR adds sigmoid/softmax activation options and renames to_onehot_y to include_background, but no new tests are mentioned. These API changes need coverage to verify:

  • sigmoid/softmax mutual exclusivity
  • One-hot conversion with different activation modes
  • Background inclusion/exclusion behavior
  • Multi-class vs. binary edge cases

Per coding guidelines, modified definitions must be covered by tests. Do you want me to generate test cases or open an issue to track this?
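A minimal sketch of what such tests could look like (class and method names are illustrative, the tolerance is arbitrary, and the constructor arguments assume the interface added in this PR):

import unittest

import torch

from monai.losses import AsymmetricUnifiedFocalLoss


class TestAsymmetricUnifiedFocalLossInterface(unittest.TestCase):
    def test_softmax_multiclass_near_perfect_prediction(self):
        # 3-class integer labels and strongly confident matching logits
        y_true = torch.randint(0, 3, (2, 1, 8, 8))
        logits = torch.nn.functional.one_hot(y_true.squeeze(1), num_classes=3)
        logits = logits.permute(0, 3, 1, 2).float() * 20.0
        loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, softmax=True)
        loss = loss_fn(logits, y_true)
        self.assertLess(float(loss), 0.1)

    def test_sigmoid_softmax_mutually_exclusive(self):
        # the PR adds a mutual-exclusivity check; the exact exception type is assumed here
        with self.assertRaises(ValueError):
            AsymmetricUnifiedFocalLoss(sigmoid=True, softmax=True)


if __name__ == "__main__":
    unittest.main()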


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR generalizes AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface common to other MONAI losses.

Major changes:

  • Replaced to_onehot_y parameter with include_background in component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss)
  • Added sigmoid/softmax/to_onehot_y parameters to AsymmetricUnifiedFocalLoss
  • Refactored loss computation to handle multi-class scenarios
  • Updated interface to match MONAI's FocalLoss pattern

Critical issues found:

  • The gamma exponent formula in AsymmetricFocalTverskyLoss is mathematically incorrect on lines 85, 88, and 92. Uses 1.0 - self.gamma instead of 1.0 / self.gamma, which inverts the focal penalty behavior
  • Dimension handling logic error on lines 282-283 in AsymmetricUnifiedFocalLoss where shape comparison happens before necessary dimension adjustment

Confidence Score: 1/5

  • This PR contains critical mathematical errors that will cause incorrect loss computation
  • The gamma exponent formula is fundamentally wrong in three places (lines 85, 88, 92), using 1.0 - gamma instead of 1.0 / gamma. This inverts the focal penalty behavior. Additionally, there's a dimension handling logic error (lines 282-283). These are not edge cases but core functionality bugs that will affect all users.
  • monai/losses/unified_focal_loss.py requires immediate attention to fix the gamma exponent formula and dimension handling logic

Important Files Changed

  • monai/losses/unified_focal_loss.py (score 1/5): Critical mathematical errors in gamma exponent formula (lines 85, 88, 92) and dimension handling logic error (lines 282-283). These will cause incorrect loss computation.

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid/softmax if enabled
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot if to_onehot_y=True
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure y_true and y_pred_act shapes match
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if include_background=False
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric focal weighting
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if include_background=False
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index (TP, FN, FP)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply focal exponent (WRONG: 1-gamma instead of 1/gamma)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda_focal * f_loss + (1-lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: loss

1 file reviewed, 4 comments



if not self.include_background:
# All classes are foreground, apply foreground logic
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)

logic: incorrect exponent - should be 1.0 / self.gamma (not 1.0 - self.gamma)

The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Current 1.0 - self.gamma gives (1-dice)^0.25, incorrectly decreasing penalty.

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, C)

loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
elif n_classes == 1:
# Single class, must be foreground (BG was excluded or not provided)
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)

logic: same exponent error - should be 1.0 / self.gamma

Suggested change
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, 1)

else:
# Asymmetric logic: class 0 is BG, others are FG
back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)

logic: same exponent error - should be 1.0 / self.gamma

Suggested change
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1)
fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma) # (B, C-1)

Comment on lines 282 to 281
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1) # Add channel dim

logic: logic error - when y_true.ndim == y_pred_act.ndim - 1, checking y_true.shape[1] on line 282 will compare wrong dimensions

When y_true is (B, H, W) and y_pred_act is (B, C, H, W), the first condition is True, but then checking y_true.shape[1] (which is H) against y_pred_act.shape[1] (which is C) is meaningless. Should unsqueeze immediately.

Suggested change
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
y_true = y_true.unsqueeze(1) # Add channel dim
if y_true.ndim == y_pred_act.ndim - 1: # e.g., (B, H, W) vs (B, C, H, W)
y_true = y_true.unsqueeze(1) # Add channel dim
elif y_true.shape[1] != y_pred_act.shape[1]:
y_true = y_true.unsqueeze(1)

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

34-296: Add test coverage for new sigmoid/softmax/include_background modes.

Current tests only cover default configuration with perfect predictions. Missing coverage:

  • Binary segmentation with sigmoid=True
  • Multi-class with softmax=True
  • include_background=False with sub-loss variations
  • to_onehot_y=True edge cases (sigmoid binary, multi-class)
  • Reduction modes (SUM, NONE)
  • Sub-loss classes in isolation
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)

270-277: Critical: one-hot conversion fails for sigmoid with single channel.

When sigmoid=True and n_pred_ch==1, line 273 triggers one_hot(..., num_classes=1), which raises ValueError for any label==1. The original issue #8603 requested sigmoid/softmax support, but this breaks binary sigmoid cases.

Apply this diff:

         if self.to_onehot_y:
-            if n_pred_ch == 1 and not self.sigmoid:
+            if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            elif n_pred_ch > 1 or self.sigmoid:
+            else:
                 # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                 if y_true.shape[1] != 1:
                     y_true = y_true.unsqueeze(1)
                 y_true = one_hot(y_true, num_classes=n_pred_ch)

280-284: Major: IndexError when y_true is missing channel dim.

Line 282 checks y_true.ndim == y_pred_act.ndim - 1 (e.g., y_true is (B,H,W)), but then accesses y_true.shape[1], which indexes the spatial dimension H, not the channel. This causes incorrect logic or IndexError.

Apply this diff:

         # Ensure y_true has the same shape as y_pred_act
         if y_true.shape != y_pred_act.shape:
             # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
-            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
+            if y_true.ndim == y_pred_act.ndim - 1:
                 y_true = y_true.unsqueeze(1)  # Add channel dim
+            elif y_true.shape[1] != y_pred_act.shape[1]:
+                y_true = y_true.unsqueeze(1)  # Add channel dim if mismatch

85-92: Critical: Exponent formula is mathematically incorrect.

All three branches use 1.0 - self.gamma but the Unified Focal Loss paper specifies (1 - TI)^(1/γ). With default γ=0.75, your formula gives exponent 0.25 (weakens penalty for poor predictions), but the paper requires 1.333 (strengthens penalty).

Apply this diff:

         if not self.include_background:
             # All classes are foreground, apply foreground logic
-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
         elif n_classes == 1:
             # Single class, must be foreground (BG was excluded or not provided)
-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
         else:
             # Asymmetric logic: class 0 is BG, others are FG
             back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
             loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)  # (B, C)
🧹 Nitpick comments (4)
monai/losses/unified_focal_loss.py (4)

65-65: Add stacklevel to warning.

Per PEP 565, warnings should specify stacklevel=2 so users see the caller's location.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

149-149: Add stacklevel to warning.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

266-266: Add stacklevel to warning.

-                warnings.warn("single channel prediction, softmax=True ignored.")
+                warnings.warn("single channel prediction, softmax=True ignored.", stacklevel=2)

272-272: Add stacklevel to warning.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 6b0c14c and f5a2f7e.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py

60-60: Avoid specifying long messages outside the exception class

(TRY003)


65-65: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


102-102: Avoid specifying long messages outside the exception class

(TRY003)


144-144: Avoid specifying long messages outside the exception class

(TRY003)


149-149: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


181-181: Avoid specifying long messages outside the exception class

(TRY003)


238-238: Avoid specifying long messages outside the exception class

(TRY003)


266-266: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


272-272: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


286-289: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)

154-172: Asymmetric weighting logic is correct.

Background class gets gamma-modulated penalty; foreground classes get flat delta weighting. This aligns with the asymmetric focal loss design.
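Written out, the weighting referenced here matches the pattern in the sequence diagrams earlier in the thread (notation assumed; p_0 is the predicted background probability and CE_c the per-voxel cross-entropy for channel c, prior to the chosen reduction):

L_{\mathrm{background}} = (1 - \delta)\,(1 - p_0)^{\gamma}\,\mathrm{CE}_0
L_{\mathrm{foreground},\,c} = \delta\,\mathrm{CE}_c \quad (c \ge 1)

so the focal (1 - p_0)^gamma factor suppresses easy background voxels while the foreground channels keep a flat delta weight.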


291-296: Fused loss computation is correct.

The weighted combination lambda_focal * focal + (1 - lambda_focal) * tversky properly unifies both loss components.


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Refactored AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and added standard MONAI loss interface (sigmoid/softmax/to_onehot_y/include_background).

Key Changes:

  • Replaced to_onehot_y parameter with include_background in component losses
  • Added sigmoid/softmax activation options to main loss
  • Generalized asymmetric logic to handle multi-class cases
  • Improved parameter naming (lambda_focal instead of weight, separate gamma/delta for each component)
  • Updated reduction logic to handle different output shapes

Critical Issues Identified (from previous comments):

  • Incorrect exponent formula in AsymmetricFocalTverskyLoss (should be 1.0 / self.gamma not 1.0 - self.gamma)
  • Shape handling bug on line 280 that checks wrong dimensions when y_true.ndim == y_pred_act.ndim - 1
  • String literal syntax error on line 287 (unclosed quote)

Confidence Score: 1/5

  • This PR has critical mathematical errors and runtime bugs that will cause incorrect behavior
  • Multiple critical logic errors have been identified: (1) incorrect exponent formula in 3 locations that fundamentally breaks the loss calculation, (2) shape dimension checking bug that will cause IndexError at runtime, (3) syntax error with unclosed string. These issues need to be resolved before merge.
  • monai/losses/unified_focal_loss.py requires immediate attention - fix exponent formulas on lines 83, 86, 90 and shape checking logic on lines 280-287

Important Files Changed

  • monai/losses/unified_focal_loss.py (score 1/5): Refactored to support multi-class segmentation with sigmoid/softmax interface. Critical exponent formula errors need fixing, and shape handling logic has bugs that could cause runtime errors.

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt sigmoid==True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid(y_pred)
    else softmax==True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax(y_pred)
    end
    
    alt to_onehot_y==True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Shape alignment checks
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background==False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude channel 0
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross_entropy
    
    alt n_classes > 1 and include_background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric logic<br/>(BG: gamma exponent, FG: no exponent)
    else
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply foreground logic<br/>(delta * cross_entropy)
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background==False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude channel 0
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute TP, FN, FP
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
    
    alt n_classes > 1 and include_background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric logic<br/>(BG: no exponent, FG: gamma exponent)
    else
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply foreground logic<br/>(gamma exponent)
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda*f_loss + (1-lambda)*t_loss
    
    AsymmetricUnifiedFocalLoss-->>User: Combined loss

1 file reviewed, no comments


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

83-90: Critical: incorrect exponent formula breaks Tversky loss behavior.

Lines 83, 86, and 90 use 1.0 - self.gamma but the paper specifies (1 - TI)^(1/γ). With default γ=0.75, the correct exponent is 1.333, increasing penalty for low Dice scores. Current formula gives 0.25, inverting the intended behavior.

Apply this diff:

-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)

280-281: Critical: IndexError when y_true has fewer dimensions.

When y_true.ndim == y_pred_act.ndim - 1 (e.g., (B,H,W) vs (B,C,H,W)), checking y_true.shape[1] compares spatial dimension H with channel dimension C, which is nonsensical and may raise IndexError.

Apply this diff:

-            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-                y_true = y_true.unsqueeze(1)  # Add channel dim
+            if y_true.ndim == y_pred_act.ndim - 1:
+                y_true = y_true.unsqueeze(1)  # Add channel dim
+            elif y_true.shape[1] != y_pred_act.shape[1]:
+                y_true = y_true.unsqueeze(1)
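
A small illustration of the failure mode being described, with made-up shapes:

import torch

y_pred_act = torch.rand(2, 3, 32, 32)      # (B, C, H, W)
y_true = torch.randint(0, 3, (2, 32, 32))  # (B, H, W), no channel dim

# Here y_true.shape[1] is H (32), not a channel count, so comparing it to
# y_pred_act.shape[1] (C == 3) is meaningless; unsqueeze as soon as ndim differs.
if y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)           # -> (2, 1, 32, 32)
print(y_true.shape)
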
🧹 Nitpick comments (4)
monai/losses/unified_focal_loss.py (4)

65-65: Add stacklevel to warning.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

147-147: Add stacklevel to warning.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

264-264: Add stacklevel to warning.

-                warnings.warn("single channel prediction, softmax=True ignored.")
+                warnings.warn("single channel prediction, softmax=True ignored.", stacklevel=2)

270-270: Add stacklevel to warning.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between f5a2f7e and b4e0fcc.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py

60-60: Avoid specifying long messages outside the exception class

(TRY003)


65-65: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


100-100: Avoid specifying long messages outside the exception class

(TRY003)


142-142: Avoid specifying long messages outside the exception class

(TRY003)


147-147: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


179-179: Avoid specifying long messages outside the exception class

(TRY003)


236-236: Avoid specifying long messages outside the exception class

(TRY003)


264-264: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


270-270: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


284-287: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: packaging
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.11)
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)

152-170: Asymmetric focal loss logic is correct.

Background class applies gamma exponent to (1-p), foreground classes don't. This matches the paper's asymmetric formulation.


238-249: Sub-loss composition approach is sound.

Creating separate AsymmetricFocalLoss and AsymmetricFocalTverskyLoss instances with independent gamma/delta parameters provides clean separation and reuses existing components per PR objectives.
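
A hypothetical usage sketch of the composed loss under the interface described in this PR (argument names such as use_softmax, to_onehot_y and lambda_focal follow the review text and may differ from the merged code):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True, lambda_focal=0.5)

logits = torch.randn(2, 3, 32, 32)             # (B, C, H, W) raw network output
labels = torch.randint(0, 3, (2, 1, 32, 32))   # (B, 1, H, W) integer class labels
loss = loss_fn(logits, labels)                 # scalar tensor combining focal and Tversky terms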


235-236: Mutual exclusivity check for sigmoid/softmax is correct.


289-294: Fused loss computation is correct.

Weighted combination of focal and Tversky losses matches the unified focal loss formulation.

@ytl0623
Copy link
Contributor Author

ytl0623 commented Oct 30, 2025

Hi @KumoLiu, @Nic-Ma and @ericspod,

Sorry to bother.
Could you please review the PR or give some tips?

Thanks in advance!

Member

@ericspod ericspod left a comment

Hi @ytl0623, thanks for the contribution. I had only a minor comment, but please also look at the greptile comments; a few look spurious, but they should all be resolved.


    def __init__(
        self,
        to_onehot_y: bool = False,

Why is to_onehot_y being removed? I think users will still want this functionality, so I would leave it in and add include_background as a new last argument. Even if we do want to remove it, we need to use the deprecation decorators to mark the argument as removed while still keeping it for a version or two.
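
A minimal sketch of how the argument could be retired gradually with MONAI's deprecation decorator (the since/removed versions and message are placeholders, and the decorator arguments are assumed from monai.utils.deprecated_arg):

from torch.nn.modules.loss import _Loss
from monai.utils import deprecated_arg


class AsymmetricFocalLossSketch(_Loss):
    # keep `to_onehot_y` accepted for a release or two while steering users to the new argument
    @deprecated_arg(name="to_onehot_y", since="1.5", removed="1.7",
                    msg_suffix="convert labels upstream and use `include_background` instead.")
    def __init__(self, to_onehot_y: bool = False, include_background: bool = True) -> None:
        super().__init__()
        self.include_background = include_background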

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR modernizes the AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation and adds standard MONAI loss interfaces (sigmoid/softmax/to_onehot_y/include_background).

Key changes:

  • Replaced to_onehot_y parameter with include_background in base loss classes, matching MONAI conventions
  • Added sigmoid/softmax activation options to AsymmetricUnifiedFocalLoss
  • Generalized from binary-only to multi-class segmentation support
  • Maintained asymmetric treatment: background class (index 0) receives different weighting than foreground classes
  • Updated reduction logic to properly handle "none"/"mean"/"sum" modes

Critical issues identified:

  1. Incorrect exponent formula: Lines 83, 86, 90 use 1.0 - self.gamma but should use 1.0 / self.gamma according to the Unified Focal Loss paper's formula: AFTL = (1-TI)^(1/γ). With default γ=0.75, current code gives (1-dice)^0.25 (decreasing penalty) instead of (1-dice)^1.333 (increasing penalty).

  2. Shape handling logic error: Line 280 checks y_true.shape[1] != y_pred_act.shape[1] after confirming y_true.ndim == y_pred_act.ndim - 1, which compares spatial dimension H with channel dimension C.

Recommendations:

  • Fix the critical exponent formula errors immediately
  • Fix the shape handling conditional logic
  • Add tests covering multi-class scenarios and new sigmoid/softmax interfaces
  • Verify against reference implementation or paper's equations

Confidence Score: 1/5

  • This PR has critical mathematical errors that will cause incorrect loss computation and training behavior
  • Score of 1 reflects multiple critical logical errors: (1) incorrect exponent formula in AsymmetricFocalTverskyLoss using 1.0 - self.gamma instead of 1.0 / self.gamma, causing fundamentally wrong loss values that contradict the referenced paper; (2) shape handling logic error comparing incompatible dimensions. These are not edge cases but affect core functionality. The PR cannot be merged until these mathematical errors are corrected.
  • The only changed file monai/losses/unified_focal_loss.py requires immediate attention - specifically lines 83, 86, 90 (exponent formula) and lines 280-281 (shape handling logic)

Important Files Changed

File Analysis

Filename Score Overview
monai/losses/unified_focal_loss.py 1/5 Generalizes AsymmetricUnifiedFocalLoss for multi-class segmentation and adds sigmoid/softmax interface. Critical issues: incorrect exponent formula (uses 1.0 - self.gamma instead of 1.0 / self.gamma) in AsymmetricFocalTverskyLoss, shape handling logic error when y_true.ndim differs from y_pred_act.ndim

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss

    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid/softmax if requested
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot if to_onehot_y=True
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure shape compatibility
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if needed
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting (BG vs FG)
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if needed
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky Index (TP, FN, FP)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply focal term: (1-TI)^(1/gamma)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda*f_loss + (1-lambda)*t_loss
    AsymmetricUnifiedFocalLoss-->>User: loss

1 file reviewed, 3 comments


Comment on lines 280 to 281
            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
                y_true = y_true.unsqueeze(1)  # Add channel dim

logic: condition will fail when y_true.ndim == y_pred_act.ndim - 1

When y_true has shape (B, H, W) and y_pred_act has shape (B, C, H, W), the first condition y_true.ndim == y_pred_act.ndim - 1 evaluates to True (3 == 4-1). The code then proceeds to check y_true.shape[1] != y_pred_act.shape[1] on line 280, which compares H (spatial dimension) with C (channel dimension). This is a logical error.

The code should unsqueeze immediately when ndim differs:

Suggested change

-            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-                y_true = y_true.unsqueeze(1)  # Add channel dim
+            if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+                y_true = y_true.unsqueeze(1)  # Add channel dim
+            elif y_true.shape[1] != y_pred_act.shape[1]:
+                y_true = y_true.unsqueeze(1)

Comment on lines 271 to 275
            elif n_pred_ch > 1 or self.sigmoid:
                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                if y_true.shape[1] != 1:
                    y_true = y_true.unsqueeze(1)
                y_true = one_hot(y_true, num_classes=n_pred_ch)

logic: one-hot with num_classes=1 when sigmoid=True and n_pred_ch==1

When sigmoid=True and n_pred_ch == 1, line 269's condition is False (because not self.sigmoid is False), so the code proceeds to line 271. Since self.sigmoid is True, the condition n_pred_ch > 1 or self.sigmoid is True, and one-hot encoding is performed with num_classes=1.

One-hot encoding with 1 class doesn't make logical sense - it just keeps the single channel as-is. For binary segmentation with sigmoid, the pattern should be: input has 1 channel (foreground probability), and one-hot should convert to 2 channels (background, foreground).

Suggested change

-            elif n_pred_ch > 1 or self.sigmoid:
-                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
-                if y_true.shape[1] != 1:
-                    y_true = y_true.unsqueeze(1)
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+            elif n_pred_ch > 1 or (self.sigmoid and n_pred_ch == 1):
+                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+                if y_true.shape[1] != 1:
+                    y_true = y_true.unsqueeze(1)
+                # For binary sigmoid case, need 2 classes
+                num_classes = 2 if (self.sigmoid and n_pred_ch == 1) else n_pred_ch
+                y_true = one_hot(y_true, num_classes=num_classes)

Comment on lines 268 to 275
         if self.to_onehot_y:
-            if n_pred_ch == 1:
+            if n_pred_ch == 1 and not self.sigmoid:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
+            elif n_pred_ch > 1 or self.sigmoid:
+                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+                if y_true.shape[1] != 1:
+                    y_true = y_true.unsqueeze(1)
                 y_true = one_hot(y_true, num_classes=n_pred_ch)

style: potential shape mismatch with sigmoid=True, n_pred_ch=1, and to_onehot_y=True

When using sigmoid=True with 1 input channel, y_pred_act becomes (B, 1, H, W). If to_onehot_y=True, the code converts y_true to one-hot with num_classes=1, keeping it (B, 1, H, W).

Consider adding a warning or documentation note that for binary segmentation with sigmoid=True and to_onehot_y=True, users should keep y_true as 0/1 masks without one-hot conversion, since one-hot with 1 class is redundant.

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR generalizes AsymmetricUnifiedFocalLoss, AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface used across other MONAI losses.

Key changes:

  • Replaced to_onehot_y parameter with include_background in base loss classes
  • Added use_sigmoid/use_softmax activation options to AsymmetricUnifiedFocalLoss
  • Generalized asymmetric logic to handle arbitrary number of classes (not just binary)
  • Separated gamma/delta parameters for focal and Tversky components
  • Updated parameter names for clarity (lambda_focal instead of weight)

Critical issues found:

  • Incorrect gamma exponent formula: Uses 1.0 - self.gamma but should be 1.0 / self.gamma per the paper (lines 83, 86, 90). With default gamma=0.75, current code gives (1-dice)^0.25 (decreasing penalty) instead of (1-dice)^1.333 (increasing penalty)
  • Dimension checking logic error: Line 280 compares spatial dimension with channel dimension when y_true.ndim differs from y_pred_act.ndim
  • One-hot with 1 class: When use_sigmoid=True with single channel, creates one-hot encoding with num_classes=1 which is illogical

Confidence Score: 1/5

  • This PR has critical mathematical errors that will cause incorrect loss values and training behavior
  • The incorrect gamma exponent formula (1-gamma instead of 1/gamma) is a fundamental mathematical error that affects all three loss classes. This will cause the loss to behave opposite to its intended design, reducing penalties for poor predictions instead of increasing them. Combined with the dimension checking bug and one-hot encoding issue, this PR cannot be merged safely without fixes.
  • The single file monai/losses/unified_focal_loss.py requires immediate attention to fix the gamma exponent formula in all three loss classes (lines 83, 86, 90), the dimension checking logic (line 280), and the one-hot encoding logic (line 271)

Important Files Changed

File Analysis

Filename Score Overview
monai/losses/unified_focal_loss.py 1/5 Generalizes three loss classes for multi-class support with new sigmoid/softmax interface. Critical issues: incorrect gamma exponent formula (uses 1-gamma instead of 1/gamma), dimension checking logic error, and one-hot encoding with 1 class.

Sequence Diagram

sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt use_sigmoid
        AUFL->>AUFL: y_pred_act = sigmoid(y_pred)
    else use_softmax
        AUFL->>AUFL: y_pred_act = softmax(y_pred, dim=1)
    end
    
    alt to_onehot_y
        AUFL->>AUFL: Convert y_true to one-hot
    end
    
    AUFL->>AUFL: Shape alignment checks
    
    AUFL->>AFL: forward(y_pred_act, y_true)
    
    alt include_background=False
        AFL->>AFL: Exclude channel 0
    end
    
    AFL->>AFL: cross_entropy = -y_true * log(y_pred)
    
    alt Asymmetric logic
        AFL->>AFL: back_ce with gamma exponent
        AFL->>AFL: fore_ce without gamma
        AFL->>AFL: Combine and reduce
    end
    
    AFL-->>AUFL: f_loss
    
    AUFL->>AFTL: forward(y_pred_act, y_true)
    
    alt include_background=False
        AFTL->>AFTL: Exclude channel 0
    end
    
    AFTL->>AFTL: Calculate tp, fn, fp
    AFTL->>AFTL: dice_class = (tp + ε) / (tp + δ*fn + (1-δ)*fp + ε)
    
    alt Asymmetric logic
        AFTL->>AFTL: back_dice without gamma
        AFTL->>AFTL: fore_dice with gamma exponent
        AFTL->>AFTL: Combine and reduce
    end
    
    AFTL-->>AUFL: t_loss
    
    AUFL->>AUFL: loss = λ*f_loss + (1-λ)*t_loss
    AUFL-->>User: loss

1 file reviewed, 2 comments


Comment on lines +271 to 275
            elif n_pred_ch > 1 or self.use_sigmoid:
                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                if y_true.shape[1] != 1:
                    y_true = y_true.unsqueeze(1)
                y_true = one_hot(y_true, num_classes=n_pred_ch)

logic: when use_sigmoid=True and n_pred_ch=1, one-hot with num_classes=1 doesn't make sense

When use_sigmoid=True with single channel, line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates to True (because self.use_sigmoid is True). This leads to one-hot encoding with num_classes=1, which is illogical. For binary sigmoid, should convert to 2 classes (background + foreground).

Comment on lines 280 to 281
            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
                y_true = y_true.unsqueeze(1)  # Add channel dim

logic: condition logic is backwards - checks incompatible dimensions

When y_true.ndim == y_pred_act.ndim - 1 is True (e.g., (B,H,W) vs (B,C,H,W)), the code should immediately unsqueeze. Instead, it also checks y_true.shape[1] != y_pred_act.shape[1] which compares H (spatial) with C (channel) - wrong dimensions.

Suggested change

-            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-                y_true = y_true.unsqueeze(1)  # Add channel dim
+            if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+                y_true = y_true.unsqueeze(1)  # Add channel dim
+            elif y_true.shape[1] != y_pred_act.shape[1]:
+                y_true = y_true.unsqueeze(1)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (6)
monai/losses/unified_focal_loss.py (6)

83-83: Critical: incorrect exponent formula.

The Unified Focal Loss paper specifies (1 - dice)^(1/γ). With default γ=0.75, this gives (1 - dice)^1.333, increasing penalty for low Dice scores. Current formula 1.0 - self.gamma gives (1 - dice)^0.25, which incorrectly decreases the penalty.

Apply this diff:

-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)

86-86: Critical: same exponent error.

Apply this diff:

-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)

90-90: Critical: same exponent error.

Apply this diff:

-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)

118-118: Breaking change: parameter removed without deprecation.

Same issue as AsymmetricFocalTverskyLoss—to_onehot_y removed without deprecation decorator.


280-281: Critical: IndexError when y_true has fewer dimensions.

When y_true.ndim == y_pred_act.ndim - 1 (e.g., (B, H, W) vs (B, C, H, W)), the condition is True, but then checking y_true.shape[1] != y_pred_act.shape[1] compares H (spatial) with C (channel)—wrong dimensions. Should unsqueeze immediately when ndim differs.

Apply this diff:

-            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
+            if y_true.ndim == y_pred_act.ndim - 1:
                 y_true = y_true.unsqueeze(1)  # Add channel dim
+            elif y_true.shape[1] != y_pred_act.shape[1]:
+                y_true = y_true.unsqueeze(1)

271-275: Major: one-hot with num_classes=1 when use_sigmoid=True.

When use_sigmoid=True with single-channel input (n_pred_ch=1), the condition n_pred_ch > 1 or self.use_sigmoid is True, leading to one_hot(y_true, num_classes=1). This raises ValueError when any voxel has label 1. For binary sigmoid, should convert to 2 classes (background + foreground) or skip one-hot entirely.

Apply this diff to handle binary sigmoid properly:

             elif n_pred_ch > 1 or self.use_sigmoid:
                 # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                 if y_true.shape[1] != 1:
                     y_true = y_true.unsqueeze(1)
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+                # For binary sigmoid, need 2 classes; otherwise use n_pred_ch
+                num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
+                y_true = one_hot(y_true, num_classes=num_classes)

Note: This also requires expanding y_pred_act from 1 to 2 channels for consistency.
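
One hedged way that two-channel bookkeeping could look for the binary-sigmoid path (a sketch only, not the PR's code; one_hot here is monai.networks.utils.one_hot):

import torch
from monai.networks.utils import one_hot

y_pred = torch.rand(2, 1, 32, 32)              # single-channel foreground probability after sigmoid
y_true = torch.randint(0, 2, (2, 1, 32, 32))   # 0/1 mask

y_pred_2ch = torch.cat([1.0 - y_pred, y_pred], dim=1)  # (B, 2, H, W): background, foreground
y_true_2ch = one_hot(y_true, num_classes=2)            # (B, 2, H, W)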

🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (5)

65-65: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

147-147: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

196-197: Parameter naming inconsistency.

The parameters are named use_sigmoid and use_softmax in the signature, stored as self.use_sigmoid and self.use_softmax (lines 230-231), and checked under those names throughout, so the usage is internally consistent. However, the naming convention differs from typical MONAI losses, which use sigmoid/softmax (not use_sigmoid/use_softmax).

For consistency with other MONAI losses (e.g., FocalLoss), consider renaming to sigmoid and softmax:

-        use_sigmoid: bool = False,
-        use_softmax: bool = False,
+        sigmoid: bool = False,
+        softmax: bool = False,

Then update all references accordingly.


264-264: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, use_softmax=True ignored.")
+                warnings.warn("single channel prediction, use_softmax=True ignored.", stacklevel=2)

270-270: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
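
For context on the stacklevel suggestion, a tiny standalone example (nothing MONAI-specific):

import warnings

def forward(x):
    # stacklevel=2 attributes the warning to the caller of forward(), which is usually
    # far more useful for locating the offending loss invocation than this line itself
    warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
    return x

forward(None)
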
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between b4e0fcc and 5f3d1d2.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py

60-60: Avoid specifying long messages outside the exception class

(TRY003)


65-65: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


100-100: Avoid specifying long messages outside the exception class

(TRY003)


142-142: Avoid specifying long messages outside the exception class

(TRY003)


147-147: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


179-179: Avoid specifying long messages outside the exception class

(TRY003)


236-236: Avoid specifying long messages outside the exception class

(TRY003)


264-264: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


270-270: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


284-287: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (18)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

193-294: Request test coverage for new functionality and bug fixes.

This PR introduces significant changes: new activation modes (use_sigmoid/use_softmax), generalization to multi-class, parameter API changes, and critical bug fixes (exponent formula). Without tests, these changes cannot be validated.

Please add unit tests covering:

  1. Binary segmentation with use_sigmoid=True
  2. Multi-class segmentation with use_softmax=True
  3. Background inclusion/exclusion (include_background)
  4. One-hot conversion edge cases
  5. Shape alignment scenarios
  6. Correctness of exponent formula after fixes

Would you like assistance generating test cases?
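
A hedged sketch of what such tests could look like (class and argument names follow the review text and may not match the merged code):

import unittest
import torch
from monai.losses import AsymmetricUnifiedFocalLoss


class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
    def test_multiclass_softmax(self):
        loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True)
        logits = torch.randn(2, 3, 16, 16)
        labels = torch.randint(0, 3, (2, 1, 16, 16))
        loss = loss_fn(logits, labels)
        self.assertEqual(loss.ndim, 0)
        self.assertTrue(torch.isfinite(loss))

    def test_binary_sigmoid(self):
        loss_fn = AsymmetricUnifiedFocalLoss(use_sigmoid=True)
        logits = torch.randn(2, 1, 16, 16)
        labels = torch.randint(0, 2, (2, 1, 16, 16)).float()
        loss = loss_fn(logits, labels)
        self.assertTrue(torch.isfinite(loss))


if __name__ == "__main__":
    unittest.main()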

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR successfully generalizes AsymmetricUnifiedFocalLoss and its dependencies to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface used by other MONAI losses like FocalLoss. The interface is now aligned with MONAI conventions, replacing the to_onehot_y parameter with include_background for the component losses.

Major changes:

  • Added use_sigmoid/use_softmax activation options to AsymmetricUnifiedFocalLoss
  • Replaced to_onehot_y with include_background in AsymmetricFocalLoss and AsymmetricFocalTverskyLoss
  • Implemented asymmetric treatment: background class (index 0) handled differently from foreground classes
  • Refactored to support N-class segmentation instead of binary-only

Critical issues found:

  • Exponent formula error (lines 83, 86, 90): Uses 1.0 - self.gamma instead of 1.0 / self.gamma per the Unified Focal Loss paper. With default gamma=0.75, this incorrectly computes (1-dice)^0.25 (reduces penalty) instead of (1-dice)^1.333 (increases penalty as intended).
  • Shape handling logic error (line 280): When y_true has fewer dimensions than y_pred_act, the condition compares spatial dimension with channel dimension.

Confidence Score: 1/5

  • This PR has critical mathematical errors that will produce incorrect loss values
  • Score of 1 reflects critical bugs: the exponent formula error in AsymmetricFocalTverskyLoss (lines 83, 86, 90) uses 1.0 - gamma instead of 1.0 / gamma, fundamentally breaking the loss computation per the paper's specification. This will cause incorrect training behavior. Additionally, shape handling logic at line 280 has a dimensional mismatch bug. These are not edge cases but core functionality errors that must be fixed before merge.
  • monai/losses/unified_focal_loss.py requires immediate attention for exponent formula corrections and shape handling fixes

Important Files Changed

File Analysis

Filename Score Overview
monai/losses/unified_focal_loss.py 2/5 Major refactoring to add sigmoid/softmax interface and multi-class support. Critical exponent formula errors (should be 1/gamma not 1-gamma) in lines 83, 86, 90. Shape handling logic error at line 280.

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss

    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt use_sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = sigmoid(y_pred)
    else use_softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = softmax(y_pred)
    else
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = y_pred
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_true = one_hot(y_true, num_classes)
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure y_true.shape == y_pred_act.shape
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if include_background=False
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross_entropy
    
    alt Multi-class with background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: back_ce = (1-delta) * (1-p)^gamma * CE
        AsymmetricFocalLoss->>AsymmetricFocalLoss: fore_ce = delta * CE
    else Foreground only
        AsymmetricFocalLoss->>AsymmetricFocalLoss: loss = delta * CE
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if include_background=False
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky Index (TI)
    
    alt Multi-class with background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: back_loss = 1 - TI[0]
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: fore_loss = (1 - TI[1:])^(1-gamma) [BUG: should be 1/gamma]
    else Foreground only
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: loss = (1 - TI)^(1-gamma) [BUG: should be 1/gamma]
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: loss = lambda_focal * f_loss + (1-lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: loss

1 file reviewed, 4 comments


Comment on lines 280 to 281
            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
                y_true = y_true.unsqueeze(1)  # Add channel dim

logic: When y_true.ndim == y_pred_act.ndim - 1 (e.g., (B, H, W) vs (B, C, H, W)), y_true.shape[1] is the spatial dimension H, not a channel dimension. Comparing it with y_pred_act.shape[1] (C) is illogical.

Suggested change

-            if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-                y_true = y_true.unsqueeze(1)  # Add channel dim
+            if y_true.ndim == y_pred_act.ndim - 1:
+                y_true = y_true.unsqueeze(1)  # Add channel dim
+            elif y_true.shape[1] != y_pred_act.shape[1]:
+                y_true = y_true.unsqueeze(1)


        if not self.include_background:
            # All classes are foreground, apply foreground logic
            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)

logic: Exponent formula is incorrect per the Unified Focal Loss paper. Should be 1.0 / self.gamma not 1.0 - self.gamma. With default gamma=0.75, current gives (1-dice)^0.25 (decreases penalty), should be (1-dice)^1.333 (increases penalty).

Suggested change

-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)

            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
        elif n_classes == 1:
            # Single class, must be foreground (BG was excluded or not provided)
            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)

logic: Same exponent error - should be 1.0 / self.gamma

Suggested change

-            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+            loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)

        else:
            # Asymmetric logic: class 0 is BG, others are FG
            back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: Same exponent error - should be 1.0 / self.gamma

Suggested change

-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Refactored three loss classes (AsymmetricFocalTverskyLoss, AsymmetricFocalLoss, AsymmetricUnifiedFocalLoss) to support multi-class segmentation and align with MONAI's standard interface (sigmoid/softmax/to_onehot_y parameters).

Key changes:

  • Replaced to_onehot_y parameter with include_background for component losses
  • Added use_sigmoid and use_softmax activation options to AsymmetricUnifiedFocalLoss
  • Generalized asymmetric logic to handle N classes (not just binary)
  • Improved parameter clarity: renamed weight to lambda_focal, split gamma/delta params for each component
  • Proper reduction handling (mean/sum/none) for all loss components

Critical issues found:

  • Incorrect exponent formula: Uses 1.0 - self.gamma instead of 1.0 / self.gamma in 3 locations (lines 83, 86, 90), contradicting the Unified Focal Loss paper's formula (1-TI)^(1/γ). With default γ=0.75, current code produces (1-dice)^0.25 (reduces penalty) instead of (1-dice)^1.333 (increases penalty)
  • Shape alignment bug: Line 280 condition y_true.ndim == y_pred_act.ndim - 1 and y_true.shape[1] != y_pred_act.shape[1] will compare spatial dimension H with channel dimension C when shapes are (B,H,W) vs (B,C,H,W)

Confidence Score: 1/5

  • This PR contains critical mathematical errors that will produce incorrect loss values and must be fixed before merging
  • Score of 1 reflects critical bugs: (1) the exponent formula error in 3 locations fundamentally breaks the loss calculation per the paper's specification, causing the wrong penalty direction; (2) shape alignment logic error will cause IndexError in some input configurations. These are not edge cases but affect core functionality.
  • monai/losses/unified_focal_loss.py requires immediate attention - all three exponent errors (lines 83, 86, 90) and the shape logic bug (line 280) must be fixed

Important Files Changed

File Analysis

Filename Score Overview
monai/losses/unified_focal_loss.py 2/5 Refactored to support multi-class segmentation and added sigmoid/softmax/to_onehot_y interface. Critical bug: exponent formula uses 1.0 - gamma instead of 1.0 / gamma (lines 83, 86, 90), causing incorrect loss calculations. Additional issue with shape alignment logic (line 280).

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt use_sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid(y_pred)
    else use_softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax(y_pred, dim=1)
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot encoding
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Align y_true shape with y_pred_act
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Calculate cross_entropy = -y_true * log(y_pred)
    
    alt n_classes > 1
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting<br/>(BG: gamma exponent, FG: no exponent)
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: Return f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Calculate Tversky index (TP, FN, FP)
    
    alt n_classes > 1
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric focal modulation<br/>(BG: no exponent, FG: gamma exponent)
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: Return t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda_focal * f_loss + (1-lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: Return combined loss

1 file reviewed, no comments


@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR generalizes AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and adds the standard MONAI loss interface (use_sigmoid, use_softmax, to_onehot_y, include_background). The refactoring replaces the hardcoded binary segmentation logic with flexible multi-class handling.

Major changes:

  • Replaced to_onehot_y parameter with include_background in component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss)
  • Added use_sigmoid/use_softmax activation options to AsymmetricUnifiedFocalLoss
  • Implemented asymmetric treatment: background class (index 0) uses different loss weighting than foreground classes
  • Updated parameter names for clarity (weightlambda_focal, added separate gamma/delta for each component)

Critical issues found:

  • Line 90: Exponent formula error in AsymmetricFocalTverskyLoss - uses 1.0 - self.gamma instead of 1.0 / self.gamma, causing incorrect loss behavior
  • Lines 279-280: Shape handling logic error when y_true has fewer dimensions than y_pred_act
  • Lines 271-275: One-hot encoding with num_classes=1 when using use_sigmoid=True with single-channel input

Confidence Score: 1/5

  • This PR has critical mathematical errors that will cause incorrect loss computation
  • Score of 1 reflects a critical exponent formula error on line 90 that fundamentally breaks the AsymmetricFocalTverskyLoss behavior, plus logic errors in shape handling (lines 279-280) and one-hot encoding (lines 271-275) that could cause runtime failures or incorrect results
  • monai/losses/unified_focal_loss.py requires immediate attention - line 90 has incorrect exponent formula (1.0 - gamma instead of 1.0 / gamma), lines 279-280 have shape handling logic error, and lines 271-275 have one-hot encoding issue

Important Files Changed

File Analysis

Filename Score Overview
monai/losses/unified_focal_loss.py 2/5 Refactored to add sigmoid/softmax interface and multi-class support, but contains critical exponent formula error in AsymmetricFocalTverskyLoss (line 90) and shape/one-hot handling issues in AsymmetricUnifiedFocalLoss

Sequence Diagram

sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt use_sigmoid
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = sigmoid(y_pred)
    else use_softmax
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = softmax(y_pred)
    else no activation
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = y_pred
    end
    
    alt to_onehot_y
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_true = one_hot(y_true, num_classes)
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: ensure shape compatibility
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: apply asymmetric weighting<br/>(BG: (1-delta)*(1-p)^gamma*CE, FG: delta*CE)
    else exclude background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: slice [:, 1:] and apply FG logic
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: focal_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: compute Tversky index<br/>asymmetric: BG linear, FG with exponent
    else exclude background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: slice [:, 1:] and apply FG logic
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: tversky_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: loss = lambda_focal * focal_loss<br/>+ (1-lambda_focal) * tversky_loss
    
    AsymmetricUnifiedFocalLoss-->>User: loss

1 file reviewed, 3 comments


        else:
            # Asymmetric logic: class 0 is BG, others are FG
            back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: exponent formula incorrect - should be 1.0 / self.gamma not 1.0 - self.gamma

Paper specifies AFTL = (1-TI)^(1/γ). With γ=0.75: should be (1-dice)^1.333 (increases penalty), current gives (1-dice)^0.25 (decreases penalty).

Suggested change

-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)

Comment on lines +279 to +280
            if y_true.ndim == y_pred_act.ndim - 1:
                y_true = y_true.unsqueeze(1)

logic: when y_true.ndim == y_pred_act.ndim - 1 (e.g., (B,H,W) vs (B,C,H,W)), should immediately unsqueeze. The AND condition causes the unsqueeze to be skipped when dimensions match exactly at position 1.

Suggested change

-            if y_true.ndim == y_pred_act.ndim - 1:
-                y_true = y_true.unsqueeze(1)
+            if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+                y_true = y_true.unsqueeze(1)  # Add channel dim
+            elif y_true.shape[1] != y_pred_act.shape[1]:
+                y_true = y_true.unsqueeze(1)

Comment on lines +271 to 275
            elif n_pred_ch > 1 or self.use_sigmoid:
                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                if y_true.shape[1] != 1:
                    y_true = y_true.unsqueeze(1)
                y_true = one_hot(y_true, num_classes=n_pred_ch)

logic: when use_sigmoid=True and n_pred_ch=1, line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates to True, leading to one-hot with num_classes=1 on line 275. For binary sigmoid, should convert to 2 classes (background + foreground).

Suggested change

-            elif n_pred_ch > 1 or self.use_sigmoid:
-                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
-                if y_true.shape[1] != 1:
-                    y_true = y_true.unsqueeze(1)
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+            elif n_pred_ch > 1 or (self.use_sigmoid and n_pred_ch == 1):
+                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+                if y_true.shape[1] != 1:
+                    y_true = y_true.unsqueeze(1)
+                # For binary sigmoid case, need 2 classes
+                num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
+                y_true = one_hot(y_true, num_classes=num_classes)

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR generalizes AsymmetricUnifiedFocalLoss, AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss to support multi-class segmentation and aligns them with MONAI's standard loss interface (adding sigmoid/softmax/to_onehot_y parameters).

Key Changes:

  • Replaced to_onehot_y with include_background in the component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss)
  • Added use_sigmoid/use_softmax activation options to AsymmetricUnifiedFocalLoss
  • Implemented multi-class support with asymmetric treatment: background class (index 0) gets different weighting than foreground classes (indices 1+)
  • Refactored loss computation to handle three cases: no background, single class, and multi-class with background
  • Renamed parameters for clarity (weightlambda_focal, separate focal_loss_gamma/tversky_loss_gamma)

Critical Issues:

  • Line 90: Incorrect exponent formula 1.0 - self.gamma should be 1.0 / self.gamma per the Unified Focal Loss paper
  • Line 271-275: One-hot encoding with num_classes=1 for binary sigmoid is illogical; should use 2 classes

Confidence Score: 2/5

  • This PR has a critical mathematical error in the loss formula that will cause incorrect training behavior
  • Score reflects a critical bug in the exponent formula (line 90) that fundamentally breaks the loss computation for multi-class cases with asymmetric foreground handling. Additionally, the binary sigmoid one-hot logic is flawed. These issues will cause incorrect loss values during training.
  • The single file monai/losses/unified_focal_loss.py requires immediate attention for the mathematical formula error on line 90 and one-hot encoding logic on lines 271-275

Important Files Changed

File Analysis

Filename Score Overview
monai/losses/unified_focal_loss.py 2/5 Generalized three loss classes for multi-class support and standardized interface. Critical bug in AsymmetricFocalTverskyLoss line 90 (wrong exponent formula), and logic issue in AsymmetricUnifiedFocalLoss for binary sigmoid one-hot handling.

Sequence Diagram

sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt use_sigmoid=True
        AUFL->>AUFL: y_pred_act = sigmoid(y_pred)
    else use_softmax=True
        AUFL->>AUFL: y_pred_act = softmax(y_pred)
    else
        AUFL->>AUFL: y_pred_act = y_pred
    end
    
    alt to_onehot_y=True
        AUFL->>AUFL: y_true = one_hot(y_true, num_classes)
    end
    
    AUFL->>AUFL: Shape validation & adjustment
    
    par Parallel Loss Computation
        AUFL->>AFL: forward(y_pred_act, y_true)
        AFL->>AFL: Exclude background if needed
        AFL->>AFL: cross_entropy = -y_true * log(y_pred)
        
        alt Multi-class with background
            AFL->>AFL: BG: (1-delta) * (1-p)^gamma * CE
            AFL->>AFL: FG: delta * CE
        else Single/No background
            AFL->>AFL: All FG: delta * CE
        end
        
        AFL-->>AUFL: f_loss
    and
        AUFL->>AFTL: forward(y_pred_act, y_true)
        AFTL->>AFTL: Exclude background if needed
        AFTL->>AFTL: Compute Tversky Index (tp, fn, fp)
        
        alt Multi-class with background
            AFTL->>AFTL: BG: 1 - dice
            AFTL->>AFTL: FG: (1-dice)^(1/gamma)
        else Single/No background
            AFTL->>AFTL: All: (1-dice)^(1/gamma)
        end
        
        AFTL-->>AUFL: t_loss
    end
    
    AUFL->>AUFL: loss = lambda * f_loss + (1-lambda) * t_loss
    AUFL-->>User: loss

1 file reviewed, 2 comments


        else:
            # Asymmetric logic: class 0 is BG, others are FG
            back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: exponent formula is incorrect - should be 1.0 / self.gamma not 1.0 - self.gamma

Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Current formula gives (1-dice)^0.25, incorrectly decreasing the penalty.

Suggested change

-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)

Comment on lines +271 to 275
            elif n_pred_ch > 1 or self.use_sigmoid:
                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                if y_true.shape[1] != 1:
                    y_true = y_true.unsqueeze(1)
                y_true = one_hot(y_true, num_classes=n_pred_ch)

logic: when use_sigmoid=True and n_pred_ch=1, one-hot with num_classes=1 is illogical

For binary sigmoid (1 channel foreground probability), should convert to 2 classes (background + foreground) for proper loss computation with the asymmetric logic.

Suggested change

-            elif n_pred_ch > 1 or self.use_sigmoid:
-                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
-                if y_true.shape[1] != 1:
-                    y_true = y_true.unsqueeze(1)
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+            elif n_pred_ch > 1 or (self.use_sigmoid and n_pred_ch == 1):
+                # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+                if y_true.shape[1] != 1:
+                    y_true = y_true.unsqueeze(1)
+                # For binary sigmoid case, need 2 classes
+                num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
+                y_true = one_hot(y_true, num_classes=num_classes)

pre-commit-ci bot and others added 6 commits October 31, 2025 23:05
Fixes Project-MONAI#8564 .

### Description

Add Fourier feature positional encodings to `PatchEmbeddingBlock`. It
has been shown, that Fourier feature positional encodings are better
suited for Anistropic images and videos:
https://arxiv.org/abs/2509.02488

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: NabJa <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
….com>

I, NabJa <[email protected]>, hereby add my Signed-off-by to this commit: 76c4391

Signed-off-by: NabJa <[email protected]>

Signed-off-by: ytl0623 <[email protected]>

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR successfully generalizes AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation beyond binary cases. The interface has been modernized to align with other MONAI losses by:

  • Replacing to_onehot_y parameter with include_background in component losses
  • Adding use_sigmoid/use_softmax activation options to the unified loss
  • Renaming weight to lambda_focal for clarity
  • Separating gamma and delta parameters for each loss component
  • Updating documentation to reflect multi-class support

Critical Issue Found:

  • Line 90 in AsymmetricFocalTverskyLoss uses incorrect exponent formula 1.0 - self.gamma instead of 1.0 / self.gamma, which is inconsistent with lines 83 and 86 in the same function. This causes incorrect loss computation for foreground classes when background is included.

The refactoring is well-structured and follows MONAI patterns, but the mathematical error needs to be fixed before merging.
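
As a rough usage sketch of the refactored interface (argument names taken from the docstring quoted further below; input shapes and defaults are assumed, not verified against the final code):

```python
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

y_pred = torch.randn(2, 3, 64, 64)             # (B, C, H, W) raw multi-class logits
y_true = torch.randint(0, 3, (2, 1, 64, 64))   # (B, 1, H, W) integer class labels

loss_fn = AsymmetricUnifiedFocalLoss(
    use_softmax=True,        # activate logits inside the loss
    to_onehot_y=True,        # convert integer labels to one-hot
    lambda_focal=0.5,        # focal (CE) weight; the Tversky term gets 1 - lambda_focal
    include_background=True,
)
loss = loss_fn(y_pred, y_true)
```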

Confidence Score: 2/5

  • This PR has a critical mathematical bug that will cause incorrect loss computation
  • The exponent formula error on line 90 is a critical bug that will produce incorrect gradients and affect model training. While the refactoring is well-designed and the interface improvements are good, this single mathematical error makes the loss function mathematically incorrect per the paper's specification.
  • monai/losses/unified_focal_loss.py - Line 90 has critical mathematical error

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| monai/losses/unified_focal_loss.py | 2/5 | Generalizes losses for multi-class segmentation and adds sigmoid/softmax interface. Critical bug: incorrect exponent formula on line 90 (uses 1.0 - gamma instead of 1.0 / gamma) |

Sequence Diagram

sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt use_sigmoid
        AUFL->>AUFL: y_pred_act = sigmoid(y_pred)
    else use_softmax
        AUFL->>AUFL: y_pred_act = softmax(y_pred)
    else
        AUFL->>AUFL: y_pred_act = y_pred
    end
    
    alt to_onehot_y
        AUFL->>AUFL: y_true = one_hot(y_true, num_classes)
    end
    
    AUFL->>AUFL: Ensure shapes match
    
    par Compute both loss components
        AUFL->>AFL: forward(y_pred_act, y_true)
        AFL->>AFL: Exclude background if needed
        AFL->>AFL: cross_entropy = -y_true * log(y_pred)
        
        alt include_background and n_classes > 1
            AFL->>AFL: back_ce with gamma modulation
            AFL->>AFL: fore_ce without gamma
            AFL->>AFL: Concatenate losses
        else
            AFL->>AFL: Apply foreground logic
        end
        
        AFL->>AFL: Apply reduction
        AFL-->>AUFL: f_loss
    and
        AUFL->>AFTL: forward(y_pred_act, y_true)
        AFTL->>AFTL: Exclude background if needed
        AFTL->>AFTL: Compute TP, FN, FP
        AFTL->>AFTL: dice_class = Tversky index
        
        alt include_background and n_classes > 1
            AFTL->>AFTL: back_dice_loss (no gamma)
            AFTL->>AFTL: fore_dice_loss with (1-dice)^(1/gamma)
            Note right of AFTL: BUG: Line 90 uses 1-gamma
            AFTL->>AFTL: Concatenate losses
        else
            AFTL->>AFTL: Apply (1-dice)^(1/gamma)
        end
        
        AFTL->>AFTL: Apply reduction
        AFTL-->>AUFL: t_loss
    end
    
    AUFL->>AUFL: loss = lambda_focal * f_loss + (1 - lambda_focal) * t_loss
    AUFL-->>User: loss
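
For reference, the focal cross-entropy branch of the diagram (multi-class with background) could look roughly like this; a sketch only, with tensor layout and clamping assumed:

```python
import torch

def asymmetric_focal_ce(y_pred: torch.Tensor, y_true: torch.Tensor,
                        delta: float = 0.7, gamma: float = 2.0,
                        epsilon: float = 1e-7) -> torch.Tensor:
    """Per-voxel asymmetric focal CE: focal modulation is applied to the background channel only."""
    y_pred = y_pred.clamp(epsilon, 1.0 - epsilon)
    cross_entropy = -y_true * torch.log(y_pred)                  # (B, C, ...)
    back_ce = (1.0 - delta) * torch.pow(1.0 - y_pred[:, :1], gamma) * cross_entropy[:, :1]
    fore_ce = delta * cross_entropy[:, 1:]                       # foreground kept unmodulated
    return torch.cat([back_ce, fore_ce], dim=1)
```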

1 file reviewed, 1 comment

        else:
            # Asymmetric logic: class 0 is BG, others are FG
            back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: exponent should be 1.0 / self.gamma not 1.0 - self.gamma (inconsistent with lines 83, 86)

Lines 83 and 86 correctly use 1.0 / self.gamma, but line 90 incorrectly uses 1.0 - self.gamma. Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With γ=0.75, should be (1-dice)^1.333 to increase penalty for low dice scores.

Suggested change
-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
monai/losses/unified_focal_loss.py (4)

118-136: Breaking API change without deprecation.

Same issue as AsymmetricFocalTverskyLoss: removing to_onehot_y breaks existing code. Restore with deprecation or document as breaking change.


90-90: Critical: Wrong exponent formula - should be 1.0 / self.gamma.

The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333 (increases penalty for low dice). Current 1.0 - self.gamma gives (1-dice)^0.25 (decreases penalty).

This critical error has been flagged in multiple past reviews.

Apply this diff:

-            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)

279-280: Logic error: dimension mismatch when y_true lacks channel dim.

When y_true is (B, H, W) and y_pred_act is (B, C, H, W), condition on line 279 is True (3 == 4-1) but code only unsqueezes, doesn't verify shape match. Should unsqueeze immediately when ndim differs.

This was flagged multiple times in past reviews.

Apply this diff:

         # Ensure y_true has the same shape as y_pred_act
         if y_true.shape != y_pred_act.shape:
-            if y_true.ndim == y_pred_act.ndim - 1:
+            if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
                 y_true = y_true.unsqueeze(1)
-
-            if y_true.shape != y_pred_act.shape:
+            elif y_true.shape != y_pred_act.shape:
                 raise ValueError(
                     f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
                     f"after activations/one-hot"
                 )
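
Concretely, for a channel-less target the intended flow would be (shapes and class count hypothetical):

```python
import torch
from monai.networks.utils import one_hot

y_pred_act = torch.rand(2, 3, 32, 32)        # (B, C, H, W) after softmax
y_true = torch.randint(0, 3, (2, 32, 32))    # (B, H, W): no channel dim

if y_true.ndim == y_pred_act.ndim - 1:       # unsqueeze as soon as ndim differs
    y_true = y_true.unsqueeze(1)             # (B, 1, H, W)
y_true = one_hot(y_true, num_classes=3)      # (B, 3, H, W), now matches y_pred_act
```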

271-275: Critical: One-hot with num_classes=1 breaks binary sigmoid.

When use_sigmoid=True with single-channel logits (n_pred_ch=1), line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates True, calling one_hot(y_true, num_classes=1). Any label==1 raises ValueError.

This blocking bug was flagged by coderabbitai[bot] in a past review and remains unfixed.

Apply this diff:

         if self.to_onehot_y:
             if n_pred_ch == 1 and not self.use_sigmoid:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            elif n_pred_ch > 1 or self.use_sigmoid:
+            elif n_pred_ch > 1:
                 # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                 if y_true.shape[1] != 1:
                     y_true = y_true.unsqueeze(1)
                 y_true = one_hot(y_true, num_classes=n_pred_ch)

Or expand binary logits to 2 channels before one-hotting.
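
One way to do that expansion (shown on probabilities after the sigmoid rather than on raw logits; names are illustrative):

```python
import torch

logits = torch.randn(2, 1, 32, 32)                 # single-channel binary prediction
p_fg = torch.sigmoid(logits)                       # (B, 1, H, W) foreground probability
y_pred_2ch = torch.cat([1.0 - p_fg, p_fg], dim=1)  # (B, 2, H, W): [background, foreground]

labels = torch.randint(0, 2, (2, 1, 32, 32)).float()
y_true_2ch = torch.cat([1.0 - labels, labels], dim=1)  # matching 2-class one-hot target
```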

🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (5)

65-65: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

147-147: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

264-264: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, use_softmax=True ignored.")
+                warnings.warn("single channel prediction, use_softmax=True ignored.", stacklevel=2)

270-270: Add stacklevel=2 to warning.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

193-227: Docstrings missing exception documentation.

Per coding guidelines, docstrings should document raised exceptions. Line 236 raises ValueError but it's not in the Args section.

Add a Raises section:

         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
             use_sigmoid: if True, apply a sigmoid activation to the input y_pred.
             use_softmax: if True, apply a softmax activation to the input y_pred.
             lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based).
                 The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
             focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
             focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
             tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
             tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
             include_background: whether to include loss computation for the background class. Defaults to True.
             reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
+
+        Raises:
+            ValueError: if both use_sigmoid and use_softmax are True.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 5d14d85 and f69a25b.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py

60-60: Avoid specifying long messages outside the exception class

(TRY003)


65-65: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


100-100: Avoid specifying long messages outside the exception class

(TRY003)


142-142: Avoid specifying long messages outside the exception class

(TRY003)


147-147: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


179-179: Avoid specifying long messages outside the exception class

(TRY003)


236-236: Avoid specifying long messages outside the exception class

(TRY003)


264-264: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


270-270: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


283-286: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (15)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: packaging
  • GitHub Check: build-docs

Comment on lines +36 to +51
+        include_background: bool = True,
         delta: float = 0.7,
         gamma: float = 0.75,
         epsilon: float = 1e-7,
         reduction: LossReduction | str = LossReduction.MEAN,
     ) -> None:
         """
         Args:
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            include_background: whether to include loss computation for the background class. Defaults to True.
+            delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index)
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
             reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
         """
         super().__init__(reduction=LossReduction(reduction).value)
-        self.to_onehot_y = to_onehot_y
+        self.include_background = include_background

⚠️ Potential issue | 🟠 Major

Breaking API change without deprecation.

Removing to_onehot_y from AsymmetricFocalTverskyLoss breaks existing code. Per MONAI conventions and prior reviewer feedback (ericspod), deprecated parameters must be preserved with deprecation decorators for 1-2 versions.

Either restore to_onehot_y with deprecation warnings or explicitly document this as a breaking change in the changelog.
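
If the deprecation route is taken, one possible shape (assuming `monai.utils.deprecated_arg` is available in this codebase; version strings are placeholders):

```python
from torch.nn.modules.loss import _Loss
from monai.utils import deprecated_arg

class AsymmetricFocalTverskyLoss(_Loss):
    @deprecated_arg(
        name="to_onehot_y",
        since="1.5",    # placeholder version
        removed="1.7",  # placeholder version
        msg_suffix="one-hot conversion is now handled by AsymmetricUnifiedFocalLoss.",
    )
    def __init__(self, include_background: bool = True, to_onehot_y: bool = False) -> None:
        super().__init__()
        self.include_background = include_background
        # `to_onehot_y` is still accepted (with a deprecation warning) but no longer used.
```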

Development

Successfully merging this pull request may close these issues.

Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss
