Skip to content

Commit 54b95ae

Browse files
authored
[Dynamic Shape] Add helper function MakeGenerateShapeOpAttribute (#60512)
* add helper function MakeGenerateShapeOpAttribute * fix complier complaint * Code format
1 parent 698bb42 commit 54b95ae

File tree

3 files changed

+261
-198
lines changed

3 files changed

+261
-198
lines changed

paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
16+
#include <unordered_set>
1617
#include "paddle/pir/core/builder.h"
1718
#include "paddle/pir/core/builtin_attribute.h"
1819

@@ -422,4 +423,238 @@ MakeGetterDimExpr4SymbolName(
422423
};
423424
}
424425

426+
namespace {
427+
428+
bool IsAtomicImpl(int64_t) { return true; }
429+
430+
bool IsAtomicImpl(const std::string&) { return true; }
431+
432+
bool IsAtomicImpl(const symbol::Negative<symbol::DimExpr>&) { return false; }
433+
434+
bool IsAtomicImpl(const symbol::Reciprocal<symbol::DimExpr>&) { return false; }
435+
436+
bool IsAtomicImpl(const symbol::Add<symbol::DimExpr>&) { return false; }
437+
438+
bool IsAtomicImpl(const symbol::Mul<symbol::DimExpr>&) { return false; }
439+
440+
bool IsAtomicImpl(const symbol::Max<symbol::DimExpr>&) { return false; }
441+
442+
bool IsAtomicImpl(const symbol::Min<symbol::DimExpr>&) { return false; }
443+
444+
bool IsAtomicImpl(const symbol::Broadcast<symbol::DimExpr>&) { return false; }
445+
446+
bool IsAtomic(const symbol::DimExpr& dim_expr) {
447+
return std::visit([](const auto& impl) { return IsAtomicImpl(impl); },
448+
dim_expr.variant());
449+
}
450+
451+
bool InputDimExprsAllSupported(
452+
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
453+
const std::vector<pir::Value>& input_tensors) {
454+
const auto& AllSupported =
455+
[](const std::vector<symbol::DimExpr>& dim_exprs) -> bool {
456+
for (const auto& dim_expr : dim_exprs) {
457+
if (!IsAtomic(dim_expr)) return false;
458+
}
459+
return true;
460+
};
461+
for (const auto& input_tensor : input_tensors) {
462+
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
463+
if (!AllSupported(dim_exprs.shape())) return false;
464+
if (dim_exprs.data().has_value()) {
465+
if (!AllSupported(dim_exprs.data().value())) return false;
466+
}
467+
}
468+
return true;
469+
}
470+
471+
void ConvertDimExprToAttributes(pir::IrContext* ir_context,
472+
const std::vector<symbol::DimExpr>& dim_exprs,
473+
std::vector<pir::Attribute>* attrs) {
474+
attrs->clear();
475+
attrs->reserve(dim_exprs.size());
476+
for (const auto& dim_expr : dim_exprs) {
477+
attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr));
478+
}
479+
}
480+
481+
void CollectSymbolNames(const symbol::DimExpr& dim_expr,
482+
std::set<std::string>* ret);
483+
484+
void CollectSymbolNamesImpl(const int64_t& dim_expr,
485+
std::set<std::string>* ret) {
486+
// do nothing.
487+
}
488+
489+
void CollectSymbolNamesImpl(const std::string& dim_expr,
490+
std::set<std::string>* ret) {
491+
ret->insert(dim_expr);
492+
}
493+
494+
template <typename T>
495+
void CollectSymbolNamesImplForUnary(const T& dim_expr,
496+
std::set<std::string>* ret) {
497+
const auto& [operand] = *dim_expr;
498+
CollectSymbolNames(operand, ret);
499+
}
500+
501+
void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr,
502+
std::set<std::string>* ret) {
503+
CollectSymbolNamesImplForUnary(dim_expr, ret);
504+
}
505+
506+
void CollectSymbolNamesImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
507+
std::set<std::string>* ret) {
508+
CollectSymbolNamesImplForUnary(dim_expr, ret);
509+
}
510+
511+
template <typename T>
512+
void CollectSymbolNamesImplForVariadic(const T& dim_expr,
513+
std::set<std::string>* ret) {
514+
const auto& operands = *(dim_expr.operands);
515+
for (const auto& operand : operands) {
516+
CollectSymbolNames(operand, ret);
517+
}
518+
}
519+
520+
void CollectSymbolNamesImpl(const symbol::Add<symbol::DimExpr>& dim_expr,
521+
std::set<std::string>* ret) {
522+
CollectSymbolNamesImplForVariadic(dim_expr, ret);
523+
}
524+
525+
void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr,
526+
std::set<std::string>* ret) {
527+
CollectSymbolNamesImplForVariadic(dim_expr, ret);
528+
}
529+
530+
void CollectSymbolNamesImpl(const symbol::Max<symbol::DimExpr>& dim_expr,
531+
std::set<std::string>* ret) {
532+
CollectSymbolNamesImplForVariadic(dim_expr, ret);
533+
}
534+
535+
void CollectSymbolNamesImpl(const symbol::Min<symbol::DimExpr>& dim_expr,
536+
std::set<std::string>* ret) {
537+
CollectSymbolNamesImplForVariadic(dim_expr, ret);
538+
}
539+
540+
void CollectSymbolNamesImpl(const symbol::Broadcast<symbol::DimExpr>& dim_expr,
541+
std::set<std::string>* ret) {
542+
CollectSymbolNamesImplForVariadic(dim_expr, ret);
543+
}
544+
545+
void CollectSymbolNames(const symbol::DimExpr& dim_expr,
546+
std::set<std::string>* ret) {
547+
return std::visit(
548+
[&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); },
549+
dim_expr.variant());
550+
}
551+
552+
void CollectSymbolNames(const std::vector<symbol::DimExpr>& dim_exprs,
553+
std::set<std::string>* ret) {
554+
for (const auto& dim_expr : dim_exprs) {
555+
CollectSymbolNames(dim_expr, ret);
556+
}
557+
}
558+
559+
template <typename SymbolBindingsT>
560+
void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
561+
const std::set<std::string>& symbol_names,
562+
int in_tensor_idx,
563+
GenerateShapeOp::SymbolBindings* symbol_bindings) {
564+
for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size();
565+
++in_tensor_dim_idx) {
566+
const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx);
567+
CHECK(IsAtomic(dim_expr));
568+
if (!dim_expr.isa<std::string>()) continue;
569+
const auto& sym_name = dim_expr.dyn_cast<std::string>();
570+
if (symbol_names.find(sym_name) == symbol_names.end()) continue;
571+
symbol_bindings->emplace_back(SymbolBindingsT{
572+
/*.symbol_name=*/sym_name,
573+
/*.input_tensor_idx=*/in_tensor_idx,
574+
/*.input_tensor_dim_idx=*/in_tensor_dim_idx,
575+
});
576+
}
577+
}
578+
579+
void GenerateSymbolBindings(
580+
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
581+
const std::vector<pir::Value>& input_tensors,
582+
const std::set<std::string>& symbol_names,
583+
GenerateShapeOp::SymbolBindings* symbol_bindings) {
584+
for (int i = 0; i < input_tensors.size(); ++i) {
585+
const auto& input_tensor = input_tensors.at(i);
586+
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
587+
AppendSymbolBindings<GenerateShapeOp::ShapeSymbolBinding>(
588+
dim_exprs.shape(), symbol_names, i, symbol_bindings);
589+
if (dim_exprs.data().has_value()) {
590+
AppendSymbolBindings<GenerateShapeOp::DataSymbolBinding>(
591+
dim_exprs.shape(), symbol_names, i, symbol_bindings);
592+
}
593+
}
594+
}
595+
596+
std::vector<pir::Value> GetMinimalInputs(
597+
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
598+
const std::vector<pir::Value>& input_tensors) {
599+
std::unordered_set<symbol::DimExpr> handdled_dim_exprs;
600+
std::unordered_set<pir::Value> first_occurred_input_tensors;
601+
auto TryCollectFirstOcurredInput_tensor =
602+
[&](pir::Value input_tensor,
603+
const std::vector<symbol::DimExpr>& dim_exprs) {
604+
for (const auto& dim_expr : dim_exprs) {
605+
if (dim_expr.isa<int64_t>()) continue;
606+
if (!handdled_dim_exprs.insert(dim_expr).second) {
607+
first_occurred_input_tensors.insert(input_tensor);
608+
}
609+
}
610+
};
611+
for (pir::Value input_tensor : input_tensors) {
612+
const auto& shape_or_data_dim_exprs =
613+
ShapeOrDataDimExprs4Value(input_tensor);
614+
if (shape_or_data_dim_exprs.data().has_value()) {
615+
TryCollectFirstOcurredInput_tensor(
616+
input_tensor, shape_or_data_dim_exprs.data().value());
617+
}
618+
TryCollectFirstOcurredInput_tensor(input_tensor,
619+
shape_or_data_dim_exprs.shape());
620+
}
621+
std::vector<pir::Value> ret{};
622+
ret.reserve(input_tensors.size());
623+
for (pir::Value input_tensor : input_tensors) {
624+
if (first_occurred_input_tensors.count(input_tensor) > 0) {
625+
ret.emplace_back(input_tensor);
626+
}
627+
}
628+
return ret;
629+
}
630+
631+
} // namespace
632+
633+
bool MakeGenerateShapeOpAttribute(
634+
pir::IrContext* ir_context,
635+
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
636+
const std::vector<symbol::DimExpr>& out_dim_exprs,
637+
const std::vector<pir::Value>& origin_inputs,
638+
std::vector<pir::Value>* minial_inputs,
639+
std::vector<pir::Attribute>* output_dim_expr_attrs,
640+
GenerateShapeOp::SymbolBindings* symbol_bindings) {
641+
*minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs);
642+
if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) {
643+
VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure "
644+
"they are handled by other passes";
645+
return false;
646+
}
647+
// generate output_dim_expr_attrs
648+
ConvertDimExprToAttributes(
649+
ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs);
650+
// generate symbol_bindings
651+
std::set<std::string> symbol_names_in_out_dim_exprs{};
652+
CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs);
653+
GenerateSymbolBindings(ShapeOrDataDimExprs4Value,
654+
*minial_inputs,
655+
symbol_names_in_out_dim_exprs,
656+
/*out*/ symbol_bindings);
657+
return true;
658+
}
659+
425660
} // namespace cinn::dialect

paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
#pragma once
1616

17+
#include <functional>
1718
#include <optional>
19+
#include <vector>
1820
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
1921
#include "paddle/pir/core/builder.h"
2022
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
@@ -46,4 +48,17 @@ MakeGetterDimExpr4SymbolName(
4648
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
4749
DimExpr4InputDim);
4850

51+
using ShapeOrDataDimExprs4ValueT =
52+
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>;
53+
54+
// Returns true if success.
55+
bool MakeGenerateShapeOpAttribute(
56+
pir::IrContext* ir_context,
57+
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
58+
const std::vector<symbol::DimExpr>& out_dim_exprs,
59+
const std::vector<pir::Value>& origin_inputs,
60+
std::vector<pir::Value>* minial_inputs,
61+
std::vector<pir::Attribute>* output_dim_expr_attrs,
62+
GenerateShapeOp::SymbolBindings* symbol_bindings);
63+
4964
} // namespace cinn::dialect

0 commit comments

Comments
 (0)