|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" |
| 16 | +#include <unordered_set> |
16 | 17 | #include "paddle/pir/core/builder.h" |
17 | 18 | #include "paddle/pir/core/builtin_attribute.h" |
18 | 19 |
|
@@ -422,4 +423,238 @@ MakeGetterDimExpr4SymbolName( |
422 | 423 | }; |
423 | 424 | } |
424 | 425 |
|
| 426 | +namespace { |
| 427 | + |
| 428 | +bool IsAtomicImpl(int64_t) { return true; } |
| 429 | + |
| 430 | +bool IsAtomicImpl(const std::string&) { return true; } |
| 431 | + |
| 432 | +bool IsAtomicImpl(const symbol::Negative<symbol::DimExpr>&) { return false; } |
| 433 | + |
| 434 | +bool IsAtomicImpl(const symbol::Reciprocal<symbol::DimExpr>&) { return false; } |
| 435 | + |
| 436 | +bool IsAtomicImpl(const symbol::Add<symbol::DimExpr>&) { return false; } |
| 437 | + |
| 438 | +bool IsAtomicImpl(const symbol::Mul<symbol::DimExpr>&) { return false; } |
| 439 | + |
| 440 | +bool IsAtomicImpl(const symbol::Max<symbol::DimExpr>&) { return false; } |
| 441 | + |
| 442 | +bool IsAtomicImpl(const symbol::Min<symbol::DimExpr>&) { return false; } |
| 443 | + |
| 444 | +bool IsAtomicImpl(const symbol::Broadcast<symbol::DimExpr>&) { return false; } |
| 445 | + |
| 446 | +bool IsAtomic(const symbol::DimExpr& dim_expr) { |
| 447 | + return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, |
| 448 | + dim_expr.variant()); |
| 449 | +} |
| 450 | + |
| 451 | +bool InputDimExprsAllSupported( |
| 452 | + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, |
| 453 | + const std::vector<pir::Value>& input_tensors) { |
| 454 | + const auto& AllSupported = |
| 455 | + [](const std::vector<symbol::DimExpr>& dim_exprs) -> bool { |
| 456 | + for (const auto& dim_expr : dim_exprs) { |
| 457 | + if (!IsAtomic(dim_expr)) return false; |
| 458 | + } |
| 459 | + return true; |
| 460 | + }; |
| 461 | + for (const auto& input_tensor : input_tensors) { |
| 462 | + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); |
| 463 | + if (!AllSupported(dim_exprs.shape())) return false; |
| 464 | + if (dim_exprs.data().has_value()) { |
| 465 | + if (!AllSupported(dim_exprs.data().value())) return false; |
| 466 | + } |
| 467 | + } |
| 468 | + return true; |
| 469 | +} |
| 470 | + |
| 471 | +void ConvertDimExprToAttributes(pir::IrContext* ir_context, |
| 472 | + const std::vector<symbol::DimExpr>& dim_exprs, |
| 473 | + std::vector<pir::Attribute>* attrs) { |
| 474 | + attrs->clear(); |
| 475 | + attrs->reserve(dim_exprs.size()); |
| 476 | + for (const auto& dim_expr : dim_exprs) { |
| 477 | + attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); |
| 478 | + } |
| 479 | +} |
| 480 | + |
| 481 | +void CollectSymbolNames(const symbol::DimExpr& dim_expr, |
| 482 | + std::set<std::string>* ret); |
| 483 | + |
| 484 | +void CollectSymbolNamesImpl(const int64_t& dim_expr, |
| 485 | + std::set<std::string>* ret) { |
| 486 | + // do nothing. |
| 487 | +} |
| 488 | + |
| 489 | +void CollectSymbolNamesImpl(const std::string& dim_expr, |
| 490 | + std::set<std::string>* ret) { |
| 491 | + ret->insert(dim_expr); |
| 492 | +} |
| 493 | + |
| 494 | +template <typename T> |
| 495 | +void CollectSymbolNamesImplForUnary(const T& dim_expr, |
| 496 | + std::set<std::string>* ret) { |
| 497 | + const auto& [operand] = *dim_expr; |
| 498 | + CollectSymbolNames(operand, ret); |
| 499 | +} |
| 500 | + |
| 501 | +void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr, |
| 502 | + std::set<std::string>* ret) { |
| 503 | + CollectSymbolNamesImplForUnary(dim_expr, ret); |
| 504 | +} |
| 505 | + |
| 506 | +void CollectSymbolNamesImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr, |
| 507 | + std::set<std::string>* ret) { |
| 508 | + CollectSymbolNamesImplForUnary(dim_expr, ret); |
| 509 | +} |
| 510 | + |
| 511 | +template <typename T> |
| 512 | +void CollectSymbolNamesImplForVariadic(const T& dim_expr, |
| 513 | + std::set<std::string>* ret) { |
| 514 | + const auto& operands = *(dim_expr.operands); |
| 515 | + for (const auto& operand : operands) { |
| 516 | + CollectSymbolNames(operand, ret); |
| 517 | + } |
| 518 | +} |
| 519 | + |
| 520 | +void CollectSymbolNamesImpl(const symbol::Add<symbol::DimExpr>& dim_expr, |
| 521 | + std::set<std::string>* ret) { |
| 522 | + CollectSymbolNamesImplForVariadic(dim_expr, ret); |
| 523 | +} |
| 524 | + |
| 525 | +void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr, |
| 526 | + std::set<std::string>* ret) { |
| 527 | + CollectSymbolNamesImplForVariadic(dim_expr, ret); |
| 528 | +} |
| 529 | + |
| 530 | +void CollectSymbolNamesImpl(const symbol::Max<symbol::DimExpr>& dim_expr, |
| 531 | + std::set<std::string>* ret) { |
| 532 | + CollectSymbolNamesImplForVariadic(dim_expr, ret); |
| 533 | +} |
| 534 | + |
| 535 | +void CollectSymbolNamesImpl(const symbol::Min<symbol::DimExpr>& dim_expr, |
| 536 | + std::set<std::string>* ret) { |
| 537 | + CollectSymbolNamesImplForVariadic(dim_expr, ret); |
| 538 | +} |
| 539 | + |
| 540 | +void CollectSymbolNamesImpl(const symbol::Broadcast<symbol::DimExpr>& dim_expr, |
| 541 | + std::set<std::string>* ret) { |
| 542 | + CollectSymbolNamesImplForVariadic(dim_expr, ret); |
| 543 | +} |
| 544 | + |
| 545 | +void CollectSymbolNames(const symbol::DimExpr& dim_expr, |
| 546 | + std::set<std::string>* ret) { |
| 547 | + return std::visit( |
| 548 | + [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, |
| 549 | + dim_expr.variant()); |
| 550 | +} |
| 551 | + |
| 552 | +void CollectSymbolNames(const std::vector<symbol::DimExpr>& dim_exprs, |
| 553 | + std::set<std::string>* ret) { |
| 554 | + for (const auto& dim_expr : dim_exprs) { |
| 555 | + CollectSymbolNames(dim_expr, ret); |
| 556 | + } |
| 557 | +} |
| 558 | + |
| 559 | +template <typename SymbolBindingsT> |
| 560 | +void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs, |
| 561 | + const std::set<std::string>& symbol_names, |
| 562 | + int in_tensor_idx, |
| 563 | + GenerateShapeOp::SymbolBindings* symbol_bindings) { |
| 564 | + for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); |
| 565 | + ++in_tensor_dim_idx) { |
| 566 | + const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); |
| 567 | + CHECK(IsAtomic(dim_expr)); |
| 568 | + if (!dim_expr.isa<std::string>()) continue; |
| 569 | + const auto& sym_name = dim_expr.dyn_cast<std::string>(); |
| 570 | + if (symbol_names.find(sym_name) == symbol_names.end()) continue; |
| 571 | + symbol_bindings->emplace_back(SymbolBindingsT{ |
| 572 | + /*.symbol_name=*/sym_name, |
| 573 | + /*.input_tensor_idx=*/in_tensor_idx, |
| 574 | + /*.input_tensor_dim_idx=*/in_tensor_dim_idx, |
| 575 | + }); |
| 576 | + } |
| 577 | +} |
| 578 | + |
| 579 | +void GenerateSymbolBindings( |
| 580 | + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, |
| 581 | + const std::vector<pir::Value>& input_tensors, |
| 582 | + const std::set<std::string>& symbol_names, |
| 583 | + GenerateShapeOp::SymbolBindings* symbol_bindings) { |
| 584 | + for (int i = 0; i < input_tensors.size(); ++i) { |
| 585 | + const auto& input_tensor = input_tensors.at(i); |
| 586 | + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); |
| 587 | + AppendSymbolBindings<GenerateShapeOp::ShapeSymbolBinding>( |
| 588 | + dim_exprs.shape(), symbol_names, i, symbol_bindings); |
| 589 | + if (dim_exprs.data().has_value()) { |
| 590 | + AppendSymbolBindings<GenerateShapeOp::DataSymbolBinding>( |
| 591 | + dim_exprs.shape(), symbol_names, i, symbol_bindings); |
| 592 | + } |
| 593 | + } |
| 594 | +} |
| 595 | + |
| 596 | +std::vector<pir::Value> GetMinimalInputs( |
| 597 | + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, |
| 598 | + const std::vector<pir::Value>& input_tensors) { |
| 599 | + std::unordered_set<symbol::DimExpr> handdled_dim_exprs; |
| 600 | + std::unordered_set<pir::Value> first_occurred_input_tensors; |
| 601 | + auto TryCollectFirstOcurredInput_tensor = |
| 602 | + [&](pir::Value input_tensor, |
| 603 | + const std::vector<symbol::DimExpr>& dim_exprs) { |
| 604 | + for (const auto& dim_expr : dim_exprs) { |
| 605 | + if (dim_expr.isa<int64_t>()) continue; |
| 606 | + if (!handdled_dim_exprs.insert(dim_expr).second) { |
| 607 | + first_occurred_input_tensors.insert(input_tensor); |
| 608 | + } |
| 609 | + } |
| 610 | + }; |
| 611 | + for (pir::Value input_tensor : input_tensors) { |
| 612 | + const auto& shape_or_data_dim_exprs = |
| 613 | + ShapeOrDataDimExprs4Value(input_tensor); |
| 614 | + if (shape_or_data_dim_exprs.data().has_value()) { |
| 615 | + TryCollectFirstOcurredInput_tensor( |
| 616 | + input_tensor, shape_or_data_dim_exprs.data().value()); |
| 617 | + } |
| 618 | + TryCollectFirstOcurredInput_tensor(input_tensor, |
| 619 | + shape_or_data_dim_exprs.shape()); |
| 620 | + } |
| 621 | + std::vector<pir::Value> ret{}; |
| 622 | + ret.reserve(input_tensors.size()); |
| 623 | + for (pir::Value input_tensor : input_tensors) { |
| 624 | + if (first_occurred_input_tensors.count(input_tensor) > 0) { |
| 625 | + ret.emplace_back(input_tensor); |
| 626 | + } |
| 627 | + } |
| 628 | + return ret; |
| 629 | +} |
| 630 | + |
| 631 | +} // namespace |
| 632 | + |
| 633 | +bool MakeGenerateShapeOpAttribute( |
| 634 | + pir::IrContext* ir_context, |
| 635 | + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, |
| 636 | + const std::vector<symbol::DimExpr>& out_dim_exprs, |
| 637 | + const std::vector<pir::Value>& origin_inputs, |
| 638 | + std::vector<pir::Value>* minial_inputs, |
| 639 | + std::vector<pir::Attribute>* output_dim_expr_attrs, |
| 640 | + GenerateShapeOp::SymbolBindings* symbol_bindings) { |
| 641 | + *minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs); |
| 642 | + if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) { |
| 643 | + VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " |
| 644 | + "they are handled by other passes"; |
| 645 | + return false; |
| 646 | + } |
| 647 | + // generate output_dim_expr_attrs |
| 648 | + ConvertDimExprToAttributes( |
| 649 | + ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); |
| 650 | + // generate symbol_bindings |
| 651 | + std::set<std::string> symbol_names_in_out_dim_exprs{}; |
| 652 | + CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); |
| 653 | + GenerateSymbolBindings(ShapeOrDataDimExprs4Value, |
| 654 | + *minial_inputs, |
| 655 | + symbol_names_in_out_dim_exprs, |
| 656 | + /*out*/ symbol_bindings); |
| 657 | + return true; |
| 658 | +} |
| 659 | + |
425 | 660 | } // namespace cinn::dialect |
0 commit comments