From c612cfd9ecae5e3fb8079eb81b703cacd63c0db3 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 24 May 2024 14:39:37 +0800 Subject: [PATCH 1/3] move fake channel wise dequantize max abs --- paddle/fluid/operators/fake_dequantize_op.cc | 239 ------------------ paddle/fluid/operators/fake_dequantize_op.cu | 27 -- .../fluid/operators/fake_dequantize_op.cu.h | 150 ----------- paddle/fluid/operators/fake_dequantize_op.h | 96 ------- paddle/phi/infermeta/multiary.cc | 24 ++ paddle/phi/infermeta/multiary.h | 8 + .../phi/kernels/cpu/fake_dequantize_kernel.cc | 7 + paddle/phi/kernels/fake_dequantize_kernel.h | 9 + .../kernels/funcs/fake_dequantize_functor.cc | 102 ++++++++ .../kernels/funcs/fake_dequantize_functor.cu | 124 +++++++++ .../kernels/funcs/fake_dequantize_functor.h | 13 + .../phi/kernels/gpu/fake_dequantize_kernel.cu | 8 + .../impl/fake_dequantize_kernel_impl.h | 52 ++++ paddle/phi/ops/yaml/op_compat.yaml | 6 + paddle/phi/ops/yaml/ops.yaml | 9 + test/legacy_test/test_fake_dequantize_op.py | 10 +- 16 files changed, 367 insertions(+), 517 deletions(-) delete mode 100644 paddle/fluid/operators/fake_dequantize_op.cc delete mode 100644 paddle/fluid/operators/fake_dequantize_op.cu delete mode 100644 paddle/fluid/operators/fake_dequantize_op.cu.h delete mode 100644 paddle/fluid/operators/fake_dequantize_op.h diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc deleted file mode 100644 index 9fbd4909164b7f..00000000000000 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/fake_dequantize_op.h" - -#include <string> -#include <vector> - -#include "paddle/fluid/framework/op_version_registry.h" - -namespace paddle { -namespace operators { - -template <typename T> -struct ChannelDequantizeFunctor<phi::CPUContext, T> { - void operator()(const phi::CPUContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor** scales, - const int scale_num, - T max_range, - const int quant_axis, - const int x_num_col_dims, - phi::DenseTensor* out) { - if (scale_num == 1) { - // Dequant op is before quantized op - // Dequantize the weight of quantized op - auto in_dims = in->dims(); - const int64_t channel = in_dims[quant_axis]; - const T* scale_factor = scales[0]->data<T>(); - if (quant_axis == 0) { - for (int64_t i = 0; i < channel; i++) { - T s = scale_factor[i]; - phi::DenseTensor one_channel_in = in->Slice(i, i + 1); - phi::DenseTensor one_channel_out = out->Slice(i, i + 1); - auto in_e = phi::EigenVector<T>::Flatten(one_channel_in); - auto out_e = phi::EigenVector<T>::Flatten(one_channel_out); - auto& dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * s / max_range; - } - } else if (quant_axis == 1) { - int64_t out_iter = 1; - for (int i = 0; i < quant_axis; i++) { - out_iter *= in_dims[i]; - } - int64_t step_i = in->numel() / out_iter; - int64_t step_j = in->numel() / (out_iter * channel); - auto* in_data = in->data<T>(); - auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); - for (int64_t i = 0; i < out_iter; i++) { - for (int64_t j = 0; j < channel; j++) { - auto* cur_in = in_data + i * step_i + j * step_j; - auto* cur_out = out_data + i * step_i + j * step_j; - T s = scale_factor[j]; - for (int64_t k = 0; k < step_j; k++) { - *cur_out = (*cur_in) * s / max_range; - ++cur_in; - ++cur_out; - } - } - } - } - } else if (scale_num == 2) { - // Dequant op is after quantized op - // Dequantize the output tensor of quantized op - if (x_num_col_dims > 1) { - auto in_dims = in->dims(); - const int64_t channel = in_dims[x_num_col_dims]; - const T* scale_one = scales[0]->data<T>(); - const T* scale_two = scales[1]->data<T>(); - int64_t out_iter = 1; - for (int i = 0; i < x_num_col_dims; i++) { - out_iter *= in_dims[i]; - } - int64_t step_i = in->numel() / out_iter; - int64_t step_j = in->numel() / (out_iter * channel); - auto* in_data = in->data<T>(); - auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); - for (int64_t i = 0; i < out_iter; i++) { - for (int64_t j = 0; j < channel; j++) { - auto* cur_in = in_data + i * step_i + j * step_j; - auto* cur_out = out_data + i * step_i + j * step_j; - T s = scale_one[j]; - for (int64_t k = 0; k < step_j; k++) { - *cur_out = (*cur_in) * s * scale_two[0] / max_range; - ++cur_in; - ++cur_out; - } - } - } - } else { - int batch_size = static_cast<int>(in->dims()[0]); - int channel = static_cast<int>(in->dims()[1]); - const T* scale_one = scales[0]->data<T>(); - const T* scale_two = scales[1]->data<T>(); - for (int i = 0; i < batch_size; i++) { - phi::DenseTensor one_batch_in = in->Slice(i, i + 1).Resize( - common::slice_ddim(in->dims(), 1, in->dims().size())); - phi::DenseTensor one_batch_out = out->Slice(i, i + 1).Resize( - common::slice_ddim(out->dims(), 1, out->dims().size())); - for (int j = 0; j < channel; j++) { - T s = scale_one[j]; - phi::DenseTensor one_channel_in = one_batch_in.Slice(j, j + 1); - phi::DenseTensor one_channel_out = one_batch_out.Slice(j, j + 1); - auto in_e = phi::EigenVector<T>::Flatten(one_channel_in); - auto out_e = phi::EigenVector<T>::Flatten(one_channel_out); - auto& dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * s * 
scale_two[0] / max_range; - } - } - } - } - } -}; - -template struct ChannelDequantizeFunctor<phi::CPUContext, float>; -template struct ChannelDequantizeFunctor<phi::CPUContext, double>; - -class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("X"), "Input", "X", "FakeChannelWiseDequantizeMaxAbs"); - OP_INOUT_CHECK(ctx->HasInputs("Scales"), - "Input", - "Scales", - "FakeChannelWiseDequantizeMaxAbs"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), - "Output", - "Out", - "FakeChannelWiseDequantizeMaxAbs"); - - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); - } -}; - -class FakeChannelWiseDequantizeMaxAbsOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(Tensor) The input with float-32/64 type is the " - "low precision tensor."); - AddInput("Scales", - "(Tensors) The scales in quantization stage. " - "Now, `Scales` is a vector with at most two tensors. " - "If Scales has two elements, the second tensor should only have " - "one value.") - .AsDuplicable(); - AddOutput("Out", - "(Tensor) The output is the dequantized high " - "precision tensor."); - AddAttr<std::vector<int>>( - "quant_bits", - "Quantization bit numbers in quantization stage. " - "The size of `quant_bits` should be equal to the size of `Scales`.") - .SetDefault({8}); - AddAttr<int>("quant_axis", - "(int, default 0) The axis for quantization. " - "For conv2d, depthwise_conv2d, conv2d_transpose " - "and mul, the quant_axis is equal to the cout axis.") - .SetDefault(0) - .AddCustomChecker([](const int& quant_axis) { - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, - true, - phi::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - }); - AddAttr<int>("x_num_col_dims", - "The x_num_col_dims of mul. Only used for mul or matmul.") - .SetDefault(1) - .AddCustomChecker([](const int& x_num_col_dims) { - PADDLE_ENFORCE_EQ(x_num_col_dims == 0, - false, - phi::errors::InvalidArgument( - "'x_num_col_dims' should be larger than 0, but " - "the received is %d", - x_num_col_dims)); - }); - AddComment(R"DOC( -FakeChannelWiseDequantizeMaxAbsOp operator. - -This calculation is an opposite operation of FakeChannelWiseQuantizeMaxAbsOp: - -$$Out_c = \frac{X_c\prod_{i=1}^{n}Scales_{ic}}{\prod_{i=1}^{n}(2^{quant\_bits_i-1}-1)}$$ - -In the above formula, the range value of $c$ can be represented as $0 \leq c \lt \ the\ channel\ number\ of\ X$. -Besides, the size of $quant\_bits$ should be equal to the size of $Scales$, and it is called $n$ in the formula. - -Notes: In general, the per-channel quantization is only applied to weights and the activations use per-layer quantization. 
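To ground the formula: with a single scale tensor ($n = 1$) and quant_bits = [8], the denominator is 2^(8-1) - 1 = 127, and each channel slice of X is scaled independently by its own max-abs scale. A minimal NumPy sketch of that single-scale, quant_axis = 0 case follows; the array shapes are illustrative assumptions, not taken from the patch:

import numpy as np

# Hypothetical quantized conv weight [cout, cin, kh, kw], one scale per
# output channel, quant_bits = [8] -> max_range = 2^(8-1) - 1 = 127.
cout, cin, kh, kw = 4, 3, 3, 3
x = np.random.randint(-127, 128, size=(cout, cin, kh, kw)).astype(np.float32)
scales = np.random.rand(cout).astype(np.float32)
max_range = float((1 << 7) - 1)

# Out_c = X_c * Scales_c / max_range, broadcast along quant_axis = 0.
out = x * scales.reshape(-1, 1, 1, 1) / max_range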
-)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CPU = phi::CPUContext; - -REGISTER_OPERATOR( - fake_channel_wise_dequantize_max_abs, - ops::FakeChannelWiseDequantizeMaxAbsOp, - ops::FakeChannelWiseDequantizeMaxAbsOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_dequantize_max_abs, - CPU, - ALL_LAYOUT, - ops::FakeChannelWiseDequantizeMaxAbsKernel, - float, - double) {} - -REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs) - .AddCheckpoint( - R"ROC(add new attributes [quant_axis] for applying per-channel " - "dequantization to conv2d_transpose and mul ops.)ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "quant_axis", "The axis for dequantization.", 0)) - .AddCheckpoint( - R"ROC(add new attributes [x_num_col_dims] for applying per-channel " - "dequantization to mul ops.)ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "x_num_col_dims", "The x_num_col_dims for dequantization.", 1)); diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu deleted file mode 100644 index db8b97a70e3193..00000000000000 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fake_dequantize_op.h" -#include "paddle/fluid/operators/fake_dequantize_op.cu.h" - -namespace ops = paddle::operators; -using float16 = phi::dtype::float16; - -PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_dequantize_max_abs, - GPU, - ALL_LAYOUT, - ops::FakeChannelWiseDequantizeMaxAbsKernel, - float, - double, - float16) {} diff --git a/paddle/fluid/operators/fake_dequantize_op.cu.h b/paddle/fluid/operators/fake_dequantize_op.cu.h deleted file mode 100644 index b8914ce09c9083..00000000000000 --- a/paddle/fluid/operators/fake_dequantize_op.cu.h +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifndef PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_ -#define PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_ -#endif // PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_ - -#include "paddle/fluid/operators/fake_dequantize_op.h" - -namespace paddle { -namespace operators { - -template <typename T> -__global__ void DequantizeOneScaleQuantAxis0( - const T* in, const T* scale, T max_range, int num, int channel, T* out) { - int tid = threadIdx.x; - int channel_size = num / channel; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - for (int i = tid; i < channel_size; i += blockDim.x) { - out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; - } -} - -template <typename T> -__global__ void DequantizeOneScaleQuantAxisN(const T* in, - const T* scale, - const T max_range, - const int64_t num, - const int n_scales, - const int quant_stride, - T* out) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - T s = scale[(i / quant_stride) % n_scales]; - out[i] = in[i] * s / max_range; - } -} - -template <typename T> -__global__ void DequantizeTwoScale(const T* in, - const T* scale_one, - const T* scale_two, - T max_range, - int num, - int n_scales, - int quant_stride, - T* out) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - int scale_index = (i / quant_stride) % n_scales; - T s = scale_one[scale_index] * scale_two[0]; - out[i] = in[i] * s / max_range; - } -} - -template <typename T> -struct ChannelDequantizeFunctor<phi::GPUContext, T> { - void operator()(const phi::GPUContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor** scales, - const int scale_num, - T max_range, - const int quant_axis, - const int x_num_col_dims, - phi::DenseTensor* out) { - auto in_dims = in->dims(); - const T* in_data = in->data<T>(); - T* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); - if (scale_num == 1) { - // Dequantize inputs or weights before quantizable operators and after - // quantization operators. inputs --> quant -- > deqaunt --> conv2d --> - int64_t num = in->numel(); - const T* scale_factor = scales[0]->data<T>(); - int64_t block_size = std::min( - num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), - static_cast<int64_t>(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - - int quant_stride = 1; - for (int i = quant_axis + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - - DequantizeOneScaleQuantAxisN<T> - <<<grid_size, block_size, 0, dev_ctx.stream()>>>(in_data, - scale_factor, - max_range, - num, - in_dims[quant_axis], - quant_stride, - out_data); - } else if (scale_num == 2) { - // Dequantize activations after quantizable operators. - // inputs --> quant --> conv2d --> deqaunt --> - // Note 1: Not need to consider 'quant_axis'. Because 'quant_axis' is the - // axis of weights to be quantized on while dequantization is applied on - // activations. Note 2: 'x_num_col_dims' is the axis of activations to be - // quantized on. `x_num_col_dims` is -1 for operator in ['matmul', - // 'matmul_v2', 'mul'] and is 1 for other operators. 
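The flat-index arithmetic used in these kernels (both here and in the relocated CUDA functor later in the patch) maps each linear element index back to its channel via (i / quant_stride) % n_scales. A rough NumPy equivalent of the two-scale path, with assumed shapes, reads:

import numpy as np

# Assumed activation shape [N, C, H, W] with x_num_col_dims = 1, so the
# per-channel scale lives on axis 1 and quant_stride = H * W.
n, c, h, w = 2, 4, 5, 5
x = np.random.rand(n, c, h, w).astype(np.float32)
scale_one = np.random.rand(c).astype(np.float32)
scale_two = np.random.rand(1).astype(np.float32)
max_range = 127.0 * 127.0  # two 8-bit quantization stages

quant_stride = h * w
flat = x.ravel()
out = np.empty_like(flat)
for i in range(flat.size):  # the CUDA kernel grid-strides over this range
    s = scale_one[(i // quant_stride) % c] * scale_two[0]
    out[i] = flat[i] * s / max_range

# Same result in vectorized form:
ref = x * scale_one.reshape(1, -1, 1, 1) * scale_two[0] / max_range
assert np.allclose(out.reshape(x.shape), ref)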
- - int64_t num = in->numel(); - int n_scales = in->dims()[x_num_col_dims]; - const T* scale_one = scales[0]->data<T>(); - const T* scale_two = scales[1]->data<T>(); - - int64_t block_size = std::min( - num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), - static_cast<int64_t>(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - int quant_stride = 1; - for (int i = x_num_col_dims + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - DequantizeTwoScale<T> - <<<grid_size, block_size, 0, dev_ctx.stream()>>>(in_data, - scale_one, - scale_two, - max_range, - num, - n_scales, - quant_stride, - out_data); - } - } -}; - -template struct ChannelDequantizeFunctor<phi::GPUContext, float>; -template struct ChannelDequantizeFunctor<phi::GPUContext, double>; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h deleted file mode 100644 index 6cfa4f06838de3..00000000000000 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include <vector> - -#include "paddle/common/ddim.h" - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" - -namespace paddle { -namespace operators { - -template <typename DeviceContext, typename T> -struct ChannelDequantizeFunctor { - void operator()(const DeviceContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor** scales, - const int scale_num, - T max_range, - const int quant_axis, - const int x_num_col_dims, - phi::DenseTensor* out); -}; - -template <typename T, typename DeviceContext> -class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> { - public: - virtual void Compute(const framework::ExecutionContext& ctx) const { - auto* in = ctx.Input<phi::DenseTensor>("X"); - auto scales = ctx.MultiInput<phi::DenseTensor>("Scales"); - auto* out = ctx.Output<phi::DenseTensor>("Out"); - - auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits"); - auto quant_axis = ctx.Attr<int>("quant_axis"); - auto x_num_col_dims = ctx.Attr<int>("x_num_col_dims"); - int max_range = 1; - - auto& dev_ctx = ctx.template device_context<DeviceContext>(); - out->mutable_data<T>(dev_ctx.GetPlace()); - int scale_num = scales.size(); - if (scale_num == 1) { - PADDLE_ENFORCE_EQ( - scales[0]->numel(), - in->dims()[quant_axis], - phi::errors::PreconditionNotMet( - "The number of first scale values must be the same with " - "quant_axis dimension value of Input(X) when the `Scales` has " - "only one element, but %ld != %ld here.", - scales[0]->numel(), - in->dims()[quant_axis])); - max_range *= (std::pow(2, quant_bits[0] - 1) - 1); - } else if (scale_num == 2) { - PADDLE_ENFORCE_EQ( - scales[0]->numel(), - in->dims()[x_num_col_dims], - phi::errors::PreconditionNotMet( - "The number of first scale values must be the same with " - "corresponding dimension value of Input(X) when the `Scales` " - "has two elements, but %ld != %ld here.", - scales[0]->numel(), - 
in->dims()[1])); - PADDLE_ENFORCE_EQ(scales[1]->numel(), - 1, - phi::errors::PreconditionNotMet( - "The second scale tensor should only have one " - "value at now, but it has %ld values here.", - scales[1]->numel())); - max_range *= (std::pow(2, quant_bits[0] - 1) - 1) * - (std::pow(2, quant_bits[1] - 1) - 1); - } - ChannelDequantizeFunctor<DeviceContext, T>()(dev_ctx, - in, - scales.data(), - scale_num, - static_cast<T>(max_range), - quant_axis, - x_num_col_dims, - out); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 65de2d4e3ce210..7ac5da8701dbd7 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1822,6 +1822,30 @@ void EditDistanceInferMeta(const MetaTensor& hyps, sequencenum->set_dtype(DataType::FLOAT32); } +void FakeChannelWiseDequantizeMaxAbsInferMeta( + const MetaTensor& x, + const std::vector<const MetaTensor*>& scales, + const std::vector<int>& quant_bits, + int quant_axis, + int x_num_col_dims, + MetaTensor* out) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, + true, + phi::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + PADDLE_ENFORCE_EQ(x_num_col_dims == 0, + false, + phi::errors::InvalidArgument( + "'x_num_col_dims' should be larger than 0, but " + "the received is %d", + x_num_col_dims)); + out->set_dtype(x.dtype()); + out->share_dims(x); + out->share_lod(x); +} + void FakeQuantOrWithDequantMovingAverageAbsMaxInferMeta( const MetaTensor& x, const MetaTensor& in_scale, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index d60f8b0f3c4436..3147601f7ea177 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -371,6 +371,14 @@ void EditDistanceInferMeta(const MetaTensor& hyps, MetaTensor* sequencenum, MetaTensor* out); +void FakeChannelWiseDequantizeMaxAbsInferMeta( + const MetaTensor& x, + const std::vector<const MetaTensor*>& scales, + const std::vector<int>& quant_bits, + int quant_axis, + int x_num_col_dims, + MetaTensor* out); + void FakeQuantOrWithDequantMovingAverageAbsMaxInferMeta( const MetaTensor& x, const MetaTensor& in_scale, diff --git a/paddle/phi/kernels/cpu/fake_dequantize_kernel.cc b/paddle/phi/kernels/cpu/fake_dequantize_kernel.cc index 85490e7be85172..e68b3c6b0c78aa 100644 --- a/paddle/phi/kernels/cpu/fake_dequantize_kernel.cc +++ b/paddle/phi/kernels/cpu/fake_dequantize_kernel.cc @@ -21,3 +21,10 @@ PD_REGISTER_KERNEL(fake_dequantize_max_abs, phi::FakeDequantizeMaxAbsKernel, float, double) {} + +PD_REGISTER_KERNEL(fake_channel_wise_dequantize_max_abs, + CPU, + ALL_LAYOUT, + phi::FakeChannelWiseDequantizeMaxAbsKernel, + float, + double) {} diff --git a/paddle/phi/kernels/fake_dequantize_kernel.h b/paddle/phi/kernels/fake_dequantize_kernel.h index a01940e6345aee..1bc8725fa07797 100644 --- a/paddle/phi/kernels/fake_dequantize_kernel.h +++ b/paddle/phi/kernels/fake_dequantize_kernel.h @@ -25,4 +25,13 @@ void FakeDequantizeMaxAbsKernel(const Context& dev_ctx, float max_range, DenseTensor* out); +template <typename T, typename Context> +void FakeChannelWiseDequantizeMaxAbsKernel( + const Context& dev_ctx, + const DenseTensor& x, + const std::vector<const DenseTensor*>& scales, + const std::vector<int>& quant_bits, + int quant_axis, + int x_num_col_dims, + DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/funcs/fake_dequantize_functor.cc b/paddle/phi/kernels/funcs/fake_dequantize_functor.cc index 805b46c4f0c3f4..c56a7e56b5a2b9 100644 --- a/paddle/phi/kernels/funcs/fake_dequantize_functor.cc +++ 
b/paddle/phi/kernels/funcs/fake_dequantize_functor.cc @@ -31,6 +31,108 @@ void DequantizeFunctor<Context, T>::operator()(const Context& dev_ctx, out_e.device(dev) = in_e * scale_factor[0] / max_range; } +template <typename Context, typename T> +void ChannelDequantizeFunctor<Context, T>::operator()( + const Context& dev_ctx, + const DenseTensor* in, + const DenseTensor** scales, + const int scale_num, + T max_range, + const int quant_axis, + const int x_num_col_dims, + DenseTensor* out) { + if (scale_num == 1) { + // Dequant op is before quantized op + // Dequantize the weight of quantized op + auto in_dims = in->dims(); + const int64_t channel = in_dims[quant_axis]; + const T* scale_factor = scales[0]->data<T>(); + if (quant_axis == 0) { + for (int64_t i = 0; i < channel; i++) { + T s = scale_factor[i]; + phi::DenseTensor one_channel_in = in->Slice(i, i + 1); + phi::DenseTensor one_channel_out = out->Slice(i, i + 1); + auto in_e = phi::EigenVector<T>::Flatten(one_channel_in); + auto out_e = phi::EigenVector<T>::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s / max_range; + } + } else if (quant_axis == 1) { + int64_t out_iter = 1; + for (int i = 0; i < quant_axis; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data<T>(); + auto* out_data = dev_ctx.template Alloc<T>(out); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_factor[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * s / max_range; + ++cur_in; + ++cur_out; + } + } + } + } + } else if (scale_num == 2) { + // Dequant op is after quantized op + // Dequantize the output tensor of quantized op + if (x_num_col_dims > 1) { + auto in_dims = in->dims(); + const int64_t channel = in_dims[x_num_col_dims]; + const T* scale_one = scales[0]->data<T>(); + const T* scale_two = scales[1]->data<T>(); + int64_t out_iter = 1; + for (int i = 0; i < x_num_col_dims; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data<T>(); + auto* out_data = dev_ctx.template Alloc<T>(out); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_one[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * s * scale_two[0] / max_range; + ++cur_in; + ++cur_out; + } + } + } + } else { + int batch_size = static_cast<int>(in->dims()[0]); + int channel = static_cast<int>(in->dims()[1]); + const T* scale_one = scales[0]->data<T>(); + const T* scale_two = scales[1]->data<T>(); + for (int i = 0; i < batch_size; i++) { + phi::DenseTensor one_batch_in = in->Slice(i, i + 1).Resize( + common::slice_ddim(in->dims(), 1, in->dims().size())); + phi::DenseTensor one_batch_out = out->Slice(i, i + 1).Resize( + common::slice_ddim(out->dims(), 1, out->dims().size())); + for (int j = 0; j < channel; j++) { + T s = scale_one[j]; + phi::DenseTensor one_channel_in = one_batch_in.Slice(j, j + 1); + phi::DenseTensor one_channel_out = one_batch_out.Slice(j, j + 1); + auto in_e = phi::EigenVector<T>::Flatten(one_channel_in); + auto out_e = phi::EigenVector<T>::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s * scale_two[0] / max_range; + } + } + } + } +} + 
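The pointer walk in the functor's quant_axis == 1 branch is easier to audit against a reference. The following NumPy sketch reproduces the same stride bookkeeping, assuming a weight of shape [cin, cout, kh, kw] with one scale per cout channel (shapes are illustrative assumptions):

import numpy as np

# Assumed shape [cin, cout, kh, kw]; quant_axis = 1, scale per cout channel.
cin, cout, kh, kw = 3, 4, 5, 5
x = np.random.rand(cin, cout, kh, kw).astype(np.float32)
scale = np.random.rand(cout).astype(np.float32)
max_range = 127.0

out_iter = cin               # product of dims before quant_axis
step_i = x.size // out_iter  # elements per outer iteration
step_j = step_i // cout      # contiguous elements sharing one scale (kh * kw)

flat_in = x.ravel()
flat_out = np.empty_like(flat_in)
for i in range(out_iter):
    for j in range(cout):
        base = i * step_i + j * step_j
        block = flat_in[base:base + step_j]
        flat_out[base:base + step_j] = block * scale[j] / max_range

assert np.allclose(flat_out.reshape(x.shape),
                   x * scale.reshape(1, -1, 1, 1) / max_range)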
+template class ChannelDequantizeFunctor<CPUContext, float>; +template class ChannelDequantizeFunctor<CPUContext, double>; template class DequantizeFunctor<CPUContext, float>; template class DequantizeFunctor<CPUContext, double>; diff --git a/paddle/phi/kernels/funcs/fake_dequantize_functor.cu b/paddle/phi/kernels/funcs/fake_dequantize_functor.cu index b60ef42f2fb714..6a484d2511c316 100644 --- a/paddle/phi/kernels/funcs/fake_dequantize_functor.cu +++ b/paddle/phi/kernels/funcs/fake_dequantize_functor.cu @@ -49,6 +49,130 @@ void DequantizeFunctor<Context, T>::operator()(const Context& dev_ctx, in_data, scale_factor, max_range, num, out_data); } +template <typename T> +__global__ void DequantizeOneScaleQuantAxis0( + const T* in, const T* scale, T max_range, int num, int channel, T* out) { + int tid = threadIdx.x; + int channel_size = num / channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; + } +} + +template <typename T> +__global__ void DequantizeOneScaleQuantAxisN(const T* in, + const T* scale, + const T max_range, + const int64_t num, + const int n_scales, + const int quant_stride, + T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % n_scales]; + out[i] = in[i] * s / max_range; + } +} + +template <typename T> +__global__ void DequantizeTwoScale(const T* in, + const T* scale_one, + const T* scale_two, + T max_range, + int num, + int n_scales, + int quant_stride, + T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + int scale_index = (i / quant_stride) % n_scales; + T s = scale_one[scale_index] * scale_two[0]; + out[i] = in[i] * s / max_range; + } +} + +template <typename Context, typename T> +void ChannelDequantizeFunctor<Context, T>::operator()( + const Context& dev_ctx, + const DenseTensor* in, + const DenseTensor** scales, + const int scale_num, + T max_range, + const int quant_axis, + const int x_num_col_dims, + DenseTensor* out) { + auto in_dims = in->dims(); + const T* in_data = in->data<T>(); + T* out_data = dev_ctx.template Alloc<T>(out); + if (scale_num == 1) { + // Dequantize inputs or weights before quantizable operators and after + // quantization operators. inputs --> quant --> dequant --> conv2d --> + int64_t num = in->numel(); + const T* scale_factor = scales[0]->data<T>(); + int64_t block_size = std::min( + num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + + DequantizeOneScaleQuantAxisN<T> + <<<grid_size, block_size, 0, dev_ctx.stream()>>>(in_data, + scale_factor, + max_range, + num, + in_dims[quant_axis], + quant_stride, + out_data); + } else if (scale_num == 2) { + // Dequantize activations after quantizable operators. + // inputs --> quant --> conv2d --> dequant --> + // Note 1: No need to consider 'quant_axis'. Because 'quant_axis' is the + // axis of weights to be quantized on while dequantization is applied on + // activations. Note 2: 'x_num_col_dims' is the axis of activations to be + // quantized on. `x_num_col_dims` is -1 for operator in ['matmul', + // 'matmul_v2', 'mul'] and is 1 for other operators. 
+ + int64_t num = in->numel(); + int n_scales = in->dims()[x_num_col_dims]; + const T* scale_one = scales[0]->data<T>(); + const T* scale_two = scales[1]->data<T>(); + + int64_t block_size = std::min( + num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + int quant_stride = 1; + for (int i = x_num_col_dims + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + DequantizeTwoScale<T> + <<<grid_size, block_size, 0, dev_ctx.stream()>>>(in_data, + scale_one, + scale_two, + max_range, + num, + n_scales, + quant_stride, + out_data); + } +} + +template class ChannelDequantizeFunctor<GPUContext, float>; +template class ChannelDequantizeFunctor<GPUContext, double>; +template class ChannelDequantizeFunctor<GPUContext, phi::dtype::float16>; template class DequantizeFunctor<GPUContext, float>; template class DequantizeFunctor<GPUContext, double>; template class DequantizeFunctor<GPUContext, phi::dtype::float16>; diff --git a/paddle/phi/kernels/funcs/fake_dequantize_functor.h b/paddle/phi/kernels/funcs/fake_dequantize_functor.h index eb69f0a5d863d8..bcf6568a80488d 100644 --- a/paddle/phi/kernels/funcs/fake_dequantize_functor.h +++ b/paddle/phi/kernels/funcs/fake_dequantize_functor.h @@ -34,5 +34,18 @@ class DequantizeFunctor { DenseTensor* out); }; +template <typename Context, typename T> +class ChannelDequantizeFunctor { + public: + void operator()(const Context& dev_ctx, + const DenseTensor* in, + const DenseTensor** scales, + const int scale_num, + T max_range, + const int quant_axis, + const int x_num_col_dims, + DenseTensor* out); +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/fake_dequantize_kernel.cu b/paddle/phi/kernels/gpu/fake_dequantize_kernel.cu index 97e8fd594c8d2e..3b2ac8dec44f3f 100644 --- a/paddle/phi/kernels/gpu/fake_dequantize_kernel.cu +++ b/paddle/phi/kernels/gpu/fake_dequantize_kernel.cu @@ -22,3 +22,11 @@ PD_REGISTER_KERNEL(fake_dequantize_max_abs, float, double, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(fake_channel_wise_dequantize_max_abs, + GPU, + ALL_LAYOUT, + phi::FakeChannelWiseDequantizeMaxAbsKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h b/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h index 8eb5a159d70a84..d7e3cee6c70aa9 100644 --- a/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h +++ b/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h @@ -30,4 +30,56 @@ void FakeDequantizeMaxAbsKernel(const Context& dev_ctx, dev_ctx, &x, &scale, static_cast<T>(max_range), out); } +template <typename T, typename Context> +void FakeChannelWiseDequantizeMaxAbsKernel( + const Context& dev_ctx, + const DenseTensor& x, + const std::vector<const DenseTensor*>& scales, + const std::vector<int>& quant_bits, + int quant_axis, + int x_num_col_dims, + DenseTensor* out) { + int max_range = 1; + dev_ctx.template Alloc<T>(out); + int scale_num = scales.size(); + if (scale_num == 1) { + PADDLE_ENFORCE_EQ( + scales[0]->numel(), + x.dims()[quant_axis], + phi::errors::PreconditionNotMet( + "The number of first scale values must be the same with " + "quant_axis dimension value of Input(X) when the `Scales` has " + "only one element, but %ld != %ld here.", + scales[0]->numel(), + x.dims()[quant_axis])); + max_range *= (std::pow(2, quant_bits[0] - 1) - 1); + } else if (scale_num == 2) { + PADDLE_ENFORCE_EQ( + scales[0]->numel(), + x.dims()[x_num_col_dims], + phi::errors::PreconditionNotMet( + "The number of first scale values must be the same with " + "corresponding dimension value of 
Input(X) when the `Scales` " + "has two elements, but %ld != %ld here.", + scales[0]->numel(), + x.dims()[1])); + PADDLE_ENFORCE_EQ(scales[1]->numel(), + 1, + phi::errors::PreconditionNotMet( + "The second scale tensor should only have one " + "value for now, but it has %ld values here.", + scales[1]->numel())); + max_range *= (std::pow(2, quant_bits[0] - 1) - 1) * + (std::pow(2, quant_bits[1] - 1) - 1); + } + phi::funcs::ChannelDequantizeFunctor<Context, T>()(dev_ctx, + &x, + scales.data(), + scale_num, + static_cast<T>(max_range), + quant_axis, + x_num_col_dims, + out); +} + } // namespace phi diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 3435ea6c467890..5d45654708684a 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -1112,6 +1112,12 @@ num_columns : support_tensor : true +- op : fake_channel_wise_dequantize_max_abs + inputs : + {x : X, scales : Scales} + outputs : + out : Out + - op : fake_channel_wise_quantize_abs_max inputs : x : X diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index bca412a8cce37b..9636ef54b9df56 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1389,6 +1389,15 @@ data_type : dtype backend : place +- op : fake_channel_wise_dequantize_max_abs + args : (Tensor x, Tensor[] scales, int[] quant_bits = {8}, int quant_axis = 0, int x_num_col_dims = 1) + output : Tensor(out) + infer_meta : + func : FakeChannelWiseDequantizeMaxAbsInferMeta + kernel : + func : fake_channel_wise_dequantize_max_abs + data_type : x + - op : fake_channel_wise_quantize_abs_max args : (Tensor x, int bit_length = 8, int round_type = 1, int quant_axis = 0, bool is_test = false) output : Tensor(out), Tensor(out_scale) diff --git a/test/legacy_test/test_fake_dequantize_op.py b/test/legacy_test/test_fake_dequantize_op.py index 044aaedb6759ba..2da5ca4f65d51d 100644 --- a/test/legacy_test/test_fake_dequantize_op.py +++ b/test/legacy_test/test_fake_dequantize_op.py @@ -102,7 +102,7 @@ def setUp(self): self.outputs = {'Out': ydq} def test_check_output(self): - self.check_output() + self.check_output(check_dygraph=False) class TestFakeChannelWiseDequantizeMaxAbsOpTwoScalesFloat16( @@ -112,7 +112,7 @@ def set_dtype(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(atol=1e-2) + self.check_output(check_dygraph=False, atol=1e-2) class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest): @@ -146,7 +146,7 @@ def setUp(self): self.outputs = {'Out': ydq} def test_check_output(self): - self.check_output() + self.check_output(check_dygraph=False) class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1( @@ -164,7 +164,7 @@ def set_dtype(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(atol=1e-2) + self.check_output(check_dygraph=False, atol=1e-2) class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1Float16( @@ -174,7 +174,7 @@ def set_dtype(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(atol=1e-2) + self.check_output(check_dygraph=False, atol=1e-2) class TestFakeDequantizeMaxAbsOp(OpTest): From e3d69d61d8890fcda9e0c47569cad2826436008d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 24 May 2024 15:21:28 +0800 Subject: [PATCH 2/3] fix impl --- .../kernels/impl/fake_dequantize_kernel_impl.h | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h b/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h 
index d7e3cee6c70aa9..a952dc84c9675d 100644 --- a/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h +++ b/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h @@ -72,14 +72,15 @@ void FakeChannelWiseDequantizeMaxAbsKernel( max_range *= (std::pow(2, quant_bits[0] - 1) - 1) * (std::pow(2, quant_bits[1] - 1) - 1); } - phi::funcs::ChannelDequantizeFunctor<Context, T>()(dev_ctx, - &x, - scales.data(), - scale_num, - static_cast<T>(max_range), - quant_axis, - x_num_col_dims, - out); + phi::funcs::ChannelDequantizeFunctor<Context, T>()( + dev_ctx, + &x, + (const_cast<std::vector<const DenseTensor*>>(&scales)).data(), + scale_num, + static_cast<T>(max_range), + quant_axis, + x_num_col_dims, + out); } } // namespace phi From 4aa558997425995e0e57720e7e00d1a23615ab01 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 24 May 2024 15:45:34 +0800 Subject: [PATCH 3/3] fix impl --- paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h b/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h index a952dc84c9675d..2d95f26cf4cea9 100644 --- a/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h +++ b/paddle/phi/kernels/impl/fake_dequantize_kernel_impl.h @@ -75,7 +75,7 @@ void FakeChannelWiseDequantizeMaxAbsKernel( phi::funcs::ChannelDequantizeFunctor<Context, T>()( dev_ctx, &x, - (const_cast<std::vector<const DenseTensor*>>(&scales)).data(), + (const_cast<std::vector<const DenseTensor*>*>(&scales))->data(), scale_num, static_cast<T>(max_range), quant_axis,
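As a final sanity check on the series: the kernel's two-scale range math composes both quantization stages, which can be verified against the reference computation used in test_fake_dequantize_op.py. A minimal NumPy check, with assumed shapes:

import numpy as np

# Two 8-bit stages: max_range = (2^7 - 1) * (2^7 - 1) = 16129.
quant_bits = [8, 8]
max_range = ((1 << (quant_bits[0] - 1)) - 1) * ((1 << (quant_bits[1] - 1)) - 1)
assert max_range == 16129

# Dequantize an assumed [N, C, H*W] activation the way the kernel does.
x = np.random.randint(-127, 128, size=(2, 4, 16)).astype(np.float32)
scale_one = np.random.rand(4).astype(np.float32)  # per channel, x_num_col_dims = 1
scale_two = np.random.rand(1).astype(np.float32)  # single tensor-level scale
ydq = x * scale_one.reshape(1, -1, 1) * scale_two[0] / max_range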