Skip to content

Commit 1c3eef4

Browse files
authored
Fix vgg error when num_classes is given (#28557)
* fix vgg num classes
1 parent 1de3cdd commit 1c3eef4

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/paddle/tests/test_vision_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def test_resnet101(self):
7171
def test_resnet152(self):
7272
self.models_infer('resnet152')
7373

74+
def test_vgg16_num_classes(self):
75+
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
76+
7477
def test_lenet(self):
7578
input = InputSpec([None, 1, 28, 28], 'float32', 'x')
7679
lenet = paddle.Model(models.__dict__['LeNet'](), input)

python/paddle/vision/models/vgg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,7 @@ def make_layers(cfg, batch_norm=False):
107107

108108

109109
def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
110-
model = VGG(make_layers(
111-
cfgs[cfg], batch_norm=batch_norm),
112-
num_classes=1000,
113-
**kwargs)
110+
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
114111

115112
if pretrained:
116113
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(

0 commit comments

Comments
 (0)