diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc index 5e431d3d7fb009..5bbc9b833cf0f0 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc @@ -14,12 +14,12 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h" -#include "build/paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" -#include "build/paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" namespace cinn { diff --git a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt index 3e0257aa7bffd5..78414410b8709d 100644 --- a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt +++ b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt @@ -12,6 +12,10 @@ set(EAGER_GENERATOR_DEPS imperative_profiler imperative_flag) +if(WITH_CINN) + list(REMOVE_ITEM EAGER_GENERATOR_DEPS imperative_flag) +endif() + if(WITH_CUSTOM_DEVICE) set(EAGER_GENERATOR_DEPS ${EAGER_GENERATOR_DEPS} custom_device_common_op_registry) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index d6db33c30ab534..c6b83175d21d2a 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -615,7 +615,8 @@ const std::vector kPirGpuPasses{ 
"matmul_scale_fuse_pass", "matmul_transpose_fuse_pass", "transpose_flatten_concat_fuse_pass", - "remove_redundant_transpose_pass"}; + "remove_redundant_transpose_pass", + "transfer_layout_pass"}; const std::vector kPirXpuPasses{// Functional pass "map_op_to_another_pass", diff --git a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc index c6c1401f32d5c5..41945cfa0c106d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc +++ b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc @@ -14,20 +14,319 @@ #include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/core/ir_context.h" +#include "paddle/pir/include/pass/utils.h" + namespace paddle { namespace dialect { +template +void RewriteByInfermeta(pir::Operation* op, common::DataLayout new_layout) { + std::vector new_outputs = ConcreteOp::InferMeta( + op->operands_source(), const_cast(&op->attributes())); + for (size_t i = 0; i < new_outputs.size(); ++i) { + op->result(i).set_type(new_outputs[i]); + } + + for (auto value : RelevantOutputsImpl(op)) { + pir::SetNewLayoutForValue(value, new_layout); + } +} + +template <> +common::DataLayout PreferLayoutImpl(pir::Operation* op) { + auto data_format_attr = op->attribute("data_format"); + if (!data_format_attr) { + PADDLE_THROW(phi::errors::InvalidArgument( + "op (%s) should have attribute `data_format`, but got %s", + op, + data_format_attr)); + } + + auto concrete_op = op->dyn_cast(); + if (auto in = concrete_op.input()) { + if (auto in_type = in.type()) { + if (in_type.isa()) { + if (auto tensor_type = in_type.dyn_cast()) { + if 
(tensor_type.dtype().isa()) { + return common::DataLayout::NHWC; + } + } + } + } + } + return common::StringToDataLayout(data_format_attr.AsString()); +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + op->set_attribute( + "data_format", + pir::StrAttribute::get(pir::IrContext::Instance(), + common::DataLayoutToString(new_layout))); + RewriteByInfermeta(op, new_layout); +} + template <> common::DataLayout PreferLayoutImpl(pir::Operation* op) { - return common::DataLayout::NHWC; + auto data_format_attr = op->attribute("data_format"); + if (!data_format_attr) { + PADDLE_THROW(phi::errors::InvalidArgument( + "op (%s) should have attribute `data_format`, but got %s", + op, + data_format_attr)); + } + + auto original_layout = + common::StringToDataLayout(data_format_attr.AsString()); + + auto concrete_op = op->dyn_cast(); + if (auto in = concrete_op.input()) { + if (auto in_type = in.type()) { + if (in_type.isa()) { + if (auto tensor_type = + in_type.dyn_cast()) { + if (!tensor_type.dtype().isa()) { + return original_layout; + } + } + } + } + } + + constexpr int CUDNN_ALIGNMENT = 8; + + if (auto filter = concrete_op.filter()) { + if (auto filter_type = filter.type()) { + if (filter_type.isa()) { + if (auto tensor_type = filter_type.dyn_cast()) { + if (tensor_type.dtype().isa()) { + auto dims = tensor_type.dims(); + if (dims.size() == 4 && (dims[0] % CUDNN_ALIGNMENT == 0) && + (dims[1] % CUDNN_ALIGNMENT == 0)) { + return common::DataLayout::NHWC; + } + } + } + } + } + } + + return original_layout; } template <> void RewriteByLayoutImpl(pir::Operation* op, common::DataLayout new_layout) { + op->set_attribute( + "data_format", + pir::StrAttribute::get(pir::IrContext::Instance(), + common::DataLayoutToString(new_layout))); + + RewriteByInfermeta(op, new_layout); +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + op->set_attribute( + "data_format", + 
pir::StrAttribute::get(pir::IrContext::Instance(), + common::DataLayoutToString(new_layout))); + RewriteByInfermeta(op, new_layout); +} + +template <> +std::vector RelevantInputsImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + return {concrete_op.x()}; +} + +template <> +std::vector RelevantOutputsImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + return {concrete_op.y()}; +} + +template <> +std::vector RelevantInputsImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + return {concrete_op.x()}; +} + +template <> +std::vector RelevantOutputsImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + return {concrete_op.out()}; +} + +template <> +bool CanBeModifiedImpl(pir::Operation* op) { + return false; +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + PADDLE_THROW(common::errors::Unimplemented( + "Op %s should have a specialized RewriteByLayout function", op->name())); return; } +template <> +std::vector RelevantInputsImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + return {concrete_op.x()}; +} + +template <> +std::vector RelevantOutputsImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + return {concrete_op.out()}; +} + +template <> +bool CanBeModifiedImpl(pir::Operation* op) { + return false; +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + RewriteByInfermeta(op, new_layout); +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + RewriteByInfermeta(op, new_layout); +} + +template <> +bool CanBeModifiedImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + if (auto x = concrete_op.x(), y = concrete_op.y(); x && y) { + if (auto xt = x.type(), yt = y.type(); xt && yt) { + if (auto xdt = xt.dyn_cast(), + ydt = yt.dyn_cast(); + xdt && ydt) { + if (xdt.dims().size() != ydt.dims().size()) { + return false; + } + } + } + } + return 
true;
+}
+
+template <>
+void RewriteByLayoutImpl(pir::Operation* op,
+                         common::DataLayout new_layout) {
+  RewriteByInfermeta(op, new_layout);
+}
+
+template <>
+std::vector RelevantInputsImpl(pir::Operation* op) {
+  auto concrete_op = op->dyn_cast();
+  return {concrete_op.x()};
+}
+
+template <>
+void RewriteByLayoutImpl(pir::Operation* op,
+                         common::DataLayout new_layout) {
+  // we must know the value of concat axis, but this is an input
+  // which is really hard to process.
+  // here we handle the simple case like pd_op.full and throw
+  // error in other cases.
+  auto concrete_op = op->dyn_cast();
+  auto axis = concrete_op.axis();
+  if (!axis || !(axis.defining_op()->isa())) {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Concat's axis must be processed when rewirte by layout."));
+  }
+
+  // TODO(lyk): we must assert this full int array op has one user which is
+  // reshape
+  auto axis_op = axis.defining_op()->dyn_cast();
+  int axis_value =
+      axis_op.attribute("value").dyn_cast().data().to();
+
+  // The layout of the tensor type is unreliable, since it's always
+  // NCHW, which is a default value. So we cannot deduce the new
+  // axis by new layout, since we do not know if the layout changed.
+  // So we simply assume the old layout must be NCHW, new layout must
+  // be NHWC.
+ PADDLE_ENFORCE_EQ( + axis_value, + 1, + common::errors::InvalidArgument( + "Concat's axis was expected as 1, but got %d", axis_value)); + axis.defining_op()->set_attribute( + "value", + ScalarAttribute::get(pir::IrContext::Instance(), phi::Scalar(3))); + + // infer new meta for concat + RewriteByInfermeta(op, new_layout); +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + auto concrete_op = op->dyn_cast(); + auto out = concrete_op.out(); + if (!out) return; + std::vector new_out_type; + for (auto v : op->operands_source()) { + new_out_type.push_back(v.type()); + } + auto new_out_type_v = + pir::VectorType::get(pir::IrContext::Instance(), new_out_type); + out.set_type(new_out_type_v); + + return; +} + +template <> +std::vector RelevantInputsImpl(pir::Operation* op) { + auto concrete_op = op->dyn_cast(); + return {concrete_op.x()}; +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + op->set_attribute( + "data_format", + pir::StrAttribute::get(pir::IrContext::Instance(), + common::DataLayoutToString(new_layout))); + + RewriteByInfermeta(op, new_layout); +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + RewriteByInfermeta(op, new_layout); +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + RewriteByInfermeta(op, new_layout); +} + +template <> +void RewriteByLayoutImpl(pir::Operation* op, + common::DataLayout new_layout) { + RewriteByInfermeta(op, new_layout); +} + } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LayoutTransformationInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.h b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.h index 71678029fb48cc..52ed9c6c289e7a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.h +++ 
b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.h @@ -34,21 +34,25 @@ class LayoutTransformationInterface using RewriteByLayoutFn = void (*)(pir::Operation*, common::DataLayout); using RelevantInputsFn = std::vector (*)(pir::Operation*); using RelevantOutputsFn = std::vector (*)(pir::Operation*); + using CanBeModifiedFn = bool (*)(pir::Operation*); struct Concept { explicit Concept(PreferLayoutFn prefer_layout, RewriteByLayoutFn rewrite_by_layout, RelevantInputsFn relevant_inputs, - RelevantOutputsFn relevant_outputs) + RelevantOutputsFn relevant_outputs, + CanBeModifiedFn can_be_modified) : prefer_layout(prefer_layout), rewrite_by_layout(rewrite_by_layout), relevant_inputs(relevant_inputs), - relevant_outputs(relevant_outputs) {} + relevant_outputs(relevant_outputs), + can_be_modified(can_be_modified) {} PreferLayoutFn prefer_layout; RewriteByLayoutFn rewrite_by_layout; RelevantInputsFn relevant_inputs; RelevantOutputsFn relevant_outputs; + CanBeModifiedFn can_be_modified; }; template @@ -70,11 +74,16 @@ class LayoutTransformationInterface return RelevantOutputsImpl(op); } + static bool CanBeModifiedModel(pir::Operation* op) { + return CanBeModifiedImpl(op); + } + Model() : Concept(PreferLayoutModel, RewriteByLayoutModel, RelevantInputsModel, - RelevantOutputsModel) {} + RelevantOutputsModel, + CanBeModifiedModel) {} }; LayoutTransformationInterface(pir::Operation* op, Concept* impl) @@ -96,6 +105,8 @@ class LayoutTransformationInterface return impl_->relevant_outputs(op); } + bool CanBeModified(pir::Operation* op) { return impl_->can_be_modified(op); } + private: Concept* impl_; }; diff --git a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.hpp b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.hpp index c1860cbbac1080..05719bc1dfb2f1 100644 --- a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.hpp +++ b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.hpp @@ -14,12 +14,43 
@@ #pragma once +#include + #include "paddle/common/enforce.h" #include "paddle/common/layout.h" #include "paddle/phi/common/data_type.h" +#include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/operation.h" #include "paddle/pir/include/core/type_name.h" +#define OVERLOAD_PREFER_LAYOUT(op) \ + template <> \ + common::DataLayout PreferLayoutImpl(pir::Operation*); \ + extern template common::DataLayout PreferLayoutImpl(pir::Operation*); + +#define OVERLOAD_REWRITE_BY_LAYOUT(op) \ + template <> \ + void RewriteByLayoutImpl(pir::Operation*, common::DataLayout); \ + extern template void RewriteByLayoutImpl(pir::Operation*, \ + common::DataLayout); + +#define OVERLOAD_RELEVANT_INPUTS(op) \ + template <> \ + std::vector RelevantInputsImpl(pir::Operation * op); \ + extern template std::vector RelevantInputsImpl( \ + pir::Operation * op); + +#define OVERLOAD_RELEVANT_OUTPUTS(op) \ + template <> \ + std::vector RelevantOutputsImpl(pir::Operation * op); \ + extern template std::vector RelevantOutputsImpl( \ + pir::Operation * op); + +#define OVERLOAD_CAN_BE_MODIFIED(op) \ + template <> \ + bool CanBeModifiedImpl(pir::Operation * op); \ + extern template bool CanBeModifiedImpl(pir::Operation * op); + namespace paddle { namespace dialect { @@ -28,6 +59,11 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) { return common::DataLayout::ALL_LAYOUT; } +template +common::DataLayout CurrentLayoutImpl(pir::Operation* op) { + return common::DataLayout::UNDEFINED; +} + template void RewriteByLayoutImpl(pir::Operation* op, common::DataLayout new_layout) { PADDLE_THROW(common::errors::Unimplemented( @@ -37,24 +73,96 @@ void RewriteByLayoutImpl(pir::Operation* op, common::DataLayout new_layout) { template std::vector RelevantInputsImpl(pir::Operation* op) { - return op->operands_source(); + std::vector relevant_inputs; + for (auto& operand : op->operands_source()) { + if (!operand || !operand.type()) continue; + if (auto operand_type = 
operand.type().dyn_cast()) { + if (operand_type.size() == 0) continue; + } + relevant_inputs.push_back(operand); + } + return relevant_inputs; } template std::vector RelevantOutputsImpl(pir::Operation* op) { - return op->results(); + std::vector relevant_outputs; + for (auto& result : op->results()) { + if (!result || !result.type()) continue; + if (auto result_type = result.type().dyn_cast()) { + if (result_type.size() == 0) continue; + } + relevant_outputs.push_back(result); + } + return relevant_outputs; +} + +template +bool CanBeModifiedImpl(pir::Operation* op) { + return true; } class FusedConv2dAddActOp; -template <> -common::DataLayout PreferLayoutImpl(pir::Operation*); -extern template common::DataLayout PreferLayoutImpl( - pir::Operation*); -template <> -void RewriteByLayoutImpl(pir::Operation*, - common::DataLayout); -extern template void RewriteByLayoutImpl( - pir::Operation*, common::DataLayout); +OVERLOAD_PREFER_LAYOUT(FusedConv2dAddActOp); +OVERLOAD_REWRITE_BY_LAYOUT(FusedConv2dAddActOp); + +class Conv2dOp; +OVERLOAD_PREFER_LAYOUT(Conv2dOp); +OVERLOAD_REWRITE_BY_LAYOUT(Conv2dOp); + +class GroupNormOp; +OVERLOAD_REWRITE_BY_LAYOUT(GroupNormOp); +OVERLOAD_RELEVANT_INPUTS(GroupNormOp); +OVERLOAD_RELEVANT_OUTPUTS(GroupNormOp); + +class ReshapeOp; +OVERLOAD_RELEVANT_INPUTS(ReshapeOp); +OVERLOAD_RELEVANT_OUTPUTS(ReshapeOp); +OVERLOAD_CAN_BE_MODIFIED(ReshapeOp); + +class SqueezeOp; +OVERLOAD_REWRITE_BY_LAYOUT(SqueezeOp); +OVERLOAD_RELEVANT_INPUTS(SqueezeOp); +OVERLOAD_RELEVANT_OUTPUTS(SqueezeOp); +OVERLOAD_CAN_BE_MODIFIED(SqueezeOp); + +class SiluOp; +OVERLOAD_REWRITE_BY_LAYOUT(SiluOp); + +class AddOp; +OVERLOAD_REWRITE_BY_LAYOUT(AddOp); +OVERLOAD_CAN_BE_MODIFIED(AddOp); + +class CastOp; +OVERLOAD_REWRITE_BY_LAYOUT(CastOp); + +class ConcatOp; +OVERLOAD_REWRITE_BY_LAYOUT(ConcatOp); +OVERLOAD_RELEVANT_INPUTS(ConcatOp); + +class Pool2dOp; +OVERLOAD_RELEVANT_INPUTS(Pool2dOp); +OVERLOAD_REWRITE_BY_LAYOUT(Pool2dOp); + +class MultiplyOp; 
+OVERLOAD_REWRITE_BY_LAYOUT(MultiplyOp); + +class AssignOp; +OVERLOAD_REWRITE_BY_LAYOUT(AssignOp); + +class SwishOp; +OVERLOAD_REWRITE_BY_LAYOUT(SwishOp); + +} // namespace dialect +} // namespace paddle + +namespace pir { +class CombineOp; +} + +namespace paddle { +namespace dialect { +OVERLOAD_REWRITE_BY_LAYOUT(::pir::CombineOp); } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 8fd0215293a0d7..fda9cb14d77128 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/framework/custom_operator_utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" @@ -238,6 +239,9 @@ OperatorDialect::OperatorDialect(pir::IrContext* ctx) info.AttachInterface( pir::InterfaceValue::Get()); + info.AttachInterface(pir::InterfaceValue::Get< + LayoutTransformationInterface, + LayoutTransformationInterface::Model>()); info = ctx->GetRegisteredOpInfo(pir::ParameterOp::name()); info.AttachInterface( diff --git a/paddle/fluid/pir/transforms/general/transfer_layout_pass.cc b/paddle/fluid/pir/transforms/general/transfer_layout_pass.cc new file mode 100644 index 00000000000000..e87690b6021ba7 --- /dev/null +++ b/paddle/fluid/pir/transforms/general/transfer_layout_pass.cc @@ -0,0 +1,752 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/general/transfer_layout_pass.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/common/layout.h" +#include "paddle/fluid/inference/api/paddle_pass_builder.h" +#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/include/core/builtin_dialect.h" +#include "paddle/pir/include/core/ir_context.h" +#include "paddle/pir/include/core/program.h" +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/pir/include/pass/pass_registry.h" +#include "paddle/pir/include/pass/utils.h" + +struct Node; + +struct SrcNode { + bool operator==(const SrcNode& rhs) const { return true; } + operator Node() const; + friend std::ostream& operator<<(std::ostream& os, const SrcNode& n) { + os << "Src"; + return os; + } +}; +struct DstNode { + bool operator==(const DstNode& rhs) const { return true; } + operator Node() const; + friend std::ostream& operator<<(std::ostream& os, const DstNode& n) { + os << "Dst"; + return os; + } +}; + +SrcNode src_node() { return SrcNode(); } +DstNode dst_node() { return DstNode(); } + +const float INF = std::numeric_limits::max(); +template +struct overloaded : Ts... 
{ + using Ts::operator()...; +}; +template +overloaded(Ts...) -> overloaded; + +struct Node { + using DataType = + std::variant; + DataType data; + + explicit Node(const pir::Operation* op) : data(op) {} + explicit Node(pir::Value value) : data(value) {} + explicit Node(SrcNode n) : data(n) {} + explicit Node(DstNode n) : data(n) {} + + Node() : data(pir::Value(nullptr)) {} + + bool operator==(const Node& rhs) const { + bool ret = std::visit( + overloaded{ + [](const pir::Operation* left, const pir::Operation* right) { + return (left == right); + }, + [](const pir::Value& left, const pir::Value& right) { + return (left == right); + }, + [](const SrcNode& left, const SrcNode& right) { return true; }, + [](const DstNode& left, const DstNode& right) { return true; }, + [](auto& left, auto& right) { return false; }}, + data, + rhs.data); + return ret; + } + friend std::ostream& operator<<(std::ostream& os, const Node& n) { + std::visit(overloaded{[&](const pir::Operation* op) { + os << "Op(" << op->name() << " " << op << ")"; + }, + [&](const pir::Value& value) { + if (!value) + os << "Var(null)"; + else + os << "Var(" << value.defining_op()->name() << " " + << value.defining_op() << ")"; + }, + [&](SrcNode arg) { os << "Src"; }, + [&](DstNode arg) { os << "Dst"; }}, + n.data); + return os; + } +}; + +SrcNode::operator Node() const { return Node(*this); } + +DstNode::operator Node() const { return Node(*this); } + +namespace std { + +template <> +struct hash { + size_t operator()(const Node& s) const noexcept { return 0x111; } +}; + +template <> +struct hash { + size_t operator()(const Node& s) const noexcept { return 0x222; } +}; + +template <> +struct hash { + size_t operator()(const Node& s) const noexcept { + return hash{}(s.data); + } +}; + +} // namespace std + +struct FlowGraph { + using EdgeIndex = size_t; + + struct Edge { + Node src; + Node dst; + float capacity; + float flow; + bool real; + + Edge(Node src, + Node dst, + float capacity = 0.0f, + float flow = 
0.0f, + bool real = false) + : src(src), dst(dst), capacity(capacity), flow(flow), real(real) {} + friend std::ostream& operator<<(std::ostream& os, const Edge& n) { + os << "(" << n.src << "->" << n.dst << ")"; + return os; + } + }; + + std::vector edges; + // std::vector nodes; + std::unordered_map> adjs; + std::unordered_map cur_arcs; + std::unordered_map heights; + const pir::Program& program; + + void AddEdge(Node src, + Node dst, + float capacity = 0.0f, + float flow = 0.0f, + bool real = false) { + if (src == dst) { + return; + } + + edges.emplace_back(src, dst, capacity, flow, real); + adjs[src].push_back(edges.size() - 1); + + // add reverse edge + edges.emplace_back(dst, src, 0, flow); + adjs[dst].push_back(edges.size() - 1); + } + + explicit FlowGraph(const pir::Program& program) : program(program) { + // We assume by default that the program is topologically sorted; + // otherwise, it will fail during destruction. + + for (auto& op : *(program.block())) { + Node op_node(&op); + auto layout_transform_iface = + op.dyn_cast(); + const auto& relevate_inputs = + layout_transform_iface ? layout_transform_iface.RelevantInputs(&op) + : op.operands_source(); + const auto& relevate_outputs = + layout_transform_iface ? 
layout_transform_iface.RelevantOutputs(&op) + : op.results(); + VLOG(10) << "[BuildGraph]" << op_node << " isz:" << relevate_inputs.size() + << " osz:" << relevate_outputs.size(); + + // add in edge + for (auto& operand : relevate_inputs) { + Node operand_node(operand); + // the capacity should be set as the out_degree of operand node + float weight = 1.0f; + if (operand && operand.type()) { + weight = 1.0f / (operand.use_count()); + if (auto t = operand.type().dyn_cast()) { + weight = INF; + } + } + AddEdge(operand_node, op_node, weight, 0.0f, true); + } + + for (const auto& op_result : relevate_outputs) { + // we have ssa, so the output must not be processed + Node op_result_node(op_result); + + float weight = 1.0f; + if (op_result && op_result.type()) { + if (auto t = op_result.type().dyn_cast()) { + weight = INF; + } + } + AddEdge(op_node, op_result_node, weight, 0.0f, true); + } + } + + PreProcess(); + } + + void PreProcess() { + // the algorithm only accepts two kinds of layout, we assign src node + // and dst node each a kind. in the begin, each var node have a + // layout, but the layout of op node is uncertain. + + // TODO(lyk): we need a getLayout interface to get the layout of op / + // value and determine how many kinds of layout in program currently. + // Then we call prefer_layout get the total count while running the + // algorithm. To simplify the experiment, we skip the first step here + // and just assume they're all NCHW + + for (auto& op : *(program.block())) { + // we need to ensure the edge from src node to real src node in + // calculation graph + + if (!op.HasTrait() && op.num_operands() > 0) { + continue; + } + Node op_node(&op); + AddEdge(src_node(), op_node, INF); + + auto layout_transform_iface = + op.dyn_cast(); + const auto& relevate_inputs = + layout_transform_iface ? layout_transform_iface.RelevantInputs(&op) + : op.operands_source(); + const auto& relevate_outputs = + layout_transform_iface ? 
layout_transform_iface.RelevantOutputs(&op)
+                                 : op.results();
+
+      for (const auto& op_operand : relevate_inputs) {
+        Node operand_node(op_operand);
+        AddEdge(src_node(), operand_node, INF);
+      }
+
+      for (const auto& op_result : relevate_outputs) {
+        Node op_result_node(op_result);
+        AddEdge(src_node(), op_result_node, INF);
+      }
+    }
+
+    std::unordered_set nhwc_nodes;
+    for (auto& op : *(program.block())) {
+      auto layout_transform_iface =
+          op.dyn_cast();
+      if (!layout_transform_iface) {
+        continue;
+      }
+
+      auto prefer_layout = layout_transform_iface.PreferLayout(&op);
+      if (prefer_layout == common::DataLayout::NHWC) {
+        Node op_node(&op);
+        nhwc_nodes.insert(op_node);
+        AddEdge(op_node, dst_node(), INF);
+        VLOG(10) << "[PreProcess] node: " << op_node
+                 << " should be set to NHWC";
+      }
+    }
+
+    // Since VarDesc doesn't store layout, in pir we set all layout to
+    // NCHW after translation. However, we need the real layout to decide
+    // if we need to alter the operation and value. Here we start from the
+    // operations that have a determined layout and spread its layout to
+    // its output and inputs recursively.
+    std::queue q;
+    for (auto& n : nhwc_nodes) {
+      q.push(n);
+    }
+    std::unordered_set is_node_layout_visited;
+    int i = 0;
+    while (!q.empty()) {
+      VLOG(10) << "before : " << q.size() << " " << i;
+      i++;
+      Node node = q.front();
+      VLOG(10) << "visiting node: " << node;
+      q.pop();
+      if (is_node_layout_visited.find(node) != is_node_layout_visited.end()) {
+        continue;
+      }
+      is_node_layout_visited.insert(node);
+
+      VLOG(10) << "judging node: " << node;
+
+      auto judge_dense_tensor_type = [](paddle::dialect::DenseTensorType t) {
+        if (t.dims().size() == 4) {
+          return false;
+        }
+        return true;
+      };
+
+      bool should_interrupt = std::visit(
+          overloaded{
+              [&](const pir::Operation* op) {
+                // TODO(lyk): These conditions may be too loose,
+                // we should make a white list here.
+ + pir::Operation* fop = const_cast(op); + + auto layout_transform_iface = fop->dyn_cast< + paddle::dialect::LayoutTransformationInterface>(); + if (layout_transform_iface) { + return !layout_transform_iface.CanBeModified(fop); + } + return true; + }, + [&](const pir::Value& v) { + if (!v) return true; + auto vt = v.type(); + if (!vt) return true; + // maybe not DenseTensor, but we can handle other types later + if (auto vdt = + vt.dyn_cast()) { + VLOG(10) << "judging var: " << v.defining_op() << " " + << v.type() << " " << vdt.dims() << " " + << (vdt.dims().size() == 4); + return judge_dense_tensor_type(vdt); + } else if (auto vdt = vt.dyn_cast()) { + if (vdt.size() == 0) return false; + auto vt_elem = vdt[0]; + if (auto vdt_elem = + vt_elem.dyn_cast()) + return judge_dense_tensor_type(vdt_elem); + } + return true; + }, + [](const auto&) { return true; }, + }, + node.data); + if (should_interrupt) { + continue; + } + + VLOG(10) << "add node to nhwc set: " << node; + nhwc_nodes.insert(node); + + VLOG(10) << "processing node successor: " << node; + + int j = 0; + for (const auto& e : adjs[node]) { + auto& edge = edges[e]; + q.push(edge.dst); + VLOG(10) << "add node to queue: " << node << " -> " << edge.dst; + j++; + } + } + + q.push(src_node()); + is_node_layout_visited.clear(); + while (!q.empty()) { + auto node = q.front(); + q.pop(); + if (is_node_layout_visited.find(node) != is_node_layout_visited.end()) { + continue; + } + is_node_layout_visited.insert(node); + if (nhwc_nodes.count(node) == 0) { + VLOG(10) << "add node to nchw set: " << node; + AddEdge(src_node(), node, INF); + } + for (const auto& e : adjs[node]) { + auto& edge = edges[e]; + q.push(edge.dst); + } + } + } + + bool ConstructLevelGraph() { + heights.clear(); + std::queue> q; + q.push({src_node(), 0}); + while (!q.empty()) { + auto [node, height] = q.front(); + q.pop(); + if (heights.find(node) != heights.end()) { + continue; + } + heights[node] = height; + for (auto e_ind : adjs[node]) { + auto& 
e = edges[e_ind]; + if (e.capacity - e.flow > 0 && heights.find(e.dst) == heights.end()) { + q.push({e.dst, height + 1}); + } + } + } + return (heights[dst_node()] > 0); + } + + // cf is the admissable flow in current path + float FindBlockingFlow(Node src, float cf) { + if (src == dst_node() || abs(cf) < 1e-9) { + return cf; + } + + auto& next_arc = cur_arcs[src]; // notice this is a reference + float ret = 0.0f; + while (next_arc < adjs[src].size()) { + auto e_ind = adjs[src][next_arc]; + auto& e = edges[e_ind]; + next_arc++; + auto next_node = e.dst; + if (heights[next_node] == heights[src] + 1) { + auto left_capacity = e.capacity - e.flow; + auto update_flow = std::min(cf - ret, left_capacity); + auto f = FindBlockingFlow(next_node, update_flow); + if (f > 0) { + e.flow += f; + auto reverse_e_ind = e_ind ^ 1; + auto& reverse_e = edges[reverse_e_ind]; + reverse_e.flow -= f; + ret += f; + + if (abs(ret - cf) < 1e-9) { + return ret; + } + } + } + } + + if (ret == 0) { + heights[src] = 0; + } + + return ret; + } + + float MaxFlow() { + VLOG(10) + << "--------------------[max flow start]---------------------------"; + float total_flow = 0.0f; + while (ConstructLevelGraph()) { + for (auto& [node, nexts] : adjs) { + cur_arcs[node] = 0; + } + while (auto f = FindBlockingFlow(src_node(), INF)) { + total_flow += f; + } + } + VLOG(10) << "--------------------[max flow end]---------------------------"; + return total_flow; + } + + std::tuple, std::vector> MinCut() { // NOLINT + MaxFlow(); + // from src_node get its reachable nodes and call them S + // other nodes are in T + // collect edges between S and T + std::unordered_set src_set; + std::queue q; + q.push(src_node()); + while (!q.empty()) { + auto n = q.front(); + q.pop(); + VLOG(10) << "bfs access: " << n; + if (src_set.count(n) > 0) continue; + src_set.insert(n); + VLOG(10) << "bfs insert " << n << " " << src_set.size(); + for (auto& ind : adjs[n]) { + VLOG(10) << "bfs edge: " << edges[ind] << " c:" << 
edges[ind].capacity + << " f:" << edges[ind].flow; + if (edges[ind].capacity > edges[ind].flow) { + VLOG(10) << "bfs add: " << edges[ind].dst; + q.push(edges[ind].dst); + } + } + } + + VLOG(10) << "src_set.size()=" << src_set.size(); + + std::vector cut; + for (const auto& e : edges) { + if (!e.real) continue; + auto& src = e.src; + auto& dst = e.dst; + bool src_cond = (src_set.count(src) > 0); + bool dst_cond = (src_set.count(dst) > 0); + if (src_cond == dst_cond) { + continue; + } + VLOG(10) << "cut " << src << "(" << src_cond << ")" + << " " << dst << "(" << dst_cond << ")"; + cut.push_back(e); + } + + VLOG(10) << "cut set.size()=" << cut.size(); + VLOG(10) << "-----------------------------------------------"; + + return {src_set, cut}; + } +}; + +using Edge = FlowGraph::Edge; + +class TransferLayoutPass : public pir::Pass { + public: + TransferLayoutPass() : pir::Pass("transfer_layout_pass", 4) {} + + bool CanApplyOn(pir::Operation* op) const override { + if (!op->isa()) { + return false; + } + return op->num_regions() > 0; + } + + void Run(pir::Operation* op) override { + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + auto module_op = op->dyn_cast(); + auto* program = module_op.program(); + + // MinCut + VLOG(10) << "---------------------MinCut---------------------"; + FlowGraph graph(*program); + auto&& [src_set, cut] = graph.MinCut(); + for (auto& e : cut) { + VLOG(10) << e; + } + + // collect all edges from variable to operation + // for these, we only need to add 1 for every variable + // instead of every edges + std::unordered_map> var_set; + std::unordered_map op_src_set; + std::vector op_set; + for (const auto& e : cut) { + if (std::get_if(&(e.src.data))) { + op_src_set[e.src] = e.dst; + op_set.push_back(e); + } else { + var_set[e.src].push_back(e.dst); + } + } + + VLOG(10) << "-----------------------[var set]------------------------"; + + // cout var_set + for (auto& [var, ops] : 
var_set) { + VLOG(10) << var << ":"; + for (auto op : ops) { + VLOG(10) << op << ","; + } + VLOG(10); + } + + VLOG(10) << "-----------------------[op src set]------------------------"; + + // cout op_set + for (auto& [k, v] : op_src_set) { + VLOG(10) << k << "," << v; + } + + VLOG(10) << "-----------------------[min cut end]------------------------"; + + pir::Builder builder(ctx, program->block()); + auto layout_to_perm = [](std::string src, std::string dst) { + std::vector perm(src.size(), 0); + std::unordered_map d; + for (size_t i = 0; i < src.size(); ++i) { + d[src[i]] = i; + } + for (size_t i = 0; i < dst.size(); ++i) { + perm[i] = d[dst[i]]; + } + return perm; + }; + + std::deque q; + std::unordered_set is_node_layout_visited; + std::function topological_visit = [&](Node node) -> void { + if (is_node_layout_visited.count(node)) return; + is_node_layout_visited.insert(node); + + // add successors to queue + for (auto& ind : graph.adjs[node]) { + auto& e = graph.edges[ind]; + if (!e.real) continue; + topological_visit(e.dst); + } + + q.push_front(node); + }; + + for (auto& op : *(program->block())) { + if (op.num_operands() > 0) continue; + Node op_node(&op); + topological_visit(op_node); + q.push_front(op_node); + } + + VLOG(10) + << "-----------------------[topological sort]------------------------"; + + for (auto n : q) { + VLOG(10) << n; + } + + VLOG(10) + << "-----------------------[rewrite begin]------------------------"; + + while (!q.empty()) { + auto node = q.front(); + q.pop_front(); + + // not in cut set and its layout should not be changed + if (src_set.find(node) == src_set.end()) { + // process layout transformation + if (std::get_if(&(node.data)) != nullptr) { + auto* op = const_cast( + std::get(node.data)); + VLOG(10) << "[Rewrite][RewriteByLayout] " << node; + auto layout_transformation_iface = + op->dyn_cast(); + if (layout_transformation_iface) { + layout_transformation_iface.RewriteByLayout( + op, common::DataLayout::NHWC); + } else { + 
PADDLE_THROW(common::errors::Unimplemented( + "Op %s should have a specialized RewriteByLayout function", + op->name())); + } + } + } + + VLOG(10) << "[Rewrite] for " << node; + // if node is the src node of a cut edge + // and it's an operation + if (op_src_set.find(node) != op_src_set.end()) { + VLOG(10) << "[Rewrite][Op] for " << node; + + // just insert a transpose op + auto src = node; + auto dst = op_src_set[src]; + auto dst_value = std::get(dst.data); + + VLOG(10) << "[Rewrite][Op] for var:" + << (dst_value ? (dst_value.defining_op()) : nullptr) + << " t:" << (dst_value ? (dst_value.type()) : pir::Type()); + + // enforce dst value.defining_op = src + const auto& perm = + ((src_set.count(node) > 0) ? layout_to_perm("NCHW", "NHWC") + : layout_to_perm("NHWC", "NCHW")); + const auto& new_layout = + ((src_set.count(node) > 0) ? common::DataLayout::NHWC + : common::DataLayout::NCHW); + builder.SetInsertionPointAfter(dst_value.defining_op()); + auto transpose_op = + builder.Build(dst_value, perm); + transpose_op->set_attribute( + "source", + pir::StrAttribute::get(transpose_op->ir_context(), + "transfer_layout_pass")); + auto replace_uses_without_self = [&](pir::OpOperand arg) { + return arg.owner() != transpose_op.operation(); + }; + pir::SetNewLayoutForValue(transpose_op.out(), new_layout); + dst_value.ReplaceUsesWithIf(transpose_op.out(), + replace_uses_without_self); + } + + // if node is the src node of a cut edge + // and it's a value + // this node must not be in the nhwc set + if (var_set.find(node) != var_set.end()) { + VLOG(10) << "[Rewrite][Var] for " << node; + const auto& ops = var_set[node]; + // operand should be replaced + std::unordered_set operation_set; + for (auto op : ops) { + operation_set.insert(std::get(op.data)); + } + + auto value = std::get(node.data); + VLOG(10) << "[Rewrite][Var] for var:" + << (value ? 
value.defining_op() : nullptr); + for (const auto& op : operation_set) { + VLOG(10) << " op: " << op << ","; + } + VLOG(10); + const auto& perm = + ((src_set.count(node) > 0) ? layout_to_perm("NCHW", "NHWC") + : layout_to_perm("NHWC", "NCHW")); + const auto& new_layout = + ((src_set.count(node) > 0) ? common::DataLayout::NHWC + : common::DataLayout::NCHW); + builder.SetInsertionPointAfter(value.defining_op()); + auto transpose_op = + builder.Build(value, perm); + transpose_op->set_attribute( + "source", + pir::StrAttribute::get(transpose_op->ir_context(), + "transfer_layout_pass")); + auto replace_uses_in_cut_set = [&](pir::OpOperand arg) { + return (operation_set.find(arg.owner()) != operation_set.end()) && + (arg.owner() != transpose_op.operation()); + }; + pir::SetNewLayoutForValue(transpose_op.out(), new_layout); + value.ReplaceUsesWithIf(transpose_op.out(), replace_uses_in_cut_set); + } + } + } +}; + +namespace pir { + +std::unique_ptr CreateTransferLayoutPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(transfer_layout_pass, TransferLayoutPass); diff --git a/paddle/fluid/pir/transforms/general/transfer_layout_pass.h b/paddle/fluid/pir/transforms/general/transfer_layout_pass.h new file mode 100644 index 00000000000000..98e90292c2c7f3 --- /dev/null +++ b/paddle/fluid/pir/transforms/general/transfer_layout_pass.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateTransferLayoutPass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h index 57e8958470c925..db6a50a8ec3ade 100644 --- a/paddle/fluid/pir/transforms/passes.h +++ b/paddle/fluid/pir/transforms/passes.h @@ -41,6 +41,7 @@ USE_PIR_PASS(add_norm_fuse_pass); USE_PIR_PASS(fused_dot_product_attention_pass); USE_PIR_PASS(fused_flash_attn_pass); USE_PIR_PASS(remove_redundant_transpose_pass); +USE_PIR_PASS(transfer_layout_pass); #ifdef PADDLE_WITH_DNNL USE_PIR_PASS(depthwise_conv_onednn_pass); diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index bf957554a3d755..91845ea04fc998 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -36,6 +36,7 @@ extern bool HasCUDNN(); * different cudnn version has different interfaces **/ #define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetCallback); \ __macro(cudnnSetTensor4dDescriptor); \ __macro(cudnnSetTensor4dDescriptorEx); \ __macro(cudnnSetTensorNdDescriptor); \ diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index a020cb4c8122d3..6deffc89271f9e 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -54,6 +54,7 @@ set(PYBIND_DEPS if(WITH_CINN) set(PYBIND_DEPS ${PYBIND_DEPS} pir_transforms cinn_transforms sub_graph_checker add_cinn_pass) + list(REMOVE_ITEM PYBIND_DEPS imperative_flag) endif() if(WITH_PSCORE) diff --git a/paddle/phi/backends/dynload/cudnn.h b/paddle/phi/backends/dynload/cudnn.h index 7a7dce241ff0ac..2cd91fa8ab517e 100644 --- a/paddle/phi/backends/dynload/cudnn.h +++ b/paddle/phi/backends/dynload/cudnn.h @@ -49,6 +49,7 @@ TEST_API extern void 
EnforceCUDNNLoaded(const char* fn_name); * different cudnn version has different interfaces **/ #define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetCallback); \ __macro(cudnnSetTensor4dDescriptor); \ __macro(cudnnSetTensor4dDescriptorEx); \ __macro(cudnnSetTensorNdDescriptor); \ diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 14c474ca1b21b5..7a469b67da8b71 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -14,7 +14,7 @@ data_type: x inplace : (x -> out) backward : add_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface # this add_n is only for ops_api_gen.py and onednn - op : add_n @@ -47,7 +47,7 @@ func : assign backward : assign_grad inplace : (x -> out) - interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface - op : assign_out_ args : (Tensor x, Tensor output) @@ -709,7 +709,7 @@ data_transform : support_trans_dtype : x, y backward : multiply_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface - op : nop args : (Tensor x) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index bca412a8cce37b..cb36b6528d1648 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -772,7 +772,7 @@ func : concat data_type : x backward : concat_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface - op : conj args : (Tensor x) @@ -793,7 +793,7 @@ func : conv2d data_type : input backward : conv2d_grad - interfaces : 
paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface - op : conv2d_transpose args : (Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") @@ -1936,7 +1936,7 @@ optional : scale, bias intermediate : mean, variance backward : group_norm_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface - op : gumbel_softmax args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1) @@ -2995,6 +2995,7 @@ func : pool2d param : [x, kernel_size, strides, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] backward : pool2d_grad + interfaces : paddle::dialect::LayoutTransformationInterface - op : pool3d args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) @@ -3550,6 +3551,7 @@ kernel : func : silu backward : silu_grad + interfaces : paddle::dialect::LayoutTransformationInterface - op : sin args : (Tensor x) @@ -3690,7 +3692,7 @@ view: (x -> out) intermediate : xshape backward : squeeze_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface - op : stack args : (Tensor[] x, int axis = 0) @@ -3770,6 +3772,7 @@ kernel : func : swish backward : swish_grad + interfaces : paddle::dialect::LayoutTransformationInterface - op : sync_batch_norm_ args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_format, bool 
use_global_stats, bool trainable_statistics) diff --git a/paddle/pir/include/pass/pass_registry.h b/paddle/pir/include/pass/pass_registry.h index 9fba4e09c54339..003292a576b0d1 100644 --- a/paddle/pir/include/pass/pass_registry.h +++ b/paddle/pir/include/pass/pass_registry.h @@ -25,7 +25,7 @@ namespace pir { using PassCreator = std::function()>; -class PassRegistry { +class IR_API PassRegistry { public: static PassRegistry &Instance(); @@ -57,7 +57,7 @@ class PassRegistry { }; template -class PassRegistrar { +class IR_API PassRegistrar { public: // In our design, various kinds of passes, // have their corresponding registry and registrar. The action of @@ -87,7 +87,7 @@ class PassRegistrar { "REGISTER_IR_PASS must be called in global namespace"); \ static ::pir::PassRegistrar \ __pir_pass_registrar_##pass_type##__(#pass_type); \ - int TouchPirPassRegistrar_##pass_type() { \ + IR_API int TouchPirPassRegistrar_##pass_type() { \ __pir_pass_registrar_##pass_type##__.Touch(); \ return 0; \ } \ diff --git a/paddle/pir/include/pass/utils.h b/paddle/pir/include/pass/utils.h new file mode 100644 index 00000000000000..9a2cbc0274793f --- /dev/null +++ b/paddle/pir/include/pass/utils.h @@ -0,0 +1,24 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/common/layout.h" +#include "paddle/pir/include/core/value.h" + +namespace pir { + +void SetNewLayoutForValue(pir::Value value, common::DataLayout new_layout); + +} // namespace pir diff --git a/paddle/pir/src/pass/utils.cc b/paddle/pir/src/pass/utils.cc new file mode 100644 index 00000000000000..f866d7beaf8a2b --- /dev/null +++ b/paddle/pir/src/pass/utils.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/pir/include/pass/utils.h" + +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/ir_context.h" + +namespace pir { + +void SetNewLayoutForValue(pir::Value value, common::DataLayout new_layout) { + if (!value || !value.type()) { + return; + } + auto tensor_type = value.type().dyn_cast(); + if (!tensor_type) { + return; + } + auto new_tensor_type = pir::DenseTensorType::get(pir::IrContext::Instance(), + tensor_type.dtype(), + tensor_type.dims(), + new_layout, + tensor_type.lod(), + tensor_type.offset()); + value.set_type(new_tensor_type); +} + +} // namespace pir diff --git a/test/cpp/pir/operator/layout_transformation_interface_test.cc b/test/cpp/pir/operator/layout_transformation_interface_test.cc index 7b9ed0e8c171bd..2067df3f8d70a7 100644 --- a/test/cpp/pir/operator/layout_transformation_interface_test.cc +++ b/test/cpp/pir/operator/layout_transformation_interface_test.cc @@ -33,7 +33,7 @@ TEST(layout_transformation_interface_test, operator) { auto build_input_value = [&](std::vector shape = {2, 2}) { auto uniform = builder.Build( - shape, phi::DataType::FLOAT32, 0.0, 1.0, 2, phi::CPUPlace()); + shape, phi::DataType::FLOAT16, 0.0, 1.0, 2, phi::CPUPlace()); return uniform; }; @@ -48,13 +48,12 @@ TEST(layout_transformation_interface_test, operator) { EXPECT_TRUE(layout_transformation_iface); EXPECT_EQ(layout_transformation_iface.PreferLayout(fused_conv), - common::DataLayout::NHWC); + common::DataLayout::NCHW); EXPECT_NO_THROW(layout_transformation_iface.RewriteByLayout( fused_conv, common::DataLayout::NHWC)); EXPECT_EQ(layout_transformation_iface.RelevantInputs(fused_conv).size(), fused_conv->operands().size()); - EXPECT_EQ(layout_transformation_iface.RelevantOutputs(fused_conv).size(), - fused_conv->results().size()); + EXPECT_EQ(layout_transformation_iface.RelevantOutputs(fused_conv).size(), 1u); } TEST(immutable_layout_trait_test, operator) { diff --git a/test/cpp/pir/pass/CMakeLists.txt 
b/test/cpp/pir/pass/CMakeLists.txt index 3507cfd0708dd7..904d6e5e4e7ab0 100644 --- a/test/cpp/pir/pass/CMakeLists.txt +++ b/test/cpp/pir/pass/CMakeLists.txt @@ -5,3 +5,16 @@ if(WITH_ONNXRUNTIME AND WIN32) # be build only in CI, so suppose the generator in Windows is Ninja. copy_onnx(pass_manager_test) endif() + +if(WITH_GPU) + file(DOWNLOAD https://paddle-ci.gz.bcebos.com/test/sd15_unet.pdmodel + ${CMAKE_CURRENT_BINARY_DIR}/sd15_unet.pdmodel + EXPECTED_MD5 4b5a3b8ea5b49bfd12172847cfe5a92a) + + paddle_test(transfer_layout_pass_test SRCS transfer_layout_pass_test.cc) + if(WITH_ONNXRUNTIME AND WIN32) + # Copy onnxruntime for some c++ test in Windows, since the test will + # be build only in CI, so suppose the generator in Windows is Ninja. + copy_onnx(transfer_layout_pass_test) + endif() +endif() diff --git a/test/cpp/pir/pass/transfer_layout_pass_test.cc b/test/cpp/pir/pass/transfer_layout_pass_test.cc new file mode 100644 index 00000000000000..26f7649d161b14 --- /dev/null +++ b/test/cpp/pir/pass/transfer_layout_pass_test.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/common/layout.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/api/paddle_pass_builder.h" +#include "paddle/fluid/ir_adaptor/translator/translate.h" +#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/general/transfer_layout_pass.h" +#include "paddle/fluid/pir/transforms/passes.h" +#include "paddle/pir/include/core/builtin_dialect.h" +#include "paddle/pir/include/core/ir_context.h" +#include "paddle/pir/include/core/program.h" +#include "paddle/pir/include/pass/pass_manager.h" + +using ProgramDesc = paddle::framework::ProgramDesc; +ProgramDesc load_from_file(const std::string& file_name) { + std::ifstream fin(file_name, std::ios::in | std::ios::binary); + fin.seekg(0, std::ios::end); + + std::string buffer(fin.tellg(), ' '); + fin.seekg(0, std::ios::beg); + fin.read(&buffer[0], buffer.size()); // NOLINT + fin.close(); + return ProgramDesc(buffer); +} + +TEST(transfer_layout_pass, pass_test) { + // Load Unet Program + const std::string model_name = "sd15_unet.pdmodel"; + auto p = load_from_file(model_name); + EXPECT_EQ(p.Size(), 1u); + EXPECT_GT(p.Block(0).OpSize(), 0u); + + // Translate to PIR Program + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + pir::PassManager pass_pm(::pir::IrContext::Instance(), 3); + + // Note(lyk) To avoid windows compiler error: + // 
paddle::kPirGpuPasses has already been declared as + // PD_INFER_DECL, but we still got LINK ERROR. The + // reason is still unclear, skip it now. + const std::vector kCopiedPirGpuPasses{ + // Functional pass + "map_op_to_another_pass", + "identity_op_clean_pass", + // Operator fusion pass + "silu_fuse_pass", + "conv2d_bn_fuse_pass", + "conv2d_add_act_fuse_pass", + "conv2d_add_fuse_pass", + "embedding_eltwise_layernorm_fuse_pass", + "fused_flash_attn_pass", + "multihead_matmul_fuse_pass", + "matmul_add_act_fuse_pass", + "fc_elementwise_layernorm_fuse_pass", + "matmul_scale_fuse_pass", + "matmul_transpose_fuse_pass", + "transpose_flatten_concat_fuse_pass", + "remove_redundant_transpose_pass"}; + + for (const auto& gpu_pass : kCopiedPirGpuPasses) { + pass_pm.AddPass(pir::PassRegistry::Instance().Get(gpu_pass)); + } + pass_pm.Run(program.get()); + + pir::PassManager transfer_layout_manager(::pir::IrContext::Instance(), 4); + transfer_layout_manager.AddPass(pir::CreateTransferLayoutPass()); + transfer_layout_manager.Run(program.get()); +}