Merged
@@ -396,10 +396,15 @@ inline bool horizontal_with_injective(
return true;
}

- if (!is_same_size(first, second)) {
- return false;
+ if (horizontal_relation(first, second, OpPatternKind::kInjective)) {
+ return true;
}
- return horizontal_relation(first, second, OpPatternKind::kInjective);
+
+ if (is_same_size(first, second)) {
+ return true;
+ }
+
+ return true;
}

inline bool injective_horizontal_with_reduce(
@@ -609,13 +609,17 @@ class OpFusionPassHelper {
hlir::framework::pir::CompatibleInfo::OpKind(*consumer))) {
auto& consumer_group = fusion_groups_[consumer];
// second step: check producer can be fused into consumer group
VLOG(3) << "Call ConditionFunction, Producer Op Pattern : "
VLOG(3) << "Call ConditionFunction, Producer Op: [" << producer->name()
<< "] Pattern : "
<< hlir::framework::pir::CompatibleInfo::OpKind(*producer)
<< " , Consumer Group Pattern : "
<< consumer_group->op_pattern_kind;
<< " , Consumer Group [" << consumer->name()
<< "] Pattern : " << consumer_group->op_pattern_kind;

return relation.fusion_op_kind[consumer_group->op_pattern_kind](
bool result = relation.fusion_op_kind[consumer_group->op_pattern_kind](
producer, fusion_groups_[consumer], shape_analysis);
VLOG(3) << " CanFuse: " << result;

return result;
}

return false;
@@ -655,7 +659,7 @@ GroupList OpFusionPassInternal(
OpFusionPassHelper(op_list, output_op_list, shape_analysis);
auto res = op_fusion_helper();

- if (VLOG_IS_ON(6)) {
+ if (VLOG_IS_ON(3)) {
std::stringstream ss;
::pir::IrPrinter printer(ss);
for (size_t i = 0; i < res.size(); ++i) {
@@ -668,7 +672,7 @@
ss << "\n";
}
}
- VLOG(6) << ss.str();
+ VLOG(1) << ss.str();
Review comment (Member):
VLOG mismatch?

Reply (@zyfncg, Contributor Author, Jan 28, 2024):
just for debug, will change back before merging

}
VLOG(3) << "OpFusionPass Finish...!";

@@ -302,7 +302,11 @@ inline bool horizontal_or_can_inline(
return false;
}
}

// vertical relation: 1.can compute inline
if (producer->result(0).use_count() == 1) {
Review comment (Contributor):
Is this the tricky code that enables fusing slice/concat?

Reply (@zyfncg, Contributor Author, Jan 27, 2024):
Yes. The check here isn't really that tricky; once the fusion rules are made explicit, it can be used directly later on.

return true;
}
// if (helper->GetNodeData(producer)->outlinks().size() == 1 &&
// helper->output_ops_set_.count(producer) == 0) {
// return true;
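For context, a minimal Python sketch (hypothetical, not part of this diff) of the pattern discussed in the review thread above. It mirrors the rotate_half helper in the new test added below: each intermediate slice result has exactly one consumer, which is the producer->result(0).use_count() == 1 case the new check treats as fusable.

```python
import paddle


def rotate_half_like(x):
    # Each slice result below has a single consumer (the negation or the
    # concat), so in the lowered program these producers are the single-use
    # case targeted by the use_count() == 1 branch above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return paddle.concat([-x2, x1], axis=-1)


print(rotate_half_like(paddle.randn([4, 8])).shape)  # [4, 8]
```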
94 changes: 50 additions & 44 deletions paddle/cinn/hlir/framework/pir/utils.cc
@@ -54,16 +54,9 @@ const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd_op.maximum", "max"},
{"pd_op.minimum", "min"},
{"pd_op.split_with_num", "split"},
{"pd_op.reshape", "reshape"},
{"pd_op.expand", "broadcast_to"},
{"cinn_op.generate_shape", "generate_shape"},
{"cinn_op.reshape", "reshape"},
{"cinn_op.scale", "scale"},
{"cinn_op.broadcast", "broadcast_to"},
// The following should implement OpPattern in pd_to_cinn_pass,
// otherwise, it will be block in BuildCinnPass.
{"cinn_op.squeeze", ""},
{"cinn_op.unsqueeze", ""}};
{"cinn_op.broadcast", "broadcast_to"}};

namespace {
using GroupOpsVec = std::vector<::pir::Operation*>;
@@ -119,8 +112,8 @@ bool IsSupportForCinn(const ::pir::Operation& op);
// implement OpPattern in pd_to_cinn_pass. Otherwise, we mark them
// as unimplement ops.
bool UnimplementOps(const ::pir::Operation& op) {
- // cinn not support uniform, the FullOp of max and min support NOT generate by
- // CINN
+ // cinn not support uniform, the FullOp of max and min support
+ // NOT generate by CINN
if (op.isa<paddle::dialect::FullOp>()) {
auto out = op.result(0);
if (out.use_count() > 0) {
@@ -131,59 +124,69 @@ bool UnimplementOps(const ::pir::Operation& op) {
}

bool HaveZeroDimInput(const ::pir::Operation& op) {
- bool have_zero_dim = false;
+ auto HasZeroDim = [](const ::pir::Type& type) {
+ auto tensor_type = type.dyn_cast<::pir::DenseTensorType>();
+ return tensor_type && tensor_type.dims().size() == 0U;
+ };
+ // Judge for vector<Type>
+ auto HasZeroDimInVT = [&](const std::vector<::pir::Type>& types) {
+ for (auto& type : types) {
+ if (HasZeroDim(type)) return true;
+ }
+ return false;
+ };

for (size_t i = 0; i < op.num_operands(); ++i) {
- auto in = op.operand_source(i);
- if (in) {
- if (auto tensor_type =
- in.type().dyn_cast<paddle::dialect::DenseTensorType>()) {
- if (tensor_type.dims().size() == 0) {
- have_zero_dim = true;
- break;
- }
- }
+ auto value = op.operand_source(i);
+ if (!value || !value.type()) continue;
+ if (auto vector_type = value.type().dyn_cast<::pir::VectorType>()) {
+ if (HasZeroDimInVT(vector_type.data())) return true;
+ } else if (HasZeroDim(value.type())) {
+ return true;
}
}

- return have_zero_dim;
+ return false;
}

bool AllInputDenseTensor(const ::pir::Operation& op) {
- bool all_denese_tensor = true;
+ auto IsDenseTensor = [](const ::pir::Type& type) {
+ return type.isa<::pir::DenseTensorType>();
+ };

+ // Judge for vector<Type>
+ auto IsAllDenseTensor = [&](const std::vector<::pir::Type>& types) {
+ for (auto& type : types) {
+ if (!IsDenseTensor(type)) return false;
+ }
+ return true;
+ };

for (size_t i = 0; i < op.num_operands(); ++i) {
- auto in = op.operand_source(i);
- if (in) {
- if (!(in.type().isa<paddle::dialect::DenseTensorType>())) {
- all_denese_tensor = false;
- break;
- }
+ auto value = op.operand_source(i);
+ if (!value || !value.type()) continue;
+ if (auto vector_type = value.type().dyn_cast<::pir::VectorType>()) {
+ if (!IsAllDenseTensor(vector_type.data())) return false;
+ } else if (!IsDenseTensor(value.type())) {
+ return false;
}
}

- return all_denese_tensor;
+ return true;
}

bool IsRegisteredInCINN(const ::pir::Operation& op) {
if (CompatibleInfo::OP_NAMES.find(op.name()) !=
CompatibleInfo::OP_NAMES.end()) {
return true;
}
+ // After PdToCinnPass, if pd_op.reshape still exists, return false.
+ std::string black_op_name =
+ std::string(cinn::dialect::OperatorDialect::name()) + "." +
+ CompatibleInfo::OpName(op);
+ if (CompatibleInfo::OP_NAMES.find(black_op_name) !=
+ CompatibleInfo::OP_NAMES.end()) {
+ VLOG(4) << "Found black op after PdToCinnPass, because it has Attribute "
+ "Tensor: "
+ << op.name();
+ return false;
+ }
return OpRegistry::Global()->Find(CompatibleInfo::OpName(op)) != nullptr;
}

bool IsSupportForCinn(const ::pir::Operation& op) {
if (!AllInputDenseTensor(op) || HaveZeroDimInput(op) || UnimplementOps(op)) {
VLOG(4) << "Found " << op.name()
<< " HaveZeroDimInput or UnimplementOps or NotAllInputDenseTensor. "
<< "So mark IsSupportForCinn: " << false;
return false;
}
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
@@ -196,8 +199,8 @@ bool IsSupportForCinn(const ::pir::Operation& op) {
OpTransInfo trans_info;
bool is_support =
IsRegisteredInCINN(op) && !trans_info.default_deny_ops().count(op_name);
- VLOG(4) << op_name << " is_support: " << is_support << " "
- << IsRegisteredInCINN(op);
+ VLOG(4) << op_name << " is_support: " << is_support
+ << " IsRegisteredInCINN: " << IsRegisteredInCINN(op);
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (!allow_ops.empty()) {
@@ -221,7 +224,10 @@ bool IsSupportForCinn(const ::pir::Operation& op) {
// Such as cinn_op.reshape, except pd_op.reshape;
// 3. otherwise, it should be registered in OpRegistry;
bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) {
- return IsSupportForCinn(op);
+ bool flag = IsSupportForCinn(op);
+ VLOG(4) << "CompatibleInfo::IsSupportCinn of " << op.name()
+ << " is: " << flag;
+ return flag;
}

std::string CompatibleInfo::OpName(const ::pir::Operation& op) {
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/op/contrib/gather_nd.cc
@@ -249,6 +249,8 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
MakeOpFunction(cinn::hlir::op::InferShapeForGatherNd))
.set_attr("inferdtype",
MakeOpFunction(cinn::hlir::op::InferDtypeForGatherNd))
+ .set_attr<cinn::hlir::framework::OpPatternKind>(
+ "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective)
.set_support_level(4);

return true;
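As a side note, a small illustrative paddle snippet (not part of the diff) of gather_nd semantics: each output element is taken from a single input location, the kind of index-mapping op that CINN groups under the kInjective pattern being registered here.

```python
import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
index = paddle.to_tensor([[1, 0], [0, 1]])
# out[0] = x[1, 0] and out[1] = x[0, 1]; every output element reads
# exactly one input element.
print(paddle.gather_nd(x, index))  # [3., 2.]
```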
95 changes: 95 additions & 0 deletions test/ir/pir/cinn/test_rope.py
@@ -0,0 +1,95 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle
from paddle import nn


def apply_to_static(net, use_cinn, input_spec=None):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net,
input_spec=input_spec,
build_strategy=build_strategy,
full_graph=True,
)


class RotaryPosEmb(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, q, k, cos, sin, position_ids):
cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]

cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed

def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x


class TestRotaryPosEmb(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.q = paddle.randn([1, 2048, 8, 96], dtype="float32")
self.q.stop_gradient = False

self.k = paddle.randn([1, 2048, 8, 96], dtype="float32")
self.k.stop_gradient = False

self.cos = paddle.randn([1, 2048, 1, 96], dtype="float32")
self.cos.stop_gradient = False

self.sin = paddle.randn([1, 2048, 1, 96], dtype="float32")
self.sin.stop_gradient = False

self.position_ids = paddle.arange(end=2048, dtype="int64").unsqueeze(0)
self.position_ids.stop_gradient = False

def eval(self, use_cinn):
paddle.seed(2022)
net = RotaryPosEmb()
net.eval()
if use_cinn:
net = apply_to_static(net, use_cinn)

out = net(self.q, self.k, self.cos, self.sin, self.position_ids)
return out

def test_eval(self):
cinn_outs = self.eval(use_cinn=True)
# dy_outs = self.eval(use_cinn=False)

# TODO(phlrain): Need to check result
# for cinn_out, dy_out in zip(cinn_outs, dy_outs):
# np.testing.assert_allclose(
# cinn_out.numpy(), dy_out.numpy(), atol=1e-8
# )


if __name__ == '__main__':
unittest.main()
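As a closing note, a small NumPy sketch (not part of the PR) of the identity RotaryPosEmb relies on: with RoPE-style cos/sin (the same angle repeated across the two halves of the last dim), q * cos + rotate_half(q) * sin applies an ordinary 2-D rotation to each (first-half, second-half) pair of elements.

```python
import numpy as np

d = 8
theta = np.random.rand(d // 2)  # one angle per (first-half, second-half) pair
cos = np.cos(np.concatenate([theta, theta]))
sin = np.sin(np.concatenate([theta, theta]))
q = np.random.rand(d)

rotated = np.concatenate([-q[d // 2 :], q[: d // 2]])  # rotate_half(q)
q_embed = q * cos + rotated * sin

# Reference: rotate each (q[i], q[i + d//2]) pair by theta[i].
ref = np.empty_like(q)
for i, t in enumerate(theta):
    a, b = q[i], q[i + d // 2]
    ref[i] = a * np.cos(t) - b * np.sin(t)
    ref[i + d // 2] = a * np.sin(t) + b * np.cos(t)

np.testing.assert_allclose(q_embed, ref, atol=1e-12)
```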