Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/op_base.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

namespace cinn {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"

#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
Expand All @@ -45,7 +46,6 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"

COMMON_DECLARE_bool(print_ir);
PD_DECLARE_bool(group_schedule_tiling_first);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"

COMMON_DECLARE_bool(check_infer_symbolic);
PD_DECLARE_bool(prim_all);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
#include "paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h"
#include "paddle/fluid/pir/transforms/passes.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "paddle/pir/include/pass/pass_registry.h"

Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/pir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,7 @@ set(op_dialect_srcs
${pir_op_source_file}
${pir_bwd_op_source_file}
${pir_update_op_source_file}
${api_source_file}
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/transforms/shape_optimization_pass.cc)
${api_source_file})

if(WITH_ONEDNN)
set(op_dialect_srcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,11 @@
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h"
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"

// Type inference is currently modelled executionally for operation creation
// using the `InferMetaInterface`. While `InferSymbolicShapeInterface` is used
// to implement the shape and element type inference. The return type can often
// be deduced from the deduced return shape and elemental type (queryable from
// `InferSymbolicShapeInterface`) and so type inference for tensor types can be
// implemented with `InferSymbolicShapeInterface`.
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"

namespace paddle::dialect {

class InferSymbolicShapeInterface
: public pir::OpInterfaceBase<InferSymbolicShapeInterface> {
public:
/// Defined these methods with the interface.
struct Concept {
explicit Concept(bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis))
: infer_symbolic_shapes(infer_symbolic_shapes) {}
bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
};

template <class ConcreteOp>
struct Model : public Concept {
static inline bool InferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return op->dyn_cast<ConcreteOp>().InferSymbolicShape(shape_analysis);
}

Model() : Concept(InferSymbolicShape) {}
};

/// Constructor
InferSymbolicShapeInterface(pir::Operation *op, Concept *impl)
: pir::OpInterfaceBase<InferSymbolicShapeInterface>(op), impl_(impl) {}

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);

private:
Concept *impl_;
};
using InferSymbolicShapeInterface = pir::InferSymbolicShapeInterface;

} // namespace paddle::dialect

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp, paddle::dialect::HasElementsOp,
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"

#include "paddle/phi/core/enforce.h"
#include "paddle/pir/include/core/builder.h"
#include "paddle/pir/include/core/builtin_attribute.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/transforms/passes.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/fluid/pybind/control_flow_api.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"
Expand All @@ -63,6 +62,7 @@
#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "paddle/pir/include/pass/pass_registry.h"
Expand Down
11 changes: 5 additions & 6 deletions paddle/pir/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
add_definitions(-DIR_LIBRARY)
set_property(GLOBAL PROPERTY IR_TARGETS "")

file(
GLOB_RECURSE
PIR_CPP_SOURCES
"*.cc"
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.cc
)
file(GLOB_RECURSE PIR_CPP_SOURCES "*.cc")

if(WIN32)
if(WITH_SHARED_IR)
Expand Down Expand Up @@ -56,3 +51,7 @@ else()
set(ir_targets pir)
set_property(GLOBAL PROPERTY IR_TARGETS "${ir_targets}")
endif()

if((CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
set_target_properties(pir PROPERTIES COMPILE_FLAGS "-Wno-maybe-uninitialized")
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"

// Type inference is currently modelled executionally for operation creation
// using the `InferMetaInterface`. While `InferSymbolicShapeInterface` is used
// to implement the shape and element type inference. The return type can often
// be deduced from the deduced return shape and elemental type (queryable from
// `InferSymbolicShapeInterface`) and so type inference for tensor types can be
// implemented with `InferSymbolicShapeInterface`.

namespace pir {

class InferSymbolicShapeInterface
: public pir::OpInterfaceBase<InferSymbolicShapeInterface> {
public:
/// Defined these methods with the interface.
struct Concept {
explicit Concept(bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis))
: infer_symbolic_shapes(infer_symbolic_shapes) {}
bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
};

template <class ConcreteOp>
struct Model : public Concept {
static inline bool InferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return op->dyn_cast<ConcreteOp>().InferSymbolicShape(shape_analysis);
}

Model() : Concept(InferSymbolicShape) {}
};

/// Constructor
InferSymbolicShapeInterface(pir::Operation *op, Concept *impl)
: pir::OpInterfaceBase<InferSymbolicShapeInterface>(op), impl_(impl) {}

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);

private:
Concept *impl_;
};

} // namespace pir

