diff --git a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc
index 313c59e2e66765..61304e43d4e85a 100644
--- a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc
+++ b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc
@@ -22,7 +22,11 @@ limitations under the License. */
 
 namespace phi {
 
-template
+template
 void quant_compute(const DeviceContext& dev_ctx,
                    const DenseTensor& x,
                    DenseTensor* out,
@@ -48,7 +52,7 @@ void quant_compute(const DeviceContext& dev_ctx,
   DDim dims = {num};
   const T* x_data = x.data();
   D* out_data = out->data();
-  T* scale_data = scale->data();
+  ScaleT* scale_data = scale->data();
 
   DenseTensor x_int(out->type());
@@ -121,11 +125,16 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                           DenseTensor* out,
                           DenseTensor* scale) {
   dev_ctx.template Alloc(out);
-  dev_ctx.template Alloc(scale);
-  if (algo == "weight_only_int8" || algo == "llm.int8") {
+  if (algo == "weight_only_int8") {
+    dev_ctx.template Alloc(scale);
     quant_compute(
         dev_ctx, x, out, scale, algo, arch, group_size);
+  } else if (algo == "llm.int8") {
+    dev_ctx.template Alloc(scale);
+    quant_compute(
+        dev_ctx, x, out, scale, algo, arch, group_size);
   } else if (algo == "weight_only_int4") {
+    dev_ctx.template Alloc(scale);
     quant_compute(
         dev_ctx, x, out, scale, algo, arch, group_size);
   } else {
diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu
index 8cd5598e2e92a3..103691f9cd8a47 100644
--- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu
+++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu
@@ -37,7 +37,6 @@ void WeightQuantizeKernel(const Context& dev_ctx,
   DenseTensor quanted_x;
   dev_ctx.template Alloc(out);
-  dev_ctx.template Alloc(scale);
   size_t m = x.dims()[0];
   size_t n = x.dims()[1];
   quanted_x.Resize({static_cast(m), static_cast(n)});
@@ -51,15 +50,17 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                         "Currently, arch only support 70, 75, 80, 86."));
 
   if (algo == "llm.int8") {
+    dev_ctx.template Alloc(scale);
     std::vector axis = {1, 0};
     funcs::Transpose trans;
     weight_quant_gpu(dev_ctx,
                      x.data(),
                      quanted_x.data(),
-                     scale->data(),
+                     scale->data(),
                      weight_shape);
     trans(dev_ctx, quanted_x, out, axis);
   } else if (algo == "weight_only_int8") {
+    dev_ctx.template Alloc(scale);
     weight_quant_gpu(dev_ctx,
                      x.data(),
                      quanted_x.data(),
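The CPU and GPU kernel changes above move the `Alloc` of the scale tensor into the per-algorithm branches and, together with the header changes that follow, parameterize the scale buffer on its own element type (`ScaleT`) instead of reusing the weight type `T`. Judging by the static-graph test further down, which now feeds `weight_scale` as `float32`, the intent is to let `llm.int8` emit `float32` per-channel scales while the weight-only paths keep the scale in the weight dtype. Below is a minimal NumPy sketch of per-channel symmetric int8 quantization with an independently typed scale buffer; it is an illustration only, not the Paddle kernels, and the helper name is invented.

```python
# Illustration only: per-channel (per output column) symmetric int8
# quantization of a [k, n] weight, storing the scale in its own dtype.
import numpy as np


def per_channel_quant(weight, scale_dtype=np.float32, bound=127.0):
    # One scale per output column (the reduction runs over axis 0, i.e. over k).
    abs_max = np.max(np.abs(weight.astype(np.float32)), axis=0)
    scale = (abs_max / bound).astype(scale_dtype)  # the "ScaleT" buffer
    quanted = np.round(weight.astype(np.float32) / scale.astype(np.float32))
    return np.clip(quanted, -bound, bound).astype(np.int8), scale


w = np.random.randn(64, 128).astype(np.float16)
w_int8, w_scale = per_channel_quant(w)  # int8 weight, float32 scale
assert w_int8.dtype == np.int8 and w_scale.dtype == np.float32
```

Only the dtype of the stored scale changes in this patch; the quantization arithmetic itself is untouched.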
diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
index 201dd403270f36..05d0e47b314555 100644
--- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
+++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -106,10 +106,10 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
   }
 }
 
-template
+template
 __global__ void per_channel_quant_gpu(const T* weight_data,
                                       int8_t* quanted_weight_data,
-                                      T* scale_data,
+                                      ScaleT* scale_data,
                                       int total_k,
                                       int total_vec_n) {
   int n = blockIdx.x * blockDim.x + threadIdx.x;
@@ -133,10 +133,10 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
        abs_max[i] = fmaxf((abs_max[i]), fabsf((weight[i])));
      }
    }
-    phi::AlignedVector scale;
+    phi::AlignedVector scale;
 #pragma unroll
    for (int i = 0; i < VectorSize; ++i) {
-      scale[i] = static_cast(abs_max[i] / static_cast(127.0f));
+      scale[i] = static_cast(abs_max[i] / static_cast(127.0f));
    }
    *reinterpret_cast(scale_data + VectorSize * n) =
        *reinterpret_cast(&scale);
@@ -161,11 +161,11 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
       }
     }
   }
 }
-template
+template
 void weight_quant_gpu(const GPUContext& dev_ctx,
                       const T* weight_data,
                       int8_t* quanted_weight_data,
-                      T* scale_data,
+                      ScaleT* scale_data,
                       const std::vector& shape) {
   int total_k = shape[0];
   int total_n = shape[1];
diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h
index 2905fd14e6b335..6f7fc1e9c06806 100644
--- a/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h
+++ b/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h
@@ -42,9 +42,9 @@ inline T xabs(const T x) {
   return x < static_cast(0.0) ? -x : x;
 }
 
-template
+template
 void per_channel_scale(
-    T* scale, const T* input, size_t m, size_t n, float bound) {
+    ScaleT* scale, const T* input, size_t m, size_t n, float bound) {
   for (size_t i = 0; i < n; ++i) {
     float max = static_cast(input[i]);
     for (size_t j = 0; j < m; ++j) {
@@ -52,12 +52,12 @@ void per_channel_scale(
               ? static_cast(xabs(input[j * n + i]))
               : max;
     }
-    scale[i] = static_cast(max / bound);
+    scale[i] = static_cast(max / bound);
   }
 }
 
-template
-void group_wise_scale(T* scale,
+template
+void group_wise_scale(ScaleT* scale,
                       const T* input,
                       size_t m,
                       size_t n,
@@ -72,15 +72,15 @@ void group_wise_scale(T* scale,
                 : max;
       }
       scale[static_cast(j / group_size) * n + i] =
-          static_cast(max / bound);
+          static_cast(max / bound);
     }
   }
 }
 
-template
+template
 void per_channel_quant(int8_t* output,
                        const T* input,
-                       const T* scale,
+                       const ScaleT* scale,
                        size_t num_rows,
                        size_t num_cols) {
   size_t bytes_per_out_col = num_cols * quant_bit / 8;
@@ -123,10 +123,10 @@ void per_channel_quant(int8_t* output,
     }
   }
 }
 
-template
+template
 void group_wise_quant(int8_t* output,
                       const T* input,
-                      const T* scale,
+                      const ScaleT* scale,
                       size_t num_rows,
                       size_t num_cols,
                       const int group_size) {
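On the Python side, the test below stops comparing against a plain `self.linear(self.x)` reference and instead builds an explicit llm.int8-style reference: input columns whose maximum exceeds `threshold` stay in floating point, the remaining columns are quantized per row to int8 on the fly, multiplied against the quantized weight, and dequantized with `row_ranges * weight_scale / 127`. A rough NumPy sketch of that reference computation follows; it is an illustration of the idea only (not the test code or the CUDA kernel), the shapes and helper name are assumptions, and a plain max stands in for the abs-max outlier check of the llm.int8 paper because the test's inputs are non-negative.

```python
# Sketch of the llm.int8-style reference path exercised by the test below.
import numpy as np


def llm_int8_reference(x, w_fp, w_int8, w_scale, threshold=6.0):
    # x: [batch, token, in], w_fp: [in, out], w_int8: [out, in], w_scale: [out]
    outliers = np.max(x, axis=(0, 1)) > threshold        # outlier input columns
    res_fp = x[..., outliers] @ w_fp[outliers, :]         # full-precision path

    x_rest = x.copy()
    x_rest[..., outliers] = 0
    row_range = np.max(x_rest, axis=-1, keepdims=True)    # dynamic per-row scale
    x_q = np.clip(np.round(x_rest * 127.0 / row_range), -127, 127).astype(np.int8)

    res_int = x_q.astype(np.float32) @ w_int8.T.astype(np.float32)
    return res_int * (row_range * w_scale / 127.0) + res_fp


x = np.random.random((1, 32, 64))
w = np.random.randn(64, 128) * 0.1
w_scale = np.max(np.abs(w), axis=0) / 127.0               # per-channel scale
w_int8 = np.round(w / w_scale).astype(np.int8).T          # [out, in] layout
out = llm_int8_reference(x, w, w_int8, w_scale)
assert out.shape == (1, 32, 128)
```

Keeping `row_range` and `w_scale` in `float32` for the dequantization step is presumably why the kernel changes above give `llm.int8` its own scale dtype.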
diff --git a/test/quantization/test_llm_int8_linear.py b/test/quantization/test_llm_int8_linear.py
index 972c41bd31f52c..909f44c0ca4041 100644
--- a/test/quantization/test_llm_int8_linear.py
+++ b/test/quantization/test_llm_int8_linear.py
@@ -24,9 +24,6 @@
 from paddle.framework import set_default_dtype
 from paddle.pir_utils import test_with_pir_api
 
-np.random.seed(123)
-paddle.seed(42)
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
@@ -43,11 +40,13 @@ def config(self):
         self.batch = 1
         self.token = 32
         self.in_features = 64
-        self.out_features = 256
+        self.out_features = 128
         self.threshold = 6.0
         self.static = False
 
     def setUp(self):
+        np.random.seed(123)
+        paddle.seed(42)
         self.config()
         x = np.random.random((self.batch, self.token, self.in_features))
         self.x = paddle.to_tensor(x, dtype=self.dtype)
@@ -64,49 +63,89 @@ def setUp(self):
             self.in_features, self.out_features, bias_attr=bias_attr
         )
 
-        self.bias = self.linear.bias
         self.weight = self.linear.weight
         self.weight_scale = None
         self.weight, self.weight_scale = Q.weight_quantize(
             self.weight, algo="llm.int8"
         )
 
+    def dynamic_quant(self, x):
+        row_ranges = paddle.max(x, axis=[-1]).astype('float32')
+        row_ranges = row_ranges.unsqueeze(-1)
+        quant_x = paddle.round(
+            paddle.clip(
+                x.astype('float32') * 127.0 * (1 / row_ranges),
+                min=-127.0,
+                max=127.0,
+            )
+        ).astype('int8')
+        return quant_x, row_ranges
+
     def get_linear_out(self):
-        out = self.linear(self.x)
+        outlier_cols = (
+            paddle.nonzero(paddle.max(self.x, axis=[0, 1]) > self.threshold)
+            .reshape([-1])
+            .numpy()
+            .tolist()
+        )
+
+        x_int8 = self.x
+        if len(outlier_cols) > 0:
+            x_fp = self.x[:, :, outlier_cols]
+            w_fp = self.linear.weight[outlier_cols]
+            res_fp = paddle.matmul(x_fp, w_fp)
+
+            x_int8[:, :, outlier_cols] = 0
+        x_int8, row_ranges = self.dynamic_quant(x_int8)
+
+        res_int8 = paddle.matmul(x_int8, self.weight.transpose((1, 0)))
+        dequant_scale = row_ranges * self.weight_scale / 127.0
+        res_dequant = (res_int8.astype('float32') * dequant_scale).astype(
+            self.dtype
+        )
+
+        if len(outlier_cols) > 0:
+            out = res_dequant + res_fp
+        else:
+            out = res_dequant
+
+        if self.bias:
+            out += self.bias
+
         return out.numpy()
 
     def get_llm_int8_linear_out(self):
         out = Q.llm_int8_linear(
             self.x,
             self.weight,
-            bias=self.bias,
+            bias=self.linear.bias,
             weight_scale=self.weight_scale,
             threshold=self.threshold,
         )
         return out.numpy()
 
     @test_with_pir_api
-    def get_llm_int8_linear_out_static(self):
+    def llm_int8_linear_out_static(self, out_expect):
         paddle.enable_static()
-        main = base.static.Program()
-        start = base.static.Program()
-        with base.static.program_guard(main, start):
-            x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype)
+        main = paddle.static.Program()
+        start = paddle.static.Program()
+        with paddle.static.program_guard(main, start):
+            x = paddle.static.data("x", self.x.shape, dtype=self.dtype)
             weight = paddle.static.data(
-                "weight", self.weight.shape, dtype=self.weight.dtype
+                "weight", self.weight.shape, dtype='int8'
             )
             bias = paddle.static.data(
-                "bias", self.bias.shape, dtype=self.bias.dtype
+                "bias", self.linear.bias.shape, dtype=self.dtype
            )
            x_np = self.x.numpy()
            weight_np = self.weight.numpy()
-            bias_np = self.bias.numpy()
+            bias_np = self.linear.bias.numpy()
            if self.weight_scale is not None:
                weight_scale = paddle.static.data(
                    "weight_scale",
                    self.weight_scale.shape,
-                    dtype=self.weight_scale.dtype,
+                    dtype='float32',
                )
                weight_scale_np = self.weight_scale.numpy()
            else:
@@ -128,20 +167,30 @@ def get_llm_int8_linear_out_static(self):
             }
             exe = base.Executor(paddle.CUDAPlace(0))
             exe.run(start)
-            (out,) = exe.run(main, feed=feed_dict, fetch_list=[out])
+            (out_real,) = exe.run(main, feed=feed_dict, fetch_list=[out])
+
         paddle.disable_static()
-        return out
+
+        if self.dtype == "bfloat16":
+            out_real = convert_uint16_to_float(out_real)
+            out_expect = convert_uint16_to_float(out_expect)
+
+        np.testing.assert_allclose(
+            out_real, out_expect, rtol=self.rtol, atol=self.atol
+        )
 
     def test_llm_int8_linear(self):
         out_expect = self.get_linear_out()
         if self.static:
-            out_real = self.get_llm_int8_linear_out_static()
+            self.llm_int8_linear_out_static(out_expect)
+            return
         else:
             out_real = self.get_llm_int8_linear_out()
 
         if self.dtype == "bfloat16":
             out_real = convert_uint16_to_float(out_real)
             out_expect = convert_uint16_to_float(out_expect)
+
         np.testing.assert_allclose(
             out_real, out_expect, rtol=self.rtol, atol=self.atol
         )
@@ -174,19 +223,6 @@ def config(self):
         self.weight_dtype = "int8"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
-class LLMInt8LinearTestCase3(LLMInt8LinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'bfloat16'
-        self.weight_dtype = "int8"
-
-
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11020
@@ -215,20 +251,6 @@ def config(self):
         self.weight_dtype = "int4"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8
-    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
-)
-class LLMInt8LinearTestCase6(LLMInt8LinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'bfloat16'
-        self.weight_dtype = "int4"
-
-
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11020
@@ -260,21 +282,6 @@ def config(self):
         self.token = 1
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
-class LLMInt8LinearTestCase9(LLMInt8LinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'bfloat16'
-        self.weight_dtype = "int8"
-        self.batch = 1
-        self.token = 1
-
-
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11020