Generalize AsymmetricUnifiedFocalLoss for multi-class and align interface #8607
base: dev
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough: The unified focal loss module undergoes restructuring across three classes.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks: ❌ 1 failed check (warning), ✅ 3 passed.
Force-pushed from 4d192fa to 1ab9120
Greptile Overview
Greptile Summary
This PR successfully modernizes three asymmetric loss classes (AsymmetricFocalTverskyLoss, AsymmetricFocalLoss, and AsymmetricUnifiedFocalLoss) to align with MONAI's standard loss interface, addressing issue #8603. The changes generalize these losses from binary-only to multi-class segmentation support.
Key improvements:
- Replaced deprecated `to_onehot_y` parameter with `include_background` in base loss classes
- Added `sigmoid`/`softmax` activation support to `AsymmetricUnifiedFocalLoss`
- Implemented proper multi-class handling with asymmetric foreground/background treatment
- Improved documentation and parameter naming clarity
- Refactored reduction logic to properly support "mean", "sum", and "none" modes
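For orientation, a minimal usage sketch of the interface described above. The parameter names (`softmax`, `to_onehot_y`) are taken from this summary; the shapes, defaults, and reduction behaviour shown are assumptions, not verified against the diff:

```python
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Hypothetical 3-class example; shapes and defaults are illustrative only.
pred = torch.randn(2, 3, 64, 64)               # (B, C, H, W) raw network logits
target = torch.randint(0, 3, (2, 1, 64, 64))   # (B, 1, H, W) integer class labels

loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, softmax=True)
loss = loss_fn(pred, target)   # assumed to be a scalar with the default "mean" reduction
print(loss.item())
```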
Issues found:
- Critical: Logic error in one-hot conversion when `sigmoid=True` (lines 272-278) - attempts to create one-hot with `num_classes=1`, which produces incorrect shapes
- Critical: Shape validation logic error (lines 281-287) - the condition `y_true.ndim == y_pred_act.ndim - 1` evaluates incorrectly
- Unused import (`torch.nn.functional as F`) should be removed
- Shape validation before one-hot conversion could be more robust
The refactoring follows the pattern established by FocalLoss but contains logic errors that will cause runtime failures with certain activation configurations.
Confidence Score: 2/5
- This PR has critical logic errors that will cause runtime failures with sigmoid activation and certain input shapes
- Score of 2 reflects two critical logic errors: (1) one-hot conversion with sigmoid creates invalid shapes by using num_classes=1, and (2) shape dimension checking logic is incorrect. These will cause runtime failures in common use cases. The refactoring approach is sound and aligns well with MONAI patterns, but the implementation bugs must be fixed before merging.
- monai/losses/unified_focal_loss.py lines 272-287 contain critical logic errors in shape handling and one-hot conversion that must be resolved
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 3/5 | Generalizes loss functions from binary to multi-class with new sigmoid/softmax interface. Found critical logic errors in one-hot conversion with sigmoid and shape handling that will cause runtime failures. | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid activation
    else softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax activation
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot format
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Validate shapes match
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
    
    alt Multi-class with background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting<br/>(gamma for BG, no gamma for FG)
    else Foreground only
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply foreground weighting
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: Return focal_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute TP, FN, FP
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
    
    alt Multi-class with background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric focal<br/>(no gamma for BG, gamma for FG)
    else Foreground only
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply foreground focal
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: Return tversky_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine losses<br/>lambda_focal * focal_loss + (1-lambda_focal) * tversky_loss
    AsymmetricUnifiedFocalLoss-->>User: Return combined_loss
    1 file reviewed, 4 comments
        
          
monai/losses/unified_focal_loss.py (Outdated)

if not self.include_background:
    # All classes are foreground, apply foreground logic
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)

logic: exponent should be 1 + self.gamma (not 1 - self.gamma)
The original paper's formula uses (1 - dice)^(1/gamma), which when gamma < 1, becomes (1 - dice)^(>1), increasing the penalty. With 1 - self.gamma when gamma=0.75, you get (1-dice)^0.25, which decreases the penalty for low dice scores.
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 + self.gamma)  # (B, C)
        
          
monai/losses/unified_focal_loss.py (Outdated)

    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
elif n_classes == 1:
    # Single class, must be foreground (BG was excluded or not provided)
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)

logic: same exponent error here - should be 1 + self.gamma
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 + self.gamma)  # (B, 1)
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: same exponent error - should be 1 + self.gamma
Suggested change:
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 + self.gamma)  # (B, C-1)
elif n_pred_ch > 1 or self.sigmoid:
    # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
    if y_true.shape[1] != 1:
        y_true = y_true.unsqueeze(1)

logic: unsqueezing unconditionally can cause issues if y_true is already (B, 1, H, W)
This adds an extra dimension. Should check y_true.ndim first or only unsqueeze when y_true.ndim == 3.
Suggested change:
- y_true = y_true.unsqueeze(1)
+ if y_true.ndim == 3:  # (B, H, W) -> (B, 1, H, W)
+     y_true = y_true.unsqueeze(1)
Actionable comments posted: 1
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- monai/losses/unified_focal_loss.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
- monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
61-61: Avoid specifying long messages outside the exception class
(TRY003)
66-66: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
103-103: Avoid specifying long messages outside the exception class
(TRY003)
145-145: Avoid specifying long messages outside the exception class
(TRY003)
150-150: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
182-182: Avoid specifying long messages outside the exception class
(TRY003)
239-239: Avoid specifying long messages outside the exception class
(TRY003)
267-267: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
273-273: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
287-287: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: packaging
Force-pushed from 2959015 to dca758d
Greptile Overview
Greptile Summary
Generalizes AsymmetricUnifiedFocalLoss and its component losses for multi-class segmentation. Replaces to_onehot_y parameter with standard include_background/sigmoid/softmax interface matching other MONAI losses. Adds proper handling for background vs foreground classes using asymmetric weighting.
Critical Issues Found:
- Mathematical errors in gamma exponents (lines 86, 89, 93) - uses `1.0 - self.gamma` instead of `1.0 / self.gamma`, inverting the focal weighting behavior
- Syntax error in exception message (line 287) - unclosed string literal
- Logic error in shape validation (line 283) - incorrect dimension checking that will cause IndexError
Changes:
- Replaced `to_onehot_y` parameter with `include_background` in component losses
- Added `sigmoid`/`softmax` activation support to `AsymmetricUnifiedFocalLoss`
- Split gamma/delta parameters into separate focal and Tversky variants
- Generalized from binary to multi-class segmentation support
- Added proper reduction handling (none/mean/sum)
Confidence Score: 0/5
- This PR is unsafe to merge due to critical mathematical errors and syntax error that will cause runtime failures
- Score of 0 reflects three critical blocking issues: (1) mathematical formula errors in lines 86, 89, 93 where `1.0 - self.gamma` should be `1.0 / self.gamma`, completely inverting the focal weighting behavior, (2) syntax error on line 287 with unclosed string literal that will prevent code from running, and (3) logic error on line 283 that will cause IndexError when accessing shape dimensions
- monai/losses/unified_focal_loss.py requires immediate attention - all identified errors must be fixed before merge
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Critical mathematical errors in gamma exponents (lines 86, 89, 93), syntax error in exception message (line 287), and logic error in shape handling (line 283) | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: torch.sigmoid(y_pred)
    else softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: torch.softmax(y_pred, dim=1)
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: one_hot(y_true)
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Shape validation
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy
    
    alt include_background=False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
    else n_classes > 1
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Asymmetric weighting<br/>(BG: (1-delta)*pow(1-p, gamma)*CE)<br/>(FG: delta*CE)
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
    
    alt include_background=False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
    else n_classes > 1
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Asymmetric weighting<br/>(BG: 1-dice)<br/>(FG: pow(1-dice, 1/gamma))
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: lambda_focal * f_loss +<br/>(1 - lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: combined_loss
    1 file reviewed, 5 comments
        
          
monai/losses/unified_focal_loss.py (Outdated)

if not self.include_background:
    # All classes are foreground, apply foreground logic
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)

logic: exponent formula is incorrect - should be 1.0 / self.gamma (not 1.0 - self.gamma)
The paper uses (1 - dice)^(1/gamma). When gamma=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Your formula gives (1-dice)^0.25, decreasing it.
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
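To make the two exponents concrete, here is a small standalone check (plain PyTorch, no MONAI dependency; the 0.25 and 1.333 exponents are the values quoted in the comment above):

```python
import torch

gamma = 0.75
dice = torch.tensor([0.1, 0.5, 0.9])  # per-class Dice, from poorly to well segmented

as_implemented = torch.pow(1.0 - dice, 1.0 - gamma)  # exponent 0.25
as_in_paper = torch.pow(1.0 - dice, 1.0 / gamma)     # exponent ~1.333

# exponent 0.25 keeps the loss near 1 even for the well-segmented class (dice=0.9),
# while 1/gamma drives it toward 0 there and keeps the emphasis on low-Dice classes
print(as_implemented)  # ~[0.974, 0.841, 0.562]
print(as_in_paper)     # ~[0.869, 0.397, 0.046]
```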
        
          
monai/losses/unified_focal_loss.py (Outdated)

    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
elif n_classes == 1:
    # Single class, must be foreground (BG was excluded or not provided)
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)

logic: same exponent error - should be 1.0 / self.gamma
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: same exponent error - should be 1.0 / self.gamma
Suggested change:
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
        
          
monai/losses/unified_focal_loss.py (Outdated)

raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape})
after activations/one-hot")

syntax: syntax error - string literal not properly closed
Missing closing quote on line 287.
Suggested change:
- raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape})
- after activations/one-hot")
+ raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) after activations/one-hot")
        
          
monai/losses/unified_focal_loss.py (Outdated)

if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim

logic: logic error - condition check is incorrect
When y_true.ndim == y_pred_act.ndim - 1, y_true has fewer dimensions (e.g., (B,H,W) vs (B,C,H,W)), so the check y_true.shape[1] != y_pred_act.shape[1] will fail with IndexError.
Suggested change:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)  # Add channel dim if missing
…face Signed-off-by: ytl0623 <[email protected]>
Force-pushed from 100e9f8 to ad83444
Greptile Overview
Greptile Summary
This PR successfully generalizes the AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation, replacing the previous binary-only limitation. The interface has been modernized to align with other MONAI losses by adding sigmoid/softmax/to_onehot_y parameters and replacing to_onehot_y with include_background in the component losses.
Major Changes:
- Added `sigmoid`/`softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Replaced `to_onehot_y` with `include_background` parameter in component losses
- Generalized all three loss classes to handle arbitrary number of classes
- Improved parameter documentation and added clearer examples
- Properly implemented reduction modes (`mean`, `sum`, `none`)
Critical Issues Found:
- Incorrect exponent formula in `AsymmetricFocalTverskyLoss`: uses `1.0 - self.gamma` instead of `1.0 / self.gamma` per the paper's formula (1-TI)^(1/γ). With default γ=0.75, this produces exponent 0.25 instead of 1.333, fundamentally breaking the loss behavior by reducing instead of increasing the penalty for low Dice scores.
- Shape handling logic error in `AsymmetricUnifiedFocalLoss` line 282: when `y_true.ndim == y_pred_act.ndim - 1`, the code incorrectly tries to compare `y_true.shape[1]` (spatial dim H) with `y_pred_act.shape[1]` (channel dim C) before unsqueezing.
- String syntax error on line 287: Missing closing quote on the error message string.
Confidence Score: 1/5
- This PR has critical mathematical errors that break the loss function's core behavior
- The incorrect exponent formula (1-γ instead of 1/γ) in AsymmetricFocalTverskyLoss fundamentally changes the loss behavior, causing it to decrease penalties for low Dice scores when it should increase them. Combined with the shape handling bug and syntax error, these issues will cause runtime failures and incorrect training behavior. The mathematical error affects all three loss classes and deviates from the published paper.
- monai/losses/unified_focal_loss.py requires immediate attention - all four exponent formulas must be corrected (lines 85, 88, 92), the shape logic fixed (line 282), and the string syntax error resolved (line 287)
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Generalizes losses for multi-class segmentation and adds sigmoid/softmax interface. Critical issues: incorrect exponent formula in AsymmetricFocalTverskyLoss (uses 1-γ instead of 1/γ), shape handling logic error in AsymmetricUnifiedFocalLoss, and string syntax error. | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt sigmoid=True
        AUFL->>AUFL: Apply sigmoid activation
    else softmax=True
        AUFL->>AUFL: Apply softmax activation
    end
    
    alt to_onehot_y=True
        AUFL->>AUFL: Convert y_true to one-hot
    end
    
    AUFL->>AUFL: Shape alignment checks
    
    AUFL->>AFL: forward(y_pred_act, y_true)
    AFL->>AFL: Exclude background if needed
    AFL->>AFL: Compute cross_entropy
    
    alt Multi-class with background
        AFL->>AFL: Apply asymmetric weights<br/>BG: (1-δ)*(1-p)^γ*CE<br/>FG: δ*CE
    else Foreground only
        AFL->>AFL: Apply foreground weight δ*CE
    end
    
    AFL->>AFL: Apply reduction
    AFL-->>AUFL: Return focal_loss
    
    AUFL->>AFTL: forward(y_pred_act, y_true)
    AFTL->>AFTL: Exclude background if needed
    AFTL->>AFTL: Compute Tversky Index
    
    alt Multi-class with background
        AFTL->>AFTL: Apply asymmetric logic<br/>BG: 1-TI<br/>FG: (1-TI)^(1-γ) [INCORRECT]
    else Foreground only
        AFTL->>AFTL: Apply (1-TI)^(1-γ) [INCORRECT]
    end
    
    AFTL->>AFTL: Apply reduction
    AFTL-->>AUFL: Return tversky_loss
    
    AUFL->>AUFL: Combine losses<br/>λ*focal_loss + (1-λ)*tversky_loss
    AUFL-->>User: Return combined loss
    1 file reviewed, 5 comments
        
          
monai/losses/unified_focal_loss.py (Outdated)

if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim

logic: IndexError when y_true.ndim == y_pred_act.ndim - 1
When y_true is (B, H, W) and y_pred_act is (B, C, H, W), the condition y_true.ndim == y_pred_act.ndim - 1 is True, but then checking y_true.shape[1] != y_pred_act.shape[1] compares H with C, which is nonsensical. The logic should unsqueeze immediately when ndim differs.
Suggested change:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
        
          
monai/losses/unified_focal_loss.py (Outdated)

if not self.include_background:
    # All classes are foreground, apply foreground logic
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)

logic: Incorrect exponent formula - should be 1.0 / self.gamma
The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives exponent 1.333, increasing penalty for low Dice scores. Current formula 1.0 - self.gamma gives 0.25, incorrectly decreasing the penalty.
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
        
          
monai/losses/unified_focal_loss.py (Outdated)

    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
elif n_classes == 1:
    # Single class, must be foreground (BG was excluded or not provided)
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)

logic: Same exponent error - should be 1.0 / self.gamma
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: Same exponent error - should be 1.0 / self.gamma
Suggested change:
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
        
          
monai/losses/unified_focal_loss.py (Outdated)

loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
if y_true.shape != y_pred_act.shape:
    raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
        f"after activations/one-hot")

syntax: String literal syntax error - missing closing quote
Suggested change:
- f"after activations/one-hot")
+ "after activations/one-hot")
Actionable comments posted: 0
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)
85-92: Critical: Incorrect exponent formula in all three branches. The paper specifies `(1 - dice)^(1/gamma)`, but lines 85, 88, and 92 use `1.0 - self.gamma`. With gamma=0.75, you get `(1-dice)^0.25` (reduces penalty for low Dice), when you should get `(1-dice)^1.333` (increases penalty). Apply this diff:
  if not self.include_background:
      # All classes are foreground, apply foreground logic
-     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
  elif n_classes == 1:
      # Single class, must be foreground (BG was excluded or not provided)
-     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
  else:
      # Asymmetric logic: class 0 is BG, others are FG
      back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
-     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
      loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)  # (B, C)
270-277: Critical: sigmoid+one-hot causes ValueError with binary classification. When `sigmoid=True` and `n_pred_ch=1`, line 273's condition passes, triggering `one_hot(y_true, num_classes=1)` at line 277. Any ground truth voxel with value 1 raises `ValueError: class values must be smaller than num_classes`. Apply this diff:
  if self.to_onehot_y:
-     if n_pred_ch == 1 and not self.sigmoid:
+     if n_pred_ch == 1:
          warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-     elif n_pred_ch > 1 or self.sigmoid:
+     else:
          # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
          if y_true.shape[1] != 1:
              y_true = y_true.unsqueeze(1)
          y_true = one_hot(y_true, num_classes=n_pred_ch)
280-287: Critical: Logic error in shape alignment causes incorrect comparison. Line 282 evaluates `y_true.shape[1] != y_pred_act.shape[1]` even when `y_true.ndim == y_pred_act.ndim - 1`. If y_true is (B, H, W) and y_pred_act is (B, C, H, W), you're comparing H (spatial dim) to C (channels), which is nonsensical. Additionally, line 287 has a syntax error: the f-string is split across lines incorrectly per Ruff. Apply this diff:
  # Ensure y_true has the same shape as y_pred_act
  if y_true.shape != y_pred_act.shape:
-     # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
-     if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
+     if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
          y_true = y_true.unsqueeze(1)  # Add channel dim
-
-     if y_true.shape != y_pred_act.shape:
-         raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
-             f"after activations/one-hot")
+     elif y_true.shape != y_pred_act.shape:
+         raise ValueError(
+             f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
+             f"after activations/one-hot"
+         )
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
44-48: Improve parameter documentation clarity. Line 45 says delta is "weight of the background," but line 129 says it's "weight of the foreground." The docstrings should clarify that delta controls the FN/FP trade-off in Tversky (line 45) vs. background/foreground weight in focal CE (line 129). Also missing: documentation of raised exceptions (ValueError) per coding guidelines.
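One possible shape for the clarified documentation, sketched as a hypothetical docstring fragment (the wording below is illustrative, not the PR's actual text):

```python
# Hypothetical Google-style docstring fragment for the two `delta` parameters.
DELTA_DOCSTRING_SKETCH = """
Args:
    delta: in AsymmetricFocalTverskyLoss, the weight of false negatives versus
        false positives in the Tversky index; in AsymmetricFocalLoss, the weight
        of the foreground cross-entropy term (background is weighted by 1 - delta).
    gamma: focal exponent controlling how strongly the penalty is modulated.

Raises:
    ValueError: if ``y_pred`` and ``y_true`` have incompatible shapes.
"""
```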
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- monai/losses/unified_focal_loss.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
- monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
287-287: f-string: unterminated string
(invalid-syntax)
287-288: Expected FStringEnd, found newline
(invalid-syntax)
288-288: Expected ',', found indent
(invalid-syntax)
288-288: Expected ',', found name
(invalid-syntax)
288-288: missing closing quote in string literal
(invalid-syntax)
288-289: Expected ',', found newline
(invalid-syntax)
291-291: Expected ')', found dedent
(invalid-syntax)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)
146-152: LGTM on background exclusion logic. The `include_background` handling is consistent with `AsymmetricFocalTverskyLoss`.
237-238: LGTM on mutual exclusivity check. Correctly prevents both sigmoid and softmax from being enabled simultaneously.
290-295: LGTM on loss fusion logic. Correctly delegates to sub-loss instances and combines via weighted sum. Assuming sub-losses return scalar values when `reduction="mean"` or `reduction="sum"`, this will work.
195-251: Add unit tests for new sigmoid/softmax interface. The PR adds sigmoid/softmax activation options and renames `to_onehot_y` to `include_background`, but no new tests are mentioned. These API changes need coverage to verify:
- sigmoid/softmax mutual exclusivity
- One-hot conversion with different activation modes
- Background inclusion/exclusion behavior
- Multi-class vs. binary edge cases
Per coding guidelines, modified definitions must be covered by tests. Do you want me to generate test cases or open an issue to track this?
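As a sketch of what such coverage could look like, assuming the class and parameter names used in this PR and that the sigmoid/softmax conflict raises `ValueError` at construction (an assumption, not confirmed by the diff):

```python
import unittest

import torch

from monai.losses import AsymmetricUnifiedFocalLoss


class TestAsymmetricUnifiedFocalLossInterface(unittest.TestCase):
    def test_sigmoid_softmax_mutually_exclusive(self):
        # enabling both activations at once should be rejected
        with self.assertRaises(ValueError):
            AsymmetricUnifiedFocalLoss(sigmoid=True, softmax=True)

    def test_softmax_multiclass_scalar_output(self):
        # multi-class logits with integer labels converted via to_onehot_y
        pred = torch.randn(2, 3, 16, 16)
        target = torch.randint(0, 3, (2, 1, 16, 16))
        loss_fn = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
        loss = loss_fn(pred, target)
        self.assertEqual(loss.ndim, 0)  # default reduction should give a scalar


if __name__ == "__main__":
    unittest.main()
```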
Force-pushed from 6b0c14c to f5a2f7e
Greptile Overview
Greptile Summary
This PR generalizes AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface common to other MONAI losses.
Major changes:
- Replaced `to_onehot_y` parameter with `include_background` in component losses (`AsymmetricFocalLoss`, `AsymmetricFocalTverskyLoss`)
- Added `sigmoid`/`softmax`/`to_onehot_y` parameters to `AsymmetricUnifiedFocalLoss`
- Refactored loss computation to handle multi-class scenarios
- Updated interface to match MONAI's `FocalLoss` pattern
Critical issues found:
- The gamma exponent formula in `AsymmetricFocalTverskyLoss` is mathematically incorrect on lines 85, 88, and 92: uses `1.0 - self.gamma` instead of `1.0 / self.gamma`, which inverts the focal penalty behavior
- Dimension handling logic error on lines 282-283 in `AsymmetricUnifiedFocalLoss`, where shape comparison happens before the necessary dimension adjustment
Confidence Score: 1/5
- This PR contains critical mathematical errors that will cause incorrect loss computation
- The gamma exponent formula is fundamentally wrong in three places (lines 85, 88, 92), using `1.0 - gamma` instead of `1.0 / gamma`. This inverts the focal penalty behavior. Additionally, there's a dimension handling logic error (lines 282-283). These are not edge cases but core functionality bugs that will affect all users.
- monai/losses/unified_focal_loss.py requires immediate attention to fix the gamma exponent formula and dimension handling logic
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Critical mathematical errors in gamma exponent formula (lines 85, 88, 92) and dimension handling logic error (line 282-283). These will cause incorrect loss computation. | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid/softmax if enabled
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot if to_onehot_y=True
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure y_true and y_pred_act shapes match
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if include_background=False
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric focal weighting
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if include_background=False
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index (TP, FN, FP)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply focal exponent (WRONG: 1-gamma instead of 1/gamma)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda_focal * f_loss + (1-lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: loss
    1 file reviewed, 4 comments
        
          
monai/losses/unified_focal_loss.py (Outdated)

if not self.include_background:
    # All classes are foreground, apply foreground logic
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)

logic: incorrect exponent - should be 1.0 / self.gamma (not 1.0 - self.gamma)
The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Current 1.0 - self.gamma gives (1-dice)^0.25, incorrectly decreasing penalty.
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
        
          
monai/losses/unified_focal_loss.py (Outdated)

    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
elif n_classes == 1:
    # Single class, must be foreground (BG was excluded or not provided)
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)

logic: same exponent error - should be 1.0 / self.gamma
Suggested change:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)

logic: same exponent error - should be 1.0 / self.gamma
Suggested change:
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
        
          
monai/losses/unified_focal_loss.py (Outdated)

if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim

logic: logic error - when y_true.ndim == y_pred_act.ndim - 1, checking y_true.shape[1] on line 282 will compare wrong dimensions
When y_true is (B, H, W) and y_pred_act is (B, C, H, W), the first condition is True, but then checking y_true.shape[1] (which is H) against y_pred_act.shape[1] (which is C) is meaningless. Should unsqueeze immediately.
Suggested change:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
34-296: Add test coverage for new sigmoid/softmax/include_background modes. Current tests only cover the default configuration with perfect predictions. Missing coverage:
- Binary segmentation with sigmoid=True
- Multi-class with softmax=True
- include_background=False with sub-loss variations
- to_onehot_y=True edge cases (sigmoid binary, multi-class)
- Reduction modes (SUM, NONE)
- Sub-loss classes in isolation
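A rough starting point for the reduction-mode and sub-loss coverage listed above (class names from this PR; the expected reduction="none" output shape and constructor arguments are assumptions to be checked against the implementation):

```python
import torch
import torch.nn.functional as F

from monai.losses import AsymmetricFocalLoss, AsymmetricFocalTverskyLoss

# already-activated probabilities and matching one-hot ground truth
pred = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)
target = F.one_hot(torch.randint(0, 3, (2, 8, 8)), num_classes=3).permute(0, 3, 1, 2).float()

for loss_cls in (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss):
    scalar = loss_cls(reduction="mean")(pred, target)
    unreduced = loss_cls(reduction="none")(pred, target)
    assert scalar.ndim == 0          # mean/sum should reduce to a scalar
    assert unreduced.numel() > 1     # "none" should keep per-batch/per-class values
```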
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)
270-277: Critical: one-hot conversion fails for sigmoid with single channel. When `sigmoid=True` and `n_pred_ch==1`, line 273 triggers `one_hot(..., num_classes=1)`, which raises ValueError for any label==1. The original issue #8603 requested sigmoid/softmax support, but this breaks binary sigmoid cases. Apply this diff:
  if self.to_onehot_y:
-     if n_pred_ch == 1 and not self.sigmoid:
+     if n_pred_ch == 1:
          warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-     elif n_pred_ch > 1 or self.sigmoid:
+     else:
          # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
          if y_true.shape[1] != 1:
              y_true = y_true.unsqueeze(1)
          y_true = one_hot(y_true, num_classes=n_pred_ch)
280-284: Major: IndexError when y_true is missing channel dim. Line 282 checks `y_true.ndim == y_pred_act.ndim - 1` (e.g., y_true is (B,H,W)), but then accesses `y_true.shape[1]`, which indexes the spatial dimension H, not the channel. This causes incorrect logic or IndexError. Apply this diff:
  # Ensure y_true has the same shape as y_pred_act
  if y_true.shape != y_pred_act.shape:
      # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
-     if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
+     if y_true.ndim == y_pred_act.ndim - 1:
          y_true = y_true.unsqueeze(1)  # Add channel dim
+     elif y_true.shape[1] != y_pred_act.shape[1]:
+         y_true = y_true.unsqueeze(1)  # Add channel dim if mismatch
85-92: Critical: Exponent formula is mathematically incorrect. All three branches use `1.0 - self.gamma` but the Unified Focal Loss paper specifies `(1 - TI)^(1/γ)`. With default γ=0.75, your formula gives exponent 0.25 (weakens penalty for poor predictions), but the paper requires 1.333 (strengthens penalty). Apply this diff:
  if not self.include_background:
      # All classes are foreground, apply foreground logic
-     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
  elif n_classes == 1:
      # Single class, must be foreground (BG was excluded or not provided)
-     loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+     loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
  else:
      # Asymmetric logic: class 0 is BG, others are FG
      back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
-     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+     fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
      loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)  # (B, C)
🧹 Nitpick comments (4)
monai/losses/unified_focal_loss.py (4)
65-65: Add stacklevel to warning. Per PEP 565, warnings should specify `stacklevel=2` so users see the caller's location.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
149-149: Add stacklevel to warning.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
266-266: Add stacklevel to warning.
- warnings.warn("single channel prediction, softmax=True ignored.")
+ warnings.warn("single channel prediction, softmax=True ignored.", stacklevel=2)
272-272: Add stacklevel to warning.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- monai/losses/unified_focal_loss.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
- monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
102-102: Avoid specifying long messages outside the exception class
(TRY003)
144-144: Avoid specifying long messages outside the exception class
(TRY003)
149-149: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
181-181: Avoid specifying long messages outside the exception class
(TRY003)
238-238: Avoid specifying long messages outside the exception class
(TRY003)
266-266: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
272-272: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
286-289: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)
154-172: Asymmetric weighting logic is correct. Background class gets gamma-modulated penalty; foreground classes get flat delta weighting. This aligns with the asymmetric focal loss design.
291-296: Fused loss computation is correct. The weighted combination `lambda_focal * focal + (1 - lambda_focal) * tversky` properly unifies both loss components.
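For readers following the asymmetric weighting described in these comments, a toy standalone sketch of the scheme quoted in the sequence diagrams (BG: (1-δ)·(1-p)^γ·CE, FG: δ·CE). It mirrors the description, not the PR's exact code, and the default `delta`/`gamma` values are illustrative:

```python
import torch


def asymmetric_focal_ce_sketch(prob: torch.Tensor, target: torch.Tensor,
                               delta: float = 0.7, gamma: float = 2.0) -> torch.Tensor:
    """Toy asymmetric focal cross-entropy: channel 0 is background, the rest foreground."""
    eps = 1e-7
    ce = -target * torch.log(prob.clamp_min(eps))                            # (B, C, ...)
    back = (1.0 - delta) * torch.pow(1.0 - prob[:, :1], gamma) * ce[:, :1]   # gamma-damped BG
    fore = delta * ce[:, 1:]                                                 # flat delta-weighted FG
    return torch.cat([back, fore], dim=1).mean()

# the unified loss then combines this with the Tversky term:
# lambda_focal * focal_loss + (1 - lambda_focal) * tversky_loss
```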
Force-pushed from f5a2f7e to b4e0fcc
Greptile Overview
Greptile Summary
Refactored AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and added standard MONAI loss interface (sigmoid/softmax/to_onehot_y/include_background).
Key Changes:
- Replaced `to_onehot_y` parameter with `include_background` in component losses
- Added `sigmoid`/`softmax` activation options to main loss
- Generalized asymmetric logic to handle multi-class cases
- Improved parameter naming (`lambda_focal` instead of `weight`, separate gamma/delta for each component)
- Updated reduction logic to handle different output shapes
Critical Issues Identified (from previous comments):
- Incorrect exponent formula in `AsymmetricFocalTverskyLoss` (should be `1.0 / self.gamma`, not `1.0 - self.gamma`)
- Shape handling bug on line 280 that checks wrong dimensions when y_true.ndim == y_pred_act.ndim - 1
- String literal syntax error on line 287 (unclosed quote)
Confidence Score: 1/5
- This PR has critical mathematical errors and runtime bugs that will cause incorrect behavior
- Multiple critical logic errors have been identified: (1) incorrect exponent formula in 3 locations that fundamentally breaks the loss calculation, (2) shape dimension checking bug that will cause IndexError at runtime, (3) syntax error with unclosed string. These issues need to be resolved before merge.
- monai/losses/unified_focal_loss.py requires immediate attention - fix exponent formulas on lines 83, 86, 90 and shape checking logic on lines 280-287
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Refactored to support multi-class segmentation with sigmoid/softmax interface. Critical exponent formula errors need fixing, and shape handling logic has bugs that could cause runtime errors. | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt sigmoid==True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid(y_pred)
    else softmax==True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax(y_pred)
    end
    
    alt to_onehot_y==True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Shape alignment checks
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background==False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude channel 0
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross_entropy
    
    alt n_classes > 1 and include_background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric logic<br/>(BG: gamma exponent, FG: no exponent)
    else
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply foreground logic<br/>(delta * cross_entropy)
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background==False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude channel 0
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute TP, FN, FP
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky index
    
    alt n_classes > 1 and include_background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric logic<br/>(BG: no exponent, FG: gamma exponent)
    else
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply foreground logic<br/>(gamma exponent)
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda*f_loss + (1-lambda)*t_loss
    
    AsymmetricUnifiedFocalLoss-->>User: Combined loss
    1 file reviewed, no comments
Actionable comments posted: 0
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
83-90: Critical: incorrect exponent formula breaks Tversky loss behavior. Lines 83, 86, and 90 use `1.0 - self.gamma` but the paper specifies `(1 - TI)^(1/γ)`. With default γ=0.75, the correct exponent is 1.333, increasing penalty for low Dice scores. Current formula gives 0.25, inverting the intended behavior. Apply this diff:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
280-281: Critical: IndexError when y_true has fewer dimensions. When `y_true.ndim == y_pred_act.ndim - 1` (e.g., (B,H,W) vs (B,C,H,W)), checking `y_true.shape[1]` compares spatial dimension H with channel dimension C, which is nonsensical and may raise IndexError. Apply this diff:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
🧹 Nitpick comments (4)
monai/losses/unified_focal_loss.py (4)
65-65: Add stacklevel to warning.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
147-147: Add stacklevel to warning.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
264-264: Add stacklevel to warning.
- warnings.warn("single channel prediction, softmax=True ignored.")
+ warnings.warn("single channel prediction, softmax=True ignored.", stacklevel=2)
270-270: Add stacklevel to warning.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
- monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
100-100: Avoid specifying long messages outside the exception class
(TRY003)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
147-147: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
179-179: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Avoid specifying long messages outside the exception class
(TRY003)
264-264: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
270-270: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
284-287: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.11)
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)
152-170: Asymmetric focal loss logic is correct. Background class applies the gamma exponent to (1-p), foreground classes don't. This matches the paper's asymmetric formulation.
238-249: Sub-loss composition approach is sound. Creating separate AsymmetricFocalLoss and AsymmetricFocalTverskyLoss instances with independent gamma/delta parameters provides clean separation and reuses existing components per the PR objectives.
235-236: Mutual exclusivity check for sigmoid/softmax is correct.
289-294: Fused loss computation is correct. The weighted combination of focal and Tversky losses matches the unified focal loss formulation.
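As a side note on that combination, a minimal sketch of the weighting; the name `lambda_focal` follows the PR description, and the two loss values here are placeholders for the outputs of the focal and Tversky components:

```python
import torch

def unified_focal_combination(f_loss: torch.Tensor, t_loss: torch.Tensor, lambda_focal: float = 0.5) -> torch.Tensor:
    # lambda_focal=1.0 keeps only the focal term, 0.0 only the Tversky term
    return lambda_focal * f_loss + (1.0 - lambda_focal) * t_loss

print(unified_focal_combination(torch.tensor(0.8), torch.tensor(0.2)))  # tensor(0.5000)
```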
Hi @ytl0623, thanks for the contribution. I had only a minor comment, but please look at the greptile comments. A few look spurious, but they should all be resolved.
def __init__(
    self,
    to_onehot_y: bool = False,
Why is to_onehot_y being removed? I think users will still want this functionality, so I would leave it in and add include_background as a new last argument. Even if we do want to remove it, we need to use the deprecation decorators to mark the argument as removed but still keep it for a version or two.
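For illustration, a hedged sketch of how the old argument could be kept while steering users to the new one, assuming MONAI's `deprecated_arg` utility; the version strings and the stub class are placeholders, not the actual change:

```python
from monai.utils import deprecated_arg

class AsymmetricFocalLossStub:  # illustrative stub only
    @deprecated_arg(name="to_onehot_y", since="1.5", removed="1.7",
                    msg_suffix="use one-hot targets and `include_background` instead.")
    def __init__(self, include_background: bool = True, to_onehot_y: bool = False) -> None:
        # the old flag keeps working (with a deprecation warning) for a release or two
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
```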
Greptile Overview
Greptile Summary
This PR modernizes the AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation and adds standard MONAI loss interfaces (sigmoid/softmax/to_onehot_y/include_background).
Key changes:
- Replaced `to_onehot_y` parameter with `include_background` in base loss classes, matching MONAI conventions
- Added `sigmoid`/`softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Generalized from binary-only to multi-class segmentation support
- Maintained asymmetric treatment: background class (index 0) receives different weighting than foreground classes
- Updated reduction logic to properly handle "none"/"mean"/"sum" modes
Critical issues identified:
- Incorrect exponent formula: Lines 83, 86, 90 use `1.0 - self.gamma` but should use `1.0 / self.gamma` according to the Unified Focal Loss paper's formula AFTL = (1-TI)^(1/γ). With default γ=0.75, current code gives (1-dice)^0.25 (decreasing penalty) instead of (1-dice)^1.333 (increasing penalty).
- Shape handling logic error: Line 280 checks `y_true.shape[1] != y_pred_act.shape[1]` after confirming `y_true.ndim == y_pred_act.ndim - 1`, which compares spatial dimension H with channel dimension C.
Recommendations:
- Fix the critical exponent formula errors immediately
- Fix the shape handling conditional logic
- Add tests covering multi-class scenarios and new sigmoid/softmax interfaces
- Verify against reference implementation or paper's equations
Confidence Score: 1/5
- This PR has critical mathematical errors that will cause incorrect loss computation and training behavior
- Score of 1 reflects multiple critical logical errors: (1) incorrect exponent formula in AsymmetricFocalTverskyLoss using `1.0 - self.gamma` instead of `1.0 / self.gamma`, causing fundamentally wrong loss values that contradict the referenced paper; (2) shape handling logic error comparing incompatible dimensions. These are not edge cases but affect core functionality. The PR cannot be merged until these mathematical errors are corrected.
- The only changed file `monai/losses/unified_focal_loss.py` requires immediate attention - specifically lines 83, 86, 90 (exponent formula) and lines 280-281 (shape handling logic)
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Generalizes AsymmetricUnifiedFocalLoss for multi-class segmentation and adds sigmoid/softmax interface. Critical issues: incorrect exponent formula (uses `1.0 - self.gamma` instead of `1.0 / self.gamma`) in AsymmetricFocalTverskyLoss, shape handling logic error when y_true.ndim differs from y_pred_act.ndim |
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid/softmax if requested
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot if to_onehot_y=True
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure shape compatibility
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if needed
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross-entropy loss
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting (BG vs FG)
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if needed
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky Index (TP, FN, FP)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply focal term: (1-TI)^(1/gamma)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda*f_loss + (1-lambda)*t_loss
    AsymmetricUnifiedFocalLoss-->>User: loss
    1 file reviewed, 3 comments
        
          
monai/losses/unified_focal_loss.py (Outdated)
          
        
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim
logic: condition will fail when y_true.ndim == y_pred_act.ndim - 1
When y_true has shape (B, H, W) and y_pred_act has shape (B, C, H, W), the first condition y_true.ndim == y_pred_act.ndim - 1 evaluates to True (3 == 4-1). The code then proceeds to check y_true.shape[1] != y_pred_act.shape[1] on line 280, which compares H (spatial dimension) with C (channel dimension). This is a logical error.
The code should unsqueeze immediately when ndim differs:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
        
          
monai/losses/unified_focal_loss.py (Outdated)
          
        
elif n_pred_ch > 1 or self.sigmoid:
    # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
    if y_true.shape[1] != 1:
        y_true = y_true.unsqueeze(1)
    y_true = one_hot(y_true, num_classes=n_pred_ch)
logic: one-hot with num_classes=1 when sigmoid=True and n_pred_ch==1
When sigmoid=True and n_pred_ch == 1, line 269's condition is False (because not self.sigmoid is False), so the code proceeds to line 271. Since self.sigmoid is True, the condition n_pred_ch > 1 or self.sigmoid is True, and one-hot encoding is performed with num_classes=1.
One-hot encoding with 1 class doesn't make logical sense - it just keeps the single channel as-is. For binary segmentation with sigmoid, the pattern should be: input has 1 channel (foreground probability), and one-hot should convert to 2 channels (background, foreground).
- elif n_pred_ch > 1 or self.sigmoid:
-     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
-     if y_true.shape[1] != 1:
-         y_true = y_true.unsqueeze(1)
-     y_true = one_hot(y_true, num_classes=n_pred_ch)
+ elif n_pred_ch > 1 or (self.sigmoid and n_pred_ch == 1):
+     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+     if y_true.shape[1] != 1:
+         y_true = y_true.unsqueeze(1)
+     # For binary sigmoid case, need 2 classes
+     num_classes = 2 if (self.sigmoid and n_pred_ch == 1) else n_pred_ch
+     y_true = one_hot(y_true, num_classes=num_classes)
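To illustrate the binary-sigmoid case the suggestion targets, a short sketch using `monai.networks.one_hot`; tensors and shapes are illustrative only:

```python
import torch
from monai.networks import one_hot

# A single sigmoid channel is the foreground probability, so both prediction and
# target are expanded to two channels (background, foreground) before the asymmetric logic.
y_pred = torch.sigmoid(torch.randn(2, 1, 16, 16))       # (B, 1, H, W)
y_pred_2ch = torch.cat([1.0 - y_pred, y_pred], dim=1)   # (B, 2, H, W)

y_true = torch.randint(0, 2, (2, 1, 16, 16)).float()    # 0/1 mask with a channel dim
y_true_2ch = one_hot(y_true, num_classes=2)             # (B, 2, H, W)
print(y_pred_2ch.shape, y_true_2ch.shape)
```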
  if self.to_onehot_y:
-     if n_pred_ch == 1:
+     if n_pred_ch == 1 and not self.sigmoid:
          warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-     else:
+     elif n_pred_ch > 1 or self.sigmoid:
          # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
          if y_true.shape[1] != 1:
              y_true = y_true.unsqueeze(1)
          y_true = one_hot(y_true, num_classes=n_pred_ch)
style: potential shape mismatch with sigmoid=True, n_pred_ch=1, and to_onehot_y=True
When using sigmoid=True with 1 input channel, y_pred_act becomes (B, 1, H, W). If to_onehot_y=True, the code converts y_true to one-hot with num_classes=1, keeping it (B, 1, H, W).
Consider adding a warning or documentation note that for binary segmentation with sigmoid=True and to_onehot_y=True, users should keep y_true as 0/1 masks without one-hot conversion, since one-hot with 1 class is redundant.
Greptile Overview
Greptile Summary
This PR generalizes AsymmetricUnifiedFocalLoss, AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface used across other MONAI losses.
Key changes:
- Replaced `to_onehot_y` parameter with `include_background` in base loss classes
- Added `use_sigmoid`/`use_softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Generalized asymmetric logic to handle arbitrary number of classes (not just binary)
- Separated gamma/delta parameters for focal and Tversky components
- Updated parameter names for clarity (`lambda_focal` instead of `weight`)
Critical issues found:
- Incorrect gamma exponent formula: uses `1.0 - self.gamma` but should be `1.0 / self.gamma` per the paper (lines 83, 86, 90). With default gamma=0.75, current code gives (1-dice)^0.25 (decreasing penalty) instead of (1-dice)^1.333 (increasing penalty)
- Dimension checking logic error: line 280 compares a spatial dimension with the channel dimension when `y_true.ndim` differs from `y_pred_act.ndim`
- One-hot with 1 class: when `use_sigmoid=True` with a single channel, the code creates a one-hot encoding with `num_classes=1`, which is illogical
Confidence Score: 1/5
- This PR has critical mathematical errors that will cause incorrect loss values and training behavior
- The incorrect gamma exponent formula (1-gamma instead of 1/gamma) is a fundamental mathematical error that affects all three loss classes. This will cause the loss to behave opposite to its intended design, reducing penalties for poor predictions instead of increasing them. Combined with the dimension checking bug and one-hot encoding issue, this PR cannot be merged safely without fixes.
- The single file `monai/losses/unified_focal_loss.py` requires immediate attention to fix the gamma exponent formula in all three loss classes (lines 83, 86, 90), the dimension checking logic (line 280), and the one-hot encoding logic (line 271)
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 1/5 | Generalizes three loss classes for multi-class support with new sigmoid/softmax interface. Critical issues: incorrect gamma exponent formula (uses `1-gamma` instead of `1/gamma`), dimension checking logic error, and one-hot encoding with 1 class. |
Sequence Diagram
sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt use_sigmoid
        AUFL->>AUFL: y_pred_act = sigmoid(y_pred)
    else use_softmax
        AUFL->>AUFL: y_pred_act = softmax(y_pred, dim=1)
    end
    
    alt to_onehot_y
        AUFL->>AUFL: Convert y_true to one-hot
    end
    
    AUFL->>AUFL: Shape alignment checks
    
    AUFL->>AFL: forward(y_pred_act, y_true)
    
    alt include_background=False
        AFL->>AFL: Exclude channel 0
    end
    
    AFL->>AFL: cross_entropy = -y_true * log(y_pred)
    
    alt Asymmetric logic
        AFL->>AFL: back_ce with gamma exponent
        AFL->>AFL: fore_ce without gamma
        AFL->>AFL: Combine and reduce
    end
    
    AFL-->>AUFL: f_loss
    
    AUFL->>AFTL: forward(y_pred_act, y_true)
    
    alt include_background=False
        AFTL->>AFTL: Exclude channel 0
    end
    
    AFTL->>AFTL: Calculate tp, fn, fp
    AFTL->>AFTL: dice_class = (tp + ε) / (tp + δ*fn + (1-δ)*fp + ε)
    
    alt Asymmetric logic
        AFTL->>AFTL: back_dice without gamma
        AFTL->>AFTL: fore_dice with gamma exponent
        AFTL->>AFTL: Combine and reduce
    end
    
    AFTL-->>AUFL: t_loss
    
    AUFL->>AUFL: loss = λ*f_loss + (1-λ)*t_loss
    AUFL-->>User: loss
    1 file reviewed, 2 comments
elif n_pred_ch > 1 or self.use_sigmoid:
    # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
    if y_true.shape[1] != 1:
        y_true = y_true.unsqueeze(1)
    y_true = one_hot(y_true, num_classes=n_pred_ch)
logic: when use_sigmoid=True and n_pred_ch=1, one-hot with num_classes=1 doesn't make sense
When use_sigmoid=True with single channel, line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates to True (because self.use_sigmoid is True). This leads to one-hot encoding with num_classes=1, which is illogical. For binary sigmoid, should convert to 2 classes (background + foreground).
        
          
monai/losses/unified_focal_loss.py (Outdated)
          
        
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim
logic: condition logic is backwards - checks incompatible dimensions
When y_true.ndim == y_pred_act.ndim - 1 is True (e.g., (B,H,W) vs (B,C,H,W)), the code should immediately unsqueeze. Instead, it also checks y_true.shape[1] != y_pred_act.shape[1] which compares H (spatial) with C (channel) - wrong dimensions.
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
Actionable comments posted: 1
♻️ Duplicate comments (6)
monai/losses/unified_focal_loss.py (6)
83-83: Critical: incorrect exponent formula. The Unified Focal Loss paper specifies `(1 - dice)^(1/γ)`. With default γ=0.75, this gives (1 - dice)^1.333, increasing the penalty for low Dice scores. The current formula `1.0 - self.gamma` gives (1 - dice)^0.25, which incorrectly decreases the penalty. Apply this diff:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
86-86: Critical: same exponent error. Apply this diff:
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
90-90: Critical: same exponent error. Apply this diff:
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
118-118: Breaking change: parameter removed without deprecation. Same issue as AsymmetricFocalTverskyLoss: `to_onehot_y` removed without a deprecation decorator.
280-281: Critical: IndexError when `y_true` has fewer dimensions. When `y_true.ndim == y_pred_act.ndim - 1` (e.g., (B, H, W) vs (B, C, H, W)), the condition is True, but then checking `y_true.shape[1] != y_pred_act.shape[1]` compares H (spatial) with C (channel), the wrong dimensions. Should unsqueeze immediately when ndim differs. Apply this diff:
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
+ if y_true.ndim == y_pred_act.ndim - 1:
      y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
271-275: Major: one-hot with `num_classes=1` when `use_sigmoid=True`. When `use_sigmoid=True` with single-channel input (n_pred_ch=1), the condition `n_pred_ch > 1 or self.use_sigmoid` is True, leading to `one_hot(y_true, num_classes=1)`. This raises ValueError when any voxel has label 1. For binary sigmoid, should convert to 2 classes (background + foreground) or skip one-hot entirely. Apply this diff to handle binary sigmoid properly:
  elif n_pred_ch > 1 or self.use_sigmoid:
      # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
      if y_true.shape[1] != 1:
          y_true = y_true.unsqueeze(1)
-     y_true = one_hot(y_true, num_classes=n_pred_ch)
+     # For binary sigmoid, need 2 classes; otherwise use n_pred_ch
+     num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
+     y_true = one_hot(y_true, num_classes=num_classes)
Note: This also requires expanding `y_pred_act` from 1 to 2 channels for consistency.
y_pred_actfrom 1 to 2 channels for consistency.
🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (5)
65-65: Add `stacklevel=2` to the warning:
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
147-147: Add `stacklevel=2` to the warning:
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
196-197: Parameter naming inconsistency. Parameters are named `use_sigmoid` and `use_softmax` in the signature, stored as `self.use_sigmoid` and `self.use_softmax` (lines 230-231), and checked as `self.use_sigmoid` and `self.use_softmax` throughout. This is consistent, but the naming convention differs from typical MONAI losses, which use `sigmoid`/`softmax` (not `use_sigmoid`/`use_softmax`). For consistency with other MONAI losses (e.g., FocalLoss), consider renaming to `sigmoid` and `softmax`:
- use_sigmoid: bool = False,
- use_softmax: bool = False,
+ sigmoid: bool = False,
+ softmax: bool = False,
Then update all references accordingly.
264-264: Add `stacklevel=2` to the warning:
- warnings.warn("single channel prediction, use_softmax=True ignored.")
+ warnings.warn("single channel prediction, use_softmax=True ignored.", stacklevel=2)
270-270: Add `stacklevel=2` to the warning:
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
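For context on those nitpicks, a tiny sketch of what `stacklevel=2` changes: the warning is attributed to the code that called the loss rather than to the `warnings.warn` line itself (the function names here are made up):

```python
import warnings

def loss_forward():
    # With stacklevel=2 the reported location is the caller of loss_forward()
    warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

def training_step():
    loss_forward()  # the warning points here, which is more useful to users

training_step()
```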
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
- monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
100-100: Avoid specifying long messages outside the exception class
(TRY003)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
147-147: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
179-179: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Avoid specifying long messages outside the exception class
(TRY003)
264-264: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
270-270: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
284-287: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (18)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.8.0)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
193-294: Request test coverage for new functionality and bug fixes. This PR introduces significant changes: new activation modes (`use_sigmoid`/`use_softmax`), generalization to multi-class, parameter API changes, and critical bug fixes (exponent formula). Without tests, these changes cannot be validated. Please add unit tests covering:
- Binary segmentation with `use_sigmoid=True`
- Multi-class segmentation with `use_softmax=True`
- Background inclusion/exclusion (`include_background`)
- One-hot conversion edge cases
- Shape alignment scenarios
- Correctness of exponent formula after fixes
Would you like assistance generating test cases?
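As a starting point, a hedged sketch of one such test; the constructor arguments (`use_softmax`, `to_onehot_y`) follow the interface described in this PR and would need to match whatever is finally merged:

```python
import unittest
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

class TestAsymmetricUnifiedFocalLossMultiClass(unittest.TestCase):
    def test_multiclass_softmax_scalar_output(self):
        loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
        y_pred = torch.randn(2, 3, 16, 16)            # (B, C, H, W) logits, 3 classes
        y_true = torch.randint(0, 3, (2, 1, 16, 16))  # (B, 1, H, W) class indices
        loss = loss_fn(y_pred, y_true)
        self.assertEqual(loss.ndim, 0)                 # default "mean" reduction gives a scalar
        self.assertGreaterEqual(loss.item(), 0.0)

if __name__ == "__main__":
    unittest.main()
```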
Greptile Overview
Greptile Summary
This PR successfully generalizes AsymmetricUnifiedFocalLoss and its dependencies to support multi-class segmentation and adds the standard sigmoid/softmax/to_onehot_y interface used by other MONAI losses like FocalLoss. The interface is now aligned with MONAI conventions, replacing the to_onehot_y parameter with include_background for the component losses.
Major changes:
- Added `use_sigmoid`/`use_softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Replaced `to_onehot_y` with `include_background` in `AsymmetricFocalLoss` and `AsymmetricFocalTverskyLoss`
- Implemented asymmetric treatment: background class (index 0) handled differently from foreground classes
- Refactored to support N-class segmentation instead of binary-only
Critical issues found:
- Exponent formula error (lines 83, 86, 90): uses `1.0 - self.gamma` instead of `1.0 / self.gamma` per the Unified Focal Loss paper. With default gamma=0.75, this incorrectly computes (1-dice)^0.25 (reduces penalty) instead of (1-dice)^1.333 (increases penalty as intended).
- Shape handling logic error (line 280): when `y_true` has fewer dimensions than `y_pred_act`, the condition compares a spatial dimension with the channel dimension.
Confidence Score: 1/5
- This PR has critical mathematical errors that will produce incorrect loss values
- Score of 1 reflects critical bugs: the exponent formula error in AsymmetricFocalTverskyLoss (lines 83, 86, 90) uses `1.0 - gamma` instead of `1.0 / gamma`, fundamentally breaking the loss computation per the paper's specification. This will cause incorrect training behavior. Combined with the dimension checking bug and one-hot encoding issue, this PR cannot be merged safely without fixes.
- monai/losses/unified_focal_loss.py requires immediate attention for exponent formula corrections and shape handling fixes
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 2/5 | Major refactoring to add sigmoid/softmax interface and multi-class support. Critical exponent formula errors (should be 1/gamma not 1-gamma) in lines 83, 86, 90. Shape handling logic error at line 280. | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt use_sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = sigmoid(y_pred)
    else use_softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = softmax(y_pred)
    else
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = y_pred
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_true = one_hot(y_true, num_classes)
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Ensure y_true.shape == y_pred_act.shape
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background if include_background=False
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Compute cross_entropy
    
    alt Multi-class with background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: back_ce = (1-delta) * (1-p)^gamma * CE
        AsymmetricFocalLoss->>AsymmetricFocalLoss: fore_ce = delta * CE
    else Foreground only
        AsymmetricFocalLoss->>AsymmetricFocalLoss: loss = delta * CE
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background if include_background=False
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Compute Tversky Index (TI)
    
    alt Multi-class with background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: back_loss = 1 - TI[0]
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: fore_loss = (1 - TI[1:])^(1-gamma) [BUG: should be 1/gamma]
    else Foreground only
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: loss = (1 - TI)^(1-gamma) [BUG: should be 1/gamma]
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: loss = lambda_focal * f_loss + (1-lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: loss
    1 file reviewed, 4 comments
        
          
                monai/losses/unified_focal_loss.py
              
                Outdated
          
        
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)  # Add channel dim
logic: When y_true.ndim == y_pred_act.ndim - 1 (e.g., (B, H, W) vs (B, C, H, W)), y_true.shape[1] is the spatial dimension H, not a channel dimension. Comparing it with y_pred_act.shape[1] (C) is illogical.
- if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)  # Add channel dim
+ if y_true.ndim == y_pred_act.ndim - 1:
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
        
          
monai/losses/unified_focal_loss.py (Outdated)
          
        
if not self.include_background:
    # All classes are foreground, apply foreground logic
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
logic: Exponent formula is incorrect per the Unified Focal Loss paper. Should be 1.0 / self.gamma not 1.0 - self.gamma. With default gamma=0.75, current gives (1-dice)^0.25 (decreases penalty), should be (1-dice)^1.333 (increases penalty).
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, C)
        
          
monai/losses/unified_focal_loss.py (Outdated)
          
        
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
elif n_classes == 1:
    # Single class, must be foreground (BG was excluded or not provided)
    loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
logic: Same exponent error - should be 1.0 / self.gamma
- loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, 1)
+ loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma)  # (B, 1)
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
logic: Same exponent error - should be 1.0 / self.gamma
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
Force-pushed from f57d294 to 30c82db (compare)
Greptile Overview
Greptile Summary
Refactored three loss classes (AsymmetricFocalTverskyLoss, AsymmetricFocalLoss, AsymmetricUnifiedFocalLoss) to support multi-class segmentation and align with MONAI's standard interface (sigmoid/softmax/to_onehot_y parameters).
Key changes:
- Replaced `to_onehot_y` parameter with `include_background` for component losses
- Added `use_sigmoid` and `use_softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Generalized asymmetric logic to handle N classes (not just binary)
- Improved parameter clarity: renamed `weight` to `lambda_focal`, split gamma/delta params for each component
- Proper reduction handling (mean/sum/none) for all loss components
Critical issues found:
- Incorrect exponent formula: uses `1.0 - self.gamma` instead of `1.0 / self.gamma` in 3 locations (lines 83, 86, 90), contradicting the Unified Focal Loss paper's formula (1-TI)^(1/γ). With default γ=0.75, current code produces (1-dice)^0.25 (reduces penalty) instead of (1-dice)^1.333 (increases penalty)
- Shape alignment bug: the line 280 condition `y_true.ndim == y_pred_act.ndim - 1 and y_true.shape[1] != y_pred_act.shape[1]` will compare spatial dimension H with channel dimension C when shapes are (B,H,W) vs (B,C,H,W)
Confidence Score: 1/5
- This PR contains critical mathematical errors that will produce incorrect loss values and must be fixed before merging
- Score of 1 reflects critical bugs: (1) the exponent formula error in 3 locations fundamentally breaks the loss calculation per the paper's specification, causing the wrong penalty direction; (2) shape alignment logic error will cause IndexError in some input configurations. These are not edge cases but affect core functionality.
- monai/losses/unified_focal_loss.py requires immediate attention - all three exponent errors (lines 83, 86, 90) and the shape logic bug (line 280) must be fixed
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 2/5 | Refactored to support multi-class segmentation and added sigmoid/softmax/to_onehot_y interface. Critical bug: exponent formula uses `1.0 - gamma` instead of `1.0 / gamma` (lines 83, 86, 90), causing incorrect loss calculations. Additional issue with shape alignment logic (line 280). |
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt use_sigmoid=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply sigmoid(y_pred)
    else use_softmax=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Apply softmax(y_pred, dim=1)
    end
    
    alt to_onehot_y=True
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Convert y_true to one-hot encoding
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Align y_true shape with y_pred_act
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Exclude background channel
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Calculate cross_entropy = -y_true * log(y_pred)
    
    alt n_classes > 1
        AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply asymmetric weighting<br/>(BG: gamma exponent, FG: no exponent)
    end
    
    AsymmetricFocalLoss->>AsymmetricFocalLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: Return f_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background=False
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Exclude background channel
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Calculate Tversky index (TP, FN, FP)
    
    alt n_classes > 1
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply asymmetric focal modulation<br/>(BG: no exponent, FG: gamma exponent)
    end
    
    AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: Apply reduction (mean/sum/none)
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: Return t_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: Combine: lambda_focal * f_loss + (1-lambda_focal) * t_loss
    AsymmetricUnifiedFocalLoss-->>User: Return combined loss
    1 file reviewed, no comments
Greptile Overview
Greptile Summary
This PR generalizes AsymmetricUnifiedFocalLoss and its component losses to support multi-class segmentation and adds the standard MONAI loss interface (use_sigmoid, use_softmax, to_onehot_y, include_background). The refactoring replaces the hardcoded binary segmentation logic with flexible multi-class handling.
Major changes:
- Replaced `to_onehot_y` with `include_background` in component losses (`AsymmetricFocalLoss`, `AsymmetricFocalTverskyLoss`)
- Added `use_sigmoid`/`use_softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Implemented asymmetric treatment: background class (index 0) uses different loss weighting than foreground classes
- Updated parameter names for clarity (`weight` → `lambda_focal`, added separate gamma/delta for each component)
Critical issues found:
- Line 90: exponent formula error in `AsymmetricFocalTverskyLoss`: uses `1.0 - self.gamma` instead of `1.0 / self.gamma`, causing incorrect loss behavior
- Lines 279-280: shape handling logic error when `y_true` has fewer dimensions than `y_pred_act`
- Lines 271-275: one-hot encoding with `num_classes=1` when using `use_sigmoid=True` with single-channel input
Confidence Score: 1/5
- This PR has critical mathematical errors that will cause incorrect loss computation
- Score of 1 reflects a critical exponent formula error on line 90 that fundamentally breaks the AsymmetricFocalTverskyLoss behavior, plus logic errors in shape handling (lines 279-280) and one-hot encoding (lines 271-275) that could cause runtime failures or incorrect results
- monai/losses/unified_focal_loss.py requires immediate attention - line 90 has incorrect exponent formula (1.0 - gamma instead of 1.0 / gamma), lines 279-280 have shape handling logic error, and lines 271-275 have one-hot encoding issue
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 2/5 | Refactored to add sigmoid/softmax interface and multi-class support, but contains critical exponent formula error in AsymmetricFocalTverskyLoss (line 90) and shape/one-hot handling issues in AsymmetricUnifiedFocalLoss | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AsymmetricUnifiedFocalLoss
    participant AsymmetricFocalLoss
    participant AsymmetricFocalTverskyLoss
    
    User->>AsymmetricUnifiedFocalLoss: forward(y_pred, y_true)
    
    alt use_sigmoid
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = sigmoid(y_pred)
    else use_softmax
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = softmax(y_pred)
    else no activation
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_pred_act = y_pred
    end
    
    alt to_onehot_y
        AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: y_true = one_hot(y_true, num_classes)
    end
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: ensure shape compatibility
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalLoss: forward(y_pred_act, y_true)
    
    alt include_background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: apply asymmetric weighting<br/>(BG: (1-delta)*(1-p)^gamma*CE, FG: delta*CE)
    else exclude background
        AsymmetricFocalLoss->>AsymmetricFocalLoss: slice [:, 1:] and apply FG logic
    end
    
    AsymmetricFocalLoss-->>AsymmetricUnifiedFocalLoss: focal_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricFocalTverskyLoss: forward(y_pred_act, y_true)
    
    alt include_background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: compute Tversky index<br/>asymmetric: BG linear, FG with exponent
    else exclude background
        AsymmetricFocalTverskyLoss->>AsymmetricFocalTverskyLoss: slice [:, 1:] and apply FG logic
    end
    
    AsymmetricFocalTverskyLoss-->>AsymmetricUnifiedFocalLoss: tversky_loss
    
    AsymmetricUnifiedFocalLoss->>AsymmetricUnifiedFocalLoss: loss = lambda_focal * focal_loss<br/>+ (1-lambda_focal) * tversky_loss
    
    AsymmetricUnifiedFocalLoss-->>User: loss
    1 file reviewed, 3 comments
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
logic: exponent formula incorrect - should be 1.0 / self.gamma not 1.0 - self.gamma
Paper specifies AFTL = (1-TI)^(1/γ). With γ=0.75: should be (1-dice)^1.333 (increases penalty), current gives (1-dice)^0.25 (decreases penalty).
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
if y_true.ndim == y_pred_act.ndim - 1:
    y_true = y_true.unsqueeze(1)
logic: when y_true.ndim == y_pred_act.ndim - 1 (e.g., (B,H,W) vs (B,C,H,W)), should immediately unsqueeze. The AND condition causes the unsqueeze to be skipped when dimensions match exactly at position 1.
- if y_true.ndim == y_pred_act.ndim - 1:
-     y_true = y_true.unsqueeze(1)
+ if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
+     y_true = y_true.unsqueeze(1)  # Add channel dim
+ elif y_true.shape[1] != y_pred_act.shape[1]:
+     y_true = y_true.unsqueeze(1)
elif n_pred_ch > 1 or self.use_sigmoid:
    # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
    if y_true.shape[1] != 1:
        y_true = y_true.unsqueeze(1)
    y_true = one_hot(y_true, num_classes=n_pred_ch)
logic: when use_sigmoid=True and n_pred_ch=1, line 271's condition n_pred_ch > 1 or self.use_sigmoid evaluates to True, leading to one-hot with num_classes=1 on line 275. For binary sigmoid, should convert to 2 classes (background + foreground).
- elif n_pred_ch > 1 or self.use_sigmoid:
-     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
-     if y_true.shape[1] != 1:
-         y_true = y_true.unsqueeze(1)
-     y_true = one_hot(y_true, num_classes=n_pred_ch)
+ elif n_pred_ch > 1 or (self.use_sigmoid and n_pred_ch == 1):
+     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+     if y_true.shape[1] != 1:
+         y_true = y_true.unsqueeze(1)
+     # For binary sigmoid case, need 2 classes
+     num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
+     y_true = one_hot(y_true, num_classes=num_classes)
Force-pushed from 5d14d85 to d991466 (compare)
Greptile Overview
Greptile Summary
This PR generalizes AsymmetricUnifiedFocalLoss, AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss to support multi-class segmentation and aligns them with MONAI's standard loss interface (adding sigmoid/softmax/to_onehot_y parameters).
Key Changes:
- Replaced `to_onehot_y` with `include_background` in the component losses (`AsymmetricFocalLoss`, `AsymmetricFocalTverskyLoss`)
- Added `use_sigmoid`/`use_softmax` activation options to `AsymmetricUnifiedFocalLoss`
- Implemented multi-class support with asymmetric treatment: background class (index 0) gets different weighting than foreground classes (indices 1+)
- Refactored loss computation to handle three cases: no background, single class, and multi-class with background
- Renamed parameters for clarity (`weight` → `lambda_focal`, separate `focal_loss_gamma`/`tversky_loss_gamma`)
Critical Issues:
- Line 90: incorrect exponent formula; `1.0 - self.gamma` should be `1.0 / self.gamma` per the Unified Focal Loss paper
- Lines 271-275: one-hot encoding with `num_classes=1` for binary sigmoid is illogical; should use 2 classes
Confidence Score: 2/5
- This PR has a critical mathematical error in the loss formula that will cause incorrect training behavior
- Score reflects a critical bug in the exponent formula (line 90) that fundamentally breaks the loss computation for multi-class cases with asymmetric foreground handling. Additionally, the binary sigmoid one-hot logic is flawed. These issues will cause incorrect loss values during training.
- The single file `monai/losses/unified_focal_loss.py` requires immediate attention for the mathematical formula error on line 90 and the one-hot encoding logic on lines 271-275
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 2/5 | Generalized three loss classes for multi-class support and standardized interface. Critical bug in AsymmetricFocalTverskyLoss line 90 (wrong exponent formula), and logic issue in AsymmetricUnifiedFocalLoss for binary sigmoid one-hot handling. | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt use_sigmoid=True
        AUFL->>AUFL: y_pred_act = sigmoid(y_pred)
    else use_softmax=True
        AUFL->>AUFL: y_pred_act = softmax(y_pred)
    else
        AUFL->>AUFL: y_pred_act = y_pred
    end
    
    alt to_onehot_y=True
        AUFL->>AUFL: y_true = one_hot(y_true, num_classes)
    end
    
    AUFL->>AUFL: Shape validation & adjustment
    
    par Parallel Loss Computation
        AUFL->>AFL: forward(y_pred_act, y_true)
        AFL->>AFL: Exclude background if needed
        AFL->>AFL: cross_entropy = -y_true * log(y_pred)
        
        alt Multi-class with background
            AFL->>AFL: BG: (1-delta) * (1-p)^gamma * CE
            AFL->>AFL: FG: delta * CE
        else Single/No background
            AFL->>AFL: All FG: delta * CE
        end
        
        AFL-->>AUFL: f_loss
    and
        AUFL->>AFTL: forward(y_pred_act, y_true)
        AFTL->>AFTL: Exclude background if needed
        AFTL->>AFTL: Compute Tversky Index (tp, fn, fp)
        
        alt Multi-class with background
            AFTL->>AFTL: BG: 1 - dice
            AFTL->>AFTL: FG: (1-dice)^(1/gamma)
        else Single/No background
            AFTL->>AFTL: All: (1-dice)^(1/gamma)
        end
        
        AFTL-->>AUFL: t_loss
    end
    
    AUFL->>AUFL: loss = lambda * f_loss + (1-lambda) * t_loss
    AUFL-->>User: loss
    1 file reviewed, 2 comments
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
logic: exponent formula is incorrect - should be 1.0 / self.gamma not 1.0 - self.gamma
Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333, increasing penalty for low dice. Current formula gives (1-dice)^0.25, incorrectly decreasing the penalty.
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
elif n_pred_ch > 1 or self.use_sigmoid:
    # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
    if y_true.shape[1] != 1:
        y_true = y_true.unsqueeze(1)
    y_true = one_hot(y_true, num_classes=n_pred_ch)
logic: when use_sigmoid=True and n_pred_ch=1, one-hot with num_classes=1 is illogical
For binary sigmoid (1 channel foreground probability), should convert to 2 classes (background + foreground) for proper loss computation with the asymmetric logic.
- elif n_pred_ch > 1 or self.use_sigmoid:
-     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
-     if y_true.shape[1] != 1:
-         y_true = y_true.unsqueeze(1)
-     y_true = one_hot(y_true, num_classes=n_pred_ch)
+ elif n_pred_ch > 1 or (self.use_sigmoid and n_pred_ch == 1):
+     # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
+     if y_true.shape[1] != 1:
+         y_true = y_true.unsqueeze(1)
+     # For binary sigmoid case, need 2 classes
+     num_classes = 2 if (self.use_sigmoid and n_pred_ch == 1) else n_pred_ch
+     y_true = one_hot(y_true, num_classes=num_classes)
- for more information, see https://pre-commit.ci. Signed-off-by: ytl0623 <[email protected]>
- Fixes Project-MONAI#8564. Description: Add Fourier feature positional encodings to `PatchEmbeddingBlock`. It has been shown that Fourier feature positional encodings are better suited for anisotropic images and videos: https://arxiv.org/abs/2509.02488. Types of changes: non-breaking change; new tests added to cover the changes; integration tests passed locally by running `./runtests.sh -f -u --net --coverage`; quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`; in-line docstrings updated; documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: NabJa <[email protected]>. Co-authored-by: Eric Kerfoot <[email protected]>. Signed-off-by: ytl0623 <[email protected]>
- Signed-off-by: ytl0623 <[email protected]>
- Signed-off-by: ytl0623 <[email protected]>
- Signed-off-by: ytl0623 <[email protected]>
- ….com> I, NabJa <[email protected]>, hereby add my Signed-off-by to this commit: 76c4391. Signed-off-by: NabJa <[email protected]>. Signed-off-by: ytl0623 <[email protected]>
Force-pushed from 957edd3 to ac66250 (compare)
Greptile Overview
Greptile Summary
This PR successfully generalizes AsymmetricUnifiedFocalLoss and its component losses (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to support multi-class segmentation beyond binary cases. The interface has been modernized to align with other MONAI losses by:
- Replacing `to_onehot_y` parameter with `include_background` in component losses
- Adding `use_sigmoid`/`use_softmax` activation options to the unified loss
- Renaming `weight` to `lambda_focal` for clarity
- Separating gamma and delta parameters for each loss component
- Updating documentation to reflect multi-class support
Critical Issue Found:
- Line 90 in `AsymmetricFocalTverskyLoss` uses the incorrect exponent formula `1.0 - self.gamma` instead of `1.0 / self.gamma`, which is inconsistent with lines 83 and 86 in the same function. This causes incorrect loss computation for foreground classes when background is included.
The refactoring is well-structured and follows MONAI patterns, but the mathematical error needs to be fixed before merging.
Confidence Score: 2/5
- This PR has a critical mathematical bug that will cause incorrect loss computation
- The exponent formula error on line 90 is a critical bug that will produce incorrect gradients and affect model training. While the refactoring is well-designed and the interface improvements are good, this single mathematical error makes the loss function mathematically incorrect per the paper's specification.
- monai/losses/unified_focal_loss.py - Line 90 has critical mathematical error
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| monai/losses/unified_focal_loss.py | 2/5 | Generalizes losses for multi-class segmentation and adds sigmoid/softmax interface. Critical bug: incorrect exponent formula on line 90 (uses 1.0 - gamma instead of 1.0 / gamma) | 
Sequence Diagram
sequenceDiagram
    participant User
    participant AUFL as AsymmetricUnifiedFocalLoss
    participant AFL as AsymmetricFocalLoss
    participant AFTL as AsymmetricFocalTverskyLoss
    
    User->>AUFL: forward(y_pred, y_true)
    
    alt use_sigmoid
        AUFL->>AUFL: y_pred_act = sigmoid(y_pred)
    else use_softmax
        AUFL->>AUFL: y_pred_act = softmax(y_pred)
    else
        AUFL->>AUFL: y_pred_act = y_pred
    end
    
    alt to_onehot_y
        AUFL->>AUFL: y_true = one_hot(y_true, num_classes)
    end
    
    AUFL->>AUFL: Ensure shapes match
    
    par Compute both loss components
        AUFL->>AFL: forward(y_pred_act, y_true)
        AFL->>AFL: Exclude background if needed
        AFL->>AFL: cross_entropy = -y_true * log(y_pred)
        
        alt include_background and n_classes > 1
            AFL->>AFL: back_ce with gamma modulation
            AFL->>AFL: fore_ce without gamma
            AFL->>AFL: Concatenate losses
        else
            AFL->>AFL: Apply foreground logic
        end
        
        AFL->>AFL: Apply reduction
        AFL-->>AUFL: f_loss
    and
        AUFL->>AFTL: forward(y_pred_act, y_true)
        AFTL->>AFTL: Exclude background if needed
        AFTL->>AFTL: Compute TP, FN, FP
        AFTL->>AFTL: dice_class = Tversky index
        
        alt include_background and n_classes > 1
            AFTL->>AFTL: back_dice_loss (no gamma)
            AFTL->>AFTL: fore_dice_loss with (1-dice)^(1/gamma)
            Note right of AFTL: BUG: Line 90 uses 1-gamma
            AFTL->>AFTL: Concatenate losses
        else
            AFTL->>AFTL: Apply (1-dice)^(1/gamma)
        end
        
        AFTL->>AFTL: Apply reduction
        AFTL-->>AUFL: t_loss
    end
    
    AUFL->>AUFL: loss = lambda_focal * f_loss + (1 - lambda_focal) * t_loss
    AUFL-->>User: loss
    1 file reviewed, 1 comment
```python
else:
    # Asymmetric logic: class 0 is BG, others are FG
    back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1)  # (B, 1)
    fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
```
logic: exponent should be `1.0 / self.gamma`, not `1.0 - self.gamma` (inconsistent with lines 83, 86)
Lines 83 and 86 correctly use `1.0 / self.gamma`, but line 90 incorrectly uses `1.0 - self.gamma`. Per the Unified Focal Loss paper, AFTL = (1-TI)^(1/γ). With γ=0.75, the term should be (1-dice)^1.333 to increase the penalty for low dice scores.
```diff
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
```
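A quick numeric check (illustrative values only) shows why the exponent matters: the corrected `1.0 / self.gamma` form amplifies the penalty for poorly segmented classes, while the current `1.0 - self.gamma` form flattens it.

```python
import torch

gamma = 0.75
dice = torch.tensor([0.2, 0.6, 0.9])  # example per-class Tversky/dice scores

current = torch.pow(1.0 - dice, 1.0 - gamma)   # exponent 0.25, as currently written
proposed = torch.pow(1.0 - dice, 1.0 / gamma)  # exponent ~1.333, per the paper

print(current)   # ~[0.946, 0.795, 0.562] -> low-dice classes barely stand out
print(proposed)  # ~[0.743, 0.295, 0.046] -> low-dice classes dominate the loss
```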
Actionable comments posted: 1
♻️ Duplicate comments (4)
monai/losses/unified_focal_loss.py (4)
118-136: Breaking API change without deprecation. Same issue as `AsymmetricFocalTverskyLoss`: removing `to_onehot_y` breaks existing code. Restore it with deprecation or document this as a breaking change.
90-90: Critical: Wrong exponent formula, should be `1.0 / self.gamma`. The Unified Focal Loss paper specifies AFTL = (1-TI)^(1/γ). With default γ=0.75, this gives (1-dice)^1.333 (increases the penalty for low dice). The current `1.0 - self.gamma` gives (1-dice)^0.25 (decreases the penalty). This critical error has been flagged in multiple past reviews.
Apply this diff:
```diff
- fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
+ fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 / self.gamma)  # (B, C-1)
```
279-280: Logic error: dimension mismatch when `y_true` lacks a channel dim. When `y_true` is (B, H, W) and `y_pred_act` is (B, C, H, W), the condition on line 279 is True (3 == 4-1), but the code only unsqueezes and never verifies that the shapes actually match. It should unsqueeze immediately when the ndim differs. This was flagged multiple times in past reviews.
Apply this diff:
```diff
 # Ensure y_true has the same shape as y_pred_act
 if y_true.shape != y_pred_act.shape:
-    if y_true.ndim == y_pred_act.ndim - 1:
+    if y_true.ndim == y_pred_act.ndim - 1:  # e.g., (B, H, W) vs (B, C, H, W)
         y_true = y_true.unsqueeze(1)
-
-    if y_true.shape != y_pred_act.shape:
+    elif y_true.shape != y_pred_act.shape:
         raise ValueError(
             f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
             f"after activations/one-hot"
         )
```
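A tiny shape check with hypothetical tensors (not from the PR) illustrates the point: the ndim condition fires, yet a bare `unsqueeze(1)` still leaves the shapes mismatched whenever C > 1.

```python
import torch

B, C, H, W = 2, 3, 4, 4
y_pred_act = torch.rand(B, C, H, W)
y_true = torch.randint(0, C, (B, H, W))  # class-index labels without a channel dim

print(y_true.ndim == y_pred_act.ndim - 1)  # True (3 == 4 - 1)
print(y_true.unsqueeze(1).shape)           # torch.Size([2, 1, 4, 4]) != (2, 3, 4, 4)
```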
271-275: Critical: One-hot with num_classes=1 breaks binary sigmoid. When `use_sigmoid=True` with single-channel logits (n_pred_ch=1), the condition on line 271, `n_pred_ch > 1 or self.use_sigmoid`, evaluates True and calls `one_hot(y_true, num_classes=1)`; any label equal to 1 then raises ValueError. This blocking bug was flagged by coderabbitai[bot] in a past review and remains unfixed.
Apply this diff:
```diff
 if self.to_onehot_y:
     if n_pred_ch == 1 and not self.use_sigmoid:
         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-    elif n_pred_ch > 1 or self.use_sigmoid:
+    elif n_pred_ch > 1:
         # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
         if y_true.shape[1] != 1:
             y_true = y_true.unsqueeze(1)
         y_true = one_hot(y_true, num_classes=n_pred_ch)
```
Or expand binary logits to 2 channels before one-hotting.
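For reference, a minimal sketch of shape-safe target preparation along the lines the reviewer suggests. The helper name and structure are illustrative, not the PR's code, and it uses `torch.nn.functional.one_hot` rather than MONAI's `one_hot` helper:

```python
import torch
import torch.nn.functional as F


def prepare_target(y_true: torch.Tensor, n_pred_ch: int) -> torch.Tensor:
    """Sketch: only one-hot encode when the prediction has multiple channels."""
    if n_pred_ch == 1:
        # binary / sigmoid case: keep labels as-is, just ensure a channel dim
        return y_true if (y_true.ndim > 1 and y_true.shape[1] == 1) else y_true.unsqueeze(1)
    if y_true.ndim > 1 and y_true.shape[1] == 1:
        y_true = y_true.squeeze(1)
    # (B, spatial...) -> one-hot over a trailing class dim, then move it to dim 1
    return F.one_hot(y_true.long(), num_classes=n_pred_ch).movedim(-1, 1).float()
```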
🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (5)
65-65: Add `stacklevel=2` to warning.
```diff
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
```
147-147: Add `stacklevel=2` to warning.
```diff
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
```
264-264: Add `stacklevel=2` to warning.
```diff
- warnings.warn("single channel prediction, use_softmax=True ignored.")
+ warnings.warn("single channel prediction, use_softmax=True ignored.", stacklevel=2)
```
270-270: Add `stacklevel=2` to warning.
```diff
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
```
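For context on the `stacklevel=2` suggestions: it makes the warning's reported location the caller of the loss rather than the line inside the loss module, which is usually what users need to act on. A small standalone illustration (not MONAI code):

```python
import warnings


def _validate(include_background: bool, n_pred_ch: int) -> None:
    if not include_background and n_pred_ch == 1:
        # stacklevel=2 attributes the warning to _validate's caller,
        # so the reported file/line points at user code, not this helper
        warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)


_validate(include_background=False, n_pred_ch=1)
```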
193-227: Docstrings missing exception documentation. Per coding guidelines, docstrings should document raised exceptions; line 236 raises ValueError but the docstring does not mention it.
Add a Raises section:
```diff
 Args:
     to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
     use_sigmoid: if True, apply a sigmoid activation to the input y_pred.
     use_softmax: if True, apply a softmax activation to the input y_pred.
     lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based). The weight for
         AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
     focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
     focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
     tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
     tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
     include_background: whether to include loss computation for the background class. Defaults to True.
     reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
+
+Raises:
+    ValueError: if both use_sigmoid and use_softmax are True.
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- monai/losses/unified_focal_loss.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
- monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.2)
monai/losses/unified_focal_loss.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
65-65: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
100-100: Avoid specifying long messages outside the exception class
(TRY003)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
147-147: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
179-179: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Avoid specifying long messages outside the exception class
(TRY003)
264-264: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
270-270: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
283-286: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (15)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: packaging
- GitHub Check: build-docs
```diff
         include_background: bool = True,
         delta: float = 0.7,
         gamma: float = 0.75,
         epsilon: float = 1e-7,
         reduction: LossReduction | str = LossReduction.MEAN,
     ) -> None:
         """
         Args:
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            include_background: whether to include loss computation for the background class. Defaults to True.
+            delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index)
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
+            reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
         """
         super().__init__(reduction=LossReduction(reduction).value)
-        self.to_onehot_y = to_onehot_y
+        self.include_background = include_background
```
Breaking API change without deprecation.
Removing `to_onehot_y` from `AsymmetricFocalTverskyLoss` breaks existing code. Per MONAI conventions and prior reviewer feedback (ericspod), deprecated parameters must be preserved with deprecation decorators for 1-2 versions.
Either restore `to_onehot_y` with deprecation warnings or explicitly document this as a breaking change in the changelog.
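For illustration only, MONAI's `deprecated_arg` decorator is the usual mechanism for keeping a removed keyword alive while steering users to the new behavior. The version strings and message below are placeholders, and the class body is a simplified sketch rather than the PR's actual code:

```python
from torch.nn.modules.loss import _Loss

from monai.utils import deprecated_arg


class AsymmetricFocalTverskyLoss(_Loss):  # simplified sketch, not the PR's implementation
    @deprecated_arg(
        name="to_onehot_y",
        since="1.5",  # placeholder version
        removed="1.7",  # placeholder version
        msg_suffix="convert targets to one-hot externally and use `include_background` instead.",
    )
    def __init__(self, include_background: bool = True, to_onehot_y: bool = False) -> None:
        super().__init__()
        # the deprecated keyword is still accepted but triggers a deprecation warning
        self.include_background = include_background
```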
Fixes #8603 .
Description
Addresses the limitations of `AsymmetricUnifiedFocalLoss` and its dependencies (`AsymmetricFocalLoss`, `AsymmetricFocalTverskyLoss`). The previous implementations were hard-coded for binary segmentation and lacked the standard `sigmoid`/`softmax`/`to_onehot_y` interface common to other MONAI losses.
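A minimal usage sketch of the generalized interface, assuming the parameter names described in the review above (`to_onehot_y`, `use_softmax`, `lambda_focal`, `include_background`); exact argument names and defaults may differ in the merged version:

```python
import torch

from monai.losses import AsymmetricUnifiedFocalLoss

# 4-class segmentation: raw logits and integer class labels
y_pred = torch.randn(2, 4, 64, 64)            # (B, C, H, W) logits
y_true = torch.randint(0, 4, (2, 1, 64, 64))  # (B, 1, H, W) class indices

loss_fn = AsymmetricUnifiedFocalLoss(
    to_onehot_y=True,   # convert labels to one-hot internally
    use_softmax=True,   # apply softmax to the logits
    lambda_focal=0.5,   # weight between the focal and focal-Tversky terms
    include_background=True,
)
loss = loss_fn(y_pred, y_true)
print(loss)
```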
Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested `make html` command in the `docs/` folder.