7 changes: 4 additions & 3 deletions paddle/fluid/operators/dequantize_op.cc
@@ -31,9 +31,10 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(
}

void DeQuantOpMaker::Make() {
AddInput("Input", "input data");
AddOutput("Output", "output data");
AddAttr<float>("Scale", "scale data").SetDefault({1.0f});
AddInput("Input", "Input data");
AddOutput("Output", "Output data");
AddAttr<float>("Scale", "Scale data").SetDefault({1.0f});
AddAttr<float>("Shift", "Shift data").SetDefault({0.0f});
AddComment(R"DOC(This op will dequantize data from INT8 to FP32)DOC");
}

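For reference, the mapping the op performs with the new Shift attribute, as inferred from the MKL-DNN kernel change below (Shift = 0 recovers the previous pure-rescale behavior):

```
Output = (Input - Shift) / Scale
```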
30 changes: 28 additions & 2 deletions paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/dequantize_op.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

@@ -37,14 +38,29 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto scale_data = ctx.Attr<float>("Scale");
auto scale_shift = ctx.Attr<float>("Shift");
bool with_shift = scale_shift != 0.0f;
auto* output = ctx.Output<Tensor>("Output");

PADDLE_ENFORCE_NE(scale_data, 0.0f,
platform::errors::InvalidArgument(
"Dequantization scale cannot be 0.0"));
PADDLE_ENFORCE_GE(scale_shift, 0,
platform::errors::Unimplemented(
"Dequantization shift must be nonnegative."));
PADDLE_ENFORCE_LE(
scale_shift, 255,
platform::errors::Unimplemented(
"Dequantization shift must be less than or equal to 255."));

auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();

const T* input_data = input->data<T>();
float* output_data = output->mutable_data<float>(ctx.GetPlace());
std::vector<float> reorder_scale = {1.0f / scale_data};

float reorder_shift = -scale_shift / scale_data;

auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
@@ -65,7 +81,15 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
if (reorder_p == nullptr) {
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, reorder_scale);
float reorder_scale = 1. / scale_data;
attri.set_output_scales(mask, {reorder_scale});

if (with_shift) {
mkldnn::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
std::fill(output_data, output_data + output->numel(), reorder_shift);
}

auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
src_memory = std::make_shared<mkldnn::memory>(
@@ -92,6 +116,8 @@ class DeQuantOpKernel : public framework::OpKernel<T> {

dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_dst_mem));
if (with_shift)
std::fill(output_data, output_data + output->numel(), reorder_shift);
dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
}

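The shift is implemented through oneDNN's sum post-op: a reorder with an output scale and an appended sum computes, roughly, dst = scale * src + dst, so pre-filling the destination with -Shift/Scale (the std::fill calls above) turns the plain rescale into an affine dequantization. Below is a minimal scalar sketch of the resulting arithmetic; the function name is hypothetical and this is an illustration, not the Paddle or oneDNN API:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of the reorder above: dst starts at -shift/scale (the
// pre-fill) and the sum post-op accumulates src * (1/scale) on top,
// i.e. dst[i] = (src[i] - shift) / scale.
std::vector<float> dequantize_ref(const std::vector<int8_t>& src,
                                  float scale, float shift) {
  std::vector<float> dst(src.size(), -shift / scale);  // matches the std::fill
  for (std::size_t i = 0; i < src.size(); ++i)
    dst[i] += static_cast<float>(src[i]) / scale;      // sum post-op accumulates
  return dst;
}
```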
43 changes: 35 additions & 8 deletions paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
@@ -36,7 +36,21 @@ class QuantOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto scale_data = ctx.Attr<float>("Scale");
auto scale_shift = ctx.Attr<float>("Shift");
bool with_shift = scale_shift != 0.0f;
auto* output = ctx.Output<Tensor>("Output");

PADDLE_ENFORCE_NE(
scale_data, 0.0f,
platform::errors::InvalidArgument("Quantization scale cannot be 0.0"));
PADDLE_ENFORCE_GE(scale_shift, 0,
platform::errors::Unimplemented(
"Quantization shift must be nonnegative."));
PADDLE_ENFORCE_LE(
scale_shift, 255,
platform::errors::Unimplemented(
"Quantization shift must be less than or equal to 255."));

auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
@@ -47,11 +61,12 @@ class QuantOpKernel : public framework::OpKernel<T> {

const T* input_data = input->data<T>();

bool is_negative = ctx.Attr<bool>("is_negative_input");
bool is_negative_input = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16");
std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_data,
is_negative, ctx.OutputName("Output"));

std::string key = platform::CreateKey(
platform::ThreadIDasStr(), src_tz, scale_data, scale_shift,
is_negative_input, ctx.OutputName("Output"));
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
@@ -69,6 +84,15 @@ class QuantOpKernel : public framework::OpKernel<T> {
int mask = 0;
attri.set_output_scales(mask, {scale_data});

if (with_shift) {
mkldnn::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
// memset casts scale_shift to unsigned char (uint8_t) internally
std::memset(output_data, scale_shift, output->numel());
}

auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
input->format());
src_memory = std::make_shared<mkldnn::memory>(
@@ -78,7 +102,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
if (bfloat16) {
platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
} else if (is_negative) {
} else if (is_negative_input && !with_shift) {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
} else {
@@ -104,10 +128,13 @@ class QuantOpKernel : public framework::OpKernel<T> {
if (bfloat16) {
dst_memory->set_data_handle(
output->mutable_data<paddle::platform::bfloat16>(place));
} else if (is_negative) {
dst_memory->set_data_handle(output->mutable_data<int8_t>(place));
} else if (with_shift || !is_negative_input) {
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
if (with_shift) std::memset(output_data, scale_shift, output->numel());
dst_memory->set_data_handle(output_data);
} else {
dst_memory->set_data_handle(output->mutable_data<uint8_t>(place));
dst_memory->set_data_handle(
output->mutable_data<int8_t>(ctx.GetPlace()));
}
}

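The quantize kernel uses the same trick in the other direction: the destination is pre-filled with the shift (memset casts it to uint8_t), and the reorder's sum post-op adds src * Scale on top, saturating to unsigned int8. A scalar sketch of that computation follows; saturating after the add is an assumption about the reorder's behavior, and the function is illustrative only:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of quantization with shift:
// dst[i] = saturate_u8(round(src[i] * scale) + shift).
std::vector<uint8_t> quantize_ref(const std::vector<float>& src,
                                  float scale, float shift) {
  std::vector<uint8_t> dst(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    float v = std::round(src[i] * scale) + shift;  // shift comes from the pre-filled dst
    dst[i] = static_cast<uint8_t>(std::max(0.0f, std::min(255.0f, v)));
  }
  return dst;
}
```

This also motivates the new is_negative_input && !with_shift condition: a non-zero shift forces an unsigned-int8 destination regardless of the input's sign.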
72 changes: 58 additions & 14 deletions paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
@@ -26,20 +26,45 @@ using dnnl::reorder;
using platform::to_void_cast;
using Tensor = framework::Tensor;

namespace {

inline uint8_t clip_to_uint8(float x) {
return std::max(0L, std::min(255L, std::lround(x)));
}

} // namespace

template <typename T>
class ReQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto scale_in = ctx.Attr<float>("Scale_in");
auto shift_in = ctx.Attr<float>("Shift_in");
auto scale_out = ctx.Attr<float>("Scale_out");
auto shift_out = ctx.Attr<float>("Shift_out");
bool with_shift = shift_in != 0.0f || shift_out != 0.0f;
auto* output = ctx.Output<Tensor>("Output");

PADDLE_ENFORCE_NE(scale_in, 0.0f, platform::errors::InvalidArgument(
"Scale of input cannot be 0.0"));
PADDLE_ENFORCE_NE(scale_out, 0.0f, platform::errors::InvalidArgument(
"Scale of output cannot be 0.0"));
if (shift_in != 0.0f) {
PADDLE_ENFORCE_EQ(
input->type(), framework::proto::VarType::UINT8,
platform::errors::Unimplemented("Requantize does not support nonzero "
"shift for signed input."));
}

auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();

auto src_tz = paddle::framework::vectorize(input->dims());

float reorder_scale = scale_out / scale_in;

std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in,
scale_out, ctx.OutputName("Output"));
@@ -53,28 +78,37 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));

const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());

if (reorder_p == nullptr) {
dnnl::primitive_attr attri;
int mask = 0;
float scale_shift = scale_out / scale_in;
attri.set_output_scales(mask, {scale_shift});

auto dst_tz = paddle::framework::vectorize(output->dims());
dnnl::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
dnnl::memory::data_type dst_dt = src_dt;
auto dst_tz = framework::vectorize(output->dims());
auto src_dt = framework::ToMKLDNNDataType(input->type());
auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt;

auto src_md =
platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc);
src_memory = std::make_shared<dnnl::memory>(src_md, engine,
to_void_cast<T>(input_data));

auto dst_md =
platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc);
dst_memory = std::make_shared<dnnl::memory>(dst_md, engine,
to_void_cast<T>(output_data));

dnnl::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, {reorder_scale});
if (with_shift) {
mkldnn::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
uint8_t reorder_shift =
clip_to_uint8(shift_out - reorder_scale * shift_in);
std::memset(output_data, reorder_shift, output->numel());
dst_memory = std::make_shared<dnnl::memory>(
dst_md, engine, to_void_cast<uint8_t>(output_data));
} else {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory = std::make_shared<dnnl::memory>(
dst_md, engine, to_void_cast<T>(output_data));
}

auto reorder_pd =
reorder::primitive_desc(*src_memory, *dst_memory, attri);
@@ -90,7 +124,17 @@ class ReQuantOpKernel : public framework::OpKernel<T> {

dst_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
dst_memory->set_data_handle(output_data);
if (with_shift) {
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
uint8_t reorder_shift =
clip_to_uint8(shift_out - reorder_scale * shift_in);
std::memset(output_data, reorder_shift, output->numel());
dst_memory->set_data_handle(output_data);

} else {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory->set_data_handle(output_data);
}
}

dnnl::stream astream(engine);
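Combining the two conventions (q_in = x * Scale_in + Shift_in and q_out = x * Scale_out + Shift_out) gives q_out = q_in * (Scale_out / Scale_in) + (Shift_out - (Scale_out / Scale_in) * Shift_in), which is exactly the reorder_scale and reorder_shift computed above. A scalar sketch, illustrative only and assuming uint8 input and output:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of requantization with shifts.
std::vector<uint8_t> requantize_ref(const std::vector<uint8_t>& src,
                                    float scale_in, float shift_in,
                                    float scale_out, float shift_out) {
  const float s = scale_out / scale_in;         // reorder_scale
  const float bias = shift_out - s * shift_in;  // reorder_shift (pre-filled)
  std::vector<uint8_t> dst(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    float v = std::round(static_cast<float>(src[i]) * s + bias);
    dst[i] = static_cast<uint8_t>(std::max(0.0f, std::min(255.0f, v)));
  }
  return dst;
}
```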
10 changes: 7 additions & 3 deletions paddle/fluid/operators/quantize_op.cc
@@ -31,12 +31,16 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(
}

void QuantOpMaker::Make() {
AddInput("Input", "input data");
AddOutput("Output", "output data");
AddInput("Input", "Input data");
AddOutput("Output", "Output data");
AddAttr<bool>("is_negative_input",
"(bool, default false) Only used in mkldnn INT8 kernel")
.SetDefault(false);
AddAttr<float>("Scale", "scale data").SetDefault({1.0f});
AddAttr<float>("Scale", "Scale data").SetDefault({1.0f});
AddAttr<float>(
"Shift",
"Shift data. When Shift is non-zero, data is quantized to unsigned int8.")
.SetDefault({0.0f});
AddAttr<std::string>("output_format",
"Convert format to NHWC or NCHW during quantization.")
.SetDefault("NHWC");
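A concrete example of the non-zero-shift case (numbers chosen for illustration): with Scale = 2.0 and Shift = 128.0, an input of -3.5 quantizes to round(-3.5 * 2.0) + 128 = 121, which fits in unsigned int8 even though the input is negative. This is why a non-zero Shift implies u8 output.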
10 changes: 6 additions & 4 deletions paddle/fluid/operators/requantize_op.cc
@@ -31,10 +31,12 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(
}

void ReQuantOpMaker::Make() {
AddInput("Input", "input data");
AddOutput("Output", "output data");
AddAttr<float>("Scale_in", "scale in data").SetDefault({1.0f});
AddAttr<float>("Scale_out", "scale out data").SetDefault({1.0f});
AddInput("Input", "Input data");
AddOutput("Output", "Output data");
AddAttr<float>("Scale_in", "Scale in data").SetDefault({1.0f});
AddAttr<float>("Scale_out", "Scale out data").SetDefault({1.0f});
AddAttr<float>("Shift_in", "Shift in data").SetDefault({1.0f});
AddAttr<float>("Shift_out", "Shift out data").SetDefault({1.0f});
AddComment(
R"DOC(This op will re-quantize data from INT8 with scale_in to INT8 with scale_out)DOC");
}