Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
Expand Down Expand Up @@ -783,7 +784,14 @@ ov::pass::RoPEFusionChatGLMHF::RoPEFusionChatGLMHF() {
auto reshape = pattern::wrap_type<v1::Reshape>({qk_linear, pattern::any_input()},
pattern::shape_matches("[?, head_cnt, 1, head_size]"),
{{"special_zero", false}});
auto slice_1 = NewGenSlice(reshape, 0, "ndims", 1, 3);

auto qkv_proj = pattern::wrap_type<v1::VariadicSplit>({reshape, 3, {"ndims", "ndims"}});
qkv_proj->set_output_size(2);
auto vsplit_out0 =
pattern::wrap_type<op::v1::VariadicSplit>({reshape, 3, {"ndims", "ndims"}}, pattern::output_index_matches(0));
auto vsplit_out1 =
pattern::wrap_type<op::v1::VariadicSplit>({reshape, 3, {"ndims", "ndims"}}, pattern::output_index_matches(1));
auto slice_1 = NewGenSlice(reshape, 0, "ndims", 1, 3) | vsplit_out0;

auto const_idx =
pattern::wrap_type<ov::opset1::Constant>(pattern::type_matches(ov::element::i32) && const_idx_predicate);
Expand All @@ -807,7 +815,7 @@ ov::pass::RoPEFusionChatGLMHF::RoPEFusionChatGLMHF() {
auto multiply_1 = pattern::wrap_type<v1::Multiply>({flatten, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
auto add = pattern::wrap_type<v1::Add>({multiply, multiply_1}, {{"auto_broadcast", "numpy"}});

auto slice_5 = NewGenSlice(reshape, "ndims", INT_MAX, 1, 3);
auto slice_5 = NewGenSlice(reshape, "ndims", INT_MAX, 1, 3) | vsplit_out1;
auto result = pattern::wrap_type<v0::Concat>({add, slice_5}, {{"axis", -1}});

matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
Expand Down
73 changes: 52 additions & 21 deletions src/plugins/intel_cpu/src/nodes/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,26 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
jcp.dst_prc = precision_of<T>::value;
jcp.rotary_ndims = config.rotary_ndims;
jcp.interleave = true;
jcp.mix_cos_sin = true;
// When a precomputed rope cache is used, cos and sin are mixed in a single input;
// otherwise rope receives cos and sin as separate inputs.
jcp.mix_cos_sin = config.use_rope_cache;
m_rotaryKernel = createJitKernel(jcp, true);
}

void execute([[maybe_unused]] const dnnl::stream& strm,
const std::vector<MemoryPtr>& inputs,
const std::vector<MemoryPtr>& outputs) override {
ov::intel_cpu::PlainTensor t_src(inputs[0]);
ov::intel_cpu::PlainTensor t_cos_sin(inputs[1]);
ov::intel_cpu::PlainTensor t_cos;
ov::intel_cpu::PlainTensor t_sin;
ov::intel_cpu::PlainTensor t_cos_sin;
if (!m_config.use_rope_cache) {
t_cos.reset(inputs[1]);
t_sin.reset(inputs[2]);
} else {
t_cos_sin.reset(inputs[1]);
}

ov::intel_cpu::PlainTensor t_dst(outputs[0]);

// [seq_len, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)]
Expand All @@ -277,27 +288,47 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {

auto rotary_dims = m_config.rotary_ndims;

parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
// src [batch, length, H x S]
auto* src = t_src.ptr<T>(b, p, h * head_size);
// [batch_size, length, ndims//2, 2]
auto* cos_sin = &t_cos_sin.at<float>({b, p, 0, 0}, true);
auto* dst = t_dst.ptr<T>(b, h, p, 0);

if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr);
} else {
size_t i = 0;
for (; i < rotary_dims; i += 2) {
auto cosv = cos_sin[i];
auto sinv = cos_sin[i + 1];
dst[i] = cosv * src[i] - sinv * src[i + 1];
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
if (m_config.use_rope_cache) {
parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
// src [batch, length, H x S]
auto* src = t_src.ptr<T>(b, p, h * head_size);
// [batch_size, length, ndims//2, 2]
auto* cos_sin = &t_cos_sin.at<float>({b, p, 0, 0}, true);
auto* dst = t_dst.ptr<T>(b, h, p, 0);

if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr);
} else {
size_t i = 0;
for (; i < rotary_dims; i += 2) {
auto cosv = cos_sin[i];
auto sinv = cos_sin[i + 1];
dst[i] = cosv * src[i] - sinv * src[i + 1];
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
}
}
}

memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
});
memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
});
} else {
parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
auto* src = t_src.ptr<T>(b, p, h * head_size);
auto* dst = t_dst.ptr<T>(b, h, p);
// The pattern matching ensures that the cos/sin tables have shape [-1, 1, 1, -1], so only b is
// variable.
const auto* cos = t_cos.ptr<float>(b, 0, 0);
const auto* sin = t_sin.ptr<float>(b, 0, 0);
if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, src, dst, cos, sin);
} else {
for (size_t i = 0; i < rotary_dims; i += 2) {
dst[i] = cos[i / 2] * src[i] - sin[i / 2] * src[i + 1];
dst[i + 1] = sin[i / 2] * src[i] + cos[i / 2] * src[i + 1];
}
}
memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
});
}
} else {
auto seq_len = t_src.size(0);
auto batch_size = t_src.size(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,6 @@ void Transformations::PostLpt() {
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true);
CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true);
CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::RoPEFusionFlux);
CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::RoPEFusionChatGLMHF);
CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion);

