1414import warnings
1515
1616import torch
17- import torch .nn .functional as F
1817from torch .nn .modules .loss import _Loss
1918
2019from monai .networks import one_hot
@@ -169,7 +168,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169168 back_ce = (1.0 - self .delta ) * torch .pow (1.0 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
170169 # (B, C-1, H, W)
171170 fore_ce = self .delta * cross_entropy [:, 1 :]
172-
171+
173172 loss = torch .cat ([back_ce .unsqueeze (1 ), fore_ce ], dim = 1 ) # (B, C, H, W)
174173
175174 # Apply reduction
@@ -276,13 +275,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276275 if y_true .shape [1 ] != 1 :
277276 y_true = y_true .unsqueeze (1 )
278277 y_true = one_hot (y_true , num_classes = n_pred_ch )
279-
278+
280279 # Ensure y_true has the same shape as y_pred_act
281280 if y_true .shape != y_pred_act .shape :
282281 # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283282 if y_true .shape [1 ] != y_pred_act .shape [1 ] and y_true .ndim == y_pred_act .ndim - 1 :
284283 y_true = y_true .unsqueeze (1 ) # Add channel dim
285-
284+
286285 if y_true .shape != y_pred_act .shape :
287286 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred_act .shape } ) after activations/one-hot" )
288287
@@ -292,4 +291,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
292291
293292 loss : torch .Tensor = self .lambda_focal * f_loss + (1 - self .lambda_focal ) * t_loss
294293
295- return loss
294+ return loss
0 commit comments