@@ -17,3 +17,4 @@
from openvino._pyopenvino._offline_transformations import compress_model_transformation
from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation
from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
from openvino._pyopenvino._offline_transformations import paged_attention_transformation
@@ -8,6 +8,7 @@

#include <compress_quantize_weights.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <openvino/pass/sdpa_to_paged_attention.hpp>
#include <openvino/pass/serialize.hpp>
#include <pruning.hpp>
#include <transformations/common_optimizations/compress_float_constants.hpp>
@@ -127,4 +128,15 @@ void regmodule_offline_transformations(py::module m) {
manager.run_passes(model);
},
py::arg("model"));

m_offline_transformations.def(
"paged_attention_transformation",
[](std::shared_ptr<ov::Model> model) {
ov::pass::Manager manager;
manager.register_pass<ov::pass::SDPAToPagedAttention>();
manager.run_passes(model);
},
py::arg("model"));
}
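
For context, a minimal C++ sketch of what this binding runs under the hood. The SDPAToPagedAttention pass is the one registered above; the IR file name model.xml is a hypothetical placeholder (any stateful SDPA-based LLM IR):

#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"
#include "openvino/runtime/core.hpp"

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // hypothetical path

    // Same sequence the Python binding executes.
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>();
    manager.run_passes(model);
    return 0;
}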
@@ -5,162 +5,13 @@
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"

#include "openvino/op/op.hpp"
#include "openvino/op/paged_attention.hpp"
#include "pyopenvino/core/common.hpp"

namespace py = pybind11;

namespace {

// This is an experimental operation that is implemented in the plugins.
// Do not use in user applications; backward compatibility is not guaranteed in future releases.
class PagedAttentionExtension : public ov::op::Op {
public:
OPENVINO_OP("PagedAttentionExtension");

PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {
constructor_validate_and_infer_types();
}

void validate_and_infer_types() override {
auto value_cache_shape = get_input_partial_shape(4);
// m_num_kv_heads = value_cache_shape[1];
// m_head_size = value_cache_shape[2];
// m_block_size = value_cache_shape[3];
NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");

// key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
auto key_cache_shape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
key_cache_shape.size() == 4,
// value_cache_shape[0] == key_cache_shape[0] && // num_blocks
// key_cache_shape[1] == m_num_kv_heads &&
// key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
// m_block_size == key_cache_shape[3], // block_size,
"Key cache shape must be 4 dims");

// query: shape [batch_size, seq_len, num_heads * head_size]
auto query_type = get_input_element_type(0);
auto query_shape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(
this,
// query_type.is_real() &&
query_shape.size() == 3,
// query_shape[2] == m_num_heads * m_head_size,
"Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
"Got element type ",
query_type,
", shape ",
query_shape);

// key: shape [batch_size, seq_len, num_kv_heads * head_size]
auto key_type = get_input_element_type(1);
auto key_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
// query_type == key_type &&
key_shape.size() == 3,
"Key type must be the same as query, shape must be the same as query. "
"Got element type ",
key_type,
", shape ",
key_shape);

// value: shape [batch_size, seq_len, num_kv_heads * head_size]
// auto value_type = get_input_element_type(2);
auto value_shape = get_input_partial_shape(2);

// is_prompt: boolean scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(5) == ov::element::boolean &&
get_input_shape(5) == ov::Shape({}),
"is_prompt validation failed. ",
"Got element type ",
get_input_element_type(5),
", shape ",
get_input_shape(5));

// slot_mapping: shape [batch_size, max_context_len]
auto slot_mapping_shape = get_input_partial_shape(6);
NODE_VALIDATION_CHECK(this,
// get_input_element_type(6) == ov::element::i64 &&
slot_mapping_shape.size() == 2,
"slot_mapping validation failed. ",
"Got element type ",
get_input_element_type(6),
", shape ",
slot_mapping_shape);

// max_context_len: integer scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(7) == ov::element::i32 &&
get_input_shape(7) == ov::Shape({}),
"max_context_len validation failed. ",
"Got element type ",
get_input_element_type(7),
", shape ",
get_input_shape(7));

// context_lens: shape [batch_size]
auto context_lens_shape = get_input_partial_shape(8);
NODE_VALIDATION_CHECK(this,
// get_input_element_type(8) == ov::element::i32 &&
context_lens_shape.size() == 1,
"context_lens validation failed. ",
"Got element type ",
get_input_element_type(8),
", shape ",
context_lens_shape);

// block_tables: shape [batch_size, max_block_per_request]
NODE_VALIDATION_CHECK(this,
// get_input_element_type(9) == ov::element::i32 &&
get_input_partial_shape(9).size() == 2,
"block_tables validation failed. ",
"Got element type ",
get_input_element_type(9),
", shape ",
get_input_partial_shape(9));

// scale: float scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(10) == ov::element::f32 &&
get_input_shape(10) == ov::Shape({}),
"block_tables validation failed. ",
"Got element type ",
get_input_element_type(10),
", shape ",
get_input_shape(10));

// alibi_slopes: 1D float tensor
NODE_VALIDATION_CHECK(this,
// get_input_element_type(11) == ov::element::f32 &&
get_input_partial_shape(11).rank().get_length() == 1,
"alibi_slopes should be a 1D float tensor. ",
"Got element type ",
get_input_element_type(11),
", shape ",
get_input_partial_shape(11));

// sliding_window: int scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(12) == ov::element::i32 &&
get_input_partial_shape(12).rank().get_length() == 0,
"sliding_window argument should be an i32 scalar. ",
"Got element type ",
get_input_element_type(12),
", shape ",
get_input_partial_shape(12));

set_output_type(0, query_type, query_shape);
}

std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
return std::make_shared<PagedAttentionExtension>(new_args);
}
};

} // namespace

void regclass_graph_op_PagedAttentionExtension(py::module m) {
using ov::op::PagedAttentionExtension;
py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
m,
"_PagedAttentionExtension");
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API PositionIDsReplacer;

} // namespace pass
} // namespace ov

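// Matches the token-embedding + position-embedding sub-graph and redirects its
// position input to the externally supplied position_ids output (see
// position_ids_replacer.cpp below).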
class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("PositionIDsReplacer", "0");
explicit PositionIDsReplacer(const std::shared_ptr<Output<Node>>& position_ids);
};
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API PrevSequenceLengthPattern;

} // namespace pass
} // namespace ov

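// Matches a Gather over ShapeOf of the gathered KV-cache state that extracts
// the previous sequence length, and replaces it with the supplied Subtract,
// inserting a Convert when element types differ (see
// prev_sequence_length_pattern.cpp below).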
class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("PrevSequenceLengthPattern", "0");
explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len);
};
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API StateManagementPattern;

} // namespace pass
} // namespace ov

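// Core rewrite (implementation not shown in this diff): expected to replace the
// SDPA + KV-cache ReadValue/Assign sub-graph with PagedAttentionExtension,
// collecting obsolete Parameters and Assign sinks into the by-reference
// arguments for later removal by the caller.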
class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("StateManagementPattern", "0");
StateManagementPattern(ParameterVector& kv_parameters,
const ParameterVector& model_remaining_params,
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
ParameterVector& parameters_to_remove,
NodeVector& assignes_to_remove,
int& layer_index);
};
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API TotalSequenceLengthPattern;

} // namespace pass
} // namespace ov

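// Implementation not shown in this diff: expected to replace the computation of
// the total (past + current) sequence length, derived from the KV-cache Concat,
// with the max_context_len parameter.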
class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TotalSequenceLengthPattern", "0");
explicit TotalSequenceLengthPattern(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
};
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;

// TODO: Instead of using the following transformation, which matches quite a specific place in the model graph when
// the position_ids parameter is missing, consider replacing the always-present attention_mask parameter with a
// sub-graph that uses a new slot_mapping parameter.
ov::pass::PositionIDsReplacer::PositionIDsReplacer(const std::shared_ptr<Output<Node>>& position_ids) {
MATCHER_SCOPE(PositionIDsReplacer);

auto input_ids = pattern::any_input();
auto input_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), input_ids, pattern::any_input()});

auto position_ids_pattern = pattern::any_input();
auto offset = pattern::wrap_type<v0::Constant>();
auto add_offset = pattern::wrap_type<v1::Add>({position_ids_pattern, offset});
auto convert = pattern::wrap_type<v0::Convert>({add_offset});
auto position_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), convert, pattern::any_input()});

auto add = pattern::wrap_type<v1::Add>({input_embed, position_embed});

ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
replace_node(pattern_map.at(position_ids_pattern).get_node_shared_ptr(), position_ids->get_node_shared_ptr());
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(add, matcher_name);
register_matcher(m, callback);
}
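
A hedged test-style sketch of the sub-graph this matcher targets, with all shapes, table sizes, and function names invented for illustration: token embeddings gathered by input_ids are added to position embeddings gathered by Convert(position_ids + offset), and the pass redirects the position branch to a newly added parameter:

#include "openvino/core/model.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"

using namespace ov::op;

static std::shared_ptr<ov::Model> make_toy_model() {
    auto input_ids = std::make_shared<v0::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
    auto old_pos_ids = std::make_shared<v0::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
    // Toy embedding table (vocab 16, hidden 4) shared by both branches.
    auto table = v0::Constant::create(ov::element::f32, {16, 4}, std::vector<float>(64, 0.0f));
    auto axis = v0::Constant::create(ov::element::i64, {}, {0});
    auto input_embed = std::make_shared<v8::Gather>(table, input_ids, axis);
    auto offset = v0::Constant::create(ov::element::i64, {}, {1});
    auto shifted = std::make_shared<v1::Add>(old_pos_ids, offset);
    auto converted = std::make_shared<v0::Convert>(shifted, ov::element::i64);
    auto pos_embed = std::make_shared<v8::Gather>(table, converted, axis);
    auto sum = std::make_shared<v1::Add>(input_embed, pos_embed);
    return std::make_shared<ov::Model>(ov::OutputVector{sum}, ov::ParameterVector{input_ids, old_pos_ids});
}

static void run_replacer() {
    auto model = make_toy_model();
    auto new_pos_ids = std::make_shared<v0::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
    model->add_parameters({new_pos_ids});
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::PositionIDsReplacer>(
        std::make_shared<ov::Output<ov::Node>>(new_pos_ids->output(0)));
    manager.run_passes(model);  // the position branch now reads new_pos_ids
}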
@@ -0,0 +1,41 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;

ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len) {
MATCHER_SCOPE(PrevSequenceLengthPattern);

auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_gather});
auto seq = pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});

ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
// TODO: Check that seq has axis that really takes sequence len but not any other dimension -- use symbolics or
// look at the constant input
auto gather = m.get_match_root();
auto target_type = gather->get_output_element_type(0);
std::shared_ptr<Node> replacement;
if (prev_max_seq_len->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(prev_max_seq_len, target_type);
} else {
replacement = prev_max_seq_len;
}
replace_node(gather, replacement);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
register_matcher(m, callback);
}
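
A hedged usage sketch, with the scalar inputs invented for illustration: in the enclosing SDPAToPagedAttention pipeline, the pass would receive a Subtract computing the previous sequence length, e.g. total context length minus the current chunk length:

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"

void apply_prev_seq_len_pattern(const std::shared_ptr<ov::Model>& model) {
    using namespace ov::op;
    // Invented i32 scalar inputs; the real pipeline would add them to the model.
    auto max_context_len = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{});
    auto cur_seq_len = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{});
    auto prev_max_seq_len = std::make_shared<v1::Subtract>(max_context_len, cur_seq_len);

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::PrevSequenceLengthPattern>(prev_max_seq_len);
    manager.run_passes(model);
}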