diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
index 96427dfecc9371..cedaa3f27e7b2d 100644
--- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
@@ -60,12 +60,14 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
     auto k_concat = pattern::wrap_type<v0::Concat>(
         {k_past, std::make_shared<pattern::op::Or>(OutputVector{k_current_reshaped, k_current})});
 
-    auto kv_shaping = [=](const std::shared_ptr<Node>& kv_concat) {
+    auto kv_shaping = [=](const std::shared_ptr<Node>& kv_concat, std::shared_ptr<Node>& unsqueeze) {
+        // Return unsqueeze (via the output parameter) to deduce the number of kv heads
+        // at the place where they are broadcast in case of GQA and MQA
         auto interim = pattern::wrap_type<v1::StridedSlice>(
             {kv_concat, pattern::any_input(), pattern::any_input(), pattern::any_input()});
         interim = pattern::wrap_type<v1::StridedSlice>(
             {interim, pattern::any_input(), pattern::any_input(), pattern::any_input()});
-        auto unsqueeze = pattern::wrap_type<v0::Unsqueeze>(
+        unsqueeze = pattern::wrap_type<v0::Unsqueeze>(
             {std::make_shared<pattern::op::Or>(OutputVector{kv_concat, interim}), pattern::any_input()});
         interim = pattern::wrap_type<v1::StridedSlice>(
             {unsqueeze, pattern::any_input(), pattern::any_input(), pattern::any_input()});
@@ -90,8 +92,10 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
     auto v_concat = pattern::wrap_type<v0::Concat>(
         {v_past, std::make_shared<pattern::op::Or>(OutputVector{v_current_reshaped, v_current})});
 
-    auto k_shaped = kv_shaping(k_concat);
-    auto v_shaped = kv_shaping(v_concat);
+    std::shared_ptr<Node> k_heads_unsqueeze;
+    std::shared_ptr<Node> v_heads_unsqueeze;
+    auto k_shaped = kv_shaping(k_concat, k_heads_unsqueeze);
+    auto v_shaped = kv_shaping(v_concat, v_heads_unsqueeze);
 
     auto k_simply_shaped = pattern::wrap_type<v1::Reshape>({k_concat, pattern::any_input()});
     auto v_simply_shaped = pattern::wrap_type<v1::Reshape>({v_concat, pattern::any_input()});
@@ -149,16 +153,60 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
         auto real_k = take_4d(k_current, k_current_reshaped, k_current2);
         auto real_v = take_4d(v_current, v_current_reshaped, v_current2);
+
+        auto sdpa_node = pattern_map.at(sdpa).get_node();
+        // E and Ev are from the SDPA specification at
+        // https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/operation-specs/sequence/scaled-dot-product-attention.html
+        auto E = sdpa_node->get_input_tensor(1).get_partial_shape()[-1];
+        auto Ev = sdpa_node->get_input_tensor(2).get_partial_shape()[-1];  // in the general case it may not match E
+        auto num_q_heads = sdpa_node->get_input_tensor(0).get_partial_shape()[-3];
+
+        auto extract_num_kv_heads = [=, &pattern_map](std::shared_ptr<Node> unsqueeze) {
+            // Deduce the number of k/v heads from the Unsqueeze-Broadcast-Reshape
+            // pattern (if present) that appears in case of MQA/GQA
+            if (pattern_map.find(unsqueeze) != pattern_map.end()) {
+                // Based on the unsqueeze axis, determine the dimension that will be broadcast;
+                // if there is no expected dimension for any reason, return a dynamic dimension
+                unsqueeze = pattern_map.at(unsqueeze).get_node_shared_ptr();
+                auto shape = unsqueeze->get_output_partial_shape(0);
+                auto rank = shape.rank();
+                if (rank.is_dynamic()) {
+                    return ov::Dimension();
+                }
+                rank = rank.get_length();
+                auto axis = unsqueeze->input_value(1).get_node_shared_ptr();
+                auto constant = std::dynamic_pointer_cast<v0::Constant>(axis);
+                if (!constant) {
+                    return ov::Dimension();
+                }
+                auto data = constant->cast_vector<int64_t>();
+                if (data.size() != 1) {  // there should be exactly one axis
+                    return ov::Dimension();
+                }
+                auto first_element = data[0];
+                if (first_element == 0 ||
+                    first_element == -rank.get_length()) {  // there should be at least one dimension to the left
+                    return ov::Dimension();
+                }
+                return shape[first_element - 1];
+            } else {
+                return num_q_heads;
+            }
+        };
+
+        auto num_k_heads = extract_num_kv_heads(k_heads_unsqueeze);
+        auto num_v_heads = extract_num_kv_heads(v_heads_unsqueeze);
 
         const ov::element::Type kv_cache_type = real_q.get_element_type();
         std::string layer_index_str = std::to_string(layer_index);
-        auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, -1, -1, -1, -1}),
+        auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_k_heads, E}),
                                    std::string("key_cache.") + std::to_string(layer_index));
-        auto v_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, -1, -1, -1}),
+        auto v_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_v_heads, Ev}),
                                    std::string("value_cache.") + std::to_string(layer_index));
         layer_index += 1;
         kv_parameters.push_back(k_parameter);
         kv_parameters.push_back(v_parameter);
         auto kv_transpose_order = v0::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});
+
         auto q_transpose = std::make_shared<v1::Transpose>(real_q, kv_transpose_order);
         auto q_reshape =
             std::make_shared<v1::Reshape>(q_transpose, v0::Constant::create(element::i64, Shape{3}, {0, 0, -1}), true);
diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp
index 376031d6ad9c33..909ed9b3ae7ea9 100644
--- a/src/core/src/op/paged_attention.cpp
+++ b/src/core/src/op/paged_attention.cpp
@@ -18,17 +18,7 @@ void PagedAttentionExtension::validate_and_infer_types() {
     //     m_num_kv_heads = value_cache_shape[1];
     //     m_head_size = value_cache_shape[2];
     //     m_block_size = value_cache_shape[3];
-    NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");
-
-    // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
-    auto key_cache_shape = get_input_partial_shape(3);
-    NODE_VALIDATION_CHECK(this,
-                          value_cache_shape.size() == 4,
-                          // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
-                          // key_cache_shape[1] == m_num_kv_heads &&
-                          // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
-                          // m_block_size == key_cache_shape[3], // block_size,
-                          "Key cache shape must be 4 dims");
+    // Do not check the shapes of the K and V cache inputs because they are hardware-dependent
 
     // query: shape [batch_size, seq_len, num_heads * head_size]
     auto query_type = get_input_element_type(0);
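
Note on the deduction in `extract_num_kv_heads` above: in the MQA/GQA case the KV tensor of shape `[batch, num_kv_heads, seq_len, head_size]` is unsqueezed (an axis of size 1 is inserted), broadcast along that new axis to the group size, and reshaped to `[batch, num_q_heads, seq_len, head_size]`, so the number of KV heads is the dimension immediately to the left of the inserted axis, which is what `shape[first_element - 1]` reads. The sketch below restates that rule in plain standard C++ with no OpenVINO types; `deduce_num_kv_heads` and the example shapes are illustrative only and are not part of this patch.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Sketch of the rule used by extract_num_kv_heads: the number of KV heads is the
// dimension immediately to the LEFT of the axis inserted by Unsqueeze before the
// MQA/GQA broadcast. Returns std::nullopt where the transformation would fall back
// to a dynamic dimension (axis points at position 0, so nothing is to its left).
std::optional<int64_t> deduce_num_kv_heads(const std::vector<int64_t>& unsqueeze_out_shape,
                                           int64_t axis) {
    const auto rank = static_cast<int64_t>(unsqueeze_out_shape.size());
    if (axis < 0) {
        axis += rank;  // normalize a negative axis against the Unsqueeze output rank
    }
    if (axis <= 0 || axis >= rank) {
        return std::nullopt;
    }
    return unsqueeze_out_shape[axis - 1];
}

int main() {
    // GQA example: 8 KV heads, unsqueezed at axis 2 before broadcasting to the query heads.
    const std::vector<int64_t> shape{/*batch*/ -1, /*num_kv_heads*/ 8, /*inserted*/ 1,
                                     /*seq_len*/ -1, /*head_size*/ 128};
    if (auto heads = deduce_num_kv_heads(shape, 2)) {
        std::cout << "num_kv_heads = " << *heads << "\n";  // prints 8
    }
    // Axis 0 has no dimension to its left, so the deduction gives up.
    std::cout << std::boolalpha << deduce_num_kv_heads(shape, 0).has_value() << "\n";  // false
    return 0;
}
```

With the head count deduced this way and the E/Ev head sizes taken from the SDPA inputs, the per-layer cache Parameters can be declared as `PartialShape{-1, num_k_heads, E}` and `PartialShape{-1, num_v_heads, Ev}` instead of fully dynamic dims, which is exactly what the changed `key_cache.*`/`value_cache.*` lines in the first file do; the PagedAttention op itself then stops validating the cache ranks, since the final cache layout is hardware-dependent.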