Skip to content

Commit 5d14d85

Browse files
committed
Update unified_focal_loss.py
Signed-off-by: ytl0623 <[email protected]>
1 parent 30c82db commit 5d14d85

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
8080

8181
if not self.include_background:
8282
# All classes are foreground, apply foreground logic
83-
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C)
83+
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, C)
8484
elif n_classes == 1:
8585
# Single class, must be foreground (BG was excluded or not provided)
86-
loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1)
86+
loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, 1)
8787
else:
8888
# Asymmetric logic: class 0 is BG, others are FG
8989
back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1)
@@ -276,9 +276,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276276

277277
# Ensure y_true has the same shape as y_pred_act
278278
if y_true.shape != y_pred_act.shape:
279-
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
280-
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
281-
y_true = y_true.unsqueeze(1) # Add channel dim
279+
if y_true.ndim == y_pred_act.ndim - 1:
280+
y_true = y_true.unsqueeze(1)
282281

283282
if y_true.shape != y_pred_act.shape:
284283
raise ValueError(

0 commit comments

Comments
 (0)