diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cc b/paddle/fluid/operators/dequantize_abs_max_op.cc deleted file mode 100644 index b9d5f0bb29200e..00000000000000 --- a/paddle/fluid/operators/dequantize_abs_max_op.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/dequantize_abs_max_op.h" - -#include - -namespace paddle { -namespace framework { -class InferShapeContext; -class OpDesc; -template -class EmptyGradOpMaker; -} // namespace framework -namespace imperative { -class OpBase; -} // namespace imperative -namespace platform {} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { - -template -struct DequantizeFunctor { - void operator()(const phi::CPUContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor* scale, - float max_range, - phi::DenseTensor* out) { - const float* scale_factor = scale->data(); - const T* input_data = in->data(); - float* output_data = out->mutable_data(dev_ctx.GetPlace()); - int ind = static_cast(in->numel()); - for (size_t i = 0; i < (unsigned)ind; i++) { - output_data[i] = scale_factor[0] * input_data[i] / max_range; - } - } -}; - -template struct DequantizeFunctor; -template struct DequantizeFunctor; - -class DequantizeMaxAbsOp : public framework::OperatorWithKernel { - public: - DequantizeMaxAbsOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DequantizeMaxAbs"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DequantizeMaxAbs"); - - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); - } - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(data_type, ctx.device_context().GetPlace()); - } -}; - -class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(Int Tensor) The input with int8/16 type is the " - "low precision tensor."); - AddInput("Scale", "(float) The scale in quantization stage."); - AddOutput("Out", - "(float32 Tensor) The output is the dequantized high " - "precision tensor."); - AddAttr("max_range", "(float) The max range in quantization stage."); - AddComment(R"DOC( -DequantizeMaxAbsOp operator. - -This calculation is an opposite operation of QuantizeMaxAbsOp: - -$$Out = \frac{scale*X}{ max\_range }$$ - -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR( - dequantize_abs_max, - ops::DequantizeMaxAbsOp, - ops::DequantizeMaxAbsOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -PD_REGISTER_STRUCT_KERNEL(dequantize_abs_max, - CPU, - ALL_LAYOUT, - ops::DequantizeMaxAbsKernel, - int8_t, - int16_t) {} diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cu b/paddle/fluid/operators/dequantize_abs_max_op.cu deleted file mode 100644 index 93fc6009d556cb..00000000000000 --- a/paddle/fluid/operators/dequantize_abs_max_op.cu +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/dequantize_abs_max_op.h" - -namespace paddle { -namespace operators { - -template -__global__ void KeDequantize( - const T* in, const float* scale, float max_range, int num, float* out) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < num) { - out[idx] = in[idx] * scale[0] / max_range; - } -} - -template -struct DequantizeFunctor { - void operator()(const phi::GPUContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor* scale, - float max_range, - phi::DenseTensor* out) { - const T* in_data = in->data(); - const float* scale_factor = scale->data(); - float* out_data = out->mutable_data(dev_ctx.GetPlace()); - - int num = in->numel(); - int block = 512; - int grid = (num + block - 1) / block; - - KeDequantize<<>>( - in_data, scale_factor, max_range, num, out_data); - } -}; - -template struct DequantizeFunctor; -template struct DequantizeFunctor; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(dequantize_abs_max, - GPU, - ALL_LAYOUT, - ops::DequantizeMaxAbsKernel, - int8_t, - int16_t) {} diff --git a/paddle/fluid/operators/dequantize_abs_max_op.h b/paddle/fluid/operators/dequantize_abs_max_op.h deleted file mode 100644 index 3796c1fe3f9e3d..00000000000000 --- a/paddle/fluid/operators/dequantize_abs_max_op.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include "paddle/common/ddim.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" - -namespace phi { -class DenseTensor; -} // namespace phi - -namespace paddle { -namespace operators { - -template -struct DequantizeFunctor { - void operator()(const DeviceContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor* scale, - float max_range, - phi::DenseTensor* out); -}; - -template -class DequantizeMaxAbsKernel : public framework::OpKernel { - public: - virtual void Compute(const framework::ExecutionContext& ctx) const { - auto* in = ctx.Input("X"); - auto* scale = ctx.Input("Scale"); - - auto* out = ctx.Output("Out"); - - float max_range = ctx.Attr("max_range"); - - auto& dev_ctx = ctx.template device_context(); - out->mutable_data(dev_ctx.GetPlace()); - - DequantizeFunctor()(dev_ctx, in, scale, max_range, out); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index b19533b005a94e..5fa44e2a566dd0 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -67,7 +67,6 @@ register_unity_group( deformable_conv_v1_op.cc deformable_psroi_pooling_op.cc delete_var_op.cc - dequantize_abs_max_op.cc dequantize_op.cc onednn/dequantize_onednn_op.cc) register_unity_group( @@ -362,8 +361,7 @@ register_unity_group( unzip_op.cu data_norm_op.cu deformable_conv_op.cu - deformable_conv_v1_op.cu - dequantize_abs_max_op.cu) + deformable_conv_v1_op.cu) register_unity_group( cu dgc_clip_by_norm_op.cu diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index cca6a7245b88b3..ab8c81110b5718 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -809,6 +809,12 @@ attrs : {scale : Scale, shift : Shift} +- op : dequantize_abs_max + inputs : + {x : X, scale : Scale} + outputs : + out : Out + - op : dequantize_linear inputs : {x : X, scale : Scale, zero_point : ZeroPoint, in_accum : InAccum, in_state : InState} diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index d9a9ba4beaefc5..64d7e22ac28b51 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -776,6 +776,15 @@ data_type : input backward : depthwise_conv2d_grad +- op : dequantize_abs_max + args : (Tensor x, Tensor scale, float max_range) + output : Tensor(out) + infer_meta : + func : DequantizeAbsMaxInferMeta + kernel : + func : dequantize_abs_max + data_type : x + - op : det args : (Tensor x) output : Tensor diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index fe14db9ccb415e..5212a6fe872224 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1085,6 +1085,15 @@ void DepthwiseConvInferMeta(const MetaTensor& input, config); } +void DequantizeAbsMaxInferMeta(const MetaTensor& x, + const MetaTensor& scale, + float max_range, + MetaTensor* out) { + out->set_dtype(x.dtype()); + out->share_dims(x); + out->share_lod(x); +} + void DequantizeLogInferMeta(const MetaTensor& x, const MetaTensor& dict, MetaTensor* out) { diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index c43b4e852a19ea..bd8517f73898e0 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -182,6 +182,11 @@ void DepthwiseConvInferMeta(const MetaTensor& input, MetaTensor* out, MetaConfig config = MetaConfig()); +void DequantizeAbsMaxInferMeta(const MetaTensor& x, + const MetaTensor& scale, + float max_range, + MetaTensor* out); + void DequantizeLogInferMeta(const MetaTensor& x, const MetaTensor& dict, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/dequantize_abs_max_kernel.cc b/paddle/phi/kernels/cpu/dequantize_abs_max_kernel.cc new file mode 100644 index 00000000000000..d1c2006b9afd16 --- /dev/null +++ b/paddle/phi/kernels/cpu/dequantize_abs_max_kernel.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/dequantize_abs_max_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void DequantizeAbsMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + float max_range, + DenseTensor* out) { + const float* scale_factor = scale.data(); + const T* input_data = x.data(); + float* output_data = dev_ctx.template Alloc(out); + int ind = static_cast(x.numel()); + for (size_t i = 0; i < (unsigned)ind; i++) { + output_data[i] = scale_factor[0] * input_data[i] / max_range; + } +} +} // namespace phi + +PD_REGISTER_KERNEL(dequantize_abs_max, + CPU, + ALL_LAYOUT, + phi::DequantizeAbsMaxKernel, + int8_t, + int16_t) {} diff --git a/paddle/phi/kernels/dequantize_abs_max_kernel.h b/paddle/phi/kernels/dequantize_abs_max_kernel.h new file mode 100644 index 00000000000000..f0084e37c20250 --- /dev/null +++ b/paddle/phi/kernels/dequantize_abs_max_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DequantizeAbsMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + float max_range, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/dequantize_abs_max_kernel.cu b/paddle/phi/kernels/gpu/dequantize_abs_max_kernel.cu new file mode 100644 index 00000000000000..cb8fe971084978 --- /dev/null +++ b/paddle/phi/kernels/gpu/dequantize_abs_max_kernel.cu @@ -0,0 +1,59 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/dequantize_abs_max_kernel.h" + +#include "paddle/common/hostdevice.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math.h" + +namespace phi { + +template +__global__ void KeDequantize( + const T* in, const float* scale, float max_range, int num, float* out) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < num) { + out[idx] = in[idx] * scale[0] / max_range; + } +} + +template +void DequantizeAbsMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + float max_range, + DenseTensor* out) { + const T* in_data = x.data(); + const float* scale_factor = scale.data(); + float* out_data = dev_ctx.template Alloc(out); + + int num = x.numel(); + int block = 512; + int grid = (num + block - 1) / block; + + KeDequantize<<>>( + in_data, scale_factor, max_range, num, out_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL(dequantize_abs_max, + GPU, + ALL_LAYOUT, + phi::DequantizeAbsMaxKernel, + int8_t, + int16_t) {} diff --git a/test/legacy_test/test_dequantize_abs_max_op.py b/test/legacy_test/test_dequantize_abs_max_op.py index 0df5a3fda11c2f..2917a05a00a14f 100644 --- a/test/legacy_test/test_dequantize_abs_max_op.py +++ b/test/legacy_test/test_dequantize_abs_max_op.py @@ -51,7 +51,7 @@ def setUp(self): self.outputs = {'Out': ydq} def test_check_output(self): - self.check_output() + self.check_output(check_dygraph=False) class TestDequantizeMaxAbsOp5Bits(TestDequantizeMaxAbsOp):