17 changes: 13 additions & 4 deletions paddle/phi/kernels/cpu/weight_quantize_kernel.cc
@@ -22,7 +22,11 @@ limitations under the License. */

namespace phi {

template <typename DeviceContext, typename T, typename D, int bits>
template <typename DeviceContext,
typename T,
typename D,
int bits,
typename ScaleT = T>
void quant_compute(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out,
@@ -48,7 +52,7 @@ void quant_compute(const DeviceContext& dev_ctx,
DDim dims = {num};
const T* x_data = x.data<T>();
D* out_data = out->data<D>();
T* scale_data = scale->data<T>();
ScaleT* scale_data = scale->data<ScaleT>();

DenseTensor x_int(out->type());

@@ -121,11 +125,16 @@ void WeightQuantizeKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* scale) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<T>(scale);
if (algo == "weight_only_int8" || algo == "llm.int8") {
if (algo == "weight_only_int8") {
dev_ctx.template Alloc<T>(scale);
quant_compute<Context, T, int8_t, 8>(
dev_ctx, x, out, scale, algo, arch, group_size);
} else if (algo == "llm.int8") {
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t, 8, float>(
dev_ctx, x, out, scale, algo, arch, group_size);
} else if (algo == "weight_only_int4") {
dev_ctx.template Alloc<T>(scale);
quant_compute<Context, T, int8_t, 4>(
dev_ctx, x, out, scale, algo, arch, group_size);
} else {
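The change above makes the scale element type a template parameter (ScaleT, defaulting to the weight dtype T), so the llm.int8 branch can allocate and fill the scale tensor as float32 while weight_only_int8/int4 keep it in the weight dtype. A minimal NumPy sketch of that per-channel abs-max scale with the dtype choice made explicit — the function and argument names here are illustrative, not Paddle API:

```python
import numpy as np

def per_channel_scale(weight, scale_dtype=np.float32, bound=127.0):
    """Per-output-channel abs-max scale for int8 weight quantization.

    weight: [k, n] array in the model dtype (e.g. float16).
    scale_dtype: float32 for the llm.int8 path, weight.dtype otherwise.
    """
    abs_max = np.abs(weight.astype(np.float32)).max(axis=0)  # one value per column n
    return (abs_max / bound).astype(scale_dtype)             # shape [n]

w = np.random.randn(64, 128).astype(np.float16)
print(per_channel_scale(w).dtype)           # float32, as on the llm.int8 path
print(per_channel_scale(w, w.dtype).dtype)  # float16, as for weight_only_int8
```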
5 changes: 3 additions & 2 deletions paddle/phi/kernels/gpu/weight_quantize_kernel.cu
@@ -37,7 +37,6 @@ void WeightQuantizeKernel(const Context& dev_ctx,

DenseTensor quanted_x;
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<T>(scale);
size_t m = x.dims()[0];
size_t n = x.dims()[1];
quanted_x.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
@@ -51,15 +50,17 @@ void WeightQuantizeKernel(const Context& dev_ctx,
"Currently, arch only support 70, 75, 80, 86."));

if (algo == "llm.int8") {
dev_ctx.template Alloc<float>(scale);
std::vector<int> axis = {1, 0};
funcs::Transpose<Context, int8_t, 2> trans;
weight_quant_gpu<T, Context>(dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
scale->data<T>(),
scale->data<float>(),
weight_shape);
trans(dev_ctx, quanted_x, out, axis);
} else if (algo == "weight_only_int8") {
dev_ctx.template Alloc<T>(scale);
weight_quant_gpu<T, Context>(dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
12 changes: 6 additions & 6 deletions paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -106,10 +106,10 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
}
}

template <typename T, int VectorSize = 8>
template <typename T, int VectorSize = 8, typename ScaleT>
__global__ void per_channel_quant_gpu(const T* weight_data,
int8_t* quanted_weight_data,
T* scale_data,
ScaleT* scale_data,
int total_k,
int total_vec_n) {
int n = blockIdx.x * blockDim.x + threadIdx.x;
@@ -133,10 +133,10 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
abs_max[i] = fmaxf((abs_max[i]), fabsf((weight[i])));
}
}
phi::AlignedVector<T, VectorSize> scale;
phi::AlignedVector<ScaleT, VectorSize> scale;
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
scale[i] = static_cast<T>(abs_max[i] / static_cast<float>(127.0f));
scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(127.0f));
}
*reinterpret_cast<float4*>(scale_data + VectorSize * n) =
*reinterpret_cast<float4*>(&scale);
@@ -161,11 +161,11 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
}
}
}
template <typename T, typename GPUContext>
template <typename T, typename GPUContext, typename ScaleT>
void weight_quant_gpu(const GPUContext& dev_ctx,
const T* weight_data,
int8_t* quanted_weight_data,
T* scale_data,
ScaleT* scale_data,
const std::vector<int>& shape) {
int total_k = shape[0];
int total_n = shape[1];
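per_channel_quant_gpu now writes its scales through a separate ScaleT type: the kernel takes a per-output-channel abs-max over k, stores abs_max / 127 cast to ScaleT, and (in the elided part of the kernel body) rounds the weights into int8. A NumPy sketch of the same math under those assumptions — the CUDA kernel's vectorized loads/stores and grid layout are omitted, and the names are illustrative:

```python
import numpy as np

def quantize_per_channel(weight):
    """Round-to-nearest symmetric int8 quantization, one scale per output channel.

    Assumed to mirror the math of the templated GPU kernel above: scale = abs-max / 127
    (float32 here for illustration; the kernel casts to ScaleT), values clipped to
    [-127, 127]. The kernel's AlignedVector/float4 packing is not reproduced.
    """
    w = weight.astype(np.float32)                       # [k, n]
    scale = np.abs(w).max(axis=0) / 127.0               # [n]
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

w = np.random.randn(64, 128).astype(np.float16)
q, scale = quantize_per_channel(w)
print(q.dtype, scale.dtype)  # int8 float32
```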
20 changes: 10 additions & 10 deletions paddle/phi/kernels/impl/weight_quantize_kernel_impl.h
@@ -42,22 +42,22 @@ inline T xabs(const T x) {
return x < static_cast<T>(0.0) ? -x : x;
}

template <typename T>
template <typename T, typename ScaleT>
void per_channel_scale(
T* scale, const T* input, size_t m, size_t n, float bound) {
ScaleT* scale, const T* input, size_t m, size_t n, float bound) {
for (size_t i = 0; i < n; ++i) {
float max = static_cast<float>(input[i]);
for (size_t j = 0; j < m; ++j) {
max = static_cast<float>(xabs(input[j * n + i])) > max
? static_cast<float>(xabs(input[j * n + i]))
: max;
}
scale[i] = static_cast<T>(max / bound);
scale[i] = static_cast<ScaleT>(max / bound);
}
}

template <typename T>
void group_wise_scale(T* scale,
template <typename T, typename ScaleT>
void group_wise_scale(ScaleT* scale,
const T* input,
size_t m,
size_t n,
@@ -72,15 +72,15 @@ void group_wise_scale(T* scale,
: max;
}
scale[static_cast<int>(j / group_size) * n + i] =
static_cast<T>(max / bound);
static_cast<ScaleT>(max / bound);
}
}
}

template <typename T, int quant_bit = 8>
template <typename T, int quant_bit = 8, typename ScaleT>
void per_channel_quant(int8_t* output,
const T* input,
const T* scale,
const ScaleT* scale,
size_t num_rows,
size_t num_cols) {
size_t bytes_per_out_col = num_cols * quant_bit / 8;
@@ -123,10 +123,10 @@ void per_channel_quant(int8_t* output,
}
}

template <typename T, int quant_bit = 8>
template <typename T, int quant_bit = 8, typename ScaleT>
void group_wise_quant(int8_t* output,
const T* input,
const T* scale,
const ScaleT* scale,
size_t num_rows,
size_t num_cols,
const int group_size) {
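per_channel_scale and group_wise_scale above now write into a ScaleT buffer that may differ from the input dtype T. For the group-wise variant, one abs-max scale is produced per (row-group, column), with group_size rows per group. A NumPy sketch under the assumption that m is a multiple of group_size (function and argument names are illustrative):

```python
import numpy as np

def group_wise_scale(weight, group_size=64, bound=127.0, scale_dtype=np.float32):
    """One abs-max scale per (row-group, column), mirroring group_wise_scale above.

    weight: [m, n]; m is assumed to be a multiple of group_size for brevity
    (the C++ version indexes groups as j / group_size and has no such restriction).
    Returns an array of shape [m // group_size, n] in scale_dtype.
    """
    m, n = weight.shape
    w = np.abs(weight.astype(np.float32)).reshape(m // group_size, group_size, n)
    return (w.max(axis=1) / bound).astype(scale_dtype)

w = np.random.randn(256, 128).astype(np.float16)
print(group_wise_scale(w).shape)  # (4, 128)
```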
129 changes: 68 additions & 61 deletions test/quantization/test_llm_int8_linear.py
@@ -24,9 +24,6 @@
from paddle.framework import set_default_dtype
from paddle.pir_utils import test_with_pir_api

np.random.seed(123)
paddle.seed(42)


@unittest.skipIf(
not core.is_compiled_with_cuda()
@@ -43,11 +40,13 @@ def config(self):
self.batch = 1
self.token = 32
self.in_features = 64
self.out_features = 256
self.out_features = 128
self.threshold = 6.0
self.static = False

def setUp(self):
np.random.seed(123)
paddle.seed(42)
self.config()
x = np.random.random((self.batch, self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
@@ -64,49 +63,89 @@ def setUp(self):
self.in_features, self.out_features, bias_attr=bias_attr
)

self.bias = self.linear.bias
self.weight = self.linear.weight
self.weight_scale = None
self.weight, self.weight_scale = Q.weight_quantize(
self.weight, algo="llm.int8"
)

def dynamic_quant(self, x):
row_ranges = paddle.max(x, axis=[-1]).astype('float32')
row_ranges = row_ranges.unsqueeze(-1)
quant_x = paddle.round(
paddle.clip(
x.astype('float32') * 127.0 * (1 / row_ranges),
min=-127.0,
max=127.0,
)
).astype('int8')
return quant_x, row_ranges

def get_linear_out(self):
out = self.linear(self.x)
outlier_cols = (
paddle.nonzero(paddle.max(self.x, axis=[0, 1]) > self.threshold)
.reshape([-1])
.numpy()
.tolist()
)

x_int8 = self.x
if len(outlier_cols) > 0:
x_fp = self.x[:, :, outlier_cols]
w_fp = self.linear.weight[outlier_cols]
res_fp = paddle.matmul(x_fp, w_fp)

x_int8[:, :, outlier_cols] = 0
x_int8, row_ranges = self.dynamic_quant(x_int8)

res_int8 = paddle.matmul(x_int8, self.weight.transpose((1, 0)))
dequant_scale = row_ranges * self.weight_scale / 127.0
res_dequant = (res_int8.astype('float32') * dequant_scale).astype(
self.dtype
)

if len(outlier_cols) > 0:
out = res_dequant + res_fp
else:
out = res_dequant

if self.bias:
out += self.bias

return out.numpy()

def get_llm_int8_linear_out(self):
out = Q.llm_int8_linear(
self.x,
self.weight,
bias=self.bias,
bias=self.linear.bias,
weight_scale=self.weight_scale,
threshold=self.threshold,
)
return out.numpy()

@test_with_pir_api
def get_llm_int8_linear_out_static(self):
def llm_int8_linear_out_static(self, out_expect):
paddle.enable_static()
main = base.static.Program()
start = base.static.Program()
with base.static.program_guard(main, start):
x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype)
main = paddle.static.Program()
start = paddle.static.Program()
with paddle.static.program_guard(main, start):
x = paddle.static.data("x", self.x.shape, dtype=self.dtype)

weight = paddle.static.data(
"weight", self.weight.shape, dtype=self.weight.dtype
"weight", self.weight.shape, dtype='int8'
)
bias = paddle.static.data(
"bias", self.bias.shape, dtype=self.bias.dtype
"bias", self.linear.bias.shape, dtype=self.dtype
)
x_np = self.x.numpy()
weight_np = self.weight.numpy()
bias_np = self.bias.numpy()
bias_np = self.linear.bias.numpy()
if self.weight_scale is not None:
weight_scale = paddle.static.data(
"weight_scale",
self.weight_scale.shape,
dtype=self.weight_scale.dtype,
dtype='float32',
)
weight_scale_np = self.weight_scale.numpy()
else:
Expand All @@ -128,20 +167,30 @@ def get_llm_int8_linear_out_static(self):
}
exe = base.Executor(paddle.CUDAPlace(0))
exe.run(start)
(out,) = exe.run(main, feed=feed_dict, fetch_list=[out])
(out_real,) = exe.run(main, feed=feed_dict, fetch_list=[out])

paddle.disable_static()
return out

if self.dtype == "bfloat16":
out_real = convert_uint16_to_float(out_real)
out_expect = convert_uint16_to_float(out_expect)

np.testing.assert_allclose(
out_real, out_expect, rtol=self.rtol, atol=self.atol
)

def test_llm_int8_linear(self):
out_expect = self.get_linear_out()
if self.static:
out_real = self.get_llm_int8_linear_out_static()
self.llm_int8_linear_out_static(out_expect)
return
else:
out_real = self.get_llm_int8_linear_out()

if self.dtype == "bfloat16":
out_real = convert_uint16_to_float(out_real)
out_expect = convert_uint16_to_float(out_expect)

np.testing.assert_allclose(
out_real, out_expect, rtol=self.rtol, atol=self.atol
)
@@ -174,19 +223,6 @@ def config(self):
self.weight_dtype = "int8"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase3(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
@@ -215,20 +251,6 @@ def config(self):
self.weight_dtype = "int4"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
)
class LLMInt8LinearTestCase6(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int4"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
@@ -260,21 +282,6 @@ def config(self):
self.token = 1


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase9(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
self.batch = 1
self.token = 1


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
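The test's reference output (get_linear_out) now follows the llm.int8 decomposition rather than calling the fp16 linear directly: activation columns whose maximum exceeds threshold are computed in floating point, the remaining columns go through row-wise dynamic int8 quantization, and the int8 matmul is dequantized with row_range * weight_scale / 127. A self-contained NumPy sketch of that reference math — a simplified float32 analogue that uses abs-max row ranges and illustrative names, not the test's exact dtypes:

```python
import numpy as np

def llm_int8_reference(x, w, w_scale, threshold=6.0):
    """Reference llm.int8 matmul: outlier-column split plus row-wise dynamic quantization.

    x: [tokens, in_features] activations, w: [in_features, out_features] weights,
    w_scale: [out_features] per-channel weight scale (abs-max / 127).
    """
    x = x.astype(np.float32)
    w = w.astype(np.float32)
    outliers = np.abs(x).max(axis=0) > threshold           # columns kept in floating point
    res_fp = x[:, outliers] @ w[outliers, :]

    x_int = np.where(outliers, 0.0, x)                     # zero out outlier columns
    row_range = np.abs(x_int).max(axis=1, keepdims=True)   # dynamic per-row range
    row_range = np.maximum(row_range, 1e-8)                # avoid division by zero
    x_q = np.clip(np.round(x_int * 127.0 / row_range), -127, 127)
    w_q = np.clip(np.round(w / w_scale), -127, 127)
    res_int = (x_q @ w_q) * (row_range * w_scale / 127.0)  # dequantize int8 result
    return res_int + res_fp

x = np.random.rand(32, 64).astype(np.float32)
w = np.random.randn(64, 128).astype(np.float32) * 0.02
w_scale = np.abs(w).max(axis=0) / 127.0
print(llm_int8_reference(x, w, w_scale).shape)  # (32, 128)
```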