#if defined(OPENVINO_ARCH_X86_64)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwenVL,
::testing::ValuesIn(vit_param)),
RoPETestQwenVL::getTestCaseName);

// Instantiates RoPETestChatGLMHF on the CPU device with f32 precision.
// The trailing Values(true, false) enumerates the boolean third tuple element
// of the test parameters, so each case is run with both settings.
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
RoPETestChatGLMHF,
::testing::Combine(::testing::Values(ov::element::f32),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(true, false)),
RoPETestChatGLMHF::getTestCaseName);

} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
RoPETestChatGLMHF,
::testing::Combine(
::testing::Values(ov::element::f16, ov::element::f32),
::testing::Values(ov::test::utils::DEVICE_GPU)),
::testing::Values(ov::test::utils::DEVICE_GPU),
::testing::Values(true, false)),
RoPETestChatGLMHF::getTestCaseName);

} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace test {
using rope_params = std::tuple<ov::element::Type, std::string>;
using rope_params_2 = std::tuple<bool, ov::element::Type, std::string>;
using rope_params_qwenvit = std::tuple<ov::element::Type, std::string, std::string>;
using rope_params_chatglm = std::tuple<ov::element::Type, std::string, bool>;

class RoPETestFlux : public SubgraphBaseTest, public testing::WithParamInterface<rope_params> {
private:
Expand Down Expand Up @@ -183,19 +184,20 @@ class RoPETestChatGLM2DRoPEStridedSlice : public SubgraphBaseTest, public testin
static std::string getTestCaseName(const testing::TestParamInfo<rope_params>& obj);
};

class RoPETestChatGLMHF : public SubgraphBaseTest, public testing::WithParamInterface<rope_params> {
class RoPETestChatGLMHF : public SubgraphBaseTest, public testing::WithParamInterface<rope_params_chatglm> {
private:
std::shared_ptr<ov::Model> buildROPE_ChatGLMHF(int seq_length,
int num_heads,
int rotary_ndims,
ov::element::Type element_type);
ov::element::Type element_type,
bool use_split);

protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
void SetUp() override;

public:
static std::string getTestCaseName(const testing::TestParamInfo<rope_params>& obj);
static std::string getTestCaseName(const testing::TestParamInfo<rope_params_chatglm>& obj);
};

} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1345,20 +1345,27 @@ std::string RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName(const testing::Te
std::shared_ptr<ov::Model> RoPETestChatGLMHF::buildROPE_ChatGLMHF(int seq_length,
int num_heads,
int rotary_ndims,
ov::element::Type element_type) {
ov::element::Type element_type,
bool use_split) {
auto input = std::make_shared<ov::opset1::Parameter>(element_type, PartialShape{seq_length, 1, 4096});
auto cos =
std::make_shared<ov::opset1::Parameter>(element_type, PartialShape{seq_length, 1, 1, (rotary_ndims / 2)});
auto sin =
std::make_shared<ov::opset1::Parameter>(element_type, PartialShape{seq_length, 1, 1, (rotary_ndims / 2)});

auto transpose = makeOP<ov::opset1::Reshape>({input, {-1, num_heads, 1, 128}}, {{"special_zero", false}});
auto slice_1 = makeOP<ov::opset1::StridedSlice>({transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
std::shared_ptr<ov::Node> slice_or_split_1 = nullptr;
if (use_split) {
slice_or_split_1 = makeOP<ov::opset1::VariadicSplit>({transpose, 3, {rotary_ndims, rotary_ndims}});
} else {
slice_or_split_1 =
makeOP<ov::opset1::StridedSlice>({transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
}

std::vector<int32_t> rpi_idx(rotary_ndims, 1);
int32_t v = 0;
Expand All @@ -1370,8 +1377,8 @@ std::shared_ptr<ov::Model> RoPETestChatGLMHF::buildROPE_ChatGLMHF(int seq_length
auto repeat_interleave_cos = makeOP<ov::opset8::Gather>({cos, repeat_interleave_index, -1}, {{"batch_dims", 0}});
auto repeat_interleave_sin = makeOP<ov::opset8::Gather>({cos, repeat_interleave_index, -1}, {{"batch_dims", 0}});

auto multiply = makeOP<ov::opset1::Multiply>({slice_1, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
auto slice_2 = makeOP<ov::opset1::StridedSlice>({slice_1, {0, 0, 0, 1}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
auto multiply = makeOP<ov::opset1::Multiply>({slice_or_split_1->output(0), repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
auto slice_2 = makeOP<ov::opset1::StridedSlice>({slice_or_split_1->output(0), {0, 0, 0, 1}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
Expand All @@ -1381,7 +1388,7 @@ std::shared_ptr<ov::Model> RoPETestChatGLMHF::buildROPE_ChatGLMHF(int seq_length
auto neg = makeOP<ov::opset1::Multiply>({slice_2, minus_one}, {{"auto_broadcast", "numpy"}});
auto unsqueeze_1 =
makeOP<ov::opset1::Reshape>({neg, {-1, num_heads, 1, (rotary_ndims / 2), 1}}, {{"special_zero", false}});
auto slice_3 = makeOP<ov::opset1::StridedSlice>({slice_1, {0, 0, 0, 0}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
auto slice_3 = makeOP<ov::opset1::StridedSlice>({slice_or_split_1->output(0), {0, 0, 0, 0}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
Expand All @@ -1393,15 +1400,20 @@ std::shared_ptr<ov::Model> RoPETestChatGLMHF::buildROPE_ChatGLMHF(int seq_length
auto flatten = makeOP<ov::opset1::Reshape>({stack, {0, num_heads, 0, rotary_ndims}}, {{"special_zero", true}});
auto multiply_1 = makeOP<ov::opset1::Multiply>({flatten, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
auto add = makeOP<ov::opset1::Add>({multiply, multiply_1}, {{"auto_broadcast", "numpy"}});

auto slice_5 =
makeOP<ov::opset1::StridedSlice>({transpose, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto concat = makeOP<ov::opset1::Concat>({add, slice_5}, {{"axis", -1}});
std::shared_ptr<ov::Node> slice_or_split_5 = nullptr;
if (use_split) {
slice_or_split_5 = slice_or_split_1;
} else {
slice_or_split_5 =
makeOP<ov::opset1::StridedSlice>({transpose, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
}
ov::Output<ov::Node> slice_or_split_output_5 = use_split ? slice_or_split_5->output(1) : slice_or_split_5->output(0);
auto concat = makeOP<ov::opset1::Concat>({add, slice_or_split_output_5}, {{"axis", -1}});
return std::make_shared<ov::Model>(ov::OutputVector{concat}, ov::ParameterVector{input, cos, sin});
}

Expand Down Expand Up @@ -1436,7 +1448,7 @@ void RoPETestChatGLMHF::generate_inputs(const std::vector<ov::Shape>& targetInpu
}

void RoPETestChatGLMHF::SetUp() {
const auto& [element_type, _targetDevice] = this->GetParam();
const auto& [element_type, _targetDevice, use_split] = this->GetParam();
targetDevice = _targetDevice;

const int seq_length = 7;
Expand All @@ -1447,13 +1459,13 @@ void RoPETestChatGLMHF::SetUp() {
{{-1, 1, 1, (rotary_ndims / 2)}, {{seq_length, 1, 1, (rotary_ndims / 2)}}},
{{-1, 1, 1, (rotary_ndims / 2)}, {{seq_length, 1, 1, (rotary_ndims / 2)}}}};
init_input_shapes(input_shapes);
function = buildROPE_ChatGLMHF(-1, num_heads, rotary_ndims, element_type);
function = buildROPE_ChatGLMHF(-1, num_heads, rotary_ndims, element_type, use_split);
}

std::string RoPETestChatGLMHF::getTestCaseName(const testing::TestParamInfo<rope_params>& obj) {
const auto& [element_type, targetDevice] = obj.param;
std::string RoPETestChatGLMHF::getTestCaseName(const testing::TestParamInfo<rope_params_chatglm>& obj) {
const auto& [element_type, targetDevice, use_split] = obj.param;
std::ostringstream result;
result << "targetDevice=" << targetDevice << "_element_type=" << element_type.to_string();
result << "targetDevice=" << targetDevice << "_element_type=" << element_type.to_string() << "_use_split_" << use_split;
return result.str();
}

Expand Down
Loading