21 changes: 21 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -3113,6 +3113,25 @@ struct CEmbeddingOpTranscriber : public OpTranscriber {
}
};

struct QuantizeLinearOpTranscriber : public OpTranscriber {
void HandleNonexistentAttribute(pir::IrContext* ctx,
pir::AttributeMap* attribute_map,
const OpAttributeInfo& info) override {
if (info.name == "round_type") {
(*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 0);
}
if (info.name == "is_test") {
(*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, true);
}
if (info.name == "only_observer") {
(*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
}
if (info.name == "moving_rate") {
(*attribute_map)[info.name] = pir::FloatAttribute::get(ctx, 0.9);
}
}
};
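The transcriber above fills in defaults for attributes that older ProgramDescs may omit when their quantize_linear / dequantize_linear ops are translated into PIR. A minimal sketch of the same default-filling idea, written in Python with hypothetical names, purely for illustration:

```python
# Illustrative sketch only: mirrors the defaults injected by
# QuantizeLinearOpTranscriber above; the names here are hypothetical.
QUANT_LINEAR_ATTR_DEFAULTS = {
    "round_type": 0,         # pir::Int32Attribute
    "is_test": True,         # pir::BoolAttribute
    "only_observer": False,  # pir::BoolAttribute
    "moving_rate": 0.9,      # pir::FloatAttribute
}

def fill_missing_attrs(attribute_map: dict) -> dict:
    """Add a default for any attribute the legacy op did not carry."""
    for name, default in QUANT_LINEAR_ATTR_DEFAULTS.items():
        attribute_map.setdefault(name, default)
    return attribute_map
```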

OpTranslator::OpTranslator() {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
@@ -3185,6 +3204,8 @@ OpTranslator::OpTranslator() {
special_handlers["elementwise_mod_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_floordiv_grad"] = ElementwiseGradTranscriber();
special_handlers["c_embedding"] = CEmbeddingOpTranscriber();
special_handlers["quantize_linear"] = QuantizeLinearOpTranscriber();
special_handlers["dequantize_linear"] = QuantizeLinearOpTranscriber();
}

} // namespace translator
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -108,6 +108,12 @@
'lrn',
'multi_gru',
'matmul_with_flatten',
'moving_average_abs_max_scale',
'moving_average_abs_max_scale_',
'quantize_linear',
'quantize_linear_',
'dequantize_linear',
'dequantize_linear_',
]

NO_NEED_GEN_STATIC_ONLY_APIS = [
@@ -28,6 +28,10 @@ KernelKeyTuple UniqueOpParseKernelKey(pir::Operation* op) {
return {dtype, backend};
}

KernelKeyTuple SaveCombineOpParseKernelKey(pir::Operation* op) {
return {phi::DataType::FLOAT32, phi::Backend::UNDEFINED};
}

} // namespace paddle::dialect

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ParseKernelKeyInterface)
@@ -57,6 +57,8 @@ class ParseKernelKeyInterface
// Register the ParseKernelKeyInterface for unique op.
KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op);

KernelKeyTuple SaveCombineOpParseKernelKey(pir::Operation *op);

} // namespace dialect
} // namespace paddle

39 changes: 39 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -366,6 +366,19 @@
data_type : x
backward : depthwise_conv2d_transpose_grad

- op : dequantize_linear
args : (Tensor x, Tensor scale, Tensor zero_point, Tensor in_accum, Tensor in_state, int quant_axis = 0, int bit_length = 8, int round_type = 0, bool is_test = true, bool only_observer = false, float moving_rate=0.9f)
output : Tensor(y), Tensor(out_scale), Tensor(out_accum), Tensor(out_state)
infer_meta :
func : QuantizeLinearInferMeta
param : [x, scale, in_accum, in_state, quant_axis]
kernel :
func : quantize_linear
param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, round_type, is_test, only_observer, moving_rate]
data_type : x
optional : in_accum, in_state, out_scale, out_accum, out_state
inplace : (scale -> out_scale, in_accum -> out_accum, in_state -> out_state)

- op : disable_check_model_nan_inf
args: (Tensor x, int flag = 0)
output: Tensor(out)
@@ -1083,6 +1096,19 @@
data_type : out_grad_in
inplace: (out_grad_in -> out_grad_out)

- op : quantize_linear
args : (Tensor x, Tensor scale, Tensor zero_point, Tensor in_accum, Tensor in_state, int quant_axis = 0, int bit_length = 8, int round_type = 0, bool is_test = true, bool only_observer = false, float moving_rate=0.9f)
output : Tensor(y), Tensor(out_scale), Tensor(out_accum), Tensor(out_state)
infer_meta :
func : QuantizeLinearInferMeta
param : [x, scale, in_accum, in_state, quant_axis]
kernel :
func : quantize_linear
param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, round_type, is_test, only_observer, moving_rate]
data_type : x
optional : in_accum, in_state, out_scale, out_accum, out_state
inplace : (scale -> out_scale, in_accum -> out_accum, in_state -> out_state)

- op : randint
args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={})
output : Tensor(out)
@@ -1215,6 +1241,7 @@
func: save_combine_tensor
param: [x, file_path, overwrite, save_as_fp16, save_to_memory]
optional : out
interfaces : paddle::dialect::ParseKernelKeyInterface

- op : seed
args : (int seed, bool deterministic, str rng_name, bool force_cpu)
@@ -1635,6 +1662,18 @@
func: match_matrix_tensor
backward: match_matrix_tensor_grad

- op: moving_average_abs_max_scale
args: (Tensor x, Tensor in_accum, Tensor in_state, float moving_rate=0.9f, bool is_test=false)
output: Tensor(out), Tensor(out_scale), Tensor(out_state), Tensor(out_accum)
infer_meta:
func: MovingAverageAbsMaxScaleInferMeta
param: [x, in_accum, in_state]
kernel:
func: moving_average_abs_max_scale
param: [x, in_accum, in_state, moving_rate, is_test]
optional : in_accum, in_state, out, out_state, out_accum
inplace : (in_accum -> out_accum), (in_state -> out_state)

- op: nce
args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
output: Tensor(cost), Tensor(sample_logits), Tensor(sample_labels)
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -72,6 +72,12 @@ const std::unordered_set<std::string> LegacyOpList = {
NceGradOp::name(),
LrnOp::name(),
LrnGradOp::name(),
MovingAverageAbsMaxScaleOp::name(),
MovingAverageAbsMaxScale_Op::name(),
QuantizeLinearOp::name(),
QuantizeLinear_Op::name(),
DequantizeLinearOp::name(),
DequantizeLinear_Op::name(),
#ifdef PADDLE_WITH_DNNL
paddle::onednn::dialect::LrnOp::name(),
paddle::onednn::dialect::LrnGradOp::name(),
14 changes: 14 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -771,6 +771,10 @@
{scale : Scale, shift : Shift}

- op : dequantize_linear
inputs :
{x : X, scale : Scale, zero_point : ZeroPoint, in_accum : InAccum, in_state : InState}
outputs :
{y : Y, out_scale : OutScale, out_accum : OutAccum, out_state : OutState}
extra :
attrs : [float moving_rate = 0.9]

@@ -2197,6 +2201,12 @@
outputs :
{param_out : ParamOut, velocity_out : VelocityOut, master_param_out : MasterParamOut}

- op : moving_average_abs_max_scale
inputs :
{x : X, in_accum : InAccum, in_state : InState}
outputs :
{out : Out, out_scale : OutScale, out_state : OutState, out_accum : OutAccum}

- op : multi_dot
backward : multi_dot_grad
inputs :
@@ -2546,6 +2556,10 @@
{scale : Scale, shift : Shift, include_self: Include_self}

- op : quantize_linear
inputs :
{x : X, scale : Scale, zero_point : ZeroPoint, in_accum : InAccum, in_state : InState}
outputs :
{y : Y, out_scale : OutScale, out_accum : OutAccum, out_state : OutState}
extra :
attrs : [float moving_rate = 0.9]

2 changes: 1 addition & 1 deletion paddle/phi/core/kernel_context.cc
@@ -119,7 +119,7 @@ const AttrType& KernelContext::AttrAt(size_t idx) const {
return paddle::get<AttrType>(attrs_.at(idx));
} catch (paddle::bad_variant_access const& ex) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in Op Kernel Context."));
"Attribute %d cast error in Op Kernel Context.", idx));
}
}

26 changes: 26 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3497,6 +3497,32 @@ void PsroiPoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void QuantizeLinearInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& in_accum,
const MetaTensor& in_state,
int quant_axis,
MetaTensor* y,
MetaTensor* out_scale,
MetaTensor* out_accum,
MetaTensor* out_state) {
y->set_dims(x.dims());
y->share_lod(x);
if (out_scale) {
if (quant_axis < 0) {
out_scale->set_dims(scale.dims());
} else {
out_scale->set_dims({x.dims()[quant_axis]});
}
}
if (out_accum) {
out_accum->set_dims(in_accum.dims());
}
if (out_state) {
out_state->set_dims(in_state.dims());
}
}
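The shape rules above are: `y` follows `x`; `out_scale` keeps the scale shape for per-tensor quantization (`quant_axis < 0`) or becomes a 1-D tensor of length `x.dims()[quant_axis]` for per-channel quantization; the accumulator and state shapes pass through unchanged. A hedged Python rendering of the same logic, with a hypothetical helper name and plain lists standing in for `DDim`:

```python
# Sketch of the shape propagation in QuantizeLinearInferMeta; not Paddle API.
def quantize_linear_infer_shapes(x_shape, scale_shape, accum_shape,
                                 state_shape, quant_axis):
    y_shape = list(x_shape)                      # y mirrors x
    if quant_axis < 0:
        out_scale_shape = list(scale_shape)      # per-tensor: keep scale dims
    else:
        out_scale_shape = [x_shape[quant_axis]]  # per-channel: one scale per channel
    return y_shape, out_scale_shape, list(accum_shape), list(state_shape)

# Example: per-channel quantization along axis 1 of a [64, 32, 3, 3] weight.
print(quantize_linear_infer_shapes([64, 32, 3, 3], [1], [1], [1], 1))
# -> ([64, 32, 3, 3], [32], [1], [1])
```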

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -636,6 +636,16 @@ void PsroiPoolInferMeta(const MetaTensor& x,
float spatial_scale,
MetaTensor* out);

void QuantizeLinearInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& in_accum,
const MetaTensor& in_state,
int quant_axis,
MetaTensor* y,
MetaTensor* out_scale,
MetaTensor* out_accum,
MetaTensor* out_state);

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
20 changes: 20 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -896,6 +896,26 @@ void MultiClassNMSInferMeta(const MetaTensor& bboxes,
nms_rois_num->set_dtype(DataType::INT32);
}

void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x,
const MetaTensor& in_accum,
const MetaTensor& in_state,
MetaTensor* out,
MetaTensor* out_scale,
MetaTensor* out_state,
MetaTensor* out_accum) {
if (out) {
out->set_dims(x.dims());
out->share_lod(x);
out_scale->set_dims({1});
}
if (out_state) {
out_state->set_dims(in_state.dims());
}
if (out_accum) {
out_accum->set_dims(in_accum.dims());
}
}
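Here `out` mirrors `x`, `out_scale` is a single-element tensor holding the running scale, and the state/accumulator outputs keep their input shapes. A brief Python sketch under the same assumptions as the previous snippet:

```python
# Sketch of MovingAverageAbsMaxScaleInferMeta's shape rules; helper name is
# hypothetical and shapes are plain lists.
def moving_average_abs_max_scale_infer_shapes(x_shape, accum_shape, state_shape):
    out_shape = list(x_shape)    # out mirrors x
    out_scale_shape = [1]        # a single running scale value
    return out_shape, out_scale_shape, list(state_shape), list(accum_shape)
```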

void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/ternary.h
@@ -140,6 +140,14 @@ void MatchMatrixTensorInferMeta(const MetaTensor& x,
MetaTensor* tmp,
MetaConfig config = MetaConfig());

void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x,
const MetaTensor& in_accum,
const MetaTensor& in_state,
MetaTensor* out,
MetaTensor* out_scale,
MetaTensor* out_state,
MetaTensor* out_accum);

void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -257,6 +257,11 @@ def __init__(
paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
paddle.framework.set_flags({'FLAGS_new_executor_static_build': 1})

if auto_utils.use_new_executor():
is_pir_mode = os.environ.get("FLAGS_enable_pir_in_executor", None)
if is_pir_mode is None:
paddle.framework.set_flags({'FLAGS_enable_pir_in_executor': 1})

self.enable_job_schedule_profiler = False

# get dist input spec from shard dataloader
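
The new block turns on the PIR executor for the auto-parallel engine only when the user has not set `FLAGS_enable_pir_in_executor` themselves. An illustrative sketch of how a user could keep their own choice in effect; whether the environment value itself is honored downstream depends on when Paddle reads its flags, which is an assumption here:

```python
# Illustrative only: the engine code above checks os.environ first and only
# forces FLAGS_enable_pir_in_executor when the variable is absent, so
# pre-setting it keeps the engine from overriding the user's choice.
import os

os.environ.setdefault("FLAGS_enable_pir_in_executor", "0")
# ... construct the auto-parallel Engine afterwards; it will leave the flag alone.
```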