diff --git a/src/bindings/python/src/openvino/_offline_transformations/__init__.py b/src/bindings/python/src/openvino/_offline_transformations/__init__.py index 2cfe8cec521524..81c288657afd0d 100644 --- a/src/bindings/python/src/openvino/_offline_transformations/__init__.py +++ b/src/bindings/python/src/openvino/_offline_transformations/__init__.py @@ -17,3 +17,4 @@ from openvino._pyopenvino._offline_transformations import compress_model_transformation from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation +from openvino._pyopenvino._offline_transformations import paged_attention_transformation diff --git a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp index b91837adfa230d..fa18ba0c84d4dd 100644 --- a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp +++ b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -127,4 +128,13 @@ void regmodule_offline_transformations(py::module m) { manager.run_passes(model); }, py::arg("model")); + + m_offline_transformations.def( + "paged_attention_transformation", + [](std::shared_ptr model) { + ov::pass::Manager manager; + manager.register_pass(); + manager.run_passes(model); + }, + py::arg("model")); } diff --git a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp index 608f4fe2b61a09..3b6e4f5c897ad5 100644 --- a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp +++ b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp @@ -5,162 +5,13 @@ #include "pyopenvino/graph/ops/paged_attention_extension.hpp" #include "openvino/op/op.hpp" +#include "openvino/op/paged_attention.hpp" #include "pyopenvino/core/common.hpp" namespace py = pybind11; -namespace { - -// This is an experimental operation that is implemented in the plugins. -// Do not use in user applications, backward compatibility is not guaranteed in future releases. 
-class PagedAttentionExtension : public ov::op::Op { -public: - OPENVINO_OP("PagedAttentionExtension"); - - PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) { - constructor_validate_and_infer_types(); - } - - void validate_and_infer_types() override { - auto value_cache_shape = get_input_partial_shape(4); - // m_num_kv_heads = value_cache_shape[1]; - // m_head_size = value_cache_shape[2]; - // m_block_size = value_cache_shape[3]; - NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims"); - - // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x] - auto key_cache_shape = get_input_partial_shape(3); - NODE_VALIDATION_CHECK(this, - value_cache_shape.size() == 4, - // value_cache_shape[0] == key_cache_shape[0] && // num_blocks - // key_cache_shape[1] == m_num_kv_heads && - // key_cache_shape[2] * key_cache_shape[4] == m_head_size && - // m_block_size == key_cache_shape[3], // block_size, - "Key cache shape must be 4 dims"); - - // query: shape [batch_size, seq_len, num_heads * head_size] - auto query_type = get_input_element_type(0); - auto query_shape = get_input_partial_shape(0); - NODE_VALIDATION_CHECK( - this, - // query_type.is_real() && - query_shape.size() == 3, - // query_shape[2] == m_num_heads * m_head_size, - "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ", - "Got element type ", - query_type, - ", shape ", - query_shape); - - // key: shape [batch_size, seq_len, num_kv_heads * head_size] - auto key_type = get_input_element_type(1); - auto key_shape = get_input_partial_shape(1); - NODE_VALIDATION_CHECK(this, - // query_type == key_type && - key_shape.size() == 3, - "Key type must be the same as query, shape must be the same as query. " - "Got element type ", - key_type, - ", shape ", - key_shape); - - // value: shape [batch_size, seq_len, num_kv_heads * head_size] - // auto value_type = get_input_element_type(2); - auto value_shape = get_input_partial_shape(2); - - // is_prompt: boolean scalar - NODE_VALIDATION_CHECK(this, - // get_input_element_type(5) == ov::element::boolean && - get_input_shape(5) == ov::Shape({}), - "is_prompt validation failed. ", - "Got element type ", - get_input_element_type(5), - ", shape ", - get_input_shape(5)); - - // slot_mapping: shape [batch_size, max_context_len] - auto slot_mapping_shape = get_input_partial_shape(6); - NODE_VALIDATION_CHECK(this, - // get_input_element_type(6) == ov::element::i64 && - slot_mapping_shape.size() == 2, - "slot_mapping validation failed. ", - "Got element type ", - get_input_element_type(6), - ", shape ", - slot_mapping_shape); - - // max_context_len: integer scalar - NODE_VALIDATION_CHECK(this, - // get_input_element_type(7) == ov::element::i32 && - get_input_shape(7) == ov::Shape({}), - "max_context_len validation failed. ", - "Got element type ", - get_input_element_type(7), - ", shape ", - get_input_shape(7)); - - // context_lens: shape [batch_size] - auto context_lens_shape = get_input_partial_shape(8); - NODE_VALIDATION_CHECK(this, - // get_input_element_type(8) == ov::element::i32 && - context_lens_shape.size() == 1, - "context_lens validation failed. ", - "Got element type ", - get_input_element_type(8), - ", shape ", - context_lens_shape); - - // block_tables: shape [batch_size, max_block_per_request] - NODE_VALIDATION_CHECK(this, - // get_input_element_type(9) == ov::element::i32 && - get_input_partial_shape(9).size() == 2, - "block_tables validation failed. 
", - "Got element type ", - get_input_element_type(9), - ", shape ", - get_input_partial_shape(9)); - - // scale: float scalar - NODE_VALIDATION_CHECK(this, - // get_input_element_type(10) == ov::element::f32 && - get_input_shape(10) == ov::Shape({}), - "block_tables validation failed. ", - "Got element type ", - get_input_element_type(10), - ", shape ", - get_input_shape(10)); - - // alibi_slopes: 1D float tensor - NODE_VALIDATION_CHECK(this, - // get_input_element_type(11) == ov::element::f32 && - get_input_partial_shape(11).rank().get_length() == 1, - "alibi_slopes should be a 1D float tensor. ", - "Got element type ", - get_input_element_type(11), - ", shape ", - get_input_partial_shape(11)); - - // sliding_window: int scalar - NODE_VALIDATION_CHECK(this, - // get_input_element_type(12) == ov::element::i32 && - get_input_partial_shape(12).rank().get_length() == 0, - "sliding_window argument should be an i32 scalar. ", - "Got element type ", - get_input_element_type(12), - ", shape ", - get_input_partial_shape(12)); - - set_output_type(0, query_type, query_shape); - } - - std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override { - return std::make_shared(new_args); - } -}; - -} // namespace - void regclass_graph_op_PagedAttentionExtension(py::module m) { + using ov::op::PagedAttentionExtension; py::class_, ov::Node> cls( m, "_PagedAttentionExtension"); diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp new file mode 100644 index 00000000000000..748ed1c0887617 --- /dev/null +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/cc/pass/itt.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class PositionIDsReplacer; + +} // namespace pass +} // namespace ov + +class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("PositionIDsReplacer", "0"); + explicit PositionIDsReplacer(const std::shared_ptr>& position_ids); +}; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp new file mode 100644 index 00000000000000..8f7cc2e0a2e11f --- /dev/null +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/cc/pass/itt.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class PrevSequenceLengthPattern; + +} // namespace pass +} // namespace ov + +class ov::pass::PrevSequenceLengthPattern : 
public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("PrevSequenceLengthPattern", "0"); + explicit PrevSequenceLengthPattern(const std::shared_ptr& prev_max_seq_len); +}; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp new file mode 100644 index 00000000000000..31d78236ffd83d --- /dev/null +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class StateManagementPattern; + +} // namespace pass +} // namespace ov + +class ov::pass::StateManagementPattern : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("StateManagementPattern", "0"); + StateManagementPattern(ParameterVector& kv_parameters, + const ParameterVector& model_remaining_params, + const std::shared_ptr& sliding_window, + ParameterVector& parameters_to_remove, + NodeVector& assignes_to_remove, + int& layer_index); +}; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp new file mode 100644 index 00000000000000..3fef00bb7e05f6 --- /dev/null +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/cc/pass/itt.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TotalSequenceLengthPattern; + +} // namespace pass +} // namespace ov + +class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TotalSequenceLengthPattern", "0"); + explicit TotalSequenceLengthPattern(const std::shared_ptr& max_context_len); +}; diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp new file mode 100644 index 00000000000000..33ca7e4cad804d --- /dev/null +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp" + +#include "openvino/cc/pass/itt.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov::op; + +// TODO: Instead of using the following transformation that matches quite a specific place in a model graph in case when +// position_ids parameter is missing, consider replacing always existing attention_mask parameter with a sub-graph using +// a new slot_mapping 
parameter. +ov::pass::PositionIDsReplacer::PositionIDsReplacer(const std::shared_ptr>& position_ids) { + MATCHER_SCOPE(PositionIDsReplacer); + + auto input_ids = pattern::any_input(); + auto input_embed = pattern::wrap_type({pattern::any_input(), input_ids, pattern::any_input()}); + + auto position_ids_pattern = pattern::any_input(); + auto offset = pattern::wrap_type(); + auto add_offset = pattern::wrap_type({position_ids_pattern, offset}); + auto convert = pattern::wrap_type({add_offset}); + auto position_embed = pattern::wrap_type({pattern::any_input(), convert, pattern::any_input()}); + + auto add = pattern::wrap_type({input_embed, position_embed}); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + replace_node(pattern_map.at(position_ids_pattern).get_node_shared_ptr(), position_ids->get_node_shared_ptr()); + return true; + }; + + auto m = std::make_shared(add, matcher_name); + register_matcher(m, callback); +} \ No newline at end of file diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp new file mode 100644 index 00000000000000..8b7caddd61e491 --- /dev/null +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp" + +#include "openvino/cc/pass/itt.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov::op; + +ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern( + const std::shared_ptr& prev_max_seq_len) { + MATCHER_SCOPE(PrevSequenceLengthPattern); + + auto kv_past = pattern::wrap_type({pattern::any_input()}); + auto kv_gather = pattern::wrap_type({kv_past, pattern::any_input(), pattern::any_input()}); + auto kv_shape = pattern::wrap_type({kv_gather}); + auto seq = pattern::wrap_type({kv_past, pattern::any_input(), pattern::any_input()}); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + // TODO: Check that seq has axis that really takes sequence len but not any other dimension -- use symbolics or + // look at the constant input + auto gather = m.get_match_root(); + auto target_type = gather->get_output_element_type(0); + std::shared_ptr replacement; + if (prev_max_seq_len->get_output_element_type(0) != target_type) { + replacement = std::make_shared(prev_max_seq_len, target_type); + } else { + replacement = prev_max_seq_len; + } + replace_node(gather, replacement); + return true; + }; + + auto m = std::make_shared(seq, matcher_name); + register_matcher(m, callback); +} diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp new file mode 100644 index 00000000000000..96427dfecc9371 --- /dev/null +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp @@ -0,0 +1,278 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include 
"transformations/sdpa_to_paged_attention/state_management_pattern.hpp" + +#include "openvino/cc/pass/itt.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/paged_attention.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/strided_slice.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov::op; + +// Exactly copied the function from another file. Maybe should be moved to some general file +static std::shared_ptr setName(std::shared_ptr node, const std::string& name) { + // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a + // given single name) + node->set_friendly_name(name); + OPENVINO_ASSERT(node->get_output_size() == + 1); // Should I use assert here? I heard using ASSERTS is not the best thing + node->get_output_tensor(0).set_names({name}); + return node; +} + +ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_parameters, + const ParameterVector& model_remaining_params, + const std::shared_ptr& sliding_window, + ParameterVector& parameters_to_remove, + NodeVector& assignes_to_remove, + int& layer_index) { + MATCHER_SCOPE(StateManagementPattern); + + auto k_past_var = pattern::wrap_type({pattern::any_input()}); + auto k_past_par = pattern::wrap_type(); + auto k_past = std::make_shared( + OutputVector{pattern::wrap_type({k_past_var, pattern::any_input(), pattern::any_input()}), + k_past_par}); + k_past = std::make_shared( + OutputVector{k_past, + pattern::wrap_type( + {k_past, pattern::any_input()})}); // Transpose is used when kv-cache is stored in a not usual + // layout, example: bloom + auto k_current = pattern::any_input(); + auto k_current2 = pattern::any_input(); + auto k_current_reshaped = pattern::wrap_type({k_current2, pattern::any_input()}); + auto k_concat = pattern::wrap_type( + {k_past, std::make_shared(OutputVector{k_current_reshaped, k_current})}); + + auto kv_shaping = [=](const std::shared_ptr& kv_concat) { + auto interim = pattern::wrap_type( + {kv_concat, pattern::any_input(), pattern::any_input(), pattern::any_input()}); + interim = pattern::wrap_type( + {interim, pattern::any_input(), pattern::any_input(), pattern::any_input()}); + auto unsqueeze = pattern::wrap_type( + {std::make_shared(OutputVector{kv_concat, interim}), pattern::any_input()}); + interim = pattern::wrap_type( + {unsqueeze, pattern::any_input(), pattern::any_input(), pattern::any_input()}); + interim = pattern::wrap_type( + {interim, pattern::any_input(), pattern::any_input(), pattern::any_input()}); + interim = pattern::wrap_type( + {std::make_shared(OutputVector{unsqueeze, interim}), pattern::any_input()}); + interim = pattern::wrap_type({interim, pattern::any_input()}); + return interim; + }; + + auto v_past_var = pattern::wrap_type({pattern::any_input()}); + auto v_past_par = pattern::wrap_type(); + auto v_past = std::make_shared( + OutputVector{pattern::wrap_type({v_past_var, pattern::any_input(), pattern::any_input()}), + v_past_par}); + v_past = std::make_shared( + OutputVector{v_past, 
pattern::wrap_type({v_past, pattern::any_input()})}); + auto v_current = pattern::any_input(); + auto v_current2 = pattern::any_input(); + auto v_current_reshaped = pattern::wrap_type({v_current2, pattern::any_input()}); + auto v_concat = pattern::wrap_type( + {v_past, std::make_shared(OutputVector{v_current_reshaped, v_current})}); + + auto k_shaped = kv_shaping(k_concat); + auto v_shaped = kv_shaping(v_concat); + + auto k_simply_shaped = pattern::wrap_type({k_concat, pattern::any_input()}); + auto v_simply_shaped = pattern::wrap_type({v_concat, pattern::any_input()}); + + auto k_order = pattern::any_input(); + auto v_order = pattern::any_input(); + + // KV-path may already have Transposes that will be rewritten based on PA KV inputs required layout + auto k_shaped_transposed = pattern::wrap_type( + {std::make_shared(OutputVector{k_concat, k_shaped}), k_order}); + auto v_shaped_transposed = pattern::wrap_type( + {std::make_shared(OutputVector{v_concat, v_shaped}), v_order}); + + // Optional pattern to capture alibi slopes (based on pattern from bloom) + auto alibi = pattern::any_input(); + auto sdpa_mask = pattern::wrap_type({pattern::any_input(), alibi}); // apply input position_ids + sdpa_mask = pattern::wrap_type({sdpa_mask, pattern::any_input()}); + sdpa_mask = pattern::wrap_type({sdpa_mask, pattern::any_input()}); + sdpa_mask = pattern::wrap_type({pattern::any_input(), pattern::any_input(), sdpa_mask}); + + auto q = pattern::any_input(); + auto sdpa = pattern::wrap_type( + {q, + std::make_shared(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped}), + std::make_shared(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped}), + std::make_shared(OutputVector{sdpa_mask, pattern::any_input()})}); + + ov::matcher_pass_callback callback = [=, + &kv_parameters, + &model_remaining_params, + &sliding_window, + ¶meters_to_remove, + &assignes_to_remove, + &layer_index](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + if (pattern_map.find(sdpa) == pattern_map.end()) { + return false; + } + auto real_q = pattern_map.at(q); + + // takes option that has 4D instead of fine-grained Reshape analysis + // it avoids complication in the pattern, but we don't really have many options + auto take_4d = [=](const std::shared_ptr& option1, + const std::shared_ptr& option2, + const std::shared_ptr& option3) { + if (pattern_map.find(option1) != pattern_map.end() && + pattern_map.at(option1).get_partial_shape().rank().get_length() == 4) { + return pattern_map.at(option1); + } else if (pattern_map.at(option2).get_partial_shape().rank().get_length() == 4) { + return pattern_map.at(option2); + } else { + return pattern_map.at(option3); + } + }; + + auto real_k = take_4d(k_current, k_current_reshaped, k_current2); + auto real_v = take_4d(v_current, v_current_reshaped, v_current2); + const ov::element::Type kv_cache_type = real_q.get_element_type(); + std::string layer_index_str = std::to_string(layer_index); + auto k_parameter = setName(std::make_shared(kv_cache_type, PartialShape{-1, -1, -1, -1, -1}), + std::string("key_cache.") + std::to_string(layer_index)); + auto v_parameter = setName(std::make_shared(kv_cache_type, PartialShape{-1, -1, -1, -1}), + std::string("value_cache.") + std::to_string(layer_index)); + layer_index += 1; + kv_parameters.push_back(k_parameter); + kv_parameters.push_back(v_parameter); + auto kv_transpose_order = v0::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3}); + auto q_transpose = std::make_shared(real_q, 
kv_transpose_order);
+        auto q_reshape =
+            std::make_shared<v1::Reshape>(q_transpose, v0::Constant::create(element::i64, Shape{3}, {0, 0, -1}), true);
+
+        // TODO: Check whether the shared transpose-order Constant can be reused as-is or has to be cloned
+        std::shared_ptr<ov::Node> k_transpose_order = kv_transpose_order;
+        // Re-apply a Transpose already present in the graph by permuting the indices of our Transpose
+        if (pattern_map.find(k_order) != pattern_map.end()) {
+            k_transpose_order = std::make_shared<v8::Gather>(pattern_map.at(k_order),
+                                                             kv_transpose_order,
+                                                             v0::Constant::create(element::i64, Shape{}, {0}));
+        }
+        auto k_transpose = std::make_shared<v1::Transpose>(real_k, k_transpose_order);
+        auto k_reshape =
+            std::make_shared<v1::Reshape>(k_transpose, v0::Constant::create(element::i64, Shape{3}, {0, 0, -1}), true);
+
+        // TODO: Check whether the shared transpose-order Constant can be reused as-is or has to be cloned
+        std::shared_ptr<ov::Node> v_transpose_order = kv_transpose_order;
+        // Re-apply a Transpose already present in the graph by permuting the indices of our Transpose
+        if (pattern_map.find(v_order) != pattern_map.end()) {
+            v_transpose_order = std::make_shared<v8::Gather>(pattern_map.at(v_order),
+                                                             kv_transpose_order,
+                                                             v0::Constant::create(element::i64, Shape{}, {0}));
+        }
+        auto v_transpose = std::make_shared<v1::Transpose>(real_v, v_transpose_order);
+        auto v_reshape =
+            std::make_shared<v1::Reshape>(v_transpose, v0::Constant::create(element::i64, Shape{3}, {0, 0, -1}), true);
+
+        // TODO: Detect whether SDPA in the model graph has the `scale` argument set and use it instead of the
+        // computed scale below. Most likely `scale` will always be a constant in real inference, but dynamic
+        // dimension propagation may not always derive it as a constant. That's why a sub-graph computing `scale`
+        // is built instead of just a constant node.
+        auto hidden_shape = std::make_shared<v3::ShapeOf>(real_q);
+        auto hidden_dim = std::make_shared<v8::Gather>(hidden_shape,
+                                                       v0::Constant::create(element::i64, Shape{}, {-1}),
+                                                       v0::Constant::create(element::i64, Shape{}, {0}));
+        auto scale = std::make_shared<v1::Divide>(
+            v0::Constant::create(element::f32, Shape{}, {1}),
+            std::make_shared<v0::Sqrt>(std::make_shared<v0::Convert>(hidden_dim, element::f32)));
+
+        std::shared_ptr<ov::Node> alibi_slopes;
+        if (pattern_map.find(alibi) != pattern_map.end()) {
+            alibi_slopes = std::make_shared<v1::Reshape>(pattern_map.at(alibi),
+                                                         v0::Constant::create(element::i64, Shape{1}, {-1}),
+                                                         false);  // TODO: discuss whether {-1} matches the Python prototype
+            if (alibi_slopes->get_element_type() != element::f32) {
+                alibi_slopes = std::make_shared<v0::Convert>(alibi_slopes, element::f32);
+            }
+        } else {
+            alibi_slopes = v0::Constant::create(element::f32, Shape{0}, {});
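+            // An empty f32 constant acts as the "no alibi slopes" marker passed to PagedAttentionExtension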
+        }
+
+        OutputVector params = {q_reshape, k_reshape, v_reshape, k_parameter, v_parameter};
+        params.insert(params.end(), model_remaining_params.begin(), model_remaining_params.end());
+        std::initializer_list<std::shared_ptr<ov::Node>> additional_params = {scale, alibi_slopes, sliding_window};
+        params.insert(params.end(), additional_params.begin(), additional_params.end());
+
+        // TODO: Double-check the argument order against the Python prototype, which builds the op through a helper
+        auto paged_attention = std::make_shared<PagedAttentionExtension>(params);
+
+        auto pa_shape = std::make_shared<v0::Concat>(
+            OutputVector{
+                v0::Constant::create(element::i64, Shape{1}, {0}),
+                v0::Constant::create(element::i64, Shape{1}, {0}),
+                v0::Constant::create(element::i64, Shape{1}, {-1}),
+                std::make_shared<v0::Unsqueeze>(hidden_dim, v0::Constant::create(element::i64, Shape{}, {0})),
+            },
+            0);
+        auto pa_reshape = std::make_shared<v1::Reshape>(paged_attention, pa_shape, true);
+        auto pa_transpose = std::make_shared<v1::Transpose>(pa_reshape, kv_transpose_order);
+
+        // TODO: Complete this part to work with stateless models as well as stateful ones
+        // def add_kv_parameter(past_node):
+        //     if past_node.get_type_info().name == 'Parameter':
+        //         parameters_to_remove.append(past_node)
+
+        // add_kv_parameter(mapping[k_gather])
+        // add_kv_parameter(mapping[v_gather])
+
+        if (pattern_map.find(v_past_par) != pattern_map.end()) {
+            auto param = std::dynamic_pointer_cast<v0::Parameter>(pattern_map.at(v_past_par).get_node_shared_ptr());
+            if (!param) {
+                return false;
+            }
+            parameters_to_remove.push_back(param);
+        }
+
+        if (pattern_map.find(k_past_par) != pattern_map.end()) {
+            auto param = std::dynamic_pointer_cast<v0::Parameter>(pattern_map.at(k_past_par).get_node_shared_ptr());
+            if (!param) {
+                return false;
+            }
+            parameters_to_remove.push_back(param);
+        }
+
+        auto add_assign_consumers = [=, &assignes_to_remove](const std::shared_ptr<ov::Output<ov::Node>>& output) {
+            for (auto& consumer : output->get_target_inputs()) {
+                auto consumer_node = consumer.get_node()->shared_from_this();
+                auto consumer_type = consumer_node->get_type_info().name;
+                if (std::strcmp(consumer_type, "Assign") == 0) {  // stateful model
+                    assignes_to_remove.push_back(consumer_node);
+                } else if (std::strcmp(consumer_type, "Result") == 0) {  // stateless model
+                    assignes_to_remove.push_back(consumer_node);
+                }
+            }
+        };
+
+        add_assign_consumers(std::make_shared<ov::Output<ov::Node>>(pattern_map.at(k_concat)));
+        add_assign_consumers(std::make_shared<ov::Output<ov::Node>>(pattern_map.at(v_concat)));
+
+        replace_node(m.get_match_root(), pa_transpose);
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa, matcher_name);
+    register_matcher(m, callback);
+}
\ No newline at end of file
diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp
new file mode 100644
index 00000000000000..22a8f300eebe3f
--- /dev/null
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp
@@ -0,0 +1,44 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
+
+#include "openvino/cc/pass/itt.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/read_value.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+
+using namespace ov::op;
+
+ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
+    const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
+    MATCHER_SCOPE(TotalSequenceLengthPattern);
+
+    auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
+    auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
+    auto kv_current = pattern::any_input();
+    auto kv_concat = pattern::wrap_type<v0::Concat>({kv_gather, kv_current});
+    auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_concat});
+    auto seq = pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});
+
+    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+        // TODO: Check that seq has the axis that really takes the sequence length and not any other dimension --
+        // use the symbolic infrastructure or look at the constant input
+        auto gather = m.get_match_root();
+        auto target_type = gather->get_output_element_type(0);
+        std::shared_ptr<ov::Node> replacement;
+        if (max_context_len->get_output_element_type(0) != target_type) {
+            replacement = std::make_shared<v0::Convert>(max_context_len, target_type);
+        } else {
+            replacement = max_context_len;
+        }
+        replace_node(gather, replacement);
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
+    register_matcher(m, callback);
+}
\ No newline at end of file
diff --git a/src/core/dev_api/openvino/op/paged_attention.hpp b/src/core/dev_api/openvino/op/paged_attention.hpp
new file mode 100644
index 00000000000000..e5995e0b8699b0
--- /dev/null
+++ b/src/core/dev_api/openvino/op/paged_attention.hpp
@@ -0,0 +1,23 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include "openvino/op/op.hpp"
+
+namespace ov {
+namespace op {
+
+// This is an experimental operation that is implemented in the plugins.
+// Do not use in user applications; backward compatibility is not guaranteed in future releases.
+class OPENVINO_API PagedAttentionExtension : public ov::op::Op {
+public:
+    OPENVINO_OP("PagedAttentionExtension");
+
+    PagedAttentionExtension(const ov::OutputVector& args);
+    void validate_and_infer_types() override;
+    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
+};
+
+} // namespace op
+} // namespace ov
diff --git a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp
new file mode 100644
index 00000000000000..68fdb61957c2f5
--- /dev/null
+++ b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp
@@ -0,0 +1,25 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace pass {
+/**
+ * @brief The transformation replaces the KV-cache processing part in LLMs with the PagedAttention operation.
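+ * In addition, the PagedAttention-specific inputs (is_prompt, slot_mapping, max_context_len, context_lens,
+ * block_tables and the per-layer key/value cache tensors) are exposed as new model Parameters.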
+ * \ingroup ov_pass_cpp_api + */ +class OPENVINO_API SDPAToPagedAttention : public ModelPass { +public: + OPENVINO_RTTI("SDPAToPagedAttention"); + + bool run_on_model(const std::shared_ptr& model) override; +}; +} // namespace pass +} // namespace ov diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp new file mode 100644 index 00000000000000..376031d6ad9c33 --- /dev/null +++ b/src/core/src/op/paged_attention.cpp @@ -0,0 +1,152 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/paged_attention.hpp" + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { + +PagedAttentionExtension::PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) { + constructor_validate_and_infer_types(); +} + +void PagedAttentionExtension::validate_and_infer_types() { + auto value_cache_shape = get_input_partial_shape(4); + // m_num_kv_heads = value_cache_shape[1]; + // m_head_size = value_cache_shape[2]; + // m_block_size = value_cache_shape[3]; + NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims"); + + // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x] + auto key_cache_shape = get_input_partial_shape(3); + NODE_VALIDATION_CHECK(this, + value_cache_shape.size() == 4, + // value_cache_shape[0] == key_cache_shape[0] && // num_blocks + // key_cache_shape[1] == m_num_kv_heads && + // key_cache_shape[2] * key_cache_shape[4] == m_head_size && + // m_block_size == key_cache_shape[3], // block_size, + "Key cache shape must be 4 dims"); + + // query: shape [batch_size, seq_len, num_heads * head_size] + auto query_type = get_input_element_type(0); + auto query_shape = get_input_partial_shape(0); + NODE_VALIDATION_CHECK(this, + // query_type.is_real() && + query_shape.size() == 3, + // query_shape[2] == m_num_heads * m_head_size, + "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ", + "Got element type ", + query_type, + ", shape ", + query_shape); + + // key: shape [batch_size, seq_len, num_kv_heads * head_size] + auto key_type = get_input_element_type(1); + auto key_shape = get_input_partial_shape(1); + NODE_VALIDATION_CHECK(this, + // query_type == key_type && + key_shape.size() == 3, + "Key type must be the same as query, shape must be the same as query. " + "Got element type ", + key_type, + ", shape ", + key_shape); + + // value: shape [batch_size, seq_len, num_kv_heads * head_size] + // auto value_type = get_input_element_type(2); + auto value_shape = get_input_partial_shape(2); + + // is_prompt: boolean scalar + NODE_VALIDATION_CHECK(this, + // get_input_element_type(5) == ov::element::boolean && + get_input_shape(5) == ov::Shape({}), + "is_prompt validation failed. ", + "Got element type ", + get_input_element_type(5), + ", shape ", + get_input_shape(5)); + + // slot_mapping: shape [batch_size, max_context_len] + auto slot_mapping_shape = get_input_partial_shape(6); + NODE_VALIDATION_CHECK(this, + // get_input_element_type(6) == ov::element::i64 && + slot_mapping_shape.size() == 2, + "slot_mapping validation failed. ", + "Got element type ", + get_input_element_type(6), + ", shape ", + slot_mapping_shape); + + // max_context_len: integer scalar + NODE_VALIDATION_CHECK(this, + // get_input_element_type(7) == ov::element::i32 && + get_input_shape(7) == ov::Shape({}), + "max_context_len validation failed. 
", + "Got element type ", + get_input_element_type(7), + ", shape ", + get_input_shape(7)); + + // context_lens: shape [batch_size] + auto context_lens_shape = get_input_partial_shape(8); + NODE_VALIDATION_CHECK(this, + // get_input_element_type(8) == ov::element::i32 && + context_lens_shape.size() == 1, + "context_lens validation failed. ", + "Got element type ", + get_input_element_type(8), + ", shape ", + context_lens_shape); + + // block_tables: shape [batch_size, max_block_per_request] + NODE_VALIDATION_CHECK(this, + // get_input_element_type(9) == ov::element::i32 && + get_input_partial_shape(9).size() == 2, + "block_tables validation failed. ", + "Got element type ", + get_input_element_type(9), + ", shape ", + get_input_partial_shape(9)); + + // scale: float scalar + NODE_VALIDATION_CHECK(this, + // get_input_element_type(10) == ov::element::f32 && + get_input_shape(10) == ov::Shape({}), + "block_tables validation failed. ", + "Got element type ", + get_input_element_type(10), + ", shape ", + get_input_shape(10)); + + // alibi_slopes: 1D float tensor + NODE_VALIDATION_CHECK(this, + // get_input_element_type(11) == ov::element::f32 && + get_input_partial_shape(11).rank().get_length() == 1, + "alibi_slopes should be a 1D float tensor. ", + "Got element type ", + get_input_element_type(11), + ", shape ", + get_input_partial_shape(11)); + + // sliding_window: int scalar + NODE_VALIDATION_CHECK(this, + // get_input_element_type(12) == ov::element::i32 && + get_input_partial_shape(12).rank().get_length() == 0, + "sliding_window argument should be an i32 scalar. ", + "Got element type ", + get_input_element_type(12), + ", shape ", + get_input_partial_shape(12)); + + set_output_type(0, query_type, query_shape); +} + +std::shared_ptr PagedAttentionExtension::clone_with_new_inputs(const ov::OutputVector& new_args) const { + return std::make_shared(new_args); +} + +} // namespace op +} // namespace ov \ No newline at end of file diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp new file mode 100644 index 00000000000000..c2c899dd515149 --- /dev/null +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -0,0 +1,121 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/pass/sdpa_to_paged_attention.hpp" + +#include "openvino/cc/pass/itt.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp" +#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp" +#include "transformations/sdpa_to_paged_attention/state_management_pattern.hpp" +#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov::op; + +static std::shared_ptr setName(std::shared_ptr node, const char* name) { + // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a + // given single name) + node->set_friendly_name(name); + OPENVINO_ASSERT(node->get_output_size() == 1); // Should I use assert here? 
+    node->get_output_tensor(0).set_names({name});
+    return node;
+}
+
+bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Model>& model) {
+    RUN_ON_MODEL_SCOPE(SDPAToPagedAttention);
+    auto max_context_len =
+        setName(std::make_shared<v0::Parameter>(element::i64, PartialShape{}), "max_context_len");
+    ParameterVector model_remaining_params = {
+        setName(std::make_shared<v0::Parameter>(element::boolean, PartialShape{}), "is_prompt"),
+        setName(std::make_shared<v0::Parameter>(element::i64, PartialShape{-1, -1}), "slot_mapping"),
+        max_context_len,
+        setName(std::make_shared<v0::Parameter>(element::i64, PartialShape{-1}), "context_lens"),
+        setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1, -1}), "block_tables"),
+    };
+    auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0});  // sliding_window
+
+    auto cur_seq_len = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(model->input("input_ids")),
+                                                    v0::Constant::create(element::i64, Shape{}, {1}),
+                                                    v0::Constant::create(element::i64, Shape{}, {0}));
+    auto prev_max_seq_len = std::make_shared<v1::Subtract>(max_context_len, cur_seq_len);
+
+    auto has_parameter = [=](const std::shared_ptr<ov::Model>& model, const std::string& name) -> bool {
+        for (auto& t : model->inputs()) {
+            const auto& names = t.get_names();
+            if (names.find(name) != names.end()) {
+                return true;
+            }
+        }
+
+        return false;
+    };
+
+    ParameterVector kv_parameters;
+    std::vector<std::shared_ptr<ov::Node>> assignes_to_remove;  // filled by StateManagementPattern, but not used here:
+                                                                // all sinks are removed below instead
+    ParameterVector parameters_to_remove;
+    ResultVector results_to_remove;  // not populated here: all Results of a stateless model cannot really be tracked
+
+    if (!has_parameter(model, "position_ids")) {
+        auto position_ids =
+            setName(std::make_shared<v0::Parameter>(element::i64, PartialShape{-1, -1}), "position_ids");
+        model->add_parameters({position_ids});
+    }
+    auto position_ids = std::make_shared<ov::Output<ov::Node>>(model->input("position_ids"));
+
+    int layer_index = 0;
+
+    ov::pass::Manager manager;
+    manager.set_per_pass_validation(false);
+    manager.register_pass<StateManagementPattern>(kv_parameters,
+                                                  model_remaining_params,
+                                                  sliding_window,
+                                                  parameters_to_remove,
+                                                  assignes_to_remove,
+                                                  layer_index);
+    manager.register_pass<PrevSequenceLengthPattern>(prev_max_seq_len);
+    manager.register_pass<TotalSequenceLengthPattern>(max_context_len);
+
+    manager.register_pass<PositionIDsReplacer>(position_ids);
+
+    manager.run_passes(model);
+
+    if (has_parameter(model, "beam_idx")) {
+        if (const auto& parameter =
+                std::dynamic_pointer_cast<v0::Parameter>(model->input("beam_idx").get_node_shared_ptr())) {
+            model->remove_parameter(parameter);
+        } else {
+            return false;
+        }
+    }
+
+    if (const auto& parameter =
+            std::dynamic_pointer_cast<v0::Parameter>(model->input("attention_mask").get_node_shared_ptr())) {
+        model->remove_parameter(parameter);
+    } else {
+        return false;
+    }
+
+    for (auto& parameter : parameters_to_remove) {
+        model->remove_parameter(parameter);
+    }
+    // Remove all Assigns aggressively: the path from the kv-cache Concat to an Assign can be complicated,
+    // and there is no reason to track it and keep only part of the Assigns, because the model would remain
+    // in an incorrect form anyway.
+    auto sinks = model->get_sinks();
+
+    for (auto& sink : sinks) {
+        model->remove_sink(sink);
+    }
+
+    for (auto& result : results_to_remove) {
+        model->remove_result(result);
+    }
+
+    model->add_parameters(kv_parameters);
+    model->add_parameters(model_remaining_params);
+    return true;
+}
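
For reference, a minimal sketch of how the new pass is expected to be driven from C++ (not part of this patch; the model path is illustrative and the model is assumed to be an SDPA-based LLM with stateful KV-cache handling):

    #include "openvino/openvino.hpp"
    #include "openvino/pass/manager.hpp"
    #include "openvino/pass/sdpa_to_paged_attention.hpp"  // header added by this patch

    int main() {
        ov::Core core;
        auto model = core.read_model("llm.xml");  // hypothetical IR of an SDPA-based LLM

        // Replace the SDPA + KV-cache subgraphs with PagedAttentionExtension and expose the
        // PagedAttention-specific inputs (max_context_len, slot_mapping, block_tables, ...) as Parameters.
        ov::pass::Manager manager;
        manager.register_pass<ov::pass::SDPAToPagedAttention>();
        manager.run_passes(model);
        return 0;
    }

From Python, the same effect should be reachable through the binding registered above: openvino._offline_transformations.paged_attention_transformation(model).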