Skip to content

Commit ea1ded3

Browse files
fix resnet_cifar10
1 parent ca01557 commit ea1ded3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

image_classification/resnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def resnet_cifar10(input, class_dim, depth=32):
8585
nStages = {16, 64, 128}
8686
conv1 = conv_bn_layer(
8787
input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1)
88-
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
89-
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
90-
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
88+
res1 = layer_warp(basicblock, conv1, 16, n, 1)
89+
res2 = layer_warp(basicblock, res1, 32, n, 2)
90+
res3 = layer_warp(basicblock, res2, 64, n, 2)
9191
pool = paddle.layer.img_pool(
9292
input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg())
9393
out = paddle.layer.fc(

0 commit comments

Comments
 (0)