Merged
Changes from all commits
47 commits
286ae4e
add immutable_layout_trait
kangguangli Apr 9, 2024
0154be9
Squashed commit of the following:
kangguangli Apr 17, 2024
d26e874
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into transfe…
kangguangli Apr 17, 2024
eee7c2b
insert successfully
kangguangli Apr 25, 2024
a2c89f6
Squashed commit of the following:
kangguangli Apr 29, 2024
81c61d1
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into transfe…
kangguangli Apr 29, 2024
ddf14b8
fix
kangguangli Apr 29, 2024
79d1d18
fix
kangguangli Apr 29, 2024
629fb47
fix
kangguangli May 6, 2024
884b562
fix
kangguangli May 7, 2024
af6602b
fix
kangguangli May 7, 2024
c33dd51
fix
kangguangli May 7, 2024
d8cec6e
fix
kangguangli May 8, 2024
8958067
fix bug on windows
kangguangli May 8, 2024
c592425
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into transfe…
kangguangli May 8, 2024
5aa422f
Squashed commit of the following:
kangguangli May 10, 2024
d726fb0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
kangguangli May 10, 2024
d508681
fix
kangguangli May 10, 2024
5b37b7c
Merge commit 'refs/pull/63628/head' of github.com:PaddlePaddle/Paddle…
kangguangli May 13, 2024
d77907a
replace std::cout with glog
kangguangli May 13, 2024
ec97d9c
fix infermeta of conv
kangguangli May 15, 2024
2a7b281
revert comment
kangguangli May 15, 2024
ee94451
use whilte list instead of black list
kangguangli May 16, 2024
531f920
remove debug code
kangguangli May 16, 2024
202720f
remove debug code
kangguangli May 16, 2024
2ec4435
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into transfe…
kangguangli May 16, 2024
c74e7ac
remove debug code
kangguangli May 16, 2024
309eb48
fix bug
kangguangli May 17, 2024
a083b88
fix bug
kangguangli May 17, 2024
53f0fb7
fix bug
kangguangli May 17, 2024
1f2c164
fix bug
kangguangli May 17, 2024
382bb5e
fix windows ci
kangguangli May 17, 2024
3e948fa
fix windows ci
kangguangli May 20, 2024
ceed519
fix cinn ci
kangguangli May 20, 2024
1bad93b
fix windows ci
kangguangli May 20, 2024
5987d0c
fix ci
kangguangli May 21, 2024
eec927e
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into transfe…
kangguangli May 21, 2024
124f0b7
fix ci
kangguangli May 21, 2024
331e66e
fix windows ci
kangguangli May 21, 2024
ed9dae6
fix ci
kangguangli May 21, 2024
ff1634a
fix ci
kangguangli May 21, 2024
b00ba77
fix cinn ci
kangguangli May 22, 2024
2c06ede
to trigger CI
kangguangli May 22, 2024
589b3f3
to trigger CI
kangguangli May 22, 2024
5ff6dff
modify by reviews
kangguangli May 26, 2024
ef54728
fix windows ci
kangguangli May 27, 2024
e38738f
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into transfe…
kangguangli May 27, 2024
@@ -14,12 +14,12 @@

#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h"

#include "build/paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "build/paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

namespace cinn {
4 changes: 4 additions & 0 deletions paddle/fluid/eager/auto_code_generator/CMakeLists.txt
@@ -12,6 +12,10 @@ set(EAGER_GENERATOR_DEPS
imperative_profiler
imperative_flag)

if(WITH_CINN)
list(REMOVE_ITEM EAGER_GENERATOR_DEPS imperative_flag)
endif()

if(WITH_CUSTOM_DEVICE)
set(EAGER_GENERATOR_DEPS ${EAGER_GENERATOR_DEPS}
custom_device_common_op_registry)
3 changes: 2 additions & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -615,7 +615,8 @@ const std::vector<std::string> kPirGpuPasses{
"matmul_scale_fuse_pass",
"matmul_transpose_fuse_pass",
"transpose_flatten_concat_fuse_pass",
"remove_redundant_transpose_pass"};
"remove_redundant_transpose_pass",
"transfer_layout_pass"};

const std::vector<std::string> kPirXpuPasses{// Functional pass
"map_op_to_another_pass",
301 changes: 300 additions & 1 deletion paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc
@@ -14,20 +14,319 @@

#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h"

#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/pass/utils.h"

namespace paddle {
namespace dialect {

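// Re-runs ConcreteOp::InferMeta with the op's current operands and
// attributes, writes the inferred types back to the op's results, and tags
// every relevant output value with the new layout.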
template <typename ConcreteOp>
void RewriteByInfermeta(pir::Operation* op, common::DataLayout new_layout) {
std::vector<pir::Type> new_outputs = ConcreteOp::InferMeta(
op->operands_source(), const_cast<pir::AttributeMap*>(&op->attributes()));
for (size_t i = 0; i < new_outputs.size(); ++i) {
op->result(i).set_type(new_outputs[i]);
}

for (auto value : RelevantOutputsImpl<ConcreteOp>(op)) {
pir::SetNewLayoutForValue(value, new_layout);
}
}

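// Conv2d prefers NHWC when its input is float16; otherwise it keeps the
// layout given by its data_format attribute.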
template <>
common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
if (!data_format_attr) {
PADDLE_THROW(phi::errors::InvalidArgument(
"op (%s) should have attribute `data_format`, but got %s",
op,
data_format_attr));
}

auto concrete_op = op->dyn_cast<Conv2dOp>();
if (auto in = concrete_op.input()) {
if (auto in_type = in.type()) {
if (in_type.isa<DenseTensorType>()) {
if (auto tensor_type = in_type.dyn_cast<DenseTensorType>()) {
if (tensor_type.dtype().isa<pir::Float16Type>()) {
return common::DataLayout::NHWC;
}
}
}
}
}
return common::StringToDataLayout(data_format_attr.AsString());
}

template <>
void RewriteByLayoutImpl<Conv2dOp>(pir::Operation* op,
common::DataLayout new_layout) {
op->set_attribute(
"data_format",
pir::StrAttribute::get(pir::IrContext::Instance(),
common::DataLayoutToString(new_layout)));
RewriteByInfermeta<Conv2dOp>(op, new_layout);
}

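// FusedConv2dAddAct prefers NHWC only when its input is float16 and its
// filter is a 4-D float16 tensor whose first two dims are multiples of the
// cuDNN alignment (8); otherwise it keeps the layout from data_format.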
template <>
common::DataLayout PreferLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op) {
return common::DataLayout::NHWC;
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
if (!data_format_attr) {
PADDLE_THROW(phi::errors::InvalidArgument(
"op (%s) should have attribute `data_format`, but got %s",
op,
data_format_attr));
}

auto original_layout =
common::StringToDataLayout(data_format_attr.AsString());

auto concrete_op = op->dyn_cast<FusedConv2dAddActOp>();
if (auto in = concrete_op.input()) {
if (auto in_type = in.type()) {
if (in_type.isa<paddle::dialect::DenseTensorType>()) {
if (auto tensor_type =
in_type.dyn_cast<paddle::dialect::DenseTensorType>()) {
if (!tensor_type.dtype().isa<pir::Float16Type>()) {
return original_layout;
}
}
}
}
}

constexpr int CUDNN_ALIGNMENT = 8;

if (auto filter = concrete_op.filter()) {
if (auto filter_type = filter.type()) {
if (filter_type.isa<DenseTensorType>()) {
if (auto tensor_type = filter_type.dyn_cast<DenseTensorType>()) {
if (tensor_type.dtype().isa<pir::Float16Type>()) {
auto dims = tensor_type.dims();
if (dims.size() == 4 && (dims[0] % CUDNN_ALIGNMENT == 0) &&
(dims[1] % CUDNN_ALIGNMENT == 0)) {
return common::DataLayout::NHWC;
}
}
}
}
}
}

return original_layout;
}

template <>
void RewriteByLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op,
common::DataLayout new_layout) {
op->set_attribute(
"data_format",
pir::StrAttribute::get(pir::IrContext::Instance(),
common::DataLayoutToString(new_layout)));

RewriteByInfermeta<FusedConv2dAddActOp>(op, new_layout);
}

template <>
void RewriteByLayoutImpl<GroupNormOp>(pir::Operation* op,
common::DataLayout new_layout) {
op->set_attribute(
"data_format",
pir::StrAttribute::get(pir::IrContext::Instance(),
common::DataLayoutToString(new_layout)));
RewriteByInfermeta<GroupNormOp>(op, new_layout);
}

template <>
std::vector<pir::Value> RelevantInputsImpl<GroupNormOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<GroupNormOp>();
return {concrete_op.x()};
}

template <>
std::vector<pir::Value> RelevantOutputsImpl<GroupNormOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<GroupNormOp>();
return {concrete_op.y()};
}

template <>
std::vector<pir::Value> RelevantInputsImpl<ReshapeOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<ReshapeOp>();
return {concrete_op.x()};
}

template <>
std::vector<pir::Value> RelevantOutputsImpl<ReshapeOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<ReshapeOp>();
return {concrete_op.out()};
}

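// Reshape is excluded from layout rewriting, since its target shape is
// expressed in terms of the original layout.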
template <>
bool CanBeModifiedImpl<ReshapeOp>(pir::Operation* op) {
return false;
}

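// Squeeze is likewise excluded from layout rewriting (CanBeModifiedImpl
// below returns false); the throw here guards against an accidental rewrite.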
template <>
void RewriteByLayoutImpl<SqueezeOp>(pir::Operation* op,
common::DataLayout new_layout) {
PADDLE_THROW(common::errors::Unimplemented(
"Op %s should have a specialized RewriteByLayout function", op->name()));
return;
}

template <>
std::vector<pir::Value> RelevantInputsImpl<SqueezeOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<SqueezeOp>();
return {concrete_op.x()};
}

template <>
std::vector<pir::Value> RelevantOutputsImpl<SqueezeOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<SqueezeOp>();
return {concrete_op.out()};
}

template <>
bool CanBeModifiedImpl<SqueezeOp>(pir::Operation* op) {
return false;
}

template <>
void RewriteByLayoutImpl<SiluOp>(pir::Operation* op,
common::DataLayout new_layout) {
RewriteByInfermeta<SiluOp>(op, new_layout);
}

template <>
void RewriteByLayoutImpl<AddOp>(pir::Operation* op,
common::DataLayout new_layout) {
RewriteByInfermeta<AddOp>(op, new_layout);
}

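// Layout rewriting is disallowed for add when x and y are dense tensors of
// different ranks, since the implicit broadcast would no longer line up
// after a layout change.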
template <>
bool CanBeModifiedImpl<AddOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<AddOp>();
if (auto x = concrete_op.x(), y = concrete_op.y(); x && y) {
if (auto xt = x.type(), yt = y.type(); xt && yt) {
if (auto xdt = xt.dyn_cast<pir::DenseTensorType>(),
ydt = yt.dyn_cast<pir::DenseTensorType>();
xdt && ydt) {
if (xdt.dims().size() != ydt.dims().size()) {
return false;
}
}
}
}
return true;
}

template <>
void RewriteByLayoutImpl<CastOp>(pir::Operation* op,
common::DataLayout new_layout) {
RewriteByInfermeta<CastOp>(op, new_layout);
}

template <>
std::vector<pir::Value> RelevantInputsImpl<ConcatOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<ConcatOp>();
return {concrete_op.x()};
}

template <>
void RewriteByLayoutImpl<ConcatOp>(pir::Operation* op,
common::DataLayout new_layout) {
// We must know the value of concat's axis, but it is an input value,
// which is hard to process in general. Here we handle only the simple
// case where the axis comes from pd_op.full and throw an error otherwise.
auto concrete_op = op->dyn_cast<ConcatOp>();
auto axis = concrete_op.axis();
if (!axis || !(axis.defining_op()->isa<FullOp>())) {
PADDLE_THROW(common::errors::InvalidArgument(
"Concat's axis must be processed when rewirte by layout."));
}

// TODO(lyk): we must assert this full int array op has one user which is
// reshape
auto axis_op = axis.defining_op()->dyn_cast<FullOp>();
int axis_value =
axis_op.attribute("value").dyn_cast<ScalarAttribute>().data().to<int>();

// The layout recorded in the tensor type is unreliable, since it is
// always NCHW, the default value. So we cannot deduce the new axis from
// the new layout, because we do not know whether the layout actually
// changed. We therefore simply assume the old layout is NCHW and the new
// layout is NHWC.
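// For example, concatenation along axis 1 (the C axis of NCHW [N, C, H, W])
// becomes concatenation along axis 3 in NHWC [N, H, W, C], so the full op's
// value is rewritten from 1 to 3.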
PADDLE_ENFORCE_EQ(
axis_value,
1,
common::errors::InvalidArgument(
"Concat's axis was expected as 1, but got %d", axis_value));
axis.defining_op()->set_attribute(
"value",
ScalarAttribute::get(pir::IrContext::Instance(), phi::Scalar(3)));

// infer new meta for concat
RewriteByInfermeta<ConcatOp>(op, new_layout);
}

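// CombineOp carries no layout attribute; its vector output type is simply
// rebuilt from the (possibly already rewritten) operand types.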
template <>
void RewriteByLayoutImpl<pir::CombineOp>(pir::Operation* op,
common::DataLayout new_layout) {
auto concrete_op = op->dyn_cast<pir::CombineOp>();
auto out = concrete_op.out();
if (!out) return;
std::vector<pir::Type> new_out_type;
for (auto v : op->operands_source()) {
new_out_type.push_back(v.type());
}
auto new_out_type_v =
pir::VectorType::get(pir::IrContext::Instance(), new_out_type);
out.set_type(new_out_type_v);

return;
}

template <>
std::vector<pir::Value> RelevantInputsImpl<Pool2dOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<Pool2dOp>();
return {concrete_op.x()};
}

template <>
void RewriteByLayoutImpl<Pool2dOp>(pir::Operation* op,
common::DataLayout new_layout) {
op->set_attribute(
"data_format",
pir::StrAttribute::get(pir::IrContext::Instance(),
common::DataLayoutToString(new_layout)));

RewriteByInfermeta<Pool2dOp>(op, new_layout);
}

template <>
void RewriteByLayoutImpl<MultiplyOp>(pir::Operation* op,
common::DataLayout new_layout) {
RewriteByInfermeta<MultiplyOp>(op, new_layout);
}

template <>
void RewriteByLayoutImpl<AssignOp>(pir::Operation* op,
common::DataLayout new_layout) {
RewriteByInfermeta<AssignOp>(op, new_layout);
}

template <>
void RewriteByLayoutImpl<SwishOp>(pir::Operation* op,
common::DataLayout new_layout) {
RewriteByInfermeta<SwishOp>(op, new_layout);
}

} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LayoutTransformationInterface)