@@ -80,10 +80,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
8080
8181 if not self .include_background :
8282 # All classes are foreground, apply foreground logic
83- loss = torch .pow (1.0 - dice_class , 1.0 - self .gamma ) # (B, C)
83+ loss = torch .pow (1.0 - dice_class , 1.0 / self .gamma ) # (B, C)
8484 elif n_classes == 1 :
8585 # Single class, must be foreground (BG was excluded or not provided)
86- loss = torch .pow (1.0 - dice_class , 1.0 - self .gamma ) # (B, 1)
86+ loss = torch .pow (1.0 - dice_class , 1.0 / self .gamma ) # (B, 1)
8787 else :
8888 # Asymmetric logic: class 0 is BG, others are FG
8989 back_dice_loss = (1.0 - dice_class [:, 0 ]).unsqueeze (1 ) # (B, 1)
@@ -276,9 +276,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276276
277277 # Ensure y_true has the same shape as y_pred_act
278278 if y_true .shape != y_pred_act .shape :
279- # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
280- if y_true .shape [1 ] != y_pred_act .shape [1 ] and y_true .ndim == y_pred_act .ndim - 1 :
281- y_true = y_true .unsqueeze (1 ) # Add channel dim
279+ if y_true .ndim == y_pred_act .ndim - 1 :
280+ y_true = y_true .unsqueeze (1 )
282281
283282 if y_true .shape != y_pred_act .shape :
284283 raise ValueError (
0 commit comments