From f8a65d56302a85e5283f9f999e68b9a2270b8082 Mon Sep 17 00:00:00 2001
From: Andrii Staikov
Date: Tue, 7 May 2024 13:14:34 +0330
Subject: [PATCH] Deduce the number of KV heads and head_size from the model

Deduce the number of KV heads and head_size from the model without
relying on the HF config, and set the deduced values as the KV cache
input dimensions. Applied a HW-specific layout rearrangement based on
the current expectations of CPU and GPU, preserving those deduced
dimensions.

---
 .../state_management_pattern.cpp    | 60 +++++++++++++++++---
 src/core/src/op/paged_attention.cpp | 12 +---
 2 files changed, 55 insertions(+), 17 deletions(-)

diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
index 96427dfecc9371..cedaa3f27e7b2d 100644
--- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
@@ -60,12 +60,14 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
     auto k_concat = pattern::wrap_type<v0::Concat>(
         {k_past, std::make_shared<pattern::op::Or>(OutputVector{k_current_reshaped, k_current})});
 
-    auto kv_shaping = [=](const std::shared_ptr<Node>& kv_concat) {
+    auto kv_shaping = [=](const std::shared_ptr<Node>& kv_concat, std::shared_ptr<Node>& unsqueeze) {
+        // Return unsqueeze (via the output parameter) to deduce the number of kv heads in
+        // the place where they are broadcast in the case of GQA and MQA
         auto interim = pattern::wrap_type(
             {kv_concat, pattern::any_input(), pattern::any_input(), pattern::any_input()});
         interim = pattern::wrap_type(
             {interim, pattern::any_input(), pattern::any_input(), pattern::any_input()});
-        auto unsqueeze = pattern::wrap_type<v0::Unsqueeze>(
+        unsqueeze = pattern::wrap_type<v0::Unsqueeze>(
             {std::make_shared<pattern::op::Or>(OutputVector{kv_concat, interim}), pattern::any_input()});
         interim = pattern::wrap_type(
             {unsqueeze, pattern::any_input(), pattern::any_input(), pattern::any_input()});
@@ -90,8 +92,10 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
     auto v_concat = pattern::wrap_type<v0::Concat>(
         {v_past, std::make_shared<pattern::op::Or>(OutputVector{v_current_reshaped, v_current})});
 
-    auto k_shaped = kv_shaping(k_concat);
-    auto v_shaped = kv_shaping(v_concat);
+    std::shared_ptr<Node> k_heads_unsqueeze;
+    std::shared_ptr<Node> v_heads_unsqueeze;
+    auto k_shaped = kv_shaping(k_concat, k_heads_unsqueeze);
+    auto v_shaped = kv_shaping(v_concat, v_heads_unsqueeze);
 
     auto k_simply_shaped = pattern::wrap_type({k_concat, pattern::any_input()});
     auto v_simply_shaped = pattern::wrap_type({v_concat, pattern::any_input()});
@@ -149,16 +153,60 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
         auto real_k = take_4d(k_current, k_current_reshaped, k_current2);
         auto real_v = take_4d(v_current, v_current_reshaped, v_current2);
+
+        auto sdpa_node = pattern_map.at(sdpa).get_node();
+        // E and Ev are from the SDPA specification at
+        // https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/operation-specs/sequence/scaled-dot-product-attention.html
+        auto E = sdpa_node->get_input_tensor(1).get_partial_shape()[-1];
+        auto Ev = sdpa_node->get_input_tensor(2).get_partial_shape()[-1];  // in the general case it may not match E
+        auto num_q_heads = sdpa_node->get_input_tensor(0).get_partial_shape()[-3];
+
+        auto extract_num_kv_heads = [=, &pattern_map](std::shared_ptr<Node> unsqueeze) {
+            // Deduce the number of k/v heads from the Unsqueeze-Broadcast-Reshape
+            // pattern (if present) that appears in the case of MQA/GQA
+            if (pattern_map.find(unsqueeze) != pattern_map.end()) {
+                // Based on the unsqueeze axis, determine the dimension that will be broadcast;
+                // if there is no expected dimension for any reason, return a dynamic dimension
+                unsqueeze = pattern_map.at(unsqueeze).get_node_shared_ptr();
+                auto shape = unsqueeze->get_output_partial_shape(0);
+                auto rank = shape.rank();
+                if (rank.is_dynamic()) {
+                    return ov::Dimension();
+                }
+                rank = rank.get_length();
+                auto axis = unsqueeze->input_value(1).get_node_shared_ptr();
+                auto constant = std::dynamic_pointer_cast<v0::Constant>(axis);
+                if (!constant) {
+                    return ov::Dimension();
+                }
+                auto data = constant->cast_vector<int64_t>();
+                if (data.size() != 1) {  // there should be only one axis
+                    return ov::Dimension();
+                }
+                auto first_element = data[0];
+                if (first_element == 0 ||
+                    first_element == -rank.get_length()) {  // there should be at least one dimension to the left
+                    return ov::Dimension();
+                }
+                return shape[first_element - 1];
+            } else {
+                return num_q_heads;
+            }
+        };
+
+        auto num_k_heads = extract_num_kv_heads(k_heads_unsqueeze);
+        auto num_v_heads = extract_num_kv_heads(v_heads_unsqueeze);
         const ov::element::Type kv_cache_type = real_q.get_element_type();
         std::string layer_index_str = std::to_string(layer_index);
-        auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, -1, -1, -1, -1}),
+        auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_k_heads, E}),
                                    std::string("key_cache.") + std::to_string(layer_index));
-        auto v_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, -1, -1, -1}),
+        auto v_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_v_heads, Ev}),
                                    std::string("value_cache.") + std::to_string(layer_index));
         layer_index += 1;
         kv_parameters.push_back(k_parameter);
         kv_parameters.push_back(v_parameter);
         auto kv_transpose_order = v0::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});
+
         auto q_transpose = std::make_shared<v1::Transpose>(real_q, kv_transpose_order);
         auto q_reshape =
             std::make_shared<v1::Reshape>(q_transpose, v0::Constant::create(element::i64, Shape{3}, {0, 0, -1}), true);
diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp
index 376031d6ad9c33..909ed9b3ae7ea9 100644
--- a/src/core/src/op/paged_attention.cpp
+++ b/src/core/src/op/paged_attention.cpp
@@ -18,17 +18,7 @@ void PagedAttentionExtension::validate_and_infer_types() {
     //   m_num_kv_heads = value_cache_shape[1];
     //   m_head_size = value_cache_shape[2];
     //   m_block_size = value_cache_shape[3];
-    NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");
-
-    // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
-    auto key_cache_shape = get_input_partial_shape(3);
-    NODE_VALIDATION_CHECK(this,
-                          value_cache_shape.size() == 4,
-                          // value_cache_shape[0] == key_cache_shape[0] &&  // num_blocks
-                          // key_cache_shape[1] == m_num_kv_heads &&
-                          // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
-                          // m_block_size == key_cache_shape[3],  // block_size,
-                          "Key cache shape must be 4 dims");
+    // Do not check the shapes of the key cache and value cache inputs, because they are hardware-dependent

     // query: shape [batch_size, seq_len, num_heads * head_size]
     auto query_type = get_input_element_type(0);
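
Note on the deduction logic: the extract_num_kv_heads lambda in the patch treats the dimension immediately to the left of the Unsqueeze axis as the KV-head count, because in GQA/MQA graphs the subsequent Broadcast replicates the KV heads along the inserted axis to match the number of query heads. Below is a minimal, standalone C++ sketch of that axis arithmetic only. It deliberately does not use the OpenVINO API; the function name, the -1 convention for a dynamic dimension, and the example shapes are illustrative assumptions, not part of the patch, and the negative-axis normalization is slightly more permissive than the patch's guard.

// Standalone sketch of the KV-head deduction performed by extract_num_kv_heads.
// A value of -1 stands in for a fully dynamic dimension (ov::Dimension() in the patch).
#include <cstdint>
#include <iostream>
#include <vector>

// unsqueeze_out_shape: output shape of the Unsqueeze that feeds the GQA Broadcast.
// axis: the single axis the Unsqueeze inserted (value of its second input).
int64_t deduce_num_kv_heads(const std::vector<int64_t>& unsqueeze_out_shape, int64_t axis) {
    const int64_t rank = static_cast<int64_t>(unsqueeze_out_shape.size());
    // Mirror the patch's guard: the inserted axis must leave at least one
    // dimension to its left, otherwise the head count is unknown.
    if (axis == 0 || axis == -rank) {
        return -1;  // dynamic: nothing to the left of the inserted axis
    }
    if (axis < 0) {
        axis += rank;  // normalize a negative axis (assumption for this sketch)
    }
    if (axis <= 0 || axis >= rank) {
        return -1;  // out of range for this sketch
    }
    // The Broadcast replicates KV heads along the inserted axis, so the KV-head
    // count is the dimension immediately to the left of it.
    return unsqueeze_out_shape[axis - 1];
}

int main() {
    // Hypothetical GQA example: K after concat is [batch, num_kv_heads, seq_len, head_size] = [1, 8, 128, 64];
    // an Unsqueeze at axis 2 yields [1, 8, 1, 128, 64] before broadcasting up to the query heads.
    std::vector<int64_t> unsqueezed = {1, 8, 1, 128, 64};
    std::cout << "num_kv_heads = " << deduce_num_kv_heads(unsqueezed, 2) << "\n";  // prints 8
    std::cout << "num_kv_heads = " << deduce_num_kv_heads(unsqueezed, 0) << "\n";  // prints -1 (dynamic)
    return 0;
}

When no Unsqueeze-Broadcast-Reshape pattern is matched (plain multi-head attention), the patch falls back to num_q_heads, since in that case the key/value tensors carry one head per query head.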