diff --git a/demo/dygraph/unstructured_pruning/README.md b/demo/dygraph/unstructured_pruning/README.md index 3a861dffc..a10685425 100644 --- a/demo/dygraph/unstructured_pruning/README.md +++ b/demo/dygraph/unstructured_pruning/README.md @@ -89,9 +89,11 @@ python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshol ## 推理: ```bash -python3.7 evalualte.py --pruned_model models/ --data imagenet +python3.7 evaluate.py --pruned_model models/model-pruned.pdparams --data imagenet ``` +**注意**,上述`pruned_model` 参数应该指向pdparams文件。 + 剪裁训练代码示例: ```python model = mobilenet_v1(num_classes=class_dim, pretrained=True) diff --git a/demo/dygraph/unstructured_pruning/evaluate.py b/demo/dygraph/unstructured_pruning/evaluate.py index d2c6aa56c..e8827d1f5 100644 --- a/demo/dygraph/unstructured_pruning/evaluate.py +++ b/demo/dygraph/unstructured_pruning/evaluate.py @@ -67,8 +67,6 @@ def test(epoch): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py index 39b6804bf..5cf717892 100644 --- a/demo/dygraph/unstructured_pruning/train.py +++ b/demo/dygraph/unstructured_pruning/train.py @@ -145,8 +145,6 @@ def test(epoch): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) @@ -180,8 +178,6 @@ def train(epoch): train_reader_cost += time.time() - reader_start x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) train_start = time.time() logits = model(x_data)