diff --git a/demo/dygraph/unstructured_pruning/evaluate.py b/demo/dygraph/unstructured_pruning/evaluate.py index e8827d1f5..d2c6aa56c 100644 --- a/demo/dygraph/unstructured_pruning/evaluate.py +++ b/demo/dygraph/unstructured_pruning/evaluate.py @@ -67,6 +67,8 @@ def test(epoch): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py index 7bd3f479b..1c859684c 100644 --- a/demo/dygraph/unstructured_pruning/train.py +++ b/demo/dygraph/unstructured_pruning/train.py @@ -145,6 +145,8 @@ def test(epoch): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) @@ -178,6 +180,8 @@ def train(epoch): train_reader_cost += time.time() - reader_start x_data = data[0] y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) train_start = time.time() logits = model(x_data)