Commit a05f195
[Dynamic Shape] Add FullyInsertBroadcastPass and Broadcast Op (#60511)
* add ShapeBroadcastOp
* add pass FullyInsertBroadcastPass
* InferSymbolicShape of ShapeBroadcast Op
* Delete unit test
* Fix return error
* Code format
* Fix error message
* Update paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc

Co-authored-by: Bo Zhang <[email protected]>

---------

Co-authored-by: Bo Zhang <[email protected]>
1 parent 6b2d74c commit a05f195
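At a glance, the pass rewrites every matched binary op so that both operands are explicitly expanded to the broadcasted shape before the op runs. A rough before/after sketch in textual IR (op names follow the pd_op dialect used in this commit; the exact printed form is illustrative, not tool output):

// before: the elementwise op broadcasts implicitly
//   %out = pd_op.add(%x, %y)
//
// after: broadcasting is made explicit via shape_broadcast + expand
//   %x_shape = pd_op.shape(%x)
//   %y_shape = pd_op.shape(%y)
//   %bshape  = pd_op.shape_broadcast(%x_shape, %y_shape)
//   %x_b     = pd_op.expand(%x, %bshape)
//   %y_b     = pd_op.expand(%y, %bshape)
//   %out     = pd_op.add(%x_b, %y_b)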

File tree: 5 files changed, +344 -1 lines changed

paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -29,4 +29,13 @@ if(NOT CINN_ONLY)
     cinn_op_dialect
     op_dialect_vjp)
 
+  cinn_cc_library(
+    fully_insert_broadcast_pass
+    SRCS
+    fully_insert_broadcast_pass.cc
+    DEPS
+    pir
+    cinn_op_dialect
+    op_dialect_vjp)
+
 endif()
paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace cinn {
namespace dialect {
namespace ir {

pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
                              pir::Value x,
                              pir::Value y) {
  pir::Value x_shape = rewriter->Build<paddle::dialect::ShapeOp>(x).out();
  pir::Value y_shape = rewriter->Build<paddle::dialect::ShapeOp>(y).out();
  return rewriter->Build<paddle::dialect::ShapeBroadcastOp>(x_shape, y_shape)
      .out();
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
  pir::Value x = op->operand_source(0);
  pir::Value y = op->operand_source(1);
  pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
  {
    pir::Value broadcasted_x =
        rewriter->Build<paddle::dialect::ExpandOp>(x, output_dim_tensor).out();
    op->operand(0).set_source(broadcasted_x);
  }
  {
    pir::Value broadcasted_y =
        rewriter->Build<paddle::dialect::ExpandOp>(y, output_dim_tensor).out();
    op->operand(1).set_source(broadcasted_y);
  }
  return true;
}

template <typename OPTYPE>
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
 public:
  using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;

  bool MatchAndRewrite(OPTYPE op,
                       pir::PatternRewriter& rewriter) const override {
    return ProcessOp(op, &rewriter);
  }
};

FullyInsertBroadcastPass::FullyInsertBroadcastPass()
    : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}

pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
    pir::IrContext* context) {
  pir::RewritePatternSet ps(context);
  // elementwise ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
      context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

  // compare ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

  // bitwise ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);

  return ps;
}

bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
  return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}

}  // namespace ir
}  // namespace dialect
}  // namespace cinn
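A minimal sketch of how this pass could be scheduled, assuming the usual pir::PassManager API; the driver function and header paths for Program/PassManager are illustrative, and in-tree pass registration may differ:

#include <memory>

#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass_manager.h"

// Hypothetical driver: apply the new pass to an existing pir::Program.
// Assumes the relevant dialects are already registered on the context.
void ApplyFullyInsertBroadcast(pir::Program* program) {
  pir::PassManager pass_manager(pir::IrContext::Instance());
  pass_manager.AddPass(
      std::make_unique<cinn::dialect::ir::FullyInsertBroadcastPass>());
  pass_manager.Run(program);
}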
paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

namespace cinn {
namespace dialect {
namespace ir {

class FullyInsertBroadcastPass : public pir::PatternRewritePass {
 public:
  FullyInsertBroadcastPass();

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

  bool CanApplyOn(pir::Operation *op) const override;
};

}  // namespace ir
}  // namespace dialect
}  // namespace cinn

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

Lines changed: 155 additions & 1 deletion
@@ -23,7 +23,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
     paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
     paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
     paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
-    paddle::dialect::Increment_Op
+    paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp
 #else
 
 #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -35,6 +35,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
+#include "paddle/phi/api/lib/data_type_set.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
@@ -2925,6 +2926,158 @@ phi::DataType Increment_Op::GetKernelTypeForVar(
   return expected_kernel_dtype;
 }
 
+void ShapeBroadcastOp::Build(pir::Builder &builder,
+                             pir::OperationArgument &argument,
+                             pir::Value x_,
+                             pir::Value y_) {
+  VLOG(4) << "Start build ShapeBroadcastOp";
+
+  VLOG(4) << "Builder construction inputs";
+  std::vector<pir::Value> argument_inputs = {x_, y_};
+  argument.AddInputs(argument_inputs);
+
+  VLOG(4) << "Builder construction attributes";
+
+  VLOG(4) << "Builder construction outputs";
+  paddle::dialect::DenseTensorType x =
+      x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
+  paddle::dialect::DenseTensorType y =
+      y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
+
+  VLOG(4) << "Builder construction dense_x";
+  paddle::dialect::IrTensor ir_tensor_x(
+      paddle::dialect::TransToPhiDataType(x.dtype()),
+      x.dims(),
+      x.data_layout(),
+      x.lod(),
+      x.offset());
+  VLOG(4) << "Builder construction meta_x";
+  paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);
+
+  VLOG(4) << "Builder construction dense_y";
+  paddle::dialect::IrTensor ir_tensor_y(
+      paddle::dialect::TransToPhiDataType(y.dtype()),
+      y.dims(),
+      y.data_layout(),
+      y.lod(),
+      y.offset());
+  VLOG(4) << "Builder construction meta_y";
+  paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
+  paddle::dialect::IrTensor dense_out;
+  paddle::dialect::IrMetaTensor meta_out(&dense_out);
+
+  phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);
+
+  std::vector<pir::Type> argument_outputs;
+  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
+      pir::IrContext::Instance(),
+      paddle::dialect::TransToIrDataType(dense_out.dtype()),
+      dense_out.dims(),
+      dense_out.layout(),
+      dense_out.lod(),
+      dense_out.offset());
+  argument_outputs.push_back(out_dense_tensor_type);
+  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+  ::pir::PassStopGradientsDefaultly(argument);
+}
+
+namespace {
+
+void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x,
+                               const phi::MetaTensor &y,
+                               phi::MetaTensor *out) {
+  PADDLE_ENFORCE_EQ(
+      x.dims().size(),
+      1,
+      phi::errors::PreconditionNotMet(
+          "The size %d of x.dims() must be equal to 1.", x.dims().size()));
+  PADDLE_ENFORCE_EQ(
+      y.dims().size(),
+      1,
+      phi::errors::PreconditionNotMet(
+          "The size %d of y.dims() must be equal to 1.", y.dims().size()));
+  out->set_dims({std::max<int64_t>(x.dims().at(0), y.dims().at(0))});
+  // the dtype needs promotion when an input dtype has more precision
+  paddle::experimental::DataTypeSet dtype_set{x.dtype()};
+  dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype());
+  DataType promote_result = PromoteTypes(dtype_set);
+  if (promote_result == DataType::UNDEFINED) {
+    promote_result = x.dtype();
+  }
+  out->set_dtype(promote_result);
+  out->set_layout(x.layout());
+  out->share_lod(x);
+}
+
+}  // namespace
+
+void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) {
+  auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta);
+  fn(infer_meta);
+}
+
+phi::DataType ShapeBroadcastOp::GetKernelTypeForVar(
+    const std::string &var_name,
+    const phi::DataType &tensor_dtype,
+    const phi::DataType &expected_kernel_dtype) {
+  VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp";
+
+  return expected_kernel_dtype;
+}
+
+namespace {
+
+symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
+                                    const symbol::DimExpr &rhs) {
+  if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
+    return std::max(lhs.dyn_cast<std::int64_t>(), rhs.dyn_cast<std::int64_t>());
+  } else if (lhs.isa<std::int64_t>()) {
+    return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
+  } else if (rhs.isa<std::int64_t>()) {
+    return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
+  } else {
+    return symbol::Broadcast<symbol::DimExpr>{
+        symbol::List<symbol::DimExpr>{lhs, rhs}};
+  }
+  LOG(FATAL) << "Dead code";
+}
+
+}  // namespace
+
+bool ShapeBroadcastOp::InferSymbolicShape(
+    pir::ShapeConstraintIRAnalysis *shape_analysis) {
+  pir::Value x = operand_source(0);
+  pir::Value y = operand_source(1);
+  std::string x_id = pir::GetValueId(&x);
+  std::string y_id = pir::GetValueId(&y);
+
+  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
+             "x_id does not exist.");
+  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
+             "y_id does not exist.");
+  const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
+  const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
+  IR_ENFORCE(x_data_shape.data().has_value(),
+             "Value x comes from ShapeOp, it must have data");
+  IR_ENFORCE(y_data_shape.data().has_value(),
+             "Value y comes from ShapeOp, it must have data");
+  const auto &x_data = x_data_shape.data().value();
+  const auto &y_data = y_data_shape.data().value();
+  IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily");
+
+  std::vector<symbol::DimExpr> output_data;
+  for (std::size_t i = 0; i < x_data.size(); ++i) {
+    output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i)));
+  }
+
+  pir::OpResult res = result(0);
+  std::string res_id = pir::GetValueId(&res);
+  symbol::ShapeOrDataDimExprs output_data_shape =
+      symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
+  shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
+  return true;
+}
+
 }  // namespace dialect
 }  // namespace paddle
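The per-dimension rule in GetBroadcastDimExpr can be summarized as: two constants broadcast to their max, a constant 1 defers to the other side, and two symbolic dims produce a deferred Broadcast expression. Below is a self-contained toy model of that rule, using std::variant as a stand-in for symbol::DimExpr; all names here are illustrative, not Paddle API:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <variant>

// Toy stand-in for symbol::DimExpr: either a known extent or a symbol name.
using Dim = std::variant<std::int64_t, std::string>;

// Mirrors the branch structure of GetBroadcastDimExpr above.
Dim BroadcastDim(const Dim& lhs, const Dim& rhs) {
  if (std::holds_alternative<std::int64_t>(lhs) &&
      std::holds_alternative<std::int64_t>(rhs)) {
    return std::max(std::get<std::int64_t>(lhs), std::get<std::int64_t>(rhs));
  } else if (std::holds_alternative<std::int64_t>(lhs)) {
    // A constant 1 broadcasts to the symbolic side; any other constant wins.
    return std::get<std::int64_t>(lhs) == 1 ? rhs : lhs;
  } else if (std::holds_alternative<std::int64_t>(rhs)) {
    return std::get<std::int64_t>(rhs) == 1 ? lhs : rhs;
  }
  // Both symbolic: the real op records a Broadcast expr; the toy just tags it.
  return "Broadcast(" + std::get<std::string>(lhs) + ", " +
         std::get<std::string>(rhs) + ")";
}

int main() {
  auto print = [](const Dim& d) {
    if (std::holds_alternative<std::int64_t>(d))
      std::cout << std::get<std::int64_t>(d) << "\n";
    else
      std::cout << std::get<std::string>(d) << "\n";
  };
  print(BroadcastDim(Dim{std::int64_t{2}}, Dim{std::int64_t{3}}));   // 3
  print(BroadcastDim(Dim{std::int64_t{1}}, Dim{std::string{"S0"}})); // S0
  print(BroadcastDim(Dim{std::string{"S0"}},
                     Dim{std::string{"S1"}}));  // Broadcast(S0, S1)
  return 0;
}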

@@ -2948,4 +3101,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)
 #endif

paddle/fluid/pir/dialect/operator/ir/manual_op.h

Lines changed: 33 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
 #include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h"
+#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
 #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
 #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
 #include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
@@ -554,6 +555,37 @@ class Increment_Op
       const std::vector<std::vector<bool>> &stop_gradients);
 };
 
+class IR_API ShapeBroadcastOp
+    : public pir::Op<ShapeBroadcastOp,
+                     paddle::dialect::InferSymbolicShapeInterface,
+                     paddle::dialect::InferMetaInterface,
+                     paddle::dialect::GetKernelTypeForVarInterface> {
+ public:
+  using Op::Op;
+  static const char *name() { return "pd_op.shape_broadcast"; }
+  static constexpr const char **attributes_name = nullptr;
+  static constexpr uint32_t attributes_num = 0;
+  static void Build(pir::Builder &builder,             // NOLINT
+                    pir::OperationArgument &argument,  // NOLINT
+                    pir::Value x_,
+                    pir::Value y_);
+
+  void VerifySig() {}
+
+  pir::Value x() { return operand_source(0); }
+  pir::Value y() { return operand_source(1); }
+  pir::OpResult out() { return result(0); }
+
+  static void InferMeta(phi::InferMetaContext *infer_meta);
+
+  static phi::DataType GetKernelTypeForVar(
+      const std::string &var_name,
+      const phi::DataType &tensor_dtype,
+      const phi::DataType &expected_kernel_dtype);
+
+  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+};
+
 }  // namespace dialect
 }  // namespace paddle

@@ -577,3 +609,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
 IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
 IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
 IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)
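For reference, the accessors declared in the class above line up one-to-one with the operands and result that the Build overload wires up. A minimal hypothetical call site, mirroring how GetOutputDimTensor builds the op in the pass; `rewriter`, `x_shape`, and `y_shape` are assumed to exist:

// Hypothetical fragment; assumes a pir::PatternRewriter* and two 1-D shape
// tensors produced by pd_op.shape.
auto broadcast_op =
    rewriter->Build<paddle::dialect::ShapeBroadcastOp>(x_shape, y_shape);
pir::Value lhs = broadcast_op.x();       // first operand (x_shape)
pir::Value rhs = broadcast_op.y();       // second operand (y_shape)
pir::OpResult out = broadcast_op.out();  // the broadcasted shape tensor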
