Skip to content

Commit 5785e2b

Browse files
authored
Allow dropout overwrites on EfficientNet (#7031)
1 parent 5a75fa9 commit 5785e2b

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

test/smoke_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def smoke_test_torchvision() -> None:
1717
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
1818
)
1919

20+
2021
def smoke_test_torchvision_read_decode() -> None:
2122
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
2223
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
@@ -25,6 +26,7 @@ def smoke_test_torchvision_read_decode() -> None:
2526
if img_png.ndim != 3 or img_png.numel() < 100:
2627
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
2728

29+
2830
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
2931
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
3032

@@ -47,9 +49,8 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
4749
expected_category = "German shepherd"
4850
print(f"{category_name} ({device}): {100 * score:.1f}%")
4951
if category_name != expected_category:
50-
raise RuntimeError(
51-
f"Failed ResNet50 classify {category_name} Expected: {expected_category}"
52-
)
52+
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
53+
5354

5455
def main() -> None:
5556
print(f"torchvision: {torchvision.__version__}")
@@ -59,5 +60,6 @@ def main() -> None:
5960
if torch.cuda.is_available():
6061
smoke_test_torchvision_resnet50_classify("cuda")
6162

63+
6264
if __name__ == "__main__":
6365
main()

torchvision/models/efficientnet.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,9 @@ def efficientnet_b0(
779779
weights = EfficientNet_B0_Weights.verify(weights)
780780

781781
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
782-
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
782+
return _efficientnet(
783+
inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
784+
)
783785

784786

785787
@register_model()
@@ -808,7 +810,9 @@ def efficientnet_b1(
808810
weights = EfficientNet_B1_Weights.verify(weights)
809811

810812
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
811-
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
813+
return _efficientnet(
814+
inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
815+
)
812816

813817

814818
@register_model()
@@ -837,7 +841,9 @@ def efficientnet_b2(
837841
weights = EfficientNet_B2_Weights.verify(weights)
838842

839843
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
840-
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
844+
return _efficientnet(
845+
inverted_residual_setting, kwargs.pop("dropout", 0.3), last_channel, weights, progress, **kwargs
846+
)
841847

842848

843849
@register_model()
@@ -866,7 +872,14 @@ def efficientnet_b3(
866872
weights = EfficientNet_B3_Weights.verify(weights)
867873

868874
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
869-
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
875+
return _efficientnet(
876+
inverted_residual_setting,
877+
kwargs.pop("dropout", 0.3),
878+
last_channel,
879+
weights,
880+
progress,
881+
**kwargs,
882+
)
870883

871884

872885
@register_model()
@@ -895,7 +908,14 @@ def efficientnet_b4(
895908
weights = EfficientNet_B4_Weights.verify(weights)
896909

897910
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
898-
return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
911+
return _efficientnet(
912+
inverted_residual_setting,
913+
kwargs.pop("dropout", 0.4),
914+
last_channel,
915+
weights,
916+
progress,
917+
**kwargs,
918+
)
899919

900920

901921
@register_model()
@@ -926,7 +946,7 @@ def efficientnet_b5(
926946
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
927947
return _efficientnet(
928948
inverted_residual_setting,
929-
0.4,
949+
kwargs.pop("dropout", 0.4),
930950
last_channel,
931951
weights,
932952
progress,
@@ -963,7 +983,7 @@ def efficientnet_b6(
963983
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
964984
return _efficientnet(
965985
inverted_residual_setting,
966-
0.5,
986+
kwargs.pop("dropout", 0.5),
967987
last_channel,
968988
weights,
969989
progress,
@@ -1000,7 +1020,7 @@ def efficientnet_b7(
10001020
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
10011021
return _efficientnet(
10021022
inverted_residual_setting,
1003-
0.5,
1023+
kwargs.pop("dropout", 0.5),
10041024
last_channel,
10051025
weights,
10061026
progress,
@@ -1038,7 +1058,7 @@ def efficientnet_v2_s(
10381058
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
10391059
return _efficientnet(
10401060
inverted_residual_setting,
1041-
0.2,
1061+
kwargs.pop("dropout", 0.2),
10421062
last_channel,
10431063
weights,
10441064
progress,
@@ -1076,7 +1096,7 @@ def efficientnet_v2_m(
10761096
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
10771097
return _efficientnet(
10781098
inverted_residual_setting,
1079-
0.3,
1099+
kwargs.pop("dropout", 0.3),
10801100
last_channel,
10811101
weights,
10821102
progress,
@@ -1114,7 +1134,7 @@ def efficientnet_v2_l(
11141134
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
11151135
return _efficientnet(
11161136
inverted_residual_setting,
1117-
0.4,
1137+
kwargs.pop("dropout", 0.4),
11181138
last_channel,
11191139
weights,
11201140
progress,

0 commit comments

Comments
 (0)