22 changes: 22 additions & 0 deletions paddle/cinn/backends/codegen_device_util.cc
@@ -134,6 +134,7 @@ struct PredicatePrinter : public ir::IrPrinter {
void Visit(const ir::Or *x) { PrintBinaryOp("OR", x); }
void Visit(const ir::Max *x) { PrintBinaryOp("MAX", x); }
void Visit(const ir::Min *x) { PrintBinaryOp("MIN", x); }
void Visit(const ir::Call *x) { PrintCallOp(x); }

template <typename IRN>
void PrintBinaryOp(const std::string &op, const ir::BinaryOpNode<IRN> *x) {
@@ -143,6 +144,27 @@ struct PredicatePrinter : public ir::IrPrinter {
ir::IrPrinter::Visit(x->b());
str_ += "_BPA_";
}

void PrintCallOp(const ir::Call *x) {
str_ += "_BCALL_";
str_ += [&]() {
std::string temp = x->name;
std::transform(
temp.begin(), temp.end(), temp.begin(), [](unsigned char c) {
return std::toupper(c);
});
return temp;
}();
if (!x->read_args.empty()) {
str_ += "_R_";
for (const auto &v : x->read_args) ir::IrPrinter::Visit(v);
}
if (!x->write_args.empty()) {
str_ += "_W_";
for (const auto &v : x->write_args) ir::IrPrinter::Visit(v);
}
str_ += "_ECALL_";
}
};

std::string Predicate2String(ir::Expr predicate) {
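Illustrative and not part of the PR: a minimal, self-contained sketch of the mangling rule `PrintCallOp` implements, with plain strings standing in for the `ir::Expr` arguments (function and argument names here are hypothetical):

```c++
#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <vector>

// Mirrors PredicatePrinter::PrintCallOp: upper-case the callee name, bracket
// the call with _BCALL_/_ECALL_, and prefix read/write args with _R_/_W_.
std::string ManglePredicateCall(const std::string &name,
                                const std::vector<std::string> &read_args,
                                const std::vector<std::string> &write_args) {
  std::string str = "_BCALL_";
  std::string upper = name;
  std::transform(upper.begin(), upper.end(), upper.begin(),
                 [](unsigned char c) { return std::toupper(c); });
  str += upper;
  if (!read_args.empty()) {
    str += "_R_";
    for (const auto &a : read_args) str += a;
  }
  if (!write_args.empty()) {
    str += "_W_";
    for (const auto &a : write_args) str += a;
  }
  return str + "_ECALL_";
}

int main() {
  // Prints: _BCALL_IS_THREAD_VALID_R_tid_ECALL_
  std::cout << ManglePredicateCall("is_thread_valid", {"tid"}, {}) << "\n";
}
```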
11 changes: 11 additions & 0 deletions paddle/cinn/backends/llvm/codegen_llvm.cc
@@ -838,6 +838,14 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Dim_ *) {
CINN_NOT_IMPLEMENTED return nullptr;
}

llvm::Function *CallHostFallBack(const llvm::Module *m, const ir::Call *op) {
std::string fallback_func_name =
"cinn_host_" + op->name + "_" + common::Type2Str(op->type());
VLOG(6) << "Warn: host side has no func named '" << op->name
<< "', trying a fallback version '" << fallback_func_name << "'";
return m->getFunction(fallback_func_name);
}

llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
if (op->name == runtime::intrinsic::debug_log_repr) {
return EmitCall_debug_info(op);
@@ -854,6 +862,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
}

llvm::Function *callee = m_->getFunction(op->name);
if (!callee) {
callee = CallHostFallBack(m_, op);
}
CHECK(callee) << "Unknown function referenced. [" << op->name << "]";

std::vector<llvm::Value *> args;
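The net effect of the two hunks above is a two-step lookup. A minimal sketch of the pattern, assuming an op named "erf" and whatever string `common::Type2Str` produces for its type (the type string below is an assumption for illustration):

```c++
#include <string>
#include "llvm/IR/Module.h"

// Look up `name` in the module; if absent, retry with the host-fallback
// naming scheme "cinn_host_<name>_<type>". Returns nullptr if both miss.
llvm::Function *LookupWithFallback(llvm::Module *m,
                                   const std::string &name,
                                   const std::string &type_str) {
  if (llvm::Function *f = m->getFunction(name)) return f;
  return m->getFunction("cinn_host_" + name + "_" + type_str);
}
// e.g. LookupWithFallback(m, "erf", "fp32") tries "erf" first, then
// "cinn_host_erf_fp32".
```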
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
@@ -172,6 +172,8 @@ void ApplyCinnPreprocessPass(
if (has_dynamic_shape) {
pass_manager->AddPass(
cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass());
pass_manager->AddPass(
cinn::dialect::ir::CreatePdOpToDynamicShapeCinnOpPass());
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
}
pass_manager->Run(program);
137 changes: 107 additions & 30 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -514,18 +514,68 @@ class SliceOpPattern : public pir::OpRewritePattern<paddle::dialect::SliceOp> {
}
};

/**
 * CINN ArangeOp supports two kinds of input:
 * input from pd_op.full (static) and input from cinn_op.generate_shape
 * (symbolic). An example of the latter:
 * ```python
 * x = paddle.zeros([3, 10])
 * batch_size = paddle.shape(x)[1]
 * stop = batch_size * 2
 * paddle.arange(
 *     0,     # static start (from pd_op.full)
 *     stop,  # symbolic stop (from cinn_op.generate_shape)
 *     2,     # static step (from pd_op.full)
 * )
 * ```
 * Note that step is not allowed to be symbolic, and when start or stop
 * is symbolic, the output dtype must be an integer type.
 */
class ArangeOpPattern
: public pir::OpRewritePattern<paddle::dialect::ArangeOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ArangeOp>::OpRewritePattern;

bool Match(paddle::dialect::ArangeOp op) const override {
    // step must come from a static, positive pd_op.full; start and stop may
    // each be static (pd_op.full) or symbolic (cinn_op.generate_shape), in
    // which case the output dtype must be an integer type
bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation());
if (is_denied) return false;
// step is not allowed to be symbolic
if (IsDefinedBy<FullOp>(op, 2)) {
const FullOp full_op = CastDefinedTo<FullOp>(op, 2);
phi::Scalar step = full_op.attribute("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data();
bool positive_step = true;
#define MATCH_TYPE_TEST(TypeEnum, Dtype) \
case phi::DataType::TypeEnum: \
positive_step = step.to<Dtype>() > 0; \
break;

switch (step.dtype()) {
MATCH_TYPE_TEST(FLOAT32, float)
MATCH_TYPE_TEST(FLOAT64, double)
MATCH_TYPE_TEST(INT32, int)
MATCH_TYPE_TEST(INT64, int64_t)
MATCH_TYPE_TEST(FLOAT16, float)
MATCH_TYPE_TEST(BFLOAT16, float)
#undef MATCH_TYPE_TEST
default:
positive_step = false;
}
if (positive_step) {
const auto &dtype = op.attributes()
.at("dtype")
.dyn_cast<paddle::dialect::DataTypeAttribute>()
.data();
return (IsDefinedBy<FullOp>(op, 0) ||
IsDefinedBy<GenerateShapeOp>(op, 0)) &&
(IsDefinedBy<FullOp>(op, 1) ||
IsDefinedBy<GenerateShapeOp>(op, 1)) &&
(dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64);
} else {
return IsDefinedBy<FullOp>(op, 0) && IsDefinedBy<FullOp>(op, 1);
}
}
return false;
}

void Rewrite(paddle::dialect::ArangeOp op,
@@ -537,31 +587,39 @@ class ArangeOpPattern

std::array<phi::Scalar, 3> input_list;
for (int i = 0; i < 3; i++) {
phi::Scalar input;
if (IsDefinedBy<GenerateShapeOp>(op, i)) {
        // arange never takes bool inputs, so a bool Scalar can serve as a
        // sentinel for a symbolic value; SetFromTensor(true) records that
        // the actual value comes from a (dynamic shape) tensor
input = phi::Scalar(false);
input.SetFromTensor(true);
} else {
const FullOp full_op = CastDefinedTo<FullOp>(op, i);
input = full_op.attribute("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data();
if (input.dtype() != dtype) {
// FullOp creates a tensor (scalar) with fp64 type by default
// therefore, we might need to perform type casting
switch (dtype) {
case phi::DataType::FLOAT32:
input = phi::Scalar(input.to<float>());
break;
case phi::DataType::FLOAT64:
input = phi::Scalar(input.to<double>());
break;
case phi::DataType::INT32:
input = phi::Scalar(input.to<int>());
break;
case phi::DataType::FLOAT16:
input = phi::Scalar(input.to<float>());
break;
case phi::DataType::BFLOAT16:
input = phi::Scalar(input.to<float>());
break;
default:
input = phi::Scalar(input.to<int64_t>());
}
}
}
input_list[i] = input;
@@ -1436,6 +1494,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<
ArgMinMaxOpPattern<paddle::dialect::ArgmaxOp, cinn::dialect::ArgmaxOp>>(
context);
  // Arange in this pass is expected to see only static inputs; symbolic
  // inputs are handled earlier by PdOpToDynamicShapeCinnOpPass
ps.Add<ArangeOpPattern>(context);
ps.Add<ProdOpPattern>(context);
ps.Add<ReshapeOpPattern>(context);
Expand Down Expand Up @@ -1469,6 +1528,24 @@ std::unique_ptr<pir::Pass> CreatePdOpToCinnOpPass() {
return std::make_unique<PdOpToCinnOpPass>();
}

PdOpToDynamicShapeCinnOpPass::PdOpToDynamicShapeCinnOpPass()
: pir::PatternRewritePass("pd_to_dyn_shape_cinn_pass", 1) {}

pir::RewritePatternSet PdOpToDynamicShapeCinnOpPass::InitializePatterns(
pir::IrContext *context) {
pir::RewritePatternSet ps(context);
ps.Add<ArangeOpPattern>(context);
return ps;
}

bool PdOpToDynamicShapeCinnOpPass::CanApplyOn(pir::Operation *op) const {
return op->num_regions() > 0;
}

std::unique_ptr<pir::Pass> CreatePdOpToDynamicShapeCinnOpPass() {
return std::make_unique<PdOpToDynamicShapeCinnOpPass>();
}

} // namespace ir
} // namespace dialect
} // namespace cinn
10 changes: 10 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h
@@ -31,7 +31,17 @@ class PdOpToCinnOpPass : public pir::PatternRewritePass {
bool CanApplyOn(pir::Operation *op) const override;
};

class PdOpToDynamicShapeCinnOpPass : public pir::PatternRewritePass {
public:
PdOpToDynamicShapeCinnOpPass();

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

bool CanApplyOn(pir::Operation *op) const override;
};

IR_API std::unique_ptr<pir::Pass> CreatePdOpToCinnOpPass();
IR_API std::unique_ptr<pir::Pass> CreatePdOpToDynamicShapeCinnOpPass();

} // namespace ir
} // namespace dialect
77 changes: 77 additions & 0 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -22,6 +22,7 @@
#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/shape_constraint.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
@@ -666,6 +667,79 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
return funcs;
}

/**
 * This function converts a pir::Value's defining_op into an
 * ir::Tensor::operation. Normally, ir::Tensor::operation only records the
 * name of a compiler-generated variable, which is of little use. However,
 * operation has an attributes field, so it can also be used to record
 * op info.
 */
ir::PlaceholderOp* TensorOperationRecording(const ::pir::Value& value) {
  // TODO(heqianyue): This is somewhat ad hoc, since the conversion rules for
  // each op (and its attributes) must be spelled out manually, but the
  // current implementation works and was quick to write.
const ::pir::Operation* define_op = value.defining_op();
ir::PlaceholderOp* res = nullptr;
if (!define_op) return res;
res = cinn::common::make_shared<ir::PlaceholderOp>();
res->name = define_op->name();
// we filter some of the ops, and only record the **needed** attributes
if (define_op->name() == "pd_op.full") {
auto dtype = define_op->attribute("dtype")
.dyn_cast<paddle::dialect::DataTypeAttribute>()
.data();
phi::Scalar data = define_op->attribute("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data();
ir::Expr value;
#define DEFINE_CASE(TypeFlag, Type) \
case phi::DataType::TypeFlag: \
value = ir::Expr(data.to<Type>()); \
break;
switch (dtype) {
DEFINE_CASE(FLOAT32, float)
DEFINE_CASE(FLOAT64, double)
DEFINE_CASE(INT32, int)
      // BF16/FP16 need an extra set_type after the cast, so the macro
      // (which ends with `break`) cannot be used for these two cases
      case phi::DataType::BFLOAT16:
        value = ir::Expr(data.to<float>());
        value->set_type(cinn::common::BFloat16());
        break;
      case phi::DataType::FLOAT16:
        value = ir::Expr(data.to<float>());
        value->set_type(cinn::common::Float16());
        break;
default:
value = ir::Expr(data.to<int64_t>());
}
#undef DEFINE_CASE
res->attrs.emplace("value", value);
} else if (define_op->name() == "cinn_op.generate_shape") {
// pir::Attribute --> symbol::DimExpr --> ir::Expr

auto ir_dim_expr = [&]() {
auto dim_expr_attr = define_op->attribute("output_dim_exprs");
auto dim_exprs = dialect::ConvertAttributeToDimExprs(dim_expr_attr);

PADDLE_ENFORCE_EQ(
dim_exprs.has_value(),
true,
::common::errors::PreconditionNotMet(
"Required success to execute convert attribute to dim exprs."));

auto expr_vec = dim_exprs.value();
PADDLE_ENFORCE_EQ(
expr_vec.empty(),
false,
::common::errors::PreconditionNotMet(
"Generate shape op can not yield empty symbolic shape."));
// only the first dim_expr matters for ArangeOp
return common::DimExprConverter().ConvertToIrExpr(expr_vec[0]);
}();
res->attrs.emplace("value", ir_dim_expr);
} else {
VLOG(6) << "Tensor defining op recording: not currently supported op.";
return nullptr;
}
return res;
}

ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
const ::pir::Value& value) {
auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
@@ -704,6 +778,9 @@ ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
tensor->set_value(*tensor_value);
}
}
if (auto op_ptr = TensorOperationRecording(value)) {
tensor->operation = ir::FunctionRef(op_ptr);
}
return tensor;
}

Expand Down