@@ -60,12 +60,14 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
auto k_concat = pattern::wrap_type<v0::Concat>(
{k_past, std::make_shared<pattern::op::Or>(OutputVector{k_current_reshaped, k_current})});

auto kv_shaping = [=](const std::shared_ptr<Node>& kv_concat) {
auto kv_shaping = [=](const std::shared_ptr<Node>& kv_concat, std::shared_ptr<Node>& unsqueeze) {
// Return unsqueeze (via the output parameter) to deduce the number of kv heads in
// the place where they are broadcast in case of GQA and MQA
auto interim = pattern::wrap_type<v1::StridedSlice>(
{kv_concat, pattern::any_input(), pattern::any_input(), pattern::any_input()});
interim = pattern::wrap_type<v1::StridedSlice>(
{interim, pattern::any_input(), pattern::any_input(), pattern::any_input()});
auto unsqueeze = pattern::wrap_type<v0::Unsqueeze>(
unsqueeze = pattern::wrap_type<v0::Unsqueeze>(
{std::make_shared<pattern::op::Or>(OutputVector{kv_concat, interim}), pattern::any_input()});
interim = pattern::wrap_type<v1::StridedSlice>(
{unsqueeze, pattern::any_input(), pattern::any_input(), pattern::any_input()});
@@ -90,8 +92,10 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
auto v_concat = pattern::wrap_type<v0::Concat>(
{v_past, std::make_shared<pattern::op::Or>(OutputVector{v_current_reshaped, v_current})});

auto k_shaped = kv_shaping(k_concat);
auto v_shaped = kv_shaping(v_concat);
std::shared_ptr<Node> k_heads_unsqueeze;
std::shared_ptr<Node> v_heads_unsqueeze;
auto k_shaped = kv_shaping(k_concat, k_heads_unsqueeze);
auto v_shaped = kv_shaping(v_concat, v_heads_unsqueeze);

auto k_simply_shaped = pattern::wrap_type<v1::Reshape>({k_concat, pattern::any_input()});
auto v_simply_shaped = pattern::wrap_type<v1::Reshape>({v_concat, pattern::any_input()});
@@ -149,16 +153,60 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par

auto real_k = take_4d(k_current, k_current_reshaped, k_current2);
auto real_v = take_4d(v_current, v_current_reshaped, v_current2);

auto sdpa_node = pattern_map.at(sdpa).get_node();
// E and Ev are from the SDPA specification at
// https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/operation-specs/sequence/scaled-dot-product-attention.html
auto E = sdpa_node->get_input_tensor(1).get_partial_shape()[-1];
auto Ev = sdpa_node->get_input_tensor(2).get_partial_shape()[-1]; // in the general case it may not match E
auto num_q_heads = sdpa_node->get_input_tensor(0).get_partial_shape()[-3];

auto extract_num_kv_heads = [=, &pattern_map](std::shared_ptr<Node> unsqueeze) {
// Deduce number of k/v heads from Unsqueeze-Broadcast-Reshape (if present)
// pattern that appears in case of MQA/GQA
if (pattern_map.find(unsqueeze) != pattern_map.end()) {
// based on the unsqueeze axis, determine the dimension that will be broadcast
// if there is no expected dimension for any reason, return a dynamic dimension
unsqueeze = pattern_map.at(unsqueeze).get_node_shared_ptr();
auto shape = unsqueeze->get_output_partial_shape(0);
auto rank = shape.rank();
if (rank.is_dynamic()) {
return ov::Dimension();
}
rank = rank.get_length();
auto axis = unsqueeze->input_value(1).get_node_shared_ptr();
auto constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis);
if (!constant) {
return ov::Dimension();
}
auto data = constant->cast_vector<int64_t>();
if (data.size() != 1) { // there should be exactly one axis
return ov::Dimension();
}
auto first_element = data[0];
if (first_element == 0 ||
first_element == -rank.get_length()) { // there should be at least one dimension to the left
return ov::Dimension();
}
return shape[first_element - 1];
} else {
return num_q_heads;
}
};

auto num_k_heads = extract_num_kv_heads(k_heads_unsqueeze);
auto num_v_heads = extract_num_kv_heads(v_heads_unsqueeze);
const ov::element::Type kv_cache_type = real_q.get_element_type();
std::string layer_index_str = std::to_string(layer_index);
auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, -1, -1, -1, -1}),
auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_k_heads, E}),
std::string("key_cache.") + std::to_string(layer_index));
auto v_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, -1, -1, -1}),
auto v_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_v_heads, Ev}),
std::string("value_cache.") + std::to_string(layer_index));
layer_index += 1;
kv_parameters.push_back(k_parameter);
kv_parameters.push_back(v_parameter);
auto kv_transpose_order = v0::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});

auto q_transpose = std::make_shared<v1::Transpose>(real_q, kv_transpose_order);
auto q_reshape =
std::make_shared<v1::Reshape>(q_transpose, v0::Constant::create(element::i64, Shape{3}, {0, 0, -1}), true);
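
Aside (not part of the diff): a minimal standalone sketch, using the OpenVINO C++ API, of the axis arithmetic that extract_num_kv_heads relies on. In the MQA/GQA Unsqueeze-Broadcast-Reshape subgraph, the KV head count is the dimension immediately to the left of the inserted (unsqueezed) axis, and a dynamic dimension is returned whenever that cannot be established. The function and example values below are assumptions for illustration only.

// Sketch only: reproduces the left-of-unsqueeze-axis rule from extract_num_kv_heads
// on plain shapes, without the pattern-matching machinery of the pass.
#include <cstdint>
#include <iostream>
#include <openvino/core/partial_shape.hpp>

ov::Dimension kv_heads_from_unsqueeze(const ov::PartialShape& unsqueezed_shape, int64_t axis) {
    const auto rank = unsqueezed_shape.rank();
    if (rank.is_dynamic())
        return ov::Dimension();  // dynamic rank: head count cannot be deduced
    const int64_t r = rank.get_length();
    if (axis == 0 || axis == -r)
        return ov::Dimension();  // nothing to the left of the inserted axis
    if (axis < 0)
        axis += r;  // normalization added for this sketch; the pass itself only guards 0 / -rank
    return unsqueezed_shape[axis - 1];  // dimension just left of the unsqueeze == num_kv_heads
}

int main() {
    // GQA example: [batch, num_kv_heads, 1 (inserted by Unsqueeze), seq_len, head_size]
    ov::PartialShape shape{-1, 8, 1, -1, 128};
    auto heads = kv_heads_from_unsqueeze(shape, 2);
    if (heads.is_static())
        std::cout << "num_kv_heads = " << heads.get_length() << std::endl;  // prints 8
    return 0;
}
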
12 changes: 1 addition & 11 deletions src/core/src/op/paged_attention.cpp
@@ -18,17 +18,7 @@ void PagedAttentionExtension::validate_and_infer_types() {
// m_num_kv_heads = value_cache_shape[1];
// m_head_size = value_cache_shape[2];
// m_block_size = value_cache_shape[3];
NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");

// key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
auto key_cache_shape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
value_cache_shape.size() == 4,
// value_cache_shape[0] == key_cache_shape[0] && // num_blocks
// key_cache_shape[1] == m_num_kv_heads &&
// key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
// m_block_size == key_cache_shape[3], // block_size,
"Key cache shape must be 4 dims");
// Do not check shapes for cache K and cache V inputs, because they are hardware dependent

// query: shape [batch_size, seq_len, num_heads * head_size]
auto query_type = get_input_element_type(0);
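
Aside (not part of the diff): a minimal standalone sketch of what the key_cache.N / value_cache.N Parameters created by the pass look like once the deduced head count and head size are filled in. The concrete numbers (8 KV heads, head size 128, f32) are assumed for illustration; only the leading dimension stays dynamic, mirroring PartialShape{-1, num_k_heads, E} from the first file.

// Sketch only: builds cache Parameters with assumed values (8 KV heads, head size 128, E == Ev).
#include <iostream>
#include <memory>
#include <openvino/op/parameter.hpp>

int main() {
    // Dynamic leading dimension, deduced number of KV heads, deduced head size.
    ov::PartialShape key_cache_shape{-1, 8, 128};
    ov::PartialShape value_cache_shape{-1, 8, 128};

    auto key_cache = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, key_cache_shape);
    auto value_cache = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, value_cache_shape);
    key_cache->set_friendly_name("key_cache.0");
    value_cache->set_friendly_name("value_cache.0");

    std::cout << key_cache->get_partial_shape() << std::endl;  // partially static shape, e.g. [?,8,128]
    return 0;
}
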