IR_DECLARE_EXPLICIT_TYPE_ID(pir::InferSymbolicShapeInterface)
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"

// This file implements the infer_symbolic_shape interface for both paddle and
// cinn operators.

// Add `interfaces : paddle::dialect::InferSymbolicShapeInterface` in relative
// Add `interfaces : pir::InferSymbolicShapeInterface` in relative
// yaml file to conresponding op.

// Since necessary checks have been done in the Op's `InferMeta` and `VeriySig`,
// no more repetitive work here.

namespace paddle::dialect {
namespace pir {

bool InferSymbolicShapeInterface::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return impl_->infer_symbolic_shapes(operation(), shape_analysis);
}
} // namespace paddle::dialect

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)
} // namespace pir

IR_DEFINE_EXPLICIT_TYPE_ID(pir::InferSymbolicShapeInterface)
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"

#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/dialect.h"
#include "paddle/pir/include/core/ir_printer.h"
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/pass/pass_manager.h"
Expand Down Expand Up @@ -173,9 +175,9 @@ void CheckInferSymWithInferMeta(
// InferMeta funcs of some Ops are not corrrect now, we don't check them.
if (!NeedCheckInferSymbolicWithInferMeta(op->name(), i)) continue;

if (res.type().isa<paddle::dialect::DenseTensorType>()) {
const std::vector<int64_t>& infer_meta_shape = common::vectorize(
res.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
if (res.type().isa<pir::DenseTensorType>()) {
const std::vector<int64_t>& infer_meta_shape =
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
const std::vector<symbol::DimExpr>& infer_sym_shape =
shape_analysis->GetShapeOrDataForValue(res).shape();

Expand Down Expand Up @@ -272,12 +274,11 @@ class ShapeOptimizationPass : public pir::Pass {

static inline bool IsStaticShape(const Value& value) {
const auto& value_type = value.type();
if (!value || !value_type ||
!value_type.isa<paddle::dialect::DenseTensorType>()) {
if (!value || !value_type || !value_type.isa<pir::DenseTensorType>()) {
return false;
}
return !::common::contain_unknown_dim(
value_type.dyn_cast<paddle::dialect::DenseTensorType>().dims());
value_type.dyn_cast<pir::DenseTensorType>().dims());
}

symbol::ShapeOrDataDimExprs CreateShapeOrDataByDDim(const pir::DDim& dims) {
Expand All @@ -292,7 +293,7 @@ void InferSymExprForBlock(const Block& block,
ShapeConstraintIRAnalysis* shape_analysis) {
for (auto& op : block) {
auto infer_symbolic_shape_interface =
op.dyn_cast<paddle::dialect::InferSymbolicShapeInterface>();
op.dyn_cast<pir::InferSymbolicShapeInterface>();
if (infer_symbolic_shape_interface) {
PrintOpInfo(&op);
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -326,10 +327,7 @@ void InferSymExprForBlock(const Block& block,
shape_analysis->SetShapeOrDataForValue(
op.result(i),
CreateShapeOrDataByDDim(
op.result(i)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims()));
op.result(i).type().dyn_cast<pir::DenseTensorType>().dims()));
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/cinn/adt/map_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/ir/shape_op.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/shape_dialect/infer_symbolic_shape_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include <gtest/gtest.h>
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type_interfaces.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/shape_dialect/shape_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include <gtest/gtest.h>
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type_interfaces.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/shape_dialect/shape_optimization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include <gtest/gtest.h>
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type_interfaces.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down