@@ -74,6 +74,7 @@ def __init__(
7474 pretrained : bool = True ,
7575 pretrained_path : str | None = None ,
7676 pretrained_state_dict_key : str | None = None ,
77+ channelwise : bool = False ,
7778 ):
7879 super ().__init__ ()
7980
@@ -102,15 +103,18 @@ def __init__(
102103 self .spatial_dims = spatial_dims
103104 self .perceptual_function : nn .Module
104105 if spatial_dims == 3 and is_fake_3d is False :
105- self .perceptual_function = MedicalNetPerceptualSimilarity (net = network_type , verbose = False )
106+ self .perceptual_function = MedicalNetPerceptualSimilarity (net = network_type , verbose = False ,
107+ channelwise = channelwise )
106108 elif "radimagenet_" in network_type :
107- self .perceptual_function = RadImageNetPerceptualSimilarity (net = network_type , verbose = False )
109+ self .perceptual_function = RadImageNetPerceptualSimilarity (net = network_type , verbose = False ,
110+ channelwise = channelwise )
108111 elif network_type == "resnet50" :
109112 self .perceptual_function = TorchvisionModelPerceptualSimilarity (
110113 net = network_type ,
111114 pretrained = pretrained ,
112115 pretrained_path = pretrained_path ,
113116 pretrained_state_dict_key = pretrained_state_dict_key ,
117+ channelwise = channelwise ,
114118 )
115119 else :
116120 self .perceptual_function = LPIPS (pretrained = pretrained , net = network_type , verbose = False )
@@ -185,14 +189,21 @@ class MedicalNetPerceptualSimilarity(nn.Module):
185189 net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
186190 Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
187191 verbose: if false, mute messages from torch Hub load function.
192+ channelwise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
193+ Defaults to ``False``.
188194 """
189195
190- def __init__ (self , net : str = "medicalnet_resnet10_23datasets" , verbose : bool = False ) -> None :
196+ def __init__ (self ,
197+ net : str = "medicalnet_resnet10_23datasets" ,
198+ verbose : bool = False ,
199+ channelwise : bool = False ) -> None :
191200 super ().__init__ ()
192201 torch .hub ._validate_not_a_forked_repo = lambda a , b , c : True
193202 self .model = torch .hub .load ("warvito/MedicalNet-models" , model = net , verbose = verbose )
194203 self .eval ()
195204
205+ self .channelwise = channelwise
206+
196207 for param in self .parameters ():
197208 param .requires_grad = False
198209
@@ -206,6 +217,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
206217 Args:
207218 input: 3D input tensor with shape BCDHW.
208219 target: 3D target tensor with shape BCDHW.
220+
209221 """
210222 input = medicalnet_intensity_normalisation (input )
211223 target = medicalnet_intensity_normalisation (target )
@@ -227,7 +239,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
227239 feats_target = normalize_tensor (outs_target )
228240
229241 results : torch .Tensor = (feats_input - feats_target ) ** 2
230- results = spatial_average_3d (results .sum (dim = 1 , keepdim = True ), keepdim = True )
242+
243+ if self .channelwise :
244+ results = results .sum (dim = 1 , keepdim = True )
245+ results = spatial_average_3d (results , keepdim = True )
231246
232247 return results
233248
@@ -260,11 +275,13 @@ class RadImageNetPerceptualSimilarity(nn.Module):
260275 verbose: if false, mute messages from torch Hub load function.
261276 """
262277
263- def __init__ (self , net : str = "radimagenet_resnet50" , verbose : bool = False ) -> None :
278+ def __init__ (self , net : str = "radimagenet_resnet50" , verbose : bool = False , channelwise : bool = False ) -> None :
264279 super ().__init__ ()
265280 self .model = torch .hub .load ("Warvito/radimagenet-models" , model = net , verbose = verbose )
266281 self .eval ()
267282
283+ self .channelwise = channelwise
284+
268285 for param in self .parameters ():
269286 param .requires_grad = False
270287
@@ -297,7 +314,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
297314 feats_target = normalize_tensor (outs_target )
298315
299316 results : torch .Tensor = (feats_input - feats_target ) ** 2
300- results = spatial_average (results .sum (dim = 1 , keepdim = True ), keepdim = True )
317+
318+ if self .channelwise :
319+ results = results .sum (dim = 1 , keepdim = True )
320+ results = spatial_average (results , keepdim = True )
301321
302322 return results
303323
@@ -324,6 +344,7 @@ def __init__(
324344 pretrained : bool = True ,
325345 pretrained_path : str | None = None ,
326346 pretrained_state_dict_key : str | None = None ,
347+ channelwise : bool = False ,
327348 ) -> None :
328349 super ().__init__ ()
329350 supported_networks = ["resnet50" ]
@@ -347,6 +368,8 @@ def __init__(
347368 self .model = torchvision .models .feature_extraction .create_feature_extractor (network , [self .final_layer ])
348369 self .eval ()
349370
371+ self .channelwise = channelwise
372+
350373 for param in self .parameters ():
351374 param .requires_grad = False
352375
@@ -376,7 +399,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
376399 feats_target = normalize_tensor (outs_target )
377400
378401 results : torch .Tensor = (feats_input - feats_target ) ** 2
379- results = spatial_average (results .sum (dim = 1 , keepdim = True ), keepdim = True )
402+
403+ if self .channelwise :
404+ results = results .sum (dim = 1 , keepdim = True )
405+ results = spatial_average (results , keepdim = True )
380406
381407 return results
382408
0 commit comments