Generalize AsymmetricUnifiedFocalLoss for multi-class and align interface #8607
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough: The unified focal loss module undergoes restructuring across three classes.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Force-pushed from 4d192fa to 1ab9120 (compare)
Greptile Overview
Greptile Summary
This PR successfully modernizes three asymmetric loss classes (AsymmetricFocalTverskyLoss, AsymmetricFocalLoss, and AsymmetricUnifiedFocalLoss) to align with MONAI's standard loss interface, addressing issue #8603. The changes generalize these losses from binary-only to multi-class segmentation support.
Key improvements:
- Replaced deprecated `to_onehot_y` parameter with `include_background` in base loss classes
- Added `sigmoid`/`softmax` activation support to `AsymmetricUnifiedFocalLoss` (see the usage sketch after this list)
- Implemented proper multi-class handling with asymmetric foreground/background treatment
- Improved documentation and parameter naming clarity
- Refactored reduction logic to properly support "mean", "sum", and "none" modes
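To make the new interface concrete, here is a minimal usage sketch. It assumes the constructor arguments proposed in this PR (`sigmoid`, `softmax`, `to_onehot_y`) rather than the currently released MONAI signature, so treat it as illustrative only:

```python
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Multi-class: raw logits (B, C, H, W) with integer labels (B, 1, H, W).
loss_fn = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
logits = torch.randn(2, 3, 64, 64)
labels = torch.randint(0, 3, (2, 1, 64, 64))
print(loss_fn(logits, labels))  # scalar tensor under the default "mean" reduction

# Binary: single-channel logits with sigmoid activation.
binary_loss = AsymmetricUnifiedFocalLoss(sigmoid=True)
pred = torch.randn(2, 1, 64, 64)
target = torch.randint(0, 2, (2, 1, 64, 64)).float()
print(binary_loss(pred, target))
```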
Issues found:
- Critical: Logic error in one-hot conversion when `sigmoid=True` (lines 272-278) - attempts to create one-hot with `num_classes=1`, which produces incorrect shapes
- Critical: Shape validation logic error (lines 281-287) - the condition `y_true.ndim == y_pred_act.ndim - 1` evaluates incorrectly
- Unused import (`torch.nn.functional as F`) should be removed
- Shape validation before one-hot conversion could be more robust
The refactoring follows the pattern established by FocalLoss but contains logic errors that will cause runtime failures with certain activation configurations.
Confidence Score: 2/5
- This PR has critical logic errors that will cause runtime failures with sigmoid activation and certain input shapes
- Score of 2 reflects two critical logic errors: (1) one-hot conversion with sigmoid creates invalid shapes by using num_classes=1, and (2) shape dimension checking logic is incorrect. These will cause runtime failures in common use cases. The refactoring approach is sound and aligns well with MONAI patterns, but the implementation bugs must be fixed before merging.
- monai/losses/unified_focal_loss.py lines 272-287 contain critical logic errors in shape handling and one-hot conversion that must be resolved
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| monai/losses/unified_focal_loss.py | 3/5 | Generalizes loss functions from binary to multi-class with new sigmoid/softmax interface. Found critical logic errors in one-hot conversion with sigmoid and shape handling that will cause runtime failures. |
Sequence Diagram
sequenceDiagram
participant User
participant AsymmetricUnifiedFocalLoss
participant AsymmetricFocalLoss
participant AsymmetricFocalTverskyLoss
User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
alt sigmoid=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid activation
else softmax=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax activation
end
alt to_onehot_y=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot format
end
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Validate shapes match
AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
alt include_background=False
AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
end
AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
alt Multi-class with background
AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting<br/>(gamma for BG, no gamma for FG)
else Foreground only
AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply foreground weighting
end
AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply reduction (mean/sum/none)
AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: Return focal_loss
AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
alt include_background=False
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
end
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute TP, FN, FP
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
alt Multi-class with background
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric focal<br/>(no gamma for BG, gamma for FG)
else Foreground only
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply foreground focal
end
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply reduction (mean/sum/none)
AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: Return tversky_loss
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine losses<br/>lambda_focal * focal_loss + (1-lambda_focal) * tversky_loss
AsymmetricUnifiedFocalLoss-->>User: Return combined_loss
1 file reviewed, 4 comments
monai/losses/unified_focal_loss.py
Outdated
    if not self.include_background:
        # All classes are foreground, apply foreground logic
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
logic: exponent should be 1 + self.gamma (not 1 - self.gamma)
The original paper's formula uses (1 - dice)^(1/gamma), which when gamma < 1, becomes (1 - dice)^(>1), increasing the penalty. With 1 - self.gamma when gamma=0.75, you get (1-dice)^0.25, which decreases the penalty for low dice scores.
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    + loss = torch.pow(1.0 - dice_class, 1.0 + self.gamma)  # (B, C)
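A quick numeric check of how the competing exponents behave for a poorly segmented class (illustrative values only; which form matches the paper is exactly what the review comments dispute):

```python
dice = 0.2     # a poorly segmented class
gamma = 0.75   # the PR's default

print((1 - dice) ** (1 - gamma))  # 0.8**0.25  ~ 0.95  (exponent currently in the PR)
print((1 - dice) ** (1 + gamma))  # 0.8**1.75  ~ 0.68  (this suggestion)
print((1 - dice) ** (1 / gamma))  # 0.8**1.333 ~ 0.74  (the 1/gamma form cited in later comments)
```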
monai/losses/unified_focal_loss.py
Outdated
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    elif n_classes == 1:
        # Single class, must be foreground (BG was excluded or not provided)
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
logic: same exponent error here - should be 1 + self.gamma
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
    + loss = torch.pow(1.0 - dice_class, 1.0 + self.gamma)  # (B, 1)
    else:
        # Asymmetric logic: class 0 is BG, others are FG
        back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
        fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
logic: same exponent error - should be 1 + self.gamma
    - fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
    + fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 + self.gamma)  # (B, C-1)
    elif n_pred_ch > 1 or self.sigmoid:
        # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
        if y_true.shape[1] != 1:
            y_true = y_true.unsqueeze(1)
logic: unsqueezing unconditionally can cause issues if y_true is already (B, 1, H, W)
This adds an extra dimension. Should check y_true.ndim first or only unsqueeze when y_true.ndim == 3.
    - y_true = y_true.unsqueeze(1)
    + if y_true.ndim == 3:  # (B, H, W) -> (B, 1, H, W)
    +     y_true = y_true.unsqueeze(1)
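For reference, a standalone sketch of the label/prediction shape alignment the reviewers are converging on (hypothetical helper, not code from this PR): only add a channel dimension when the ground truth has one fewer dimension than the prediction.

```python
import torch

def align_label_shape(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Give y_true a channel dim only when it is missing one relative to y_pred."""
    if y_true.ndim == y_pred.ndim - 1:  # e.g. (B, H, W) vs (B, C, H, W)
        y_true = y_true.unsqueeze(1)
    return y_true

y_pred = torch.randn(2, 3, 8, 8)
print(align_label_shape(torch.zeros(2, 8, 8), y_pred).shape)     # torch.Size([2, 1, 8, 8])
print(align_label_shape(torch.zeros(2, 1, 8, 8), y_pred).shape)  # unchanged: torch.Size([2, 1, 8, 8])
```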
Actionable comments posted: 1
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
61-61: Avoid specifying long messages outside the exception class
(TRY003)
66-66: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
103-103: Avoid specifying long messages outside the exception class
(TRY003)
145-145: Avoid specifying long messages outside the exception class
(TRY003)
150-150: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
182-182: Avoid specifying long messages outside the exception class
(TRY003)
239-239: Avoid specifying long messages outside the exception class
(TRY003)
267-267: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
273-273: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
287-287: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: packaging
Force-pushed from 2959015 to dca758d (compare)
Greptile Overview
Greptile Summary
Generalizes AsymmetricUnifiedFocalLoss and its component losses for multi-class segmentation. Replaces to_onehot_y parameter with standard include_background/sigmoid/softmax interface matching other MONAI losses. Adds proper handling for background vs foreground classes using asymmetric weighting.
Critical Issues Found:
- Mathematical errors in gamma exponents (lines 86, 89, 93) - uses `1.0 - self.gamma` instead of `1.0 / self.gamma`, inverting the focal weighting behavior
- Syntax error in exception message (line 287) - unclosed string literal
- Logic error in shape validation (line 283) - incorrect dimension checking that will cause IndexError
Changes:
- Replaced `to_onehot_y` parameter with `include_background` in component losses
- Added `sigmoid`/`softmax` activation support to `AsymmetricUnifiedFocalLoss`
- Split gamma/delta parameters into separate focal and Tversky variants
- Generalized from binary to multi-class segmentation support
- Added proper reduction handling (`none`/`mean`/`sum`)
Confidence Score: 0/5
- This PR is unsafe to merge due to critical mathematical errors and syntax error that will cause runtime failures
- Score of 0 reflects three critical blocking issues: (1) mathematical formula errors in lines 86, 89, 93 where `1.0 - self.gamma` should be `1.0 / self.gamma`, completely inverting the focal weighting behavior, (2) a syntax error on line 287 with an unclosed string literal that will prevent the code from running, and (3) a logic error on line 283 that will cause an IndexError when accessing shape dimensions
- monai/losses/unified_focal_loss.py requires immediate attention - all identified errors must be fixed before merge
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Critical mathematical errors in gamma exponents (lines 86, 89, 93), syntax error in exception message (line 287), and logic error in shape handling (line 283) |
Sequence Diagram
sequenceDiagram
participant User
participant AsymmetricUnifiedFocalLoss
participant AsymmetricFocalLoss
participant AsymmetricFocalTverskyLoss
User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
alt sigmoid=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: torch.sigmoid(y_pred)
else softmax=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: torch.softmax(y_pred, dim=1)
end
alt to_onehot_y=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: one_hot(y_true)
end
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Shape validation
AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy
alt include_background=False
AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
else n_classes > 1
AsymmetricFocalLoss->>AsymmetricFocalLoss: Asymmetric weighting<br/>(BG: (1-delta)*pow(1-p, gamma)*CE)<br/>(FG: delta*CE)
end
AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
alt include_background=False
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
else n_classes > 1
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Asymmetric weighting<br/>(BG: 1-dice)<br/>(FG: pow(1-dice, 1/gamma))
end
AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: lambda_focal * f_loss +<br/>(1 - lambda_focal) * t_loss
AsymmetricUnifiedFocalLoss-->>User: combined_loss
1 file reviewed, 5 comments
monai/losses/unified_focal_loss.py
Outdated
    if not self.include_background:
        # All classes are foreground, apply foreground logic
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
logic: exponent formula is incorrect - should be 1.0 / self.gamma (not 1.0 - self.gamma)
The paper uses (1 - dice)^(1/gamma). When gamma=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Your formula gives (1-dice)^0.25, decreasing it.
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
monai/losses/unified_focal_loss.py
Outdated
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    elif n_classes == 1:
        # Single class, must be foreground (BG was excluded or not provided)
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
logic: same exponent error - should be 1.0 / self.gamma
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
    else:
        # Asymmetric logic: class 0 is BG, others are FG
        back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
        fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
logic: same exponent error - should be 1.0 / self.gamma
    - fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
    + fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
monai/losses/unified_focal_loss.py
Outdated
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) | ||
| after activations/one-hot") |
syntax: syntax error - string literal not properly closed
Missing closing quote on line 287.
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) | |
| after activations/one-hot") | |
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) after activations/one-hot") |
monai/losses/unified_focal_loss.py
Outdated
    if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
        y_true = y_true.unsqueeze(1)  # Add channel dim
logic: logic error - condition check is incorrect
When y_true.ndim == y_pred_act.ndim - 1, y_true has fewer dimensions (e.g., (B,H,W) vs (B,C,H,W)), so the check y_true.shape[1] != y_pred_act.shape[1] will fail with IndexError.
    - if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    -     y_true = y_true.unsqueeze(1)  # Add channel dim
    + if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
    +     y_true = y_true.unsqueeze(1)  # Add channel dim
    + elif y_true.shape[1] != y_pred_act.shape[1]:
    +     y_true = y_true.unsqueeze(1)  # Add channel dim if missing
…face Signed-off-by: ytl0623 <[email protected]>
Force-pushed from 100e9f8 to ad83444 (compare)
Greptile Overview
Greptile Summary
This PR successfully generalizes the AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation, replacing the previous binary-only limitation. The interface has been modernized to align with other MONAI losses by adding sigmoid/softmax/to_onehot_y parameters and replacing to_onehot_y with include_background in the component losses.
Major Changes:
- Added `sigmoid`/`softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Replaced `to_onehot_y` with `include_background` parameter in component losses
- Generalized all three loss classes to handle an arbitrary number of classes
- Improved parameter documentation and added clearer examples
- Properly implemented reduction modes (`mean`, `sum`, `none`)
Critical Issues Found:
- Incorrect exponent formula in `AsymmetricFocalTverskyLoss`: uses `1.0 - self.gamma` instead of `1.0 / self.gamma` per the paper's formula (1-TI)^(1/γ). With default γ=0.75, this produces exponent 0.25 instead of 1.333, fundamentally breaking the loss behavior by reducing instead of increasing the penalty for low Dice scores.
- Shape handling logic error in `AsymmetricUnifiedFocalLoss` line 282: when `y_true.ndim == y_pred_act.ndim - 1`, the code incorrectly tries to compare `y_true.shape[1]` (spatial dim H) with `y_pred_act.shape[1]` (channel dim C) before unsqueezing.
- String syntax error on line 287: missing closing quote on the error message string.
Confidence Score: 1/5
- This PR has critical mathematical errors that break the loss function's core behavior
- The incorrect exponent formula (1-γ instead of 1/γ) in AsymmetricFocalTverskyLoss fundamentally changes the loss behavior, causing it to decrease penalties for low Dice scores when it should increase them. Combined with the shape handling bug and syntax error, these issues will cause runtime failures and incorrect training behavior. The mathematical error affects all three loss classes and deviates from the published paper.
- monai/losses/unified_focal_loss.py requires immediate attention - all four exponent formulas must be corrected (lines 85, 88, 92), the shape logic fixed (line 282), and the string syntax error resolved (line 287)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Generalizes losses for multi-class segmentation and adds sigmoid/softmax interface. Critical issues: incorrect exponent formula in AsymmetricFocalTverskyLoss (uses 1-γ instead of 1/γ), shape handling logic error in AsymmetricUnifiedFocalLoss, and string syntax error. |
Sequence Diagram
sequenceDiagram
participant User
participant AUFL as AsymmetricUnifiedFocalLoss
participant AFL as AsymmetricFocalLoss
participant AFTL as AsymmetricFocalTverskyLoss
User->>AUFL: forward(y_pred, y_true)
alt sigmoid=True
AUFL->>AUFL: Apply sigmoid activation
else softmax=True
AUFL->>AUFL: Apply softmax activation
end
alt to_onehot_y=True
AUFL->>AUFL: Convert y_true to one-hot
end
AUFL->>AUFL: Shape alignment checks
AUFL->>AFL: forward(y_pred_act, y_true)
AFL->>AFL: Exclude background if needed
AFL->>AFL: Compute cross_entropy
alt Multi-class with background
AFL->>AFL: Apply asymmetric weights<br/>BG: (1-δ)*(1-p)^γ*CE<br/>FG: δ*CE
else Foreground only
AFL->>AFL: Apply foreground weight δ*CE
end
AFL->>AFL: Apply reduction
AFL-->>AUFL: Return focal_loss
AUFL->>AFTL: forward(y_pred_act, y_true)
AFTL->>AFTL: Exclude background if needed
AFTL->>AFTL: Compute Tversky Index
alt Multi-class with background
AFTL->>AFTL: Apply asymmetric logic<br/>BG: 1-TI<br/>FG: (1-TI)^(1-γ) [INCORRECT]
else Foreground only
AFTL->>AFTL: Apply (1-TI)^(1-γ) [INCORRECT]
end
AFTL->>AFTL: Apply reduction
AFTL-->>AUFL: Return tversky_loss
AUFL->>AUFL: Combine losses<br/>λ*focal_loss + (1-λ)*tversky_loss
AUFL-->>User: Return combined loss
1 file reviewed, 5 comments
monai/losses/unified_focal_loss.py
Outdated
    if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
        y_true = y_true.unsqueeze(1)  # Add channel dim
logic: IndexError when y_true.ndim == y_pred_act.ndim - 1
When y_true is (B, H, W) and y_pred_act is (B, C, H, W), the condition y_true.ndim == y_pred_act.ndim - 1 is True, but then checking y_true.shape[1] != y_pred_act.shape[1] compares H with C, which is nonsensical. The logic should unsqueeze immediately when ndim differs.
    - if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    -     y_true = y_true.unsqueeze(1)  # Add channel dim
    + if y_true.ndim == y_pred_act.ndim - 1:
    +     y_true = y_true.unsqueeze(1)  # Add channel dim
    + elif y_true.shape[1] != y_pred_act.shape[1]:
monai/losses/unified_focal_loss.py
Outdated
    if not self.include_background:
        # All classes are foreground, apply foreground logic
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
logic: Incorrect exponent formula - should be 1.0 / self.gamma
The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives exponent 1.333, increasing penalty for low Dice scores. Current formula 1.0 - self.gamma gives 0.25, incorrectly decreasing the penalty.
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
monai/losses/unified_focal_loss.py
Outdated
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    elif n_classes == 1:
        # Single class, must be foreground (BG was excluded or not provided)
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
logic: Same exponent error - should be 1.0 / self.gamma
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
    else:
        # Asymmetric logic: class 0 is BG, others are FG
        back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
        fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
logic: Same exponent error - should be 1.0 / self.gamma
    - fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
    + fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
monai/losses/unified_focal_loss.py
Outdated
    loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
    if y_true.shape != y_pred_act.shape:
        raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
                         f"after activations/one-hot")
syntax: String literal syntax error - missing closing quote
| f"after activations/one-hot") | |
| "after activations/one-hot") |
Actionable comments posted: 0
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)
85-92: Critical: Incorrect exponent formula in all three branches. The paper specifies `(1 - dice)^(1/gamma)`, but lines 85, 88, and 92 use `1.0 - self.gamma`. With gamma=0.75, you get `(1-dice)^0.25` (reduces penalty for low Dice), when you should get `(1-dice)^1.333` (increases penalty). Apply this diff:

      if not self.include_background:
          # All classes are foreground, apply foreground logic
    -     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    +     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
      elif n_classes == 1:
          # Single class, must be foreground (BG was excluded or not provided)
    -     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
    +     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
      else:
          # Asymmetric logic: class 0 is BG, others are FG
          back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    -     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
    +     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
          loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)  # (B, C)
270-277: Critical: sigmoid+one-hot causes ValueError with binary classification. When `sigmoid=True` and `n_pred_ch=1`, line 273's condition passes, triggering `one_hot(y_true, num_classes=1)` at line 277. Any ground truth voxel with value 1 raises `ValueError: class values must be smaller than num_classes`. Apply this diff:

      if self.to_onehot_y:
    -     if n_pred_ch == 1 and not self.sigmoid:
    +     if n_pred_ch == 1:
              warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
    -     elif n_pred_ch > 1 or self.sigmoid:
    +     else:
              # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
              if y_true.shape[1] != 1:
                  y_true = y_true.unsqueeze(1)
              y_true = one_hot(y_true, num_classes=n_pred_ch)
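The failure mode is easy to reproduce with a plain PyTorch one-hot call (used here as a stand-in for MONAI's `one_hot` helper, which enforces the same constraint but may raise a different exception type):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 1, 1, 0])
print(F.one_hot(labels, num_classes=2).shape)  # torch.Size([4, 2]) - fine

try:
    F.one_hot(labels, num_classes=1)  # any label value of 1 is out of range
except RuntimeError as err:
    print(err)  # class values must be smaller than num_classes
```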
280-287: Critical: Logic error in shape alignment causes incorrect comparison. Line 282 evaluates `y_true.shape[1] != y_pred_act.shape[1]` even when `y_true.ndim == y_pred_act.ndim - 1`. If y_true is (B, H, W) and y_pred_act is (B, C, H, W), you're comparing H (spatial dim) to C (channels), which is nonsensical. Additionally, line 287 has a syntax error: the f-string is split across lines incorrectly per Ruff.
Apply this diff:

      # Ensure y_true has the same shape as y_pred_act
      if y_true.shape != y_pred_act.shape:
    -     # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
    -     if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    +     if y_true.ndim == y_pred_act.ndim - 1:
              y_true = y_true.unsqueeze(1)  # Add channel dim
    -
    -     if y_true.shape != y_pred_act.shape:
    -         raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
    -             f"after activations/one-hot")
    +     elif y_true.shape != y_pred_act.shape:
    +         raise ValueError(
    +             f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
    +             f"after activations/one-hot"
    +         )
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
44-48: Improve parameter documentation clarity. Line 45 says delta is "weight of the background," but line 129 says it's "weight of the foreground." The docstrings should clarify that delta controls the FN/FP trade-off in Tversky (line 45) vs. background/foreground weight in focal CE (line 129). Also missing: documentation of raised exceptions (ValueError) per coding guidelines.
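For context on the delta trade-off mentioned above, the per-class Tversky index is conventionally written as below (a sketch following the convention in MONAI's existing implementation, where delta multiplies the false-negative term; symbols are not the PR's exact variable names):

```latex
% Tversky index for class c; larger \delta penalizes false negatives more heavily
TI_c = \frac{\mathrm{TP}_c}{\mathrm{TP}_c + \delta\,\mathrm{FN}_c + (1-\delta)\,\mathrm{FP}_c}
```

Under this convention delta acts as a foreground-recall weight, whereas the focal cross-entropy component uses delta to balance background against foreground terms, which is the inconsistency the nitpick points at.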
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
287-287: f-string: unterminated string
(invalid-syntax)
287-288: Expected FStringEnd, found newline
(invalid-syntax)
288-288: Expected ',', found indent
(invalid-syntax)
288-288: Expected ',', found name
(invalid-syntax)
288-288: missing closing quote in string literal
(invalid-syntax)
288-289: Expected ',', found newline
(invalid-syntax)
291-291: Expected ')', found dedent
(invalid-syntax)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)
146-152: LGTM on background exclusion logic. The include_background handling is consistent with AsymmetricFocalTverskyLoss.
237-238: LGTM on mutual exclusivity check. Correctly prevents both sigmoid and softmax from being enabled simultaneously.
290-295: LGTM on loss fusion logic. Correctly delegates to sub-loss instances and combines via weighted sum. Assuming sub-losses return scalar values when `reduction="mean"` or `reduction="sum"`, this will work.
195-251: Add unit tests for new sigmoid/softmax interface. The PR adds sigmoid/softmax activation options and renames `to_onehot_y` to `include_background`, but no new tests are mentioned. These API changes need coverage to verify:
- sigmoid/softmax mutual exclusivity
- One-hot conversion with different activation modes
- Background inclusion/exclusion behavior
- Multi-class vs. binary edge cases
Per coding guidelines, modified definitions must be covered by tests. Do you want me to generate test cases or open an issue to track this?
Force-pushed from 6b0c14c to f5a2f7e (compare)
Greptile Overview
Greptile Summary
This PR generalizes AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface common to other MONAI losses.
Major changes:
- Replaced `to_onehot_y` parameter with `include_background` in component losses (`AsymmetricFocalLoss`, `AsymmetricFocalTverskyLoss`)
- Added `sigmoid`/`softmax`/`to_onehot_y` parameters to `AsymmetricUnifiedFocalLoss`
- Refactored loss computation to handle multi-class scenarios
- Updated interface to match MONAI's `FocalLoss` pattern
Critical issues found:
- The gamma exponent formula in `AsymmetricFocalTverskyLoss` is mathematically incorrect on lines 85, 88, and 92: it uses `1.0 - self.gamma` instead of `1.0 / self.gamma`, which inverts the focal penalty behavior
- Dimension handling logic error on lines 282-283 in `AsymmetricUnifiedFocalLoss`, where the shape comparison happens before the necessary dimension adjustment
Confidence Score: 1/5
- This PR contains critical mathematical errors that will cause incorrect loss computation
- The gamma exponent formula is fundamentally wrong in three places (lines 85, 88, 92), using `1.0 - gamma` instead of `1.0 / gamma`. This inverts the focal penalty behavior. Additionally, there's a dimension handling logic error (lines 282-283). These are not edge cases but core functionality bugs that will affect all users.
- monai/losses/unified_focal_loss.py requires immediate attention to fix the gamma exponent formula and dimension handling logic
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Critical mathematical errors in gamma exponent formula (lines 85, 88, 92) and dimension handling logic error (line 282-283). These will cause incorrect loss computation. |
Sequence Diagram
sequenceDiagram
participant User
participant AsymmetricUnifiedFocalLoss
participant AsymmetricFocalLoss
participant AsymmetricFocalTverskyLoss
User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid/softmax if enabled
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot if to_onehot_y=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure y_true and y_pred_act shapes match
AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if include_background=False
AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric focal weighting
AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if include_background=False
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index (TP, FN, FP)
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply focal exponent (WRONG: 1-gamma instead of 1/gamma)
AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda_focal * f_loss + (1-lambda_focal) * t_loss
AsymmetricUnifiedFocalLoss-->>User: loss
1 file reviewed, 4 comments
monai/losses/unified_focal_loss.py
Outdated
    if not self.include_background:
        # All classes are foreground, apply foreground logic
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
logic: incorrect exponent - should be 1.0 / self.gamma (not 1.0 - self.gamma)
The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Current 1.0 - self.gamma gives (1-dice)^0.25, incorrectly decreasing penalty.
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
monai/losses/unified_focal_loss.py
Outdated
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    elif n_classes == 1:
        # Single class, must be foreground (BG was excluded or not provided)
        loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
logic: same exponent error - should be 1.0 / self.gamma
    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
    else:
        # Asymmetric logic: class 0 is BG, others are FG
        back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
        fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
logic: same exponent error - should be 1.0 / self.gamma
    - fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
    + fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
monai/losses/unified_focal_loss.py
Outdated
    if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
        y_true = y_true.unsqueeze(1)  # Add channel dim
logic: logic error - when y_true.ndim == y_pred_act.ndim - 1, checking y_true.shape[1] on line 282 will compare wrong dimensions
When y_true is (B, H, W) and y_pred_act is (B, C, H, W), the first condition is True, but then checking y_true.shape[1] (which is H) against y_pred_act.shape[1] (which is C) is meaningless. Should unsqueeze immediately.
    - if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    -     y_true = y_true.unsqueeze(1)  # Add channel dim
    + if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
    +     y_true = y_true.unsqueeze(1)  # Add channel dim
    + elif y_true.shape[1] != y_pred_act.shape[1]:
    +     y_true = y_true.unsqueeze(1)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
34-296: Add test coverage for new sigmoid/softmax/include_background modes. Current tests only cover the default configuration with perfect predictions. Missing coverage (a sketch of one such test follows this list):
- Binary segmentation with sigmoid=True
- Multi-class with softmax=True
- include_background=False with sub-loss variations
- to_onehot_y=True edge cases (sigmoid binary, multi-class)
- Reduction modes (SUM, NONE)
- Sub-loss classes in isolation
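A minimal sketch of what such a test could look like (hypothetical test names and constructor arguments, assuming the interface proposed in this PR; MONAI's existing test files and conventions may differ):

```python
import unittest
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

class TestAsymmetricUnifiedFocalLossInterface(unittest.TestCase):
    def test_sigmoid_softmax_mutually_exclusive(self):
        # the PR is expected to reject enabling both activations at once
        with self.assertRaises(ValueError):
            AsymmetricUnifiedFocalLoss(sigmoid=True, softmax=True)

    def test_multiclass_softmax_returns_scalar(self):
        loss_fn = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
        logits = torch.randn(2, 3, 16, 16)
        labels = torch.randint(0, 3, (2, 1, 16, 16))
        out = loss_fn(logits, labels)
        self.assertEqual(out.ndim, 0)  # default "mean" reduction yields a scalar

if __name__ == "__main__":
    unittest.main()
```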
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)
270-277: Critical: one-hot conversion fails for sigmoid with single channel. When `sigmoid=True` and `n_pred_ch==1`, line 273 triggers `one_hot(..., num_classes=1)`, which raises ValueError for any label==1. The original issue #8603 requested sigmoid/softmax support, but this breaks binary sigmoid cases. Apply this diff:

      if self.to_onehot_y:
    -     if n_pred_ch == 1 and not self.sigmoid:
    +     if n_pred_ch == 1:
              warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
    -     elif n_pred_ch > 1 or self.sigmoid:
    +     else:
              # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
              if y_true.shape[1] != 1:
                  y_true = y_true.unsqueeze(1)
              y_true = one_hot(y_true, num_classes=n_pred_ch)
280-284: Major: IndexError when y_true is missing channel dim. Line 282 checks `y_true.ndim == y_pred_act.ndim - 1` (e.g., y_true is (B,H,W)), but then accesses `y_true.shape[1]`, which indexes the spatial dimension H, not the channel. This causes incorrect logic or IndexError. Apply this diff:

      # Ensure y_true has the same shape as y_pred_act
      if y_true.shape != y_pred_act.shape:
          # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
    -     if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    +     if y_true.ndim == y_pred_act.ndim - 1:
              y_true = y_true.unsqueeze(1)  # Add channel dim
    +     elif y_true.shape[1] != y_pred_act.shape[1]:
    +         y_true = y_true.unsqueeze(1)  # Add channel dim if mismatch
85-92: Critical: Exponent formula is mathematically incorrect. All three branches use `1.0 - self.gamma` but the Unified Focal Loss paper specifies `(1 - TI)^(1/γ)`. With default γ=0.75, your formula gives exponent 0.25 (weakens penalty for poor predictions), but the paper requires 1.333 (strengthens penalty). Apply this diff:

      if not self.include_background:
          # All classes are foreground, apply foreground logic
    -     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    +     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
      elif n_classes == 1:
          # Single class, must be foreground (BG was excluded or not provided)
    -     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
    +     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
      else:
          # Asymmetric logic: class 0 is BG, others are FG
          back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    -     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
    +     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
          loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)  # (B, C)
🧹 Nitpick comments (4)
monai/losses/unified_focal_loss.py (4)
65-65: Add stacklevel to warning. Per PEP 565, warnings should specify `stacklevel=2` so users see the caller's location.

    - warnings.warn("single channel prediction, `include_background=False` ignored.")
    + warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

149-149: Add stacklevel to warning.

    - warnings.warn("single channel prediction, `include_background=False` ignored.")
    + warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

266-266: Add stacklevel to warning.

    - warnings.warn("single channel prediction, softmax=True ignored.")
    + warnings.warn("single channel prediction, softmax=True ignored.", stacklevel=2)

272-272: Add stacklevel to warning.

    - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
    + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
102-102: Avoid specifying long messages outside the exception class
(TRY003)
144-144: Avoid specifying long messages outside the exception class
(TRY003)
149-149: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
181-181: Avoid specifying long messages outside the exception class
(TRY003)
238-238: Avoid specifying long messages outside the exception class
(TRY003)
266-266: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
272-272: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
286-289: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)
154-172: Asymmetric weighting logic is correct. Background class gets gamma-modulated penalty; foreground classes get flat delta weighting. This aligns with the asymmetric focal loss design.
291-296: Fused loss computation is correct. The weighted combination `lambda_focal * focal + (1 - lambda_focal) * tversky` properly unifies both loss components.
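Written out, the combination being verified here is the usual unified focal loss form (a sketch of the standard formulation from the literature; λ corresponds to `lambda_focal` in the PR):

```latex
\mathcal{L}_{\mathrm{AUF}} \;=\; \lambda\,\mathcal{L}_{\mathrm{AF}} \;+\; (1-\lambda)\,\mathcal{L}_{\mathrm{AFT}},
\qquad \lambda \in [0, 1]
```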
Force-pushed from f5a2f7e to b4e0fcc (compare)
Greptile Overview
Greptile Summary
Refactored AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and added standard MONAI loss interface (sigmoid/softmax/to_onehot_y/include_background).
Key Changes:
- Replaced
`to_onehot_y` parameter with `include_background` in component losses
- Added
`sigmoid`/`softmax` activation options to main loss
- Generalized asymmetric logic to handle multi-class cases
- Improved parameter naming (
`lambda_focal` instead of `weight`, separate gamma/delta for each component)
- Updated reduction logic to handle different output shapes
Critical Issues Identified (from previous comments):
- Incorrect exponent formula in
`AsymmetricFocalTverskyLoss` (should be `1.0 / self.gamma`, not `1.0 - self.gamma`)
- Shape handling bug on line 280 that checks wrong dimensions when
y_true.ndim == y_pred_act.ndim - 1 - String literal syntax error on line 287 (unclosed quote)
Confidence Score: 1/5
- This PR has critical mathematical errors and runtime bugs that will cause incorrect behavior
- Multiple critical logic errors have been identified: (1) incorrect exponent formula in 3 locations that fundamentally breaks the loss calculation, (2) shape dimension checking bug that will cause IndexError at runtime, (3) syntax error with unclosed string. These issues need to be resolved before merge.
- monai/losses/unified_focal_loss.py requires immediate attention - fix exponent formulas on lines 83, 86, 90 and shape checking logic on lines 280-287
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Refactored to support multi-class segmentation with sigmoid/softmax interface. Critical exponent formula errors need fixing, and shape handling logic has bugs that could cause runtime errors. |
Sequence Diagram
sequenceDiagram
participant User
participant AsymmetricUnifiedFocalLoss
participant AsymmetricFocalLoss
participant AsymmetricFocalTverskyLoss
User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
alt sigmoid==True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid(y_pred)
else softmax==True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax(y_pred)
end
alt to_onehot_y==True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot
end
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Shape alignment checks
AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
alt include_background==False
AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude channel 0
end
AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross_entropy
alt n_classes > 1 and include_background
AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric logic<br/>(BG: gamma exponent, FG: no exponent)
else
AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply foreground logic<br/>(delta * cross_entropy)
end
AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
alt include_background==False
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude channel 0
end
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute TP, FN, FP
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
alt n_classes > 1 and include_background
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric logic<br/>(BG: no exponent, FG: gamma exponent)
else
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply foreground logic<br/>(gamma exponent)
end
AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda*f_loss + (1-lambda)*t_loss
AsymmetricUnifiedFocalLoss-->>User: Combined loss
1 file reviewed, no comments
Actionable comments posted: 0
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
83-90: Critical: incorrect exponent formula breaks Tversky loss behavior. Lines 83, 86, and 90 use `1.0 - self.gamma` but the paper specifies `(1 - TI)^(1/γ)`. With default γ=0.75, the correct exponent is 1.333, increasing penalty for low Dice scores. The current formula gives 0.25, inverting the intended behavior. Apply this diff:

    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)

    - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
    + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)

    - fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
    + fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
280-281: Critical: IndexError when y_true has fewer dimensions. When `y_true.ndim == y_pred_act.ndim - 1` (e.g., (B,H,W) vs (B,C,H,W)), checking `y_true.shape[1]` compares spatial dimension H with channel dimension C, which is nonsensical and may raise IndexError. Apply this diff:

    - if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    -     y_true = y_true.unsqueeze(1)  # Add channel dim
    + if y_true.ndim == y_pred_act.ndim - 1:
    +     y_true = y_true.unsqueeze(1)  # Add channel dim
    + elif y_true.shape[1] != y_pred_act.shape[1]:
    +     y_true = y_true.unsqueeze(1)
🧹 Nitpick comments (4)
monai/losses/unified_focal_loss.py (4)
65-65: Add stacklevel to warning.

    - warnings.warn("single channel prediction, `include_background=False` ignored.")
    + warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

147-147: Add stacklevel to warning.

    - warnings.warn("single channel prediction, `include_background=False` ignored.")
    + warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

264-264: Add stacklevel to warning.

    - warnings.warn("single channel prediction, softmax=True ignored.")
    + warnings.warn("single channel prediction, softmax=True ignored.", stacklevel=2)

270-270: Add stacklevel to warning.

    - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
    + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
100-100: Avoid specifying long messages outside the exception class
(TRY003)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
147-147: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
179-179: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Avoid specifying long messages outside the exception class
(TRY003)
264-264: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
270-270: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
284-287: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.11)
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)
152-170: Asymmetric focal loss logic is correct. The background class applies the gamma exponent to (1-p), foreground classes don't. This matches the paper's asymmetric formulation.
238-249: Sub-loss composition approach is sound. Creating separate AsymmetricFocalLoss and AsymmetricFocalTverskyLoss instances with independent gamma/delta parameters provides clean separation and reuses existing components per PR objectives.
235-236: Mutual exclusivity check for sigmoid/softmax is correct.
289-294: Fused loss computation is correct. The weighted combination of focal and Tversky losses matches the unified focal loss formulation.
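As a quick illustration of that combination, here is a minimal sketch with made-up scalar values for the two component losses (not taken from the PR):

```python
import torch

# Hypothetical reduced outputs of the two component losses for one batch.
f_loss = torch.tensor(0.42)  # AsymmetricFocalLoss (cross-entropy based)
t_loss = torch.tensor(0.31)  # AsymmetricFocalTverskyLoss (Tversky/Dice based)

lambda_focal = 0.5  # weight of the focal term; (1 - lambda_focal) weights the Tversky term
loss = lambda_focal * f_loss + (1 - lambda_focal) * t_loss
print(loss)  # tensor(0.3650)
```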
Hi @ytl0623, thanks for the contribution. I had only a minor comment, but please look at the greptile comments. A few look spurious, but they should all be resolved.
def __init__(
    self,
    to_onehot_y: bool = False,
Why is to_onehot_y being removed? I think users will still want this functionality, so I would leave it in and add include_background as a new last argument. Even if we do want to remove it, we need to use the deprecation decorators to mark the argument as removed but still keep it for a version or two.
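One way to keep the argument while steering users away from it is MONAI's `deprecated_arg` decorator; a rough sketch follows, where the version strings and the trimmed class body are placeholders rather than part of this PR:

```python
from torch.nn.modules.loss import _Loss

from monai.utils import deprecated_arg


class AsymmetricFocalTverskyLoss(_Loss):
    @deprecated_arg(
        name="to_onehot_y",
        since="1.5",
        removed="1.7",
        msg_suffix="convert the targets to one-hot before calling the loss instead.",
    )
    def __init__(self, to_onehot_y: bool = False, include_background: bool = True) -> None:
        super().__init__()
        self.to_onehot_y = to_onehot_y  # kept for backward compatibility until removal
        self.include_background = include_background
```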
Greptile Overview
Greptile Summary
This PR modernizes the AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation and adds standard MONAI loss interfaces (sigmoid/softmax/to_onehot_y/include_background).
Key changes:
- Replaced `to_onehot_y` parameter with `include_background` in base loss classes, matching MONAI conventions
- Added `sigmoid`/`softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Generalized from binary-only to multi-class segmentation support
- Maintained asymmetric treatment: background class (index 0) receives different weighting than foreground classes
- Updated reduction logic to properly handle "none"/"mean"/"sum" modes
Critical issues identified:
- Incorrect exponent formula: lines 83, 86, 90 use `1.0 - self.gamma` but should use `1.0 / self.gamma` according to the Unified Focal Loss paper's formula AFTL = (1-TI)^(1/γ). With default γ=0.75, the current code gives (1-dice)^0.25 (decreasing penalty) instead of (1-dice)^1.333 (increasing penalty).
- Shape handling logic error: line 280 checks `y_true.shape[1] != y_pred_act.shape[1]` after confirming `y_true.ndim == y_pred_act.ndim - 1`, which compares spatial dimension H with channel dimension C.
Recommendations:
- Fix the critical exponent formula errors immediately
- Fix the shape handling conditional logic
- Add tests covering multi-class scenarios and new sigmoid/softmax interfaces
- Verify against reference implementation or paper's equations
Confidence Score: 1/5
- This PR has critical mathematical errors that will cause incorrect loss computation and training behavior
- Score of 1 reflects multiple critical logical errors: (1) an incorrect exponent formula in AsymmetricFocalTverskyLoss using `1.0 - self.gamma` instead of `1.0 / self.gamma`, causing fundamentally wrong loss values that contradict the referenced paper; (2) a shape handling logic error comparing incompatible dimensions. These are not edge cases but affect core functionality. The PR cannot be merged until these mathematical errors are corrected.
- The only changed file, `monai/losses/unified_focal_loss.py`, requires immediate attention, specifically lines 83, 86, 90 (exponent formula) and lines 280-281 (shape handling logic)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Generalizes AsymmetricUnifiedFocalLoss for multi-class segmentation and adds sigmoid/softmax interface. Critical issues: incorrect exponent formula (uses 1.0 - self.gamma instead of 1.0 / self.gamma) in AsymmetricFocalTverskyLoss, shape handling logic error when y_true.ndim differs from y_pred_act.ndim |
Sequence Diagram
sequenceDiagram
participant User
participant AsymmetricUnifiedFocalLoss
participant AsymmetricFocalLoss
participant AsymmetricFocalTverskyLoss
User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid/softmax if requested
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot if to_onehot_y=True
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure shape compatibility
AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if needed
AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting (BG vs FG)
AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if needed
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky Index (TP, FN, FP)
AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply focal term: (1-TI)^(1/gamma)
AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda*f_loss + (1-lambda)*t_loss
AsymmetricUnifiedFocalLoss-->>User: loss
1 file reviewed, 3 comments
monai/losses/unified_focal_loss.py
Outdated
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim
logic: condition will fail when y_true.ndim == y_pred_act.ndim - 1
When y_true has shape (B, H, W) and y_pred_act has shape (B, C, H, W), the first condition y_true.ndim == y_pred_act.ndim - 1 evaluates to True (3 == 4-1). The code then proceeds to check y_true.shape[1] != y_pred_act.shape[1] on line 280, which compares H (spatial dimension) with C (channel dimension). This is a logical error.
The code should unsqueeze immediately when ndim differs:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
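A quick check of the intended behaviour on dummy tensors (shapes only, the values are arbitrary):

```python
import torch

y_pred_act = torch.rand(2, 3, 4, 4)       # (B, C, H, W) predictions after activation
y_true = torch.randint(0, 3, (2, 4, 4))   # (B, H, W) class-index labels without a channel dim

if y_true.ndim == y_pred_act.ndim - 1:    # label tensor is missing the channel dimension
    y_true = y_true.unsqueeze(1)          # -> (B, 1, H, W), ready for one-hot conversion

print(y_true.shape)  # torch.Size([2, 1, 4, 4])
```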
monai/losses/unified_focal_loss.py
Outdated
elif n_pred_ch > 1 or self.sigmoid:
    # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
    if y_true.shape[1] != 1:
        y_true = y_true.unsqueeze(1)
    y_true = one_hot(y_true, num_classes=n_pred_ch)
logic: one-hot with num_classes=1 when sigmoid=True and n_pred_ch==1
When sigmoid=True and n_pred_ch == 1, line 269's condition is False (because not self.sigmoid is False), so the code proceeds to line 271. Since self.sigmoid is True, the condition n_pred_ch > 1 or self.sigmoid is True, and one-hot encoding is performed with num_classes=1.
One-hot encoding with 1 class doesn't make logical sense - it just keeps the single channel as-is. For binary segmentation with sigmoid, the pattern should be: input has 1 channel (foreground probability), and one-hot should convert to 2 channels (background, foreground).
- elif n_pred_ch > 1 or self.sigmoid:
-     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
-     if y_true.shape[1] != 1:
-         y_true = y_true.unsqueeze(1)
-     y_true = one_hot(y_true, num_classes=n_pred_ch)
+ elif n_pred_ch > 1 or (self.sigmoid and n_pred_ch == 1):
+     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+     if y_true.shape[1] != 1:
+         y_true = y_true.unsqueeze(1)
+     # For binary sigmoid case, need 2 classes
+     num_classes = 2 if (self.sigmoid and n_pred_ch == 1) else n_pred_ch
+     y_true = one_hot(y_true, num_classes=num_classes)
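For reference, MONAI's `one_hot` helper expands a (B, 1, ...) index tensor along the channel dimension, so a two-class conversion is what a binary sigmoid target needs; a small standalone sketch (not code from this PR):

```python
import torch

from monai.networks import one_hot

y_true = torch.tensor([[[[0, 1], [1, 0]]]])  # (B=1, 1, 2, 2) binary mask with labels 0/1

two_class = one_hot(y_true, num_classes=2)   # (1, 2, 2, 2): channel 0 = background, channel 1 = foreground
print(two_class.shape)

# one_hot(y_true, num_classes=1) would have to scatter the label value 1 into a single channel,
# which cannot represent the foreground class of a binary sigmoid model.
```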
  if self.to_onehot_y:
-     if n_pred_ch == 1:
+     if n_pred_ch == 1 and not self.sigmoid:
          warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-     else:
+     elif n_pred_ch > 1 or self.sigmoid:
          # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
          if y_true.shape[1] != 1:
              y_true = y_true.unsqueeze(1)
          y_true = one_hot(y_true, num_classes=n_pred_ch)
style: potential shape mismatch with sigmoid=True, n_pred_ch=1, and to_onehot_y=True
When using sigmoid=True with 1 input channel, y_pred_act becomes (B, 1, H, W). If to_onehot_y=True, the code converts y_true to one-hot with num_classes=1, keeping it (B, 1, H, W).
Consider adding a warning or documentation note that for binary segmentation with sigmoid=True and to_onehot_y=True, users should keep y_true as 0/1 masks without one-hot conversion, since one-hot with 1 class is redundant.
Greptile Overview
Greptile Summary
This PR generalizes AsymmetricUnifiedFocalLoss, AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface used across other MONAI losses.
Key changes:
- Replaced `to_onehot_y` parameter with `include_background` in base loss classes
- Added `use_sigmoid`/`use_softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Generalized asymmetric logic to handle arbitrary number of classes (not just binary)
- Separated gamma/delta parameters for focal and Tversky components
- Updated parameter names for clarity (`lambda_focal` instead of `weight`)
Critical issues found:
- Incorrect gamma exponent formula: uses `1.0 - self.gamma` but should be `1.0 / self.gamma` per the paper (lines 83, 86, 90). With default gamma=0.75, the current code gives (1-dice)^0.25 (decreasing penalty) instead of (1-dice)^1.333 (increasing penalty)
- Dimension checking logic error: line 280 compares a spatial dimension with the channel dimension when `y_true.ndim` differs from `y_pred_act.ndim`
- One-hot with 1 class: when `use_sigmoid=True` with a single channel, the code creates a one-hot encoding with `num_classes=1`, which is illogical
Confidence Score: 1/5
- This PR has critical mathematical errors that will cause incorrect loss values and training behavior
- The incorrect gamma exponent formula (1-gamma instead of 1/gamma) is a fundamental mathematical error that affects all three loss classes. This will cause the loss to behave opposite to its intended design, reducing penalties for poor predictions instead of increasing them. Combined with the dimension checking bug and one-hot encoding issue, this PR cannot be merged safely without fixes.
- The single file `monai/losses/unified_focal_loss.py` requires immediate attention to fix the gamma exponent formula in all three loss classes (lines 83, 86, 90), the dimension checking logic (line 280), and the one-hot encoding logic (line 271)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Generalizes three loss classes for multi-class support with new sigmoid/softmax interface. Critical issues: incorrect gamma exponent formula (uses 1-gamma instead of 1/gamma), dimension checking logic error, and one-hot encoding with 1 class. |
Sequence Diagram
sequenceDiagram
participant User
participant AUFL as AsymmetricUnifiedFocalLoss
participant AFL as AsymmetricFocalLoss
participant AFTL as AsymmetricFocalTverskyLoss
User->>AUFL: forward(y_pred, y_true)
alt use_sigmoid
AUFL->>AUFL: y_pred_act = sigmoid(y_pred)
else use_softmax
AUFL->>AUFL: y_pred_act = softmax(y_pred, dim=1)
end
alt to_onehot_y
AUFL->>AUFL: Convert y_true to one-hot
end
AUFL->>AUFL: Shape alignment checks
AUFL->>AFL: forward(y_pred_act, y_true)
alt include_background=False
AFL->>AFL: Exclude channel 0
end
AFL->>AFL: cross_entropy = -y_true * log(y_pred)
alt Asymmetric logic
AFL->>AFL: back_ce with gamma exponent
AFL->>AFL: fore_ce without gamma
AFL->>AFL: Combine and reduce
end
AFL-->>AUFL: f_loss
AUFL->>AFTL: forward(y_pred_act, y_true)
alt include_background=False
AFTL->>AFTL: Exclude channel 0
end
AFTL->>AFTL: Calculate tp, fn, fp
AFTL->>AFTL: dice_class = (tp + ε) / (tp + δ*fn + (1-δ)*fp + ε)
alt Asymmetric logic
AFTL->>AFTL: back_dice without gamma
AFTL->>AFTL: fore_dice with gamma exponent
AFTL->>AFTL: Combine and reduce
end
AFTL-->>AUFL: t_loss
AUFL->>AUFL: loss = λ*f_loss + (1-λ)*t_loss
AUFL-->>User: loss
1 file reviewed, 2 comments
elif n_pred_ch > 1 or self.use_sigmoid:
    # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
    if y_true.shape[1] != 1:
        y_true = y_true.unsqueeze(1)
    y_true = one_hot(y_true, num_classes=n_pred_ch)
logic: when use_sigmoid=True and n_pred_ch=1, one-hot with num_classes=1 doesn't make sense
When use_sigmoid=True with single channel, line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates to True (because self.use_sigmoid is True). This leads to one-hot encoding with num_classes=1, which is illogical. For binary sigmoid, should convert to 2 classes (background + foreground).
monai/losses/unified_focal_loss.py
Outdated
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim
logic: condition logic is backwards - checks incompatible dimensions
When y_true.ndim == y_pred_act.ndim - 1 is True (e.g., (B,H,W) vs (B,C,H,W)), the code should immediately unsqueeze. Instead, it also checks y_true.shape[1] != y_pred_act.shape[1] which compares H (spatial) with C (channel) - wrong dimensions.
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
Actionable comments posted: 1
♻️ Duplicate comments (6)
monai/losses/unified_focal_loss.py (6)
83-83: Critical: incorrect exponent formula. The Unified Focal Loss paper specifies `(1 - dice)^(1/γ)`. With default `γ=0.75`, this gives `(1 - dice)^1.333`, increasing the penalty for low Dice scores. The current formula `1.0 - self.gamma` gives `(1 - dice)^0.25`, which incorrectly decreases the penalty. Apply this diff:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)

86-86: Critical: same exponent error. Apply this diff:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)

90-90: Critical: same exponent error. Apply this diff:
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
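The practical difference between the two exponents is easy to check numerically for γ=0.75: the corrected `1/γ` keeps the loss small for well-segmented classes and grows steeply as the Tversky index drops, whereas the current `1-γ` keeps it large even at high Dice, flattening the focus on hard classes. A standalone check, not code from the PR:

```python
import torch

gamma = 0.75
dice = torch.tensor([0.9, 0.5, 0.1])  # per-class Tversky/Dice index, from easy to hard

current = (1.0 - dice) ** (1.0 - gamma)   # exponent 0.25  -> approx. [0.562, 0.841, 0.974]
proposed = (1.0 - dice) ** (1.0 / gamma)  # exponent 1.333 -> approx. [0.046, 0.397, 0.869]
print(current, proposed)
```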
118-118: Breaking change: parameter removed without deprecation. Same issue as AsymmetricFocalTverskyLoss: `to_onehot_y` removed without a deprecation decorator.
280-281: Critical: IndexError when `y_true` has fewer dimensions. When `y_true.ndim == y_pred_act.ndim - 1` (e.g., `(B, H, W)` vs `(B, C, H, W)`), the condition is True, but then checking `y_true.shape[1] != y_pred_act.shape[1]` compares `H` (spatial) with `C` (channel), which are the wrong dimensions. Should unsqueeze immediately when `ndim` differs. Apply this diff:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
+ if y_true.ndim == y_pred_act.ndim - 1:
      y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
271-275: Major: one-hot with `num_classes=1` when `use_sigmoid=True`. When `use_sigmoid=True` with single-channel input (`n_pred_ch=1`), the condition `n_pred_ch > 1 or self.use_sigmoid` is True, leading to `one_hot(y_true, num_classes=1)`. This raises ValueError when any voxel has label 1. For binary sigmoid, should convert to 2 classes (background + foreground) or skip one-hot entirely. Apply this diff to handle binary sigmoid properly:
  elif n_pred_ch > 1 or self.use_sigmoid:
      # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
      if y_true.shape[1] != 1:
          y_true = y_true.unsqueeze(1)
-     y_true = one_hot(y_true, num_classes=n_pred_ch)
+     # For binary sigmoid, need 2 classes; otherwise use n_pred_ch
+     num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
+     y_true = one_hot(y_true, num_classes=num_classes)
Note: This also requires expanding `y_pred_act` from 1 to 2 channels for consistency.
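A possible way to expand the single-channel sigmoid output so it matches a two-class one-hot target, sketched with hypothetical names:

```python
import torch

p_fg = torch.sigmoid(torch.randn(2, 1, 4, 4))       # (B, 1, H, W) foreground probability
y_pred_act = torch.cat([1.0 - p_fg, p_fg], dim=1)   # (B, 2, H, W): background, foreground channels

print(y_pred_act.shape, bool(torch.allclose(y_pred_act.sum(dim=1), torch.ones(2, 4, 4))))
```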
🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (5)
65-65: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

147-147: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

196-197: Parameter naming inconsistency. Parameters are named `use_sigmoid` and `use_softmax` in the signature, stored as `self.use_sigmoid` and `self.use_softmax` (lines 230-231), and checked as such throughout. This is internally consistent, but the naming convention differs from typical MONAI losses, which use `sigmoid`/`softmax` (not `use_sigmoid`/`use_softmax`). For consistency with other MONAI losses (e.g., `FocalLoss`), consider renaming to `sigmoid` and `softmax`:
- use_sigmoid: bool = False,
- use_softmax: bool = False,
+ sigmoid: bool = False,
+ softmax: bool = False,
Then update all references accordingly.

264-264: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, use_softmax=True ignored.")
+ warnings.warn("single channel prediction, use_softmax=True ignored.", stacklevel=2)

270-270: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
100-100: Avoid specifying long messages outside the exception class
(TRY003)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
147-147: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
179-179: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Avoid specifying long messages outside the exception class
(TRY003)
264-264: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
270-270: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
284-287: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (18)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.8.0)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
193-294: Request test coverage for new functionality and bug fixes.This PR introduces significant changes: new activation modes (
use_sigmoid/use_softmax), generalization to multi-class, parameter API changes, and critical bug fixes (exponent formula). Without tests, these changes cannot be validated.Please add unit tests covering:
- Binary segmentation with
use_sigmoid=True- Multi-class segmentation with
use_softmax=True- Background inclusion/exclusion (
include_background)- One-hot conversion edge cases
- Shape alignment scenarios
- Correctness of exponent formula after fixes
Would you like assistance generating test cases?
f57d294 to
30c82db
Compare
5d14d85 to
d991466
Compare
for more information, see https://pre-commit.ci Signed-off-by: ytl0623 <[email protected]>
Fixes Project-MONAI#8564 . ### Description Add Fourier feature positional encodings to `PatchEmbeddingBlock`. It has been shown, that Fourier feature positional encodings are better suited for Anistropic images and videos: https://arxiv.org/abs/2509.02488 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: NabJa <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]> Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
….com> I, NabJa <[email protected]>, hereby add my Signed-off-by to this commit: 76c4391 Signed-off-by: NabJa <[email protected]> Signed-off-by: ytl0623 <[email protected]>
957edd3 to
ac66250
Compare
Actionable comments posted: 1
♻️ Duplicate comments (4)
monai/losses/unified_focal_loss.py (4)
118-136: Breaking API change without deprecation. Same issue as `AsymmetricFocalTverskyLoss`: removing `to_onehot_y` breaks existing code. Restore with deprecation or document as a breaking change.
90-90: Critical: Wrong exponent formula, should be `1.0 / self.gamma`. The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333 (increases penalty for low dice). The current `1.0 - self.gamma` gives (1-dice)^0.25 (decreases penalty). This critical error has been flagged in multiple past reviews.
Apply this diff:
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
279-280: Logic error: dimension mismatch when `y_true` lacks a channel dim. When `y_true` is (B, H, W) and `y_pred_act` is (B, C, H, W), the condition on line 279 is True (3 == 4-1), but the code only unsqueezes and doesn't verify the shapes match. Should unsqueeze immediately when `ndim` differs. This was flagged multiple times in past reviews.
Apply this diff:
  # Ensure y_true has the same shape as y_pred_act
  if y_true.shape != y_pred_act.shape:
-     if y_true.ndim == y_pred_act.ndim - 1:
+     if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
          y_true = y_true.unsqueeze(1)
-
-     if y_true.shape != y_pred_act.shape:
+     elif y_true.shape != y_pred_act.shape:
          raise ValueError(
              f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
              f"after activations/one-hot"
          )
271-275: Critical: One-hot with num_classes=1 breaks binary sigmoid. When `use_sigmoid=True` with single-channel logits (`n_pred_ch=1`), line 271's condition `n_pred_ch > 1 or self.use_sigmoid` evaluates True, calling `one_hot(y_true, num_classes=1)`. Any label==1 raises ValueError. This blocking bug was flagged by coderabbitai[bot] in a past review and remains unfixed.
Apply this diff:
  if self.to_onehot_y:
      if n_pred_ch == 1 and not self.use_sigmoid:
          warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-     elif n_pred_ch > 1 or self.use_sigmoid:
+     elif n_pred_ch > 1:
          # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
          if y_true.shape[1] != 1:
              y_true = y_true.unsqueeze(1)
          y_true = one_hot(y_true, num_classes=n_pred_ch)
Or expand binary logits to 2 channels before one-hotting.
🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (5)
65-65: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

147-147: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

264-264: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, use_softmax=True ignored.")
+ warnings.warn("single channel prediction, use_softmax=True ignored.", stacklevel=2)

270-270: Add `stacklevel=2` to warning.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
193-227: Docstrings missing exception documentation. Per coding guidelines, docstrings should document raised exceptions. Line 236 raises ValueError but it is not documented in the docstring.
Add a Raises section:
  Args:
      to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
      use_sigmoid: if True, apply a sigmoid activation to the input y_pred.
      use_softmax: if True, apply a softmax activation to the input y_pred.
      lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based). The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
      focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
      focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
      tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
      tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
      include_background: whether to include loss computation for the background class. Defaults to True.
      reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
+
+ Raises:
+     ValueError: if both use_sigmoid and use_softmax are True.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
100-100: Avoid specifying long messages outside the exception class
(TRY003)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
147-147: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
179-179: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Avoid specifying long messages outside the exception class
(TRY003)
264-264: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
270-270: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
283-286: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (15)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: packaging
- GitHub Check: build-docs
      include_background: bool = True,
      delta: float = 0.7,
      gamma: float = 0.75,
      epsilon: float = 1e-7,
      reduction: LossReduction | str = LossReduction.MEAN,
  ) -> None:
      """
      Args:
-         to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-         delta : weight of the background. Defaults to 0.7.
-         gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
-         epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+         include_background: whether to include loss computation for the background class. Defaults to True.
+         delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index)
+         gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+         epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
          reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
      """
      super().__init__(reduction=LossReduction(reduction).value)
-     self.to_onehot_y = to_onehot_y
+     self.include_background = include_background
Breaking API change without deprecation.
Removing to_onehot_y from AsymmetricFocalTverskyLoss breaks existing code. Per MONAI conventions and prior reviewer feedback (ericspod), deprecated parameters must be preserved with deprecation decorators for 1-2 versions.
Either restore to_onehot_y with deprecation warnings or explicitly document this as a breaking change in the changelog.
Fixes #8603.
Description
Addresses the limitations of `AsymmetricUnifiedFocalLoss` and its dependencies (`AsymmetricFocalLoss`, `AsymmetricFocalTverskyLoss`). The previous implementations were hard-coded for binary segmentation and lacked the standard `sigmoid`/`softmax`/`to_onehot_y` interface common to other MONAI losses.
Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested `make html` command in the `docs/` folder.
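For illustration, the intended multi-class usage under the revised interface might look like the sketch below; the parameter names follow the revision under review and should be treated as provisional rather than the final API:

```python
import torch

from monai.losses import AsymmetricUnifiedFocalLoss

# 4-class segmentation: raw logits and integer labels
logits = torch.randn(2, 4, 16, 16)            # (B, C, H, W)
labels = torch.randint(0, 4, (2, 1, 16, 16))  # (B, 1, H, W)

loss_fn = AsymmetricUnifiedFocalLoss(
    use_softmax=True,        # apply softmax to the logits inside the loss
    to_onehot_y=True,        # convert integer labels to one-hot
    include_background=True,
    lambda_focal=0.5,        # weight of the focal term; 1 - lambda_focal weights the Tversky term
)
loss = loss_fn(logits, labels)
print(loss)
```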