Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
channelwise: bool = False,
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated
):
super().__init__()

Expand All @@ -86,6 +87,9 @@ def __init__(
"Argument is_fake_3d must be set to False."
)

if channelwise and "medicalnet_" not in network_type:
raise ValueError("Channelwise loss is only compatible with MedicalNet networks.")
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated

if network_type.lower() not in list(PercetualNetworkType):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
Expand All @@ -102,7 +106,9 @@ def __init__(
self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
self.perceptual_function = MedicalNetPerceptualSimilarity(
net=network_type, verbose=False, channelwise=channelwise
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated
)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
elif network_type == "resnet50":
Expand Down Expand Up @@ -170,9 +176,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss = loss_sagittal + loss_axial + loss_coronal
else:
# 2D and real 3D cases
loss = self.perceptual_function(input, target)
loss = self.perceptual_function(input, target).squeeze()

return torch.mean(loss)
return torch.mean(loss, dim=0)
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated


class MedicalNetPerceptualSimilarity(nn.Module):
Expand All @@ -185,14 +191,20 @@ class MedicalNetPerceptualSimilarity(nn.Module):
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
channelwise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated
Defaults to ``False``.
"""

def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
def __init__(
self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channelwise: bool = False
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.eval()

self.channelwise = channelwise
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated

for param in self.parameters():
param.requires_grad = False

Expand All @@ -206,20 +218,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Args:
input: 3D input tensor with shape BCDHW.
target: 3D target tensor with shape BCDHW.

"""
input = medicalnet_intensity_normalisation(input)
target = medicalnet_intensity_normalisation(target)

# Get model outputs
outs_input = self.model.forward(input)
outs_target = self.model.forward(target)
feats_per_ch = 0
for ch_idx in range(input.shape[1]):
input_channel = input[:, ch_idx, ...].unsqueeze(1)
target_channel = target[:, ch_idx, ...].unsqueeze(1)

if ch_idx == 0:
outs_input = self.model.forward(input_channel)
outs_target = self.model.forward(target_channel)
feats_per_ch = outs_input.shape[1]
else:
outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1)
outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1)

# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)

results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
feats_diff: torch.Tensor = (feats_input - feats_target) ** 2
if self.channelwise:
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated
results = torch.zeros(
feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4]
)
for i in range(input.shape[1]):
l_idx = i * feats_per_ch
r_idx = (i + 1) * feats_per_ch
results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)
else:
results = feats_diff.sum(dim=1, keepdim=True)

results = spatial_average_3d(results, keepdim=True)

return results

Expand Down
29 changes: 26 additions & 3 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from monai.losses import PerceptualLoss
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick
from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, skip_if_downloading_fails, skip_if_quick

_, has_torchvision = optional_import("torchvision")
TEST_CASES = [
Expand All @@ -40,11 +40,26 @@
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False, "channelwise": True},
Comment thread
SomeUserName1 marked this conversation as resolved.
Outdated
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
(2, 1, 64, 64, 64),
Expand All @@ -63,15 +78,23 @@ def test_shape(self, input_param, input_shape, target_shape):
with skip_if_downloading_fails():
loss = PerceptualLoss(**input_param)
result = loss(torch.randn(input_shape), torch.randn(target_shape))
self.assertEqual(result.shape, torch.Size([]))

if "channelwise" in input_param.keys() and input_param["channelwise"]:
self.assertEqual(result.shape, torch.Size([input_shape[1]]))
else:
self.assertEqual(result.shape, torch.Size([]))

@parameterized.expand(TEST_CASES)
def test_identical_input(self, input_param, input_shape, target_shape):
with skip_if_downloading_fails():
loss = PerceptualLoss(**input_param)
tensor = torch.randn(input_shape)
result = loss(tensor, tensor)
self.assertEqual(result, torch.Tensor([0.0]))

if "channelwise" in input_param.keys() and input_param["channelwise"]:
assert_allclose(result, torch.Tensor([0.0] * input_shape[1]))
else:
self.assertEqual(result, torch.Tensor([0.0]))

def test_different_shape(self):
with skip_if_downloading_fails():
Expand Down