Skip to content

Commit ec36c14

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1ab9120 commit ec36c14

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import warnings
1515

1616
import torch
17-
import torch.nn.functional as F
1817
from torch.nn.modules.loss import _Loss
1918

2019
from monai.networks import one_hot
@@ -169,7 +168,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169168
back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
170169
# (B, C-1, H, W)
171170
fore_ce = self.delta * cross_entropy[:, 1:]
172-
171+
173172
loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W)
174173

175174
# Apply reduction
@@ -276,13 +275,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276275
if y_true.shape[1] != 1:
277276
y_true = y_true.unsqueeze(1)
278277
y_true = one_hot(y_true, num_classes=n_pred_ch)
279-
278+
280279
# Ensure y_true has the same shape as y_pred_act
281280
if y_true.shape != y_pred_act.shape:
282281
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283282
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
284283
y_true = y_true.unsqueeze(1) # Add channel dim
285-
284+
286285
if y_true.shape != y_pred_act.shape:
287286
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) after activations/one-hot")
288287

@@ -292,4 +291,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
292291

293292
loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss
294293

295-
return loss
294+
return loss

0 commit comments

Comments
 (0)