3 changes: 2 additions & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -617,7 +617,8 @@ const std::vector<std::string> kPirGpuPasses{
     "matmul_transpose_fuse_pass",
     "transpose_flatten_concat_fuse_pass",
     "remove_redundant_transpose_pass",
-    "transfer_layout_pass"};
+    "transfer_layout_pass",
+};

 const std::vector<std::string> kPirXpuPasses{// Functional pass
     "map_op_to_another_pass",

@@ -48,11 +48,30 @@ common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
         data_format_attr));
   }

+  // Note(lyk): We exhibit the layout transformation for conv2d
+  auto concrete_op = op->dyn_cast<Conv2dOp>();
+  if (auto in = concrete_op.input()) {
+    if (auto in_type = in.type()) {
+      if (in_type.isa<DenseTensorType>()) {
+        if (auto tensor_type = in_type.dyn_cast<DenseTensorType>()) {
+          if (tensor_type.dtype().isa<pir::Float16Type>()) {
+            return common::DataLayout::NHWC;
+          }
+        }
+      }
+    }
+  }
+
+  return common::StringToDataLayout(data_format_attr.AsString());
+}
+
+template <>
+std::vector<pir::Value> RelevantInputsImpl<Conv2dOp>(pir::Operation* op) {
   // Note(lyk): We exhibit the layout transformation for filter of conv2d
   // due to issues with its infermeta and kernel not functioning
   // properly in NHWC layout. However, if the FLAGS_manually_trans_conv_filter
   // is enabled, the transfer_layout_pass can also operate correctly.
-  return common::StringToDataLayout(data_format_attr.AsString());
+  auto concrete_op = op->dyn_cast<Conv2dOp>();
+  return {concrete_op.input()};
 }

 template <>
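
Two behavioral changes land in the hunk above: PreferLayoutImpl<Conv2dOp> now asks for NHWC whenever the convolution input is an FP16 dense tensor (NHWC is the layout FP16 convolutions run fastest in on tensor-core GPUs), and the new RelevantInputsImpl<Conv2dOp> returns only the activation input, keeping the filter out of the rewrite unless FLAGS_manually_trans_conv_filter is set, per the Note. A self-contained sketch of the preference rule, with plain enums standing in for the pir and common types:

#include <iostream>
#include <string>

// Hypothetical stand-ins for common::DataLayout and the tensor dtype,
// for illustration only.
enum class DataLayout { NCHW, NHWC };
enum class DType { FP16, FP32 };

// Mirrors PreferLayoutImpl<Conv2dOp> above: an FP16 input forces NHWC,
// otherwise the op's data_format attribute wins.
DataLayout PreferConvLayout(DType input_dtype, const std::string& data_format) {
  if (input_dtype == DType::FP16) return DataLayout::NHWC;
  return data_format == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW;
}

int main() {
  // FP16 overrides a data_format of "NCHW"; FP32 does not.
  std::cout << (PreferConvLayout(DType::FP16, "NCHW") == DataLayout::NHWC)
            << (PreferConvLayout(DType::FP32, "NCHW") == DataLayout::NCHW)
            << "\n";  // prints 11
}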

@@ -124,6 +143,31 @@ void RewriteByLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op,
   RewriteByInfermeta<FusedConv2dAddActOp>(op, new_layout);
 }

+template <>
+bool CanBeModifiedImpl<FusedConv2dAddActOp>(pir::Operation* op) {
+  auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
+  if (!data_format_attr) {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "op (%s) should have attribute `data_format`, but got %s",
+        op,
+        data_format_attr));
+  }
+  auto cur_layout = common::StringToDataLayout(data_format_attr.AsString());
+  auto prefer_layout = PreferLayoutImpl<FusedConv2dAddActOp>(op);
+  auto can_be_modified = cur_layout != prefer_layout;
+
+  for (auto value : RelevantOutputsImpl<FusedConv2dAddActOp>(op)) {
+    // TODO(lyk) if value was used in another block, we cannot rewrite this op
+    for (auto it = value.use_begin(); it != value.use_end(); ++it) {
+      if (it->owner()->GetParent() != op->GetParent()) {
+        return false;
+      }
+    }
+  }
+
+  return can_be_modified;
+}
+
 template <>
 void RewriteByLayoutImpl<GroupNormOp>(pir::Operation* op,
                                       common::DataLayout new_layout) {
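
The new CanBeModifiedImpl<FusedConv2dAddActOp> only reports an op as modifiable when its current layout differs from its preferred one, and it refuses outright when any relevant output is consumed in a different block, since the rewrite cannot follow values across block boundaries. A generic, compilable sketch of that guard; Op and Value are hypothetical stand-ins for pir::Operation and pir::Value:

#include <vector>

struct Op;
struct Value {
  std::vector<Op*> users;  // ops that consume this value
};
struct Op {
  int block_id;                // stand-in for pir::Operation::GetParent()
  std::vector<Value> outputs;  // stand-in for RelevantOutputsImpl(op)
};

// Mirrors the loop above: any user outside the producing op's block
// vetoes the layout rewrite.
bool AllUsesInSameBlock(const Op& op) {
  for (const auto& value : op.outputs)
    for (const Op* user : value.users)
      if (user->block_id != op.block_id) return false;
  return true;
}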

@@ -105,9 +105,11 @@ bool CanBeModifiedImpl(pir::Operation* op) {
 class FusedConv2dAddActOp;
 OVERLOAD_PREFER_LAYOUT(FusedConv2dAddActOp);
 OVERLOAD_REWRITE_BY_LAYOUT(FusedConv2dAddActOp);
+OVERLOAD_CAN_BE_MODIFIED(FusedConv2dAddActOp);

 class Conv2dOp;
 OVERLOAD_PREFER_LAYOUT(Conv2dOp);
+OVERLOAD_RELEVANT_INPUTS(Conv2dOp);
 OVERLOAD_REWRITE_BY_LAYOUT(Conv2dOp);

 class GroupNormOp;
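
This header hunk registers the two new specializations with the interface via the OVERLOAD_* declaration macros. The macro definitions are not part of this diff, so the expansion below is a hedged guess at their shape, with minimal forward declarations so the snippet stands alone:

#include <vector>

namespace pir { class Operation; class Value; }  // stand-ins for the IR types

template <typename ConcreteOp>
bool CanBeModifiedImpl(pir::Operation* op);
template <typename ConcreteOp>
std::vector<pir::Value> RelevantInputsImpl(pir::Operation* op);

// Plausible expansions (assumption): each macro declares the explicit
// specialization so the interface dispatches to the per-op overload.
#define OVERLOAD_CAN_BE_MODIFIED(op_name) \
  template <>                             \
  bool CanBeModifiedImpl<op_name>(pir::Operation* op);
#define OVERLOAD_RELEVANT_INPUTS(op_name) \
  template <>                             \
  std::vector<pir::Value> RelevantInputsImpl<op_name>(pir::Operation* op);

class Conv2dOp;
OVERLOAD_CAN_BE_MODIFIED(Conv2dOp)   // declares CanBeModifiedImpl<Conv2dOp>
OVERLOAD_RELEVANT_INPUTS(Conv2dOp)   // declares RelevantInputsImpl<Conv2dOp>
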
27 changes: 16 additions & 11 deletions paddle/fluid/pir/transforms/general/transfer_layout_pass.cc

@@ -278,18 +278,22 @@ struct FlowGraph {
     }
   }

-  std::unordered_set<Node> nhwc_nodes;
+  std::unordered_set<Node> mutable_nodes;
   for (auto& op : *(program.block())) {
     auto layout_transform_iface =
         op.dyn_cast<paddle::dialect::LayoutTransformationInterface>();
     if (!layout_transform_iface) {
       continue;
     }

+    if (!layout_transform_iface.CanBeModified(&op)) {
+      continue;
+    }
+
     auto prefer_layout = layout_transform_iface.PreferLayout(&op);
     if (prefer_layout == common::DataLayout::NHWC) {
       Node op_node(&op);
-      nhwc_nodes.insert(op_node);
+      mutable_nodes.insert(op_node);
       AddEdge(op_node, dst_node(), INF);
       VLOG(10) << "[PreProcess] node: " << op_node
                << " should be set to NHWC";

@@ -302,7 +306,7 @@
   // operation who have a dertermined layout and spread its layout to
   // its output and inputs recursively.
   std::queue<Node> q;
-  for (auto& n : nhwc_nodes) {
+  for (auto& n : mutable_nodes) {
     q.push(n);
   }
   std::unordered_set<Node> is_node_layout_visited;

@@ -362,13 +366,14 @@
         // a point of cut edge. So we set its outputs and inputs to
         // immutable.
         Node in_node = Node(v.defining_op());
-        nhwc_nodes.erase(in_node);
-        VLOG(10) << "erase node: " << in_node << " from nhwc set";
+        mutable_nodes.erase(in_node);
+        VLOG(10) << "erase node: " << in_node << " from mutable set";

         for (auto it = v.use_begin(); it != v.use_end(); ++it) {
           Node out_node(it->owner());
-          nhwc_nodes.erase(out_node);
-          VLOG(10) << "erase node: " << out_node << " from nhwc set";
+          mutable_nodes.erase(out_node);
+          VLOG(10)
+              << "erase node: " << out_node << " from mutable set";
         }
       }
       return !can_be_transformed;
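
When a value turns out to be untransformable, the hunk above evicts both its producer and every consumer from the mutable set, pinning that whole neighborhood to its original layout so the cut cannot pass through it. A minimal compilable sketch of the eviction rule; Node here is a hypothetical stand-in for the pass's flow-graph node:

#include <set>
#include <vector>

using Node = int;  // stand-in for the flow-graph Node type

// Mirrors the erase calls above: the defining op and all user ops of an
// untransformable value drop out of the mutable set together.
void PinNeighborhood(Node producer, const std::vector<Node>& users,
                     std::set<Node>* mutable_nodes) {
  mutable_nodes->erase(producer);
  for (Node user : users) mutable_nodes->erase(user);
}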

@@ -380,8 +385,8 @@
       continue;
     }

-    VLOG(10) << "add node to nhwc set: " << node;
-    nhwc_nodes.insert(node);
+    VLOG(10) << "add node to mutable set: " << node;
+    mutable_nodes.insert(node);

     VLOG(10) << "processing node successor: " << node;


@@ -403,7 +408,7 @@
       continue;
     }
     is_node_layout_visited.insert(node);
-    if (nhwc_nodes.count(node) == 0) {
+    if (mutable_nodes.count(node) == 0) {
       VLOG(10) << "add node to nchw set: " << node;
       AddEdge(src_node(), node, INF);
     }

@@ -542,7 +547,7 @@ using Edge = FlowGraph::Edge;

 class TransferLayoutPass : public pir::Pass {
  public:
-  TransferLayoutPass() : pir::Pass("transfer_layout_pass", 3) {}
+  TransferLayoutPass() : pir::Pass("transfer_layout_pass", 2) {}

   bool CanApplyOn(pir::Operation* op) const override {
     if (!op->isa<pir::ModuleOp>()) {
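
The constructor's second argument is the pass's opt level; dropping it from 3 to 2 means transfer_layout_pass becomes eligible under a pass manager running at optimization level 2 rather than only at level 3. A hedged usage sketch follows; the include path, CreateTransferLayoutPass, and the PassManager signatures follow the usual PIR pass boilerplate and are assumptions, not taken from this diff:

#include "paddle/pir/include/pass/pass_manager.h"  // assumed include path

// Assumed factory and manager API (standard PIR pass boilerplate).
void ApplyTransferLayout(pir::Program* program) {
  pir::PassManager pm(pir::IrContext::Instance(), /*opt_level=*/2);
  pm.AddPass(pir::CreateTransferLayoutPass());  // eligible at level 2 after this change
  pm.Run(program);
}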