Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions paddle/cinn/common/dim_expr_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/dim_expr_converter.h"
#include <unordered_map>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/tensor.h"

namespace cinn::common {
using namespace symbol; // NOLINT
Expand All @@ -27,7 +28,7 @@ struct DimExprToIrExprVisitor {

ir::Expr operator()(const int64_t& dim) { return ir::Expr(dim); }

ir::Expr operator()(const std::string& dim_expr) {
virtual ir::Expr operator()(const std::string& dim_expr) {
// The dimension must be greater equal than 1, and due to the extensive use
// of int32 in CAS, the upper bound here is temporarily INT32_MAX, otherwise
// there may be a risk of overflow.
Expand Down Expand Up @@ -111,8 +112,61 @@ struct DimExprToIrExprVisitor {

} // namespace

struct DimExprConverterWithSymbolBindings::
DimExprToIrExprVisitorWithSymbolBinding : public DimExprToIrExprVisitor {
using SymbolBinding = cinn::dialect::SymbolBinding;
using ShapeSymbolBinding = cinn::dialect::ShapeSymbolBinding;
using DataSymbolBinding = cinn::dialect::DataSymbolBinding;

const std::vector<ir::Tensor>& inputs_;
std::unordered_map<std::string, cinn::dialect::SymbolBinding>
symbol_binding_map_;

ir::Expr operator()(const std::string& dim_expr) override {
CHECK(symbol_binding_map_.count(dim_expr));
auto symbol_binding = symbol_binding_map_[dim_expr];
auto [input_idx, input_dim_idx] = std::visit(
[](auto&& symbol_binding) -> std::pair<int64_t, int64_t> {
return {symbol_binding.input_tensor_idx,
symbol_binding.input_tensor_dim_idx};
},
symbol_binding);
if (std::holds_alternative<ShapeSymbolBinding>(symbol_binding)) {
return inputs_[input_idx]->sym_shape[input_dim_idx]->GetDimExpr();
}
// for data binding [S0, a, b], inputs[a] is Tensor A, return A(b)
return inputs_[input_idx](cinn::ir::Expr(input_dim_idx));
}

DimExprToIrExprVisitorWithSymbolBinding(
const std::vector<ir::Tensor>& inputs,
const std::vector<SymbolBinding>& symbol_bindings)
: inputs_(inputs) {
for (const auto& symbol_binding : symbol_bindings) {
const auto& symbol_name = std::visit(
[](auto&& symbol_binding) -> std::string {
return symbol_binding.symbol_name;
},
symbol_binding);
symbol_binding_map_[symbol_name] = symbol_binding;
}
}
};

ir::Expr DimExprConverter::ConvertToIrExpr(const DimExpr& dim_expr) const {
return DimExprToIrExprVisitor().ConvertToIrExpr(dim_expr);
}

ir::Expr DimExprConverterWithSymbolBindings::ConvertToIrExpr(
const DimExpr& dim_expr) const {
return visitor_->ConvertToIrExpr(dim_expr);
}

DimExprConverterWithSymbolBindings::DimExprConverterWithSymbolBindings(
const std::vector<ir::Tensor>& inputs,
const cinn::dialect::SymbolBindings& symbol_bindings) {
visitor_ = std::make_shared<DimExprToIrExprVisitorWithSymbolBinding>(
inputs, symbol_bindings);
}

} // namespace cinn::common
13 changes: 13 additions & 0 deletions paddle/cinn/common/dim_expr_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <memory.h>
#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"

Expand All @@ -23,4 +25,15 @@ struct DimExprConverter final {
ir::Expr ConvertToIrExpr(const symbol::DimExpr&) const;
};

struct DimExprConverterWithSymbolBindings final {
ir::Expr ConvertToIrExpr(const symbol::DimExpr&) const;
DimExprConverterWithSymbolBindings(
const std::vector<ir::Tensor>& inputs,
const cinn::dialect::SymbolBindings& symbol_bindings);

private:
struct DimExprToIrExprVisitorWithSymbolBinding;
std::shared_ptr<DimExprToIrExprVisitorWithSymbolBinding> visitor_;
};

} // namespace cinn::common
4 changes: 3 additions & 1 deletion paddle/cinn/frontend/paddle_model_convertor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "paddle/cinn/frontend/paddle_model_convertor.h"

#include <glog/logging.h>

#include <algorithm>
Expand All @@ -25,6 +24,7 @@
#include "paddle/cinn/frontend/paddle/cpp/program_desc.h"
#include "paddle/cinn/frontend/paddle/model_parser.h"
#include "paddle/cinn/frontend/var_type_utils.h"
#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/common/enforce.h"

Expand Down Expand Up @@ -202,6 +202,8 @@ void SetOpDescAttr(const std::string& attr_name,
VISITOR_EXPAND(std::vector<int64_t>)
VISITOR_EXPAND(std::vector<double>)
#undef VISITOR_EXPAND
void operator()(const std::vector<symbol::DimExpr>& v) {}
void operator()(const cinn::dialect::SymbolBindings& v) {}

private:
paddle::cpp::OpDesc* op_desc_;
Expand Down
7 changes: 7 additions & 0 deletions paddle/cinn/frontend/syntax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "paddle/cinn/frontend/paddle/model_parser.h"
#include "paddle/cinn/frontend/paddle_model_to_program.h"
#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/utils/string.h"
Expand Down Expand Up @@ -559,6 +560,12 @@ std::string _Instruction_::debug_string() const {
void operator()(const std::vector<std::string>& x) {
s_ << "[" + utils::Join(x, ",") + "]";
}
void operator()(const std::vector<symbol::DimExpr>& x) {
s_ << "[" + utils::Join(x, ",") + "]";
}
void operator()(const cinn::dialect::SymbolBindings& x) {
s_ << "[" + utils::Join(x, ",") + "]";
}
};

std::stringstream ss;
Expand Down
34 changes: 23 additions & 11 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,19 @@ std::optional<DimExpr> ConvertAttributeToDimExpr(::pir::Attribute attribute) {
return std::nullopt;
}

std::optional<std::vector<symbol::DimExpr>> ConvertAttributeToDimExprs(
::pir::Attribute attribute) {
if (!attribute.isa<pir::ArrayAttribute>()) return std::nullopt;
auto array = attribute.dyn_cast<pir::ArrayAttribute>();
std::vector<symbol::DimExpr> dim_exprs;
for (int i = 0; i < array.size(); ++i) {
const auto& dim_expr = ConvertAttributeToDimExpr(array.at(i));
if (!dim_expr.has_value()) return std::nullopt;
dim_exprs.push_back(dim_expr.value());
}
return dim_exprs;
}

class SubstituteDimExprHelper final {
public:
using DimExpr4SymbolNameT =
Expand Down Expand Up @@ -327,7 +340,7 @@ DimExpr SubstituteDimExpr(
namespace {

std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
const GenerateShapeOp::DataSymbolBinding& symbol_binding,
const DataSymbolBinding& symbol_binding,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr =
Expand All @@ -340,7 +353,7 @@ std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
}

std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
const GenerateShapeOp::ShapeSymbolBinding& symbol_binding,
const ShapeSymbolBinding& symbol_binding,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr =
Expand All @@ -350,8 +363,7 @@ std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
return shape_or_data_dim_expr.shape().at(dim_idx);
}

std::string GetSymbolNameBySymbolBinding(
const GenerateShapeOp::SymbolBinding& symbol_binding) {
std::string GetSymbolNameBySymbolBinding(const SymbolBinding& symbol_binding) {
return std::visit([](const auto& impl) { return impl.symbol_name; },
symbol_binding);
}
Expand All @@ -360,10 +372,10 @@ std::string GetSymbolNameBySymbolBinding(

std::function<std::optional<DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
std::unordered_map<std::string, std::vector<GenerateShapeOp::SymbolBinding>>
std::unordered_map<std::string, std::vector<SymbolBinding>>
symbol_name2symbol_bindins{};
for (const auto& symbol_binding : symbol_bindings) {
symbol_name2symbol_bindins[GetSymbolNameBySymbolBinding(symbol_binding)]
Expand Down Expand Up @@ -529,7 +541,7 @@ template <typename SymbolBindingsT>
void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
const std::set<std::string>& symbol_names,
int in_tensor_idx,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
SymbolBindings* symbol_bindings) {
for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size();
++in_tensor_dim_idx) {
const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx);
Expand All @@ -549,14 +561,14 @@ void GenerateSymbolBindings(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors,
const std::set<std::string>& symbol_names,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
SymbolBindings* symbol_bindings) {
for (int i = 0; i < input_tensors.size(); ++i) {
const auto& input_tensor = input_tensors.at(i);
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
AppendSymbolBindings<GenerateShapeOp::ShapeSymbolBinding>(
AppendSymbolBindings<ShapeSymbolBinding>(
dim_exprs.shape(), symbol_names, i, symbol_bindings);
if (dim_exprs.data().has_value()) {
AppendSymbolBindings<GenerateShapeOp::DataSymbolBinding>(
AppendSymbolBindings<DataSymbolBinding>(
dim_exprs.data().value(), symbol_names, i, symbol_bindings);
}
}
Expand Down Expand Up @@ -606,7 +618,7 @@ bool MakeGenerateShapeOpAttribute(
const std::vector<pir::Value>& origin_inputs,
std::vector<pir::Value>* minimal_inputs,
std::vector<pir::Attribute>* output_dim_expr_attrs,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
SymbolBindings* symbol_bindings) {
*minimal_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs);
if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minimal_inputs)) {
VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure "
Expand Down
8 changes: 6 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <optional>
#include <vector>
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h"
#include "paddle/pir/include/core/builder.h"
#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h"

Expand All @@ -29,14 +30,17 @@ ::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx,
std::optional<symbol::DimExpr> ConvertAttributeToDimExpr(
::pir::Attribute attribute);

std::optional<std::vector<symbol::DimExpr>> ConvertAttributeToDimExprs(
::pir::Attribute attribute);

symbol::DimExpr SubstituteDimExpr(
const symbol::DimExpr& dim_expr,
const std::function<std::optional<symbol::DimExpr>(
const std::string& symbol_name)>& DimExpr4SymbolName);

std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim);

Expand All @@ -51,6 +55,6 @@ bool MakeGenerateShapeOpAttribute(
const std::vector<pir::Value>& origin_inputs,
std::vector<pir::Value>* minimal_inputs,
std::vector<pir::Attribute>* output_dim_expr_attrs,
GenerateShapeOp::SymbolBindings* symbol_bindings);
SymbolBindings* symbol_bindings);

} // namespace cinn::dialect
23 changes: 9 additions & 14 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,24 +478,19 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings(
}
return std::move(ret);
}

bool GenerateShapeOp::InferSymbolicShape(
pir::InferSymbolicShapeContext* infer_context) {
const auto attr_dim_exprs = [&] {
std::vector<symbol::DimExpr> dim_exprs{};
pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs");
PADDLE_ENFORCE(dim_expr_attr.isa<pir::ArrayAttribute>(),
::common::errors::PreconditionNotMet(
"Required dim_expr_attr is ArrayAttribute."));
auto array = dim_expr_attr.dyn_cast<pir::ArrayAttribute>();
for (int i = 0; i < array.size(); ++i) {
const auto& dim_expr = ConvertAttributeToDimExpr(array.at(i));
PADDLE_ENFORCE(dim_expr.has_value(),
::common::errors::PreconditionNotMet(
"Required dim_expr.has_value()==true."));
dim_exprs.push_back(dim_expr.value());
}
return dim_exprs;
auto dim_exprs = ConvertAttributeToDimExprs(dim_expr_attr);

PADDLE_ENFORCE_EQ(
dim_exprs.has_value(),
true,
::common::errors::PreconditionNotMet(
"Required success to execute convert attribute to dim exprs."));

return dim_exprs.value();
}();
const auto symbol_bindings = [&] {
pir::Attribute symbol_bindings_attr =
Expand Down
18 changes: 6 additions & 12 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once
#include <variant>
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/hlir/dialect/operator/ir/symbol_bindings.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/phi/core/infermeta_utils.h"
Expand Down Expand Up @@ -154,18 +155,11 @@ class IR_API GenerateShapeOp
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];

struct SymbolBindingBase {
std::string symbol_name;
int64_t input_tensor_idx;
int64_t input_tensor_dim_idx;
};

struct DataSymbolBinding : public SymbolBindingBase {};
struct ShapeSymbolBinding : public SymbolBindingBase {};

using SymbolBinding = std::variant<DataSymbolBinding, ShapeSymbolBinding>;

using SymbolBindings = std::vector<SymbolBinding>;
using SymbolBindingBase = cinn::dialect::SymbolBindingBase;
using SymbolBinding = cinn::dialect::SymbolBinding;
using ShapeSymbolBinding = cinn::dialect::ShapeSymbolBinding;
using DataSymbolBinding = cinn::dialect::DataSymbolBinding;
using SymbolBindings = cinn::dialect::SymbolBindings;

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
Expand Down
Loading