@@ -23,7 +23,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
2323 paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
2424 paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
2525 paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
26- paddle::dialect::Increment_Op
26+ paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp
2727#else
2828
2929#include " paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -35,6 +35,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
3535#include " paddle/fluid/pir/dialect/operator/ir/op_type.h"
3636#include " paddle/fluid/pir/dialect/operator/ir/pd_op.h"
3737#include " paddle/fluid/primitive/rule/vjp/vjp.h"
38+ #include " paddle/phi/api/lib/data_type_set.h"
3839#include " paddle/phi/api/lib/utils/allocator.h"
3940#include " paddle/phi/core/dense_tensor.h"
4041#include " paddle/phi/core/enforce.h"
@@ -2925,6 +2926,158 @@ phi::DataType Increment_Op::GetKernelTypeForVar(
29252926 return expected_kernel_dtype;
29262927}
29272928
2929+ void ShapeBroadcastOp::Build (pir::Builder &builder,
2930+ pir::OperationArgument &argument,
2931+ pir::Value x_,
2932+ pir::Value y_) {
2933+ VLOG (4 ) << " Start build ShapeBroadcastOp" ;
2934+
2935+ VLOG (4 ) << " Builder construction inputs" ;
2936+ std::vector<pir::Value> argument_inputs = {x_, y_};
2937+ argument.AddInputs (argument_inputs);
2938+
2939+ VLOG (4 ) << " Builder construction attributes" ;
2940+
2941+ VLOG (4 ) << " Builder construction outputs" ;
2942+ paddle::dialect::DenseTensorType x =
2943+ x_.type ().dyn_cast <paddle::dialect::DenseTensorType>();
2944+ paddle::dialect::DenseTensorType y =
2945+ y_.type ().dyn_cast <paddle::dialect::DenseTensorType>();
2946+
2947+ VLOG (4 ) << " Builder construction dense_x" ;
2948+ paddle::dialect::IrTensor ir_tensor_x (
2949+ paddle::dialect::TransToPhiDataType (x.dtype ()),
2950+ x.dims (),
2951+ x.data_layout (),
2952+ x.lod (),
2953+ x.offset ());
2954+ VLOG (4 ) << " Builder construction meta_x" ;
2955+ paddle::dialect::IrMetaTensor meta_x (&ir_tensor_x);
2956+
2957+ VLOG (4 ) << " Builder construction dense_y" ;
2958+ paddle::dialect::IrTensor ir_tensor_y (
2959+ paddle::dialect::TransToPhiDataType (y.dtype ()),
2960+ y.dims (),
2961+ y.data_layout (),
2962+ y.lod (),
2963+ y.offset ());
2964+ VLOG (4 ) << " Builder construction meta_y" ;
2965+ paddle::dialect::IrMetaTensor meta_y (&ir_tensor_y);
2966+ paddle::dialect::IrTensor dense_out;
2967+ paddle::dialect::IrMetaTensor meta_out (&dense_out);
2968+
2969+ phi::ElementwiseInferMeta (meta_x, meta_y, &meta_out);
2970+
2971+ std::vector<pir::Type> argument_outputs;
2972+ pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get (
2973+ pir::IrContext::Instance (),
2974+ paddle::dialect::TransToIrDataType (dense_out.dtype ()),
2975+ dense_out.dims (),
2976+ dense_out.layout (),
2977+ dense_out.lod (),
2978+ dense_out.offset ());
2979+ argument_outputs.push_back (out_dense_tensor_type);
2980+ argument.AddOutputs (argument_outputs.begin (), argument_outputs.end ());
2981+ ::pir::PassStopGradientsDefaultly (argument);
2982+ }
2983+
2984+ namespace {
2985+
2986+ void ShapeBroadcastOpInferMeta (const phi::MetaTensor &x,
2987+ const phi::MetaTensor &y,
2988+ phi::MetaTensor *out) {
2989+ PADDLE_ENFORCE_EQ (
2990+ x.dims ().size (),
2991+ 1 ,
2992+ phi::errors::PreconditionNotMet (
2993+ " The size %d of x.dims() must be equal to 1." , x.dims ().size ()));
2994+ PADDLE_ENFORCE_EQ (
2995+ y.dims ().size (),
2996+ 1 ,
2997+ phi::errors::PreconditionNotMet (
2998+ " The size %d of y.dims() must be equal to 1." , y.dims ().size ()));
2999+ out->set_dims ({std::max<int64_t >(x.dims ().at (0 ), y.dims ().at (0 ))});
3000+ // dtype need promote when meet input dtype with more precision
3001+ paddle::experimental::DataTypeSet dtype_set{x.dtype ()};
3002+ dtype_set = dtype_set | paddle::experimental::DataTypeSet (y.dtype ());
3003+ DataType promote_result = PromoteTypes (dtype_set);
3004+ if (promote_result == DataType::UNDEFINED) {
3005+ promote_result = x.dtype ();
3006+ }
3007+ out->set_dtype (promote_result);
3008+ out->set_layout (x.layout ());
3009+ out->share_lod (x);
3010+ }
3011+
3012+ } // namespace
3013+
3014+ void ShapeBroadcastOp::InferMeta (phi::InferMetaContext *infer_meta) {
3015+ auto fn = PD_INFER_META (ShapeBroadcastOpInferMeta);
3016+ fn (infer_meta);
3017+ }
3018+
3019+ phi::DataType ShapeBroadcastOp::GetKernelTypeForVar (
3020+ const std::string &var_name,
3021+ const phi::DataType &tensor_dtype,
3022+ const phi::DataType &expected_kernel_dtype) {
3023+ VLOG (4 ) << " Get KernelType for Var of op: ShapeBroadcastOp" ;
3024+
3025+ return expected_kernel_dtype;
3026+ }
3027+
3028+ namespace {
3029+
3030+ symbol::DimExpr GetBroadcastDimExpr (const symbol::DimExpr &lhs,
3031+ const symbol::DimExpr &rhs) {
3032+ if (lhs.isa <std::int64_t >() && rhs.isa <std::int64_t >()) {
3033+ return std::max (lhs.dyn_cast <std::int64_t >(), rhs.dyn_cast <std::int64_t >());
3034+ } else if (lhs.isa <std::int64_t >()) {
3035+ return lhs.dyn_cast <std::int64_t >() == 1 ? rhs : lhs;
3036+ } else if (rhs.isa <std::int64_t >()) {
3037+ return rhs.dyn_cast <std::int64_t >() == 1 ? lhs : rhs;
3038+ } else {
3039+ return symbol::Broadcast<symbol::DimExpr>{
3040+ symbol::List<symbol::DimExpr>{lhs, rhs}};
3041+ }
3042+ LOG (FATAL) << " Dead code" ;
3043+ }
3044+
3045+ } // namespace
3046+
3047+ bool ShapeBroadcastOp::InferSymbolicShape (
3048+ pir::ShapeConstraintIRAnalysis *shape_analysis) {
3049+ pir::Value x = operand_source (0 );
3050+ pir::Value y = operand_source (1 );
3051+ std::string x_id = pir::GetValueId (&x);
3052+ std::string y_id = pir::GetValueId (&y);
3053+
3054+ IR_ENFORCE (shape_analysis->value_id_to_shapeordata_ .count (x_id) > 0 ,
3055+ " x_id does not exist." );
3056+ IR_ENFORCE (shape_analysis->value_id_to_shapeordata_ .count (y_id) > 0 ,
3057+ " y_id does not exist." );
3058+ const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_ .at (x_id);
3059+ const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_ .at (y_id);
3060+ IR_ENFORCE (x_data_shape.data ().has_value (),
3061+ " Value x comes from ShapeOp, it must have data" );
3062+ IR_ENFORCE (y_data_shape.data ().has_value (),
3063+ " Value y comes from ShapeOp, it must have data" );
3064+ const auto &x_data = x_data_shape.data ().value ();
3065+ const auto &y_data = y_data_shape.data ().value ();
3066+ IR_ENFORCE (x_data.size () == y_data.size (), " Support same rank temporarily" );
3067+
3068+ std::vector<symbol::DimExpr> output_data;
3069+ for (std::size_t i = 0 ; i < x_data.size (); ++i) {
3070+ output_data.emplace_back (GetBroadcastDimExpr (x_data.at (i), y_data.at (i)));
3071+ }
3072+
3073+ pir::OpResult res = result (0 );
3074+ std::string res_id = pir::GetValueId (&res);
3075+ symbol::ShapeOrDataDimExprs output_data_shape =
3076+ symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData (output_data);
3077+ shape_analysis->value_id_to_shapeordata_ [res_id] = output_data_shape;
3078+ return true ;
3079+ }
3080+
29283081} // namespace dialect
29293082} // namespace paddle
29303083
@@ -2948,4 +3101,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
29483101IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
29493102IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
29503103IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
3104+ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)
29513105#endif
0 commit comments