@@ -60,12 +60,12 @@ def __init__(
             include_background: if False, channel index 0 (background category) is excluded from the calculation.
                 if the non-background segmentations are small compared to the total image size they can get overwhelmed
                 by the signal from the background so excluding it in such cases helps convergence.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction.
             softmax: if True, apply a softmax function to the prediction.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example:
-                `other_act = torch.tanh`.
+            other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
+                ``other_act = torch.tanh``.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
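For reference, a minimal sketch of the clarified `to_onehot_y` behaviour (assuming the `monai.losses.DiceLoss` API in this file; shapes and values are illustrative):

```python
import torch
from monai.losses import DiceLoss

# target carries class indices in a single channel; with to_onehot_y=True it is
# one-hot encoded internally using the prediction's channel count (input.shape[1]).
loss_fn = DiceLoss(to_onehot_y=True, softmax=True)
pred = torch.randn(2, 3, 16, 16)              # (batch, 3 classes, H, W)
target = torch.randint(0, 3, (2, 1, 16, 16))  # (batch, 1, H, W) class indices
print(loss_fn(pred, target))                  # scalar, since reduction defaults to "mean"
```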
@@ -247,12 +247,12 @@ def __init__(
247247 """
248248 Args:
249249 include_background: If False channel index 0 (background category) is excluded from the calculation.
250- to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
250+ to_onehot_y: whether to convert the ``target`` into the one-hot format,
251+ using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
251252 sigmoid: If True, apply a sigmoid function to the prediction.
252253 softmax: If True, apply a softmax function to the prediction.
253- other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
254- other activation layers, Defaults to ``None``. for example:
255- `other_act = torch.tanh`.
254+ other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
255+ ``other_act = torch.tanh``.
256256 w_type: {``"square"``, ``"simple"``, ``"uniform"``}
257257 Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
258258 reduction: {``"none"``, ``"mean"``, ``"sum"``}
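The same conversion applies to the generalized variant; a hedged sketch (illustrative shapes, `w_type` values as listed above):

```python
import torch
from monai.losses import GeneralizedDiceLoss

# w_type controls how the ground-truth volume is mapped to per-class weights.
gdl = GeneralizedDiceLoss(to_onehot_y=True, softmax=True, w_type="square")
pred = torch.randn(2, 3, 16, 16)
target = torch.randint(0, 3, (2, 1, 16, 16))
print(gdl(pred, target))
```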
@@ -639,14 +639,14 @@ def __init__(
             ``reduction`` is used for both losses and other parameters are only used for dice loss.
 
             include_background: if False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `CrossEntropyLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `CrossEntropyLoss`.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
-                only used by the `DiceLoss`, don't need to specify activation function for `CrossEntropyLoss`.
+            other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
+                ``other_act = torch.tanh``. Only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"mean"``, ``"sum"``}
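A usage sketch for the combined loss (only parameters shown in this docstring are used; values are illustrative):

```python
import torch
from monai.losses import DiceCELoss

# softmax is applied to the Dice term only; CrossEntropyLoss normalizes internally.
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True, reduction="mean")
pred = torch.randn(2, 3, 16, 16)
target = torch.randint(0, 3, (2, 1, 16, 16))
print(loss_fn(pred, target))
```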
@@ -728,7 +728,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         """
         if len(input.shape) != len(target.shape):
-            raise ValueError("the number of dimensions for input and target should be the same.")
+            raise ValueError(
+                "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} and {target.shape}."
+            )
 
         dice_loss = self.dice(input, target)
         ce_loss = self.ce(input, target)
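The sharpened message can be exercised with a deliberately mismatched target (a hedged illustration; the reported shapes depend on the inputs):

```python
import torch
from monai.losses import DiceCELoss

loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
pred = torch.randn(2, 3, 16, 16)
bad_target = torch.randint(0, 3, (2, 16, 16))  # channel dimension missing
try:
    loss_fn(pred, bad_target)
except ValueError as err:
    print(err)  # ...should be the same, got shape torch.Size([2, 3, 16, 16]) and torch.Size([2, 16, 16]).
```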
@@ -743,6 +746,10 @@ class DiceFocalLoss(_Loss):
     The details of Dice loss is shown in ``monai.losses.DiceLoss``.
     The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
 
+    ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
+    ``include_background`` and ``reduction`` are used for both losses
+    and other parameters are only used for dice loss.
+
     """
 
     def __init__(
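A construction sketch matching the relocated note (parameter routing as documented; the values are illustrative):

```python
from monai.losses import DiceFocalLoss

# gamma, focal_weight and lambda_focal reach only the focal term;
# include_background and reduction are forwarded to both sub-losses.
loss_fn = DiceFocalLoss(
    include_background=False,
    to_onehot_y=True,
    softmax=True,     # dice term only
    gamma=2.0,        # focal term only
    lambda_dice=1.0,
    lambda_focal=0.5,
)
```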
@@ -765,18 +772,15 @@ def __init__(
     ) -> None:
         """
         Args:
-            ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for focal loss.
-            ``include_background``, ``to_onehot_y``and ``reduction`` are used for both losses
-            and other parameters are only used for dice loss.
             include_background: if False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `FocalLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `FocalLoss`.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
-                only used by the `DiceLoss`, don't need to specify activation function for `FocalLoss`.
+            other_act: callable function to execute other activation layers. Defaults to ``None``.
+                For example: ``other_act = torch.tanh``. Only used by the `DiceLoss`, not for `FocalLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -803,6 +807,8 @@ def __init__(
         """
         super().__init__()
         self.dice = DiceLoss(
+            include_background=include_background,
+            to_onehot_y=False,
             sigmoid=sigmoid,
             softmax=softmax,
             other_act=other_act,
@@ -813,15 +819,20 @@ def __init__(
             smooth_dr=smooth_dr,
             batch=batch,
         )
-        self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
+        self.focal = FocalLoss(
+            include_background=include_background,
+            to_onehot_y=False,
+            gamma=gamma,
+            weight=focal_weight,
+            reduction=reduction,
+        )
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_focal < 0.0:
             raise ValueError("lambda_focal should be no less than 0.0.")
         self.lambda_dice = lambda_dice
         self.lambda_focal = lambda_focal
         self.to_onehot_y = to_onehot_y
-        self.include_background = include_background
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
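With this refactor the wrapper no longer slices the background channel itself; each sub-loss applies `include_background` on its own, while the (optional) one-hot conversion stays in the wrapper. A hedged sanity check:

```python
import torch
from monai.losses import DiceFocalLoss

loss_fn = DiceFocalLoss(include_background=False, to_onehot_y=True, softmax=True)
pred = torch.randn(2, 3, 8, 8)
target = torch.randint(0, 3, (2, 1, 8, 8))
# background exclusion now happens inside loss_fn.dice and loss_fn.focal
print(loss_fn(pred, target))
```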
@@ -836,24 +847,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         """
         if len(input.shape) != len(target.shape):
-            raise ValueError("the number of dimensions for input and target should be the same.")
-
-        n_pred_ch = input.shape[1]
-
+            raise ValueError(
+                "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} and {target.shape}."
+            )
         if self.to_onehot_y:
+            n_pred_ch = input.shape[1]
             if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
             else:
                 target = one_hot(target, num_classes=n_pred_ch)
-
-        if not self.include_background:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
-            else:
-                # if skipping background, removing first channel
-                target = target[:, 1:]
-                input = input[:, 1:]
-
         dice_loss = self.dice(input, target)
         focal_loss = self.focal(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
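Since the wrapper now only one-hot encodes and delegates, the result should equal the documented weighted sum of the sub-losses; a hedged check (using `one_hot` from `monai.networks`, as this file does):

```python
import torch
from monai.losses import DiceFocalLoss
from monai.networks import one_hot

loss_fn = DiceFocalLoss(to_onehot_y=True, softmax=True, lambda_dice=1.0, lambda_focal=0.5)
pred = torch.randn(2, 3, 8, 8)
target = torch.randint(0, 3, (2, 1, 8, 8))
total = loss_fn(pred, target)
# recompute the weighted sum manually on the one-hot encoded target
oh = one_hot(target, num_classes=pred.shape[1])
expected = 1.0 * loss_fn.dice(pred, oh) + 0.5 * loss_fn.focal(pred, oh)
assert torch.allclose(total, expected)
```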
@@ -867,11 +870,13 @@ class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
     Args:
         include_background (bool, optional): if False channel index 0 (background category) is excluded from the calculation.
             Defaults to True.
-        to_onehot_y (bool, optional): whether to convert `y` into the one-hot format. Defaults to False.
+        to_onehot_y: whether to convert the ``target`` into the one-hot format,
+            using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
         sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
         softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
-        other_act (Optional[Callable], optional): if don't want to use sigmoid or softmax, use other callable
-            function to execute other activation layers. Defaults to None.
+        other_act (Optional[Callable], optional): callable function to execute other activation layers.
+            Defaults to ``None``. For example: ``other_act = torch.tanh``.
+            Only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.
         w_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
             ground-truth volume to a weight factor. Defaults to ``"square"``.
         reduction (Union[LossReduction, str], optional): {``"none"``, ``"mean"``, ``"sum"``}. Specified the reduction to