Skip to content

Commit 4f8bc59

Browse files
authored
Enble swinunetr-v2 (#6203)
Fixes #6183 . ### Description Added a "use_v2" option in swinunetr initialization. Default is false will not affect the original swinunetr. Once changed to true, will become swinunetr-v2 with 4 additional convolution block. Tested running from auto3dseg bundles, no change needed for original swinunetr, and works for swinunetr-v2 Tested running from monai research contribution repo for swinuntr, no change needed for original swinunetr, and works for swinunetr-v2 Tested TensorRT, compiled .ts file successfully. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: heyufan1995 <[email protected]>
1 parent b8f158b commit 4f8bc59

1 file changed

Lines changed: 38 additions & 0 deletions

File tree

monai/networks/nets/swin_unetr.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
use_checkpoint: bool = False,
6666
spatial_dims: int = 3,
6767
downsample="merging",
68+
use_v2=False,
6869
) -> None:
6970
"""
7071
Args:
@@ -84,6 +85,7 @@ def __init__(
8485
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
8586
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
8687
The default is currently `"merging"` (the original version defined in v0.9.0).
88+
use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.
8789
8890
Examples::
8991
@@ -142,6 +144,7 @@ def __init__(
142144
use_checkpoint=use_checkpoint,
143145
spatial_dims=spatial_dims,
144146
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
147+
use_v2=use_v2,
145148
)
146149

147150
self.encoder1 = UnetrBasicBlock(
@@ -921,6 +924,7 @@ def __init__(
921924
use_checkpoint: bool = False,
922925
spatial_dims: int = 3,
923926
downsample="merging",
927+
use_v2=False,
924928
) -> None:
925929
"""
926930
Args:
@@ -942,6 +946,7 @@ def __init__(
942946
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
943947
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
944948
The default is currently `"merging"` (the original version defined in v0.9.0).
949+
use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.
945950
"""
946951

947952
super().__init__()
@@ -959,10 +964,16 @@ def __init__(
959964
)
960965
self.pos_drop = nn.Dropout(p=drop_rate)
961966
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
967+
self.use_v2 = use_v2
962968
self.layers1 = nn.ModuleList()
963969
self.layers2 = nn.ModuleList()
964970
self.layers3 = nn.ModuleList()
965971
self.layers4 = nn.ModuleList()
972+
if self.use_v2:
973+
self.layers1c = nn.ModuleList()
974+
self.layers2c = nn.ModuleList()
975+
self.layers3c = nn.ModuleList()
976+
self.layers4c = nn.ModuleList()
966977
down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
967978
for i_layer in range(self.num_layers):
968979
layer = BasicLayer(
@@ -987,6 +998,25 @@ def __init__(
987998
self.layers3.append(layer)
988999
elif i_layer == 3:
9891000
self.layers4.append(layer)
1001+
if self.use_v2:
1002+
layerc = UnetrBasicBlock(
1003+
spatial_dims=3,
1004+
in_channels=embed_dim * 2**i_layer,
1005+
out_channels=embed_dim * 2**i_layer,
1006+
kernel_size=3,
1007+
stride=1,
1008+
norm_name="instance",
1009+
res_block=True,
1010+
)
1011+
if i_layer == 0:
1012+
self.layers1c.append(layerc)
1013+
elif i_layer == 1:
1014+
self.layers2c.append(layerc)
1015+
elif i_layer == 2:
1016+
self.layers3c.append(layerc)
1017+
elif i_layer == 3:
1018+
self.layers4c.append(layerc)
1019+
9901020
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
9911021

9921022
def proj_out(self, x, normalize=False):
@@ -1008,12 +1038,20 @@ def forward(self, x, normalize=True):
10081038
x0 = self.patch_embed(x)
10091039
x0 = self.pos_drop(x0)
10101040
x0_out = self.proj_out(x0, normalize)
1041+
if self.use_v2:
1042+
x0 = self.layers1c[0](x0.contiguous())
10111043
x1 = self.layers1[0](x0.contiguous())
10121044
x1_out = self.proj_out(x1, normalize)
1045+
if self.use_v2:
1046+
x1 = self.layers2c[0](x1.contiguous())
10131047
x2 = self.layers2[0](x1.contiguous())
10141048
x2_out = self.proj_out(x2, normalize)
1049+
if self.use_v2:
1050+
x2 = self.layers3c[0](x2.contiguous())
10151051
x3 = self.layers3[0](x2.contiguous())
10161052
x3_out = self.proj_out(x3, normalize)
1053+
if self.use_v2:
1054+
x3 = self.layers4c[0](x3.contiguous())
10171055
x4 = self.layers4[0](x3.contiguous())
10181056
x4_out = self.proj_out(x4, normalize)
10191057
return [x0_out, x1_out, x2_out, x3_out, x4_out]

0 commit comments

Comments
 (0)