From dbd6f0c08b2ddef09c4f9f82d959a3777a03a6d4 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sat, 13 Apr 2024 08:44:15 +0000 Subject: [PATCH] add apply_per_channel_scale --- python/paddle/nn/quant/quantized_linear.py | 3 +-- .../test_apply_per_channel_scale.py | 23 +++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index 7803c3bd387669..1c2d962f720cf0 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -18,7 +18,6 @@ from paddle.device.cuda import get_device_capability from paddle.framework import ( LayerHelper, - in_dynamic_mode, in_dynamic_or_pir_mode, ) @@ -326,7 +325,7 @@ def apply_per_channel_scale(x, scales): >>> out = apply_per_channel_scale(x, scales) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.apply_per_channel_scale(x, scales) else: type = "apply_per_channel_scale" diff --git a/test/quantization/test_apply_per_channel_scale.py b/test/quantization/test_apply_per_channel_scale.py index a28b69525c2213..c0fb5b254b10c4 100644 --- a/test/quantization/test_apply_per_channel_scale.py +++ b/test/quantization/test_apply_per_channel_scale.py @@ -21,8 +21,8 @@ import paddle import paddle.nn.quant as Q -from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def get_cuda_version(): @@ -71,9 +71,9 @@ def setUp(self): def get_out_static(self): paddle.enable_static() - main = base.Program() - start = base.Program() - with base.program_guard(main, start): + main = paddle.static.Program() + start = paddle.static.Program() + with paddle.static.program_guard(main, start): x = paddle.static.data("x", self.x.shape, dtype=self.dtype) scales = paddle.static.data( "scales", self.scales.shape, dtype=self.dtype @@ -86,26 +86,31 @@ def get_out_static(self): 'scales': self.scales.numpy(), } - exe = base.Executor(paddle.CUDAPlace(0)) + exe = paddle.static.Executor(paddle.CUDAPlace(0)) exe.run(start) (out,) = exe.run(main, feed=feed_dict, fetch_list=[out]) paddle.disable_static() return out + @test_with_pir_api def test_apply_per_channel_scale(self): if self.static: self.out_real = self.get_out_static() else: + paddle.disable_static() self.out_real = Q.apply_per_channel_scale( x=self.x, scales=self.scales, ) - - if self.dtype == 'bfloat16': + out_expected = self.out_expected + if self.dtype == 'bfloat16' and isinstance( + self.out_real, paddle.Tensor + ): self.out_real = convert_uint16_to_float(self.out_real) - self.out_expected = convert_uint16_to_float(self.out_expected) + out_expected = convert_uint16_to_float(self.out_expected) + np.testing.assert_allclose( - self.out_expected, self.out_real, rtol=self.rtol, atol=self.atol + out_expected, self.out_real, rtol=self.rtol, atol=self.atol )