
Commit 194ef8b

Fix llm.int8 unit test (#61591)
* fix llm.int8 unit test
* fix llm.int8 unittest when cpu
* fix numerical mismatch
* code clean
1 parent: 082f954 · commit: 194ef8b

5 files changed: 100 additions & 83 deletions
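
For context: llm.int8 keeps activation columns whose magnitude exceeds a threshold in floating point and sends the remaining columns through an int8 matmul, then dequantizes with scales that this commit stores as float32. The updated test below rebuilds that decomposition as its reference. What follows is only a minimal NumPy sketch of the same idea, not Paddle's API; the helper name llm_int8_reference and the absolute-max scale conventions (per-row activation max, per-column weight max / 127) are illustrative assumptions that roughly mirror the test's dynamic_quant and dequant_scale logic.

import numpy as np


def llm_int8_reference(x, w, threshold=6.0):
    # Per-output-channel weight scale, kept in float32 (the dtype this
    # commit switches the llm.int8 scale tensor to).
    w_scale = np.max(np.abs(w), axis=0) / 127.0
    w_int8 = np.round(w / w_scale).astype(np.int8)

    # Activation columns whose max magnitude exceeds the threshold stay in float.
    outliers = np.max(np.abs(x), axis=0) > threshold
    res_fp = x[:, outliers] @ w[outliers, :]

    # Remaining columns: dynamic per-row quantization to int8.
    x_rest = np.where(outliers, 0.0, x)
    row_max = np.maximum(np.max(np.abs(x_rest), axis=-1, keepdims=True), 1e-8)
    x_int8 = np.clip(np.round(x_rest * 127.0 / row_max), -127, 127)

    # int8 matmul, then dequantize with the float32 scales.
    res_int8 = x_int8 @ w_int8.astype(np.float32)
    return res_int8 * (row_max * w_scale / 127.0) + res_fp


# Tiny self-check against a plain float matmul.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal((8, 16)).astype(np.float32)
print(np.abs(llm_int8_reference(x, w) - x @ w).max())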


paddle/phi/kernels/cpu/weight_quantize_kernel.cc

Lines changed: 13 additions & 4 deletions
@@ -22,7 +22,11 @@ limitations under the License. */
 
 namespace phi {
 
-template <typename DeviceContext, typename T, typename D, int bits>
+template <typename DeviceContext,
+          typename T,
+          typename D,
+          int bits,
+          typename ScaleT = T>
 void quant_compute(const DeviceContext& dev_ctx,
                    const DenseTensor& x,
                    DenseTensor* out,
@@ -48,7 +52,7 @@ void quant_compute(const DeviceContext& dev_ctx,
   DDim dims = {num};
   const T* x_data = x.data<T>();
   D* out_data = out->data<D>();
-  T* scale_data = scale->data<T>();
+  ScaleT* scale_data = scale->data<ScaleT>();
 
   DenseTensor x_int(out->type());
 
@@ -121,11 +125,16 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                           DenseTensor* out,
                           DenseTensor* scale) {
   dev_ctx.template Alloc<int8_t>(out);
-  dev_ctx.template Alloc<T>(scale);
-  if (algo == "weight_only_int8" || algo == "llm.int8") {
+  if (algo == "weight_only_int8") {
+    dev_ctx.template Alloc<T>(scale);
     quant_compute<Context, T, int8_t, 8>(
         dev_ctx, x, out, scale, algo, arch, group_size);
+  } else if (algo == "llm.int8") {
+    dev_ctx.template Alloc<float>(scale);
+    quant_compute<Context, T, int8_t, 8, float>(
+        dev_ctx, x, out, scale, algo, arch, group_size);
   } else if (algo == "weight_only_int4") {
+    dev_ctx.template Alloc<T>(scale);
     quant_compute<Context, T, int8_t, 4>(
         dev_ctx, x, out, scale, algo, arch, group_size);
   } else {

paddle/phi/kernels/gpu/weight_quantize_kernel.cu

Lines changed: 3 additions & 2 deletions
@@ -37,7 +37,6 @@ void WeightQuantizeKernel(const Context& dev_ctx,
 
   DenseTensor quanted_x;
   dev_ctx.template Alloc<int8_t>(out);
-  dev_ctx.template Alloc<T>(scale);
   size_t m = x.dims()[0];
   size_t n = x.dims()[1];
   quanted_x.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
@@ -51,15 +50,17 @@ void WeightQuantizeKernel(const Context& dev_ctx,
           "Currently, arch only support 70, 75, 80, 86."));
 
   if (algo == "llm.int8") {
+    dev_ctx.template Alloc<float>(scale);
     std::vector<int> axis = {1, 0};
     funcs::Transpose<Context, int8_t, 2> trans;
     weight_quant_gpu<T, Context>(dev_ctx,
                                  x.data<T>(),
                                  quanted_x.data<int8_t>(),
-                                 scale->data<T>(),
+                                 scale->data<float>(),
                                  weight_shape);
     trans(dev_ctx, quanted_x, out, axis);
   } else if (algo == "weight_only_int8") {
+    dev_ctx.template Alloc<T>(scale);
     weight_quant_gpu<T, Context>(dev_ctx,
                                  x.data<T>(),
                                  quanted_x.data<int8_t>(),

paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h

Lines changed: 6 additions & 6 deletions
@@ -106,10 +106,10 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
   }
 }
 
-template <typename T, int VectorSize = 8>
+template <typename T, int VectorSize = 8, typename ScaleT>
 __global__ void per_channel_quant_gpu(const T* weight_data,
                                       int8_t* quanted_weight_data,
-                                      T* scale_data,
+                                      ScaleT* scale_data,
                                       int total_k,
                                       int total_vec_n) {
   int n = blockIdx.x * blockDim.x + threadIdx.x;
@@ -133,10 +133,10 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
         abs_max[i] = fmaxf((abs_max[i]), fabsf((weight[i])));
       }
     }
-    phi::AlignedVector<T, VectorSize> scale;
+    phi::AlignedVector<ScaleT, VectorSize> scale;
 #pragma unroll
     for (int i = 0; i < VectorSize; ++i) {
-      scale[i] = static_cast<T>(abs_max[i] / static_cast<float>(127.0f));
+      scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(127.0f));
     }
     *reinterpret_cast<float4*>(scale_data + VectorSize * n) =
         *reinterpret_cast<float4*>(&scale);
@@ -161,11 +161,11 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
     }
   }
 }
-template <typename T, typename GPUContext>
+template <typename T, typename GPUContext, typename ScaleT>
 void weight_quant_gpu(const GPUContext& dev_ctx,
                       const T* weight_data,
                       int8_t* quanted_weight_data,
-                      T* scale_data,
+                      ScaleT* scale_data,
                       const std::vector<int>& shape) {
   int total_k = shape[0];
   int total_n = shape[1];

paddle/phi/kernels/impl/weight_quantize_kernel_impl.h

Lines changed: 10 additions & 10 deletions
@@ -42,22 +42,22 @@ inline T xabs(const T x) {
   return x < static_cast<T>(0.0) ? -x : x;
 }
 
-template <typename T>
+template <typename T, typename ScaleT>
 void per_channel_scale(
-    T* scale, const T* input, size_t m, size_t n, float bound) {
+    ScaleT* scale, const T* input, size_t m, size_t n, float bound) {
   for (size_t i = 0; i < n; ++i) {
     float max = static_cast<float>(input[i]);
     for (size_t j = 0; j < m; ++j) {
       max = static_cast<float>(xabs(input[j * n + i])) > max
                 ? static_cast<float>(xabs(input[j * n + i]))
                 : max;
     }
-    scale[i] = static_cast<T>(max / bound);
+    scale[i] = static_cast<ScaleT>(max / bound);
   }
 }
 
-template <typename T>
-void group_wise_scale(T* scale,
+template <typename T, typename ScaleT>
+void group_wise_scale(ScaleT* scale,
                       const T* input,
                       size_t m,
                       size_t n,
@@ -72,15 +72,15 @@ void group_wise_scale(T* scale,
                   : max;
       }
       scale[static_cast<int>(j / group_size) * n + i] =
-          static_cast<T>(max / bound);
+          static_cast<ScaleT>(max / bound);
     }
   }
 }
 
-template <typename T, int quant_bit = 8>
+template <typename T, int quant_bit = 8, typename ScaleT>
 void per_channel_quant(int8_t* output,
                        const T* input,
-                       const T* scale,
+                       const ScaleT* scale,
                        size_t num_rows,
                        size_t num_cols) {
   size_t bytes_per_out_col = num_cols * quant_bit / 8;
@@ -123,10 +123,10 @@ void per_channel_quant(int8_t* output,
   }
 }
 
-template <typename T, int quant_bit = 8>
+template <typename T, int quant_bit = 8, typename ScaleT>
 void group_wise_quant(int8_t* output,
                       const T* input,
-                      const T* scale,
+                      const ScaleT* scale,
                       size_t num_rows,
                       size_t num_cols,
                       const int group_size) {
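
The per_channel_scale and per_channel_quant helpers above compute one scale per output channel and quantize the weights against it; after this change the scale buffer for llm.int8 is float32 even when the weights are float16 or bfloat16. Below is a rough NumPy mirror of that round trip, a sketch only: the helper name per_channel_quant_ref is made up, and the [in_features, out_features] column-per-channel layout is an assumption borrowed from the test's weights.

import numpy as np


def per_channel_quant_ref(w, bound=127.0):
    # One float32 scale per output channel (column), regardless of w's dtype.
    w_f32 = w.astype(np.float32)
    scale = np.max(np.abs(w_f32), axis=0) / bound
    w_int8 = np.clip(np.round(w_f32 / scale), -bound, bound).astype(np.int8)
    return w_int8, scale


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 128)).astype(np.float16)
w_int8, scale = per_channel_quant_ref(w)

# Dequantize and check the round-trip error stays within half a quantization step.
w_round_trip = w_int8.astype(np.float32) * scale
assert np.abs(w_round_trip - w.astype(np.float32)).max() <= 0.5 * scale.max() + 1e-6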

test/quantization/test_llm_int8_linear.py

Lines changed: 68 additions & 61 deletions
@@ -24,9 +24,6 @@
 from paddle.framework import set_default_dtype
 from paddle.pir_utils import test_with_pir_api
 
-np.random.seed(123)
-paddle.seed(42)
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
@@ -43,11 +40,13 @@ def config(self):
         self.batch = 1
         self.token = 32
         self.in_features = 64
-        self.out_features = 256
+        self.out_features = 128
         self.threshold = 6.0
         self.static = False
 
     def setUp(self):
+        np.random.seed(123)
+        paddle.seed(42)
         self.config()
         x = np.random.random((self.batch, self.token, self.in_features))
         self.x = paddle.to_tensor(x, dtype=self.dtype)
@@ -64,49 +63,89 @@ def setUp(self):
             self.in_features, self.out_features, bias_attr=bias_attr
         )
 
-        self.bias = self.linear.bias
         self.weight = self.linear.weight
         self.weight_scale = None
         self.weight, self.weight_scale = Q.weight_quantize(
             self.weight, algo="llm.int8"
         )
 
+    def dynamic_quant(self, x):
+        row_ranges = paddle.max(x, axis=[-1]).astype('float32')
+        row_ranges = row_ranges.unsqueeze(-1)
+        quant_x = paddle.round(
+            paddle.clip(
+                x.astype('float32') * 127.0 * (1 / row_ranges),
+                min=-127.0,
+                max=127.0,
+            )
+        ).astype('int8')
+        return quant_x, row_ranges
+
     def get_linear_out(self):
-        out = self.linear(self.x)
+        outlier_cols = (
+            paddle.nonzero(paddle.max(self.x, axis=[0, 1]) > self.threshold)
+            .reshape([-1])
+            .numpy()
+            .tolist()
+        )
+
+        x_int8 = self.x
+        if len(outlier_cols) > 0:
+            x_fp = self.x[:, :, outlier_cols]
+            w_fp = self.linear.weight[outlier_cols]
+            res_fp = paddle.matmul(x_fp, w_fp)
+
+            x_int8[:, :, outlier_cols] = 0
+        x_int8, row_ranges = self.dynamic_quant(x_int8)
+
+        res_int8 = paddle.matmul(x_int8, self.weight.transpose((1, 0)))
+        dequant_scale = row_ranges * self.weight_scale / 127.0
+        res_dequant = (res_int8.astype('float32') * dequant_scale).astype(
+            self.dtype
+        )
+
+        if len(outlier_cols) > 0:
+            out = res_dequant + res_fp
+        else:
+            out = res_dequant
+
+        if self.bias:
+            out += self.bias
+
         return out.numpy()
 
     def get_llm_int8_linear_out(self):
         out = Q.llm_int8_linear(
             self.x,
             self.weight,
-            bias=self.bias,
+            bias=self.linear.bias,
             weight_scale=self.weight_scale,
             threshold=self.threshold,
         )
         return out.numpy()
 
     @test_with_pir_api
-    def get_llm_int8_linear_out_static(self):
+    def llm_int8_linear_out_static(self, out_expect):
         paddle.enable_static()
-        main = base.static.Program()
-        start = base.static.Program()
-        with base.static.program_guard(main, start):
-            x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype)
+        main = paddle.static.Program()
+        start = paddle.static.Program()
+        with paddle.static.program_guard(main, start):
+            x = paddle.static.data("x", self.x.shape, dtype=self.dtype)
 
             weight = paddle.static.data(
-                "weight", self.weight.shape, dtype=self.weight.dtype
+                "weight", self.weight.shape, dtype='int8'
            )
            bias = paddle.static.data(
-                "bias", self.bias.shape, dtype=self.bias.dtype
+                "bias", self.linear.bias.shape, dtype=self.dtype
            )
            x_np = self.x.numpy()
            weight_np = self.weight.numpy()
-            bias_np = self.bias.numpy()
+            bias_np = self.linear.bias.numpy()
            if self.weight_scale is not None:
                weight_scale = paddle.static.data(
                    "weight_scale",
                    self.weight_scale.shape,
-                    dtype=self.weight_scale.dtype,
+                    dtype='float32',
                )
                weight_scale_np = self.weight_scale.numpy()
            else:
@@ -128,20 +167,30 @@ def get_llm_int8_linear_out_static(self):
             }
             exe = base.Executor(paddle.CUDAPlace(0))
             exe.run(start)
-            (out,) = exe.run(main, feed=feed_dict, fetch_list=[out])
+            (out_real,) = exe.run(main, feed=feed_dict, fetch_list=[out])
+
         paddle.disable_static()
-        return out
+
+        if self.dtype == "bfloat16":
+            out_real = convert_uint16_to_float(out_real)
+            out_expect = convert_uint16_to_float(out_expect)
+
+        np.testing.assert_allclose(
+            out_real, out_expect, rtol=self.rtol, atol=self.atol
+        )
 
     def test_llm_int8_linear(self):
         out_expect = self.get_linear_out()
         if self.static:
-            out_real = self.get_llm_int8_linear_out_static()
+            self.llm_int8_linear_out_static(out_expect)
+            return
         else:
             out_real = self.get_llm_int8_linear_out()
 
         if self.dtype == "bfloat16":
             out_real = convert_uint16_to_float(out_real)
             out_expect = convert_uint16_to_float(out_expect)
+
         np.testing.assert_allclose(
             out_real, out_expect, rtol=self.rtol, atol=self.atol
         )
@@ -174,19 +223,6 @@ def config(self):
         self.weight_dtype = "int8"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
-class LLMInt8LinearTestCase3(LLMInt8LinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'bfloat16'
-        self.weight_dtype = "int8"
-
-
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11020
@@ -215,20 +251,6 @@ def config(self):
         self.weight_dtype = "int4"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8
-    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
-)
-class LLMInt8LinearTestCase6(LLMInt8LinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'bfloat16'
-        self.weight_dtype = "int4"
-
-
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11020
@@ -260,21 +282,6 @@ def config(self):
         self.token = 1
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
-class LLMInt8LinearTestCase9(LLMInt8LinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'bfloat16'
-        self.weight_dtype = "int8"
-        self.batch = 1
-        self.token = 1
-
-
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11020
