diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index c426d3325a0811..657ff43683560e 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -94,6 +94,7 @@ 'LegacyInterpolateInferMeta', 'NceInferMeta', 'PyramidHashInferMeta', + 'RmsNormInferMeta', 'SigmoidCrossEntropyWithLogitsInferMeta', 'StackInferMeta', 'WeightOnlyLinearInferMeta', diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index bb10157cfc69da..51af7a9c2fe168 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -4944,15 +4944,22 @@ void RmsNormInferMeta(const MetaTensor& x, const float quant_min_bound, MetaTensor* out, MetaTensor* residual_out, - MetaTensor* inv_var) { + MetaTensor* inv_var, + MetaConfig config) { size_t x_dims_size = x.dims().size(); size_t normalized_dims = 1; + bool has_minus_one = false; for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { normalized_dims *= x.dims().at(i); + has_minus_one |= (x.dims().at(i) == -1); } - if (normalized_dims != 0) { + bool skip_check = false; + if (normalized_dims == 0) skip_check = true; + if (has_minus_one && !config.is_runtime) skip_check = true; + + if (!skip_check) { PADDLE_ENFORCE_EQ(normalized_dims, norm_weight.dims()[0], common::errors::InvalidArgument( @@ -4963,7 +4970,6 @@ void RmsNormInferMeta(const MetaTensor& x, normalized_dims, norm_weight.dims()[0])); } - out->set_dims(x.dims()); if (quant_scale > 0) { diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 67027f75097f7e..224a1376902672 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -988,7 +988,8 @@ void RmsNormInferMeta(const MetaTensor& x, const float quant_min_bound, MetaTensor* out, MetaTensor* residual_out, - MetaTensor* inv_var); + MetaTensor* inv_var, + MetaConfig config = MetaConfig()); void RmspropInferMeta(const MetaTensor& param, const MetaTensor& mean_square,