Skip to content

Commit 7448911

Browse files
11happyrkazants
authored andcommitted
[JAX FE]: add support for jax.lax.logistic (openvinotoolkit#28240)
**Overview:** This PR fixes openvinotoolkit#26576. **Testing:** - Tested the Updated code - Verified that other functionalities remain unaffected ![Screenshot from 2025-01-01 13-11-04](https://github.com/user-attachments/assets/5acfabc2-dded-4c65-b408-d4174fa3c025) **Dependencies:** - No dependencies on other pull requests **CC:** - @rkazants --------- Signed-off-by: 11happy <[email protected]> Co-authored-by: Roman Kazantsev <[email protected]>
1 parent 2e24dfa commit 7448911

File tree

18 files changed

+428
-94
lines changed

18 files changed

+428
-94
lines changed

src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ static void reg_pattern_wrap_type(py::module m) {
8989

9090
wrap_type.def(py::init([](const std::string& type_name, const ov::Output<ov::Node>& input) {
9191
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
92-
nullptr,
92+
ov::pass::pattern::op::Predicate(),
9393
ov::OutputVector{input});
9494
}),
9595
py::arg("type_name"),
@@ -106,7 +106,7 @@ static void reg_pattern_wrap_type(py::module m) {
106106

107107
wrap_type.def(py::init([](const std::string& type_name, const std::shared_ptr<ov::Node>& input) {
108108
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
109-
nullptr,
109+
ov::pass::pattern::op::Predicate(),
110110
ov::OutputVector{input});
111111
}),
112112
py::arg("type_name"),
@@ -165,7 +165,9 @@ static void reg_pattern_wrap_type(py::module m) {
165165
)");
166166

167167
wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs) {
168-
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), nullptr, inputs);
168+
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
169+
ov::pass::pattern::op::Predicate(),
170+
inputs);
169171
}),
170172
py::arg("type_name"),
171173
py::arg("inputs"),
@@ -181,7 +183,7 @@ static void reg_pattern_wrap_type(py::module m) {
181183

182184
wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs) {
183185
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
184-
nullptr,
186+
ov::pass::pattern::op::Predicate(),
185187
ov::as_output_vector(inputs));
186188
}),
187189
py::arg("type_name"),
@@ -264,7 +266,7 @@ static void reg_pattern_wrap_type(py::module m) {
264266

265267
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
266268
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
267-
nullptr,
269+
ov::pass::pattern::op::Predicate(),
268270
ov::OutputVector{input});
269271
}),
270272
py::arg("type_names"),
@@ -281,7 +283,7 @@ static void reg_pattern_wrap_type(py::module m) {
281283

282284
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
283285
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
284-
nullptr,
286+
ov::pass::pattern::op::Predicate(),
285287
ov::OutputVector{input});
286288
}),
287289
py::arg("type_names"),
@@ -343,7 +345,9 @@ static void reg_pattern_wrap_type(py::module m) {
343345
)");
344346

345347
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
346-
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), nullptr, inputs);
348+
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
349+
ov::pass::pattern::op::Predicate(),
350+
inputs);
347351
}),
348352
py::arg("type_names"),
349353
py::arg("inputs"),
@@ -359,7 +363,7 @@ static void reg_pattern_wrap_type(py::module m) {
359363

360364
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs) {
361365
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
362-
nullptr,
366+
ov::pass::pattern::op::Predicate(),
363367
ov::as_output_vector(inputs));
364368
}),
365369
py::arg("type_names"),
@@ -501,8 +505,7 @@ static void reg_pattern_optional(py::module m) {
501505

502506
optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
503507
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names),
504-
ov::OutputVector{input},
505-
nullptr);
508+
ov::OutputVector{input});
506509
}),
507510
py::arg("type_names"),
508511
py::arg("input"),
@@ -518,8 +521,7 @@ static void reg_pattern_optional(py::module m) {
518521

519522
optional_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
520523
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names),
521-
ov::OutputVector{input},
522-
nullptr);
524+
ov::OutputVector{input});
523525
}),
524526
py::arg("type_names"),
525527
py::arg("input"),
@@ -533,13 +535,12 @@ static void reg_pattern_optional(py::module m) {
533535
:type input: openvino.runtime.Node
534536
)");
535537

536-
optional_type.def(
537-
py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
538-
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), inputs, nullptr);
539-
}),
540-
py::arg("type_names"),
541-
py::arg("inputs"),
542-
R"(
538+
optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
539+
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), inputs);
540+
}),
541+
py::arg("type_names"),
542+
py::arg("inputs"),
543+
R"(
543544
Create Optional with the given node type and input node.
544545
545546
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
@@ -551,8 +552,7 @@ static void reg_pattern_optional(py::module m) {
551552

552553
optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs) {
553554
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names),
554-
ov::as_output_vector(inputs),
555-
nullptr);
555+
ov::as_output_vector(inputs));
556556
}),
557557
py::arg("type_names"),
558558
py::arg("inputs"),

src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#pragma once
66

77
#include "openvino/pass/graph_rewrite.hpp"
8+
#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"
89
#include "transformations_visibility.hpp"
910

1011
namespace ov {
@@ -88,29 +89,42 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass {
8889
* @ingroup ov_transformation_common_api
8990
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation
9091
*/
91-
class ov::pass::RoPEFusion : public ov::pass::GraphRewrite {
92+
93+
class ov::pass::RoPEFusion : public ov::pass::ModelPass {
9294
public:
93-
OPENVINO_GRAPH_REWRITE_RTTI("RoPEFusion");
94-
RoPEFusion(bool support_2d_rope = false) {
95-
add_matcher<ov::pass::RoPEFusionFlux>();
96-
add_matcher<ov::pass::RoPEFusionGPTNEOX>();
97-
add_matcher<ov::pass::RoPEFusionGPTJ>();
95+
OPENVINO_MODEL_PASS_RTTI("RoPEFusion");
96+
97+
explicit RoPEFusion(bool support_2d_rope = false) : support_2d_rope(support_2d_rope){};
98+
99+
bool run_on_model(const std::shared_ptr<ov::Model>& m) override {
100+
auto symbolic_pipeline = ov::pass::SymbolicOptimizations(false);
101+
auto rope_fusions = symbolic_pipeline.get_manager()->register_pass<ov::pass::GraphRewrite>();
102+
rope_fusions->set_name("RoPEFusions");
103+
104+
rope_fusions->add_matcher<ov::pass::RoPEFusionFlux>();
105+
rope_fusions->add_matcher<ov::pass::RoPEFusionGPTNEOX>();
106+
rope_fusions->add_matcher<ov::pass::RoPEFusionGPTJ>();
98107
// optional heads & tails are fused in separate matcher pass,
99108
// after RoPENode has been created.
100-
add_matcher<ov::pass::RoPEFusionCosSinPreprocess>();
101-
add_matcher<ov::pass::RoPEFusionIOSlicing>();
102-
add_matcher<ov::pass::RoPEFusionPreprocess>();
109+
rope_fusions->add_matcher<ov::pass::RoPEFusionCosSinPreprocess>();
110+
rope_fusions->add_matcher<ov::pass::RoPEFusionIOSlicing>();
111+
rope_fusions->add_matcher<ov::pass::RoPEFusionPreprocess>();
103112

104-
add_matcher<ov::pass::RoPEFusionChatGLM>(0);
105-
add_matcher<ov::pass::RoPEFusionChatGLM>(1);
113+
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(0);
114+
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(1);
106115
if (support_2d_rope) {
107-
add_matcher<ov::pass::RoPEFusionChatGLM>(0, true);
108-
add_matcher<ov::pass::RoPEFusionChatGLM>(1, true);
116+
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(0, true);
117+
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(1, true);
109118
}
110119

111-
add_matcher<ov::pass::RoPEFusionQwen>(0);
112-
add_matcher<ov::pass::RoPEFusionQwen>(1);
120+
rope_fusions->add_matcher<ov::pass::RoPEFusionQwen>(0);
121+
rope_fusions->add_matcher<ov::pass::RoPEFusionQwen>(1);
122+
123+
rope_fusions->add_matcher<ov::pass::RoPEShareCosSin>();
113124

114-
add_matcher<ov::pass::RoPEShareCosSin>();
125+
return symbolic_pipeline.run_on_model(m);
115126
}
127+
128+
protected:
129+
bool support_2d_rope = false;
116130
};

src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ ov::pass::LabelResolvingThroughSelect::LabelResolvingThroughSelect() {
172172
}
173173

174174
ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
175-
m_manager = std::make_shared<pass::Manager>("Symbolic");
175+
m_manager = std::make_shared<pass::Manager>(get_pass_config(), "Symbolic");
176176
m_manager->set_per_pass_validation(false);
177177

178178
#define REGISTER_SYMBOLIC(region, ...) m_manager->register_pass<region>(__VA_ARGS__);
@@ -208,7 +208,7 @@ bool ov::pass::SymbolicOptimizations::run_on_model(const std::shared_ptr<ov::Mod
208208
pass_config->disable<EliminateSqueeze>();
209209
pass_config->disable<EliminateUnsqueeze>();
210210

211-
m_manager->run_passes(m);
211+
bool status = m_manager->run_passes(m);
212212
ov::remove_skip_invalidation_rti(m);
213-
return true;
213+
return status;
214214
}

src/core/include/openvino/pass/pattern/matcher.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ class OPENVINO_API Matcher {
164164
PatternValueMap& get_pattern_value_map() {
165165
return m_pattern_map;
166166
}
167+
PatternSymbolMap& get_symbols() {
168+
return m_pattern_symbols;
169+
}
167170
PatternValueMaps& get_pattern_value_maps() {
168171
return m_pattern_value_maps;
169172
}
@@ -198,6 +201,7 @@ class OPENVINO_API Matcher {
198201
Output<Node> m_match_root;
199202
Output<Node> m_pattern_node;
200203
PatternValueMap m_pattern_map;
204+
PatternSymbolMap m_pattern_symbols;
201205
PatternValueMaps m_pattern_value_maps;
202206
OutputVector m_matched_list;
203207

src/core/include/openvino/pass/pattern/op/any.hpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,27 @@ class OPENVINO_API Any : public Pattern {
1818
OPENVINO_RTTI("patternAny");
1919
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
2020
/// shape.
21-
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
21+
Any(const element::Type& type, const PartialShape& s, Predicate pred, const OutputVector& wrapped_values)
2222
: Pattern(wrapped_values, pred) {
2323
set_output_type(0, type, s);
2424
}
25+
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
26+
: Any(type, s, Predicate(pred), wrapped_values) {}
27+
Any(const element::Type& type, const PartialShape& s, SymbolPredicate pred, const OutputVector& wrapped_values)
28+
: Any(type, s, Predicate(pred), wrapped_values) {}
2529
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
26-
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
30+
: Any(type, s, pred, as_output_vector(wrapped_values)) {}
31+
2732
/// \brief creates a Any node containing a sub-pattern described by the type and
2833
/// shape of \sa node.
29-
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
34+
Any(const Output<Node>& node, Predicate pred, const OutputVector& wrapped_values)
3035
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
36+
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
37+
: Any(node, Predicate(pred), wrapped_values) {}
38+
Any(const Output<Node>& node, SymbolPredicate pred, const OutputVector& wrapped_values)
39+
: Any(node, Predicate(pred), wrapped_values) {}
3140
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
32-
: Any(node.get_element_type(),
33-
node.get_partial_shape(),
34-
as_value_predicate(pred),
35-
as_output_vector(wrapped_values)) {}
41+
: Any(node, pred, as_output_vector(wrapped_values)) {}
3642

3743
bool match_value(pattern::Matcher* matcher,
3844
const Output<Node>& pattern_value,

src/core/include/openvino/pass/pattern/op/any_of.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ class OPENVINO_API AnyOf : public Pattern {
3030
}
3131
set_output_type(0, type, s);
3232
}
33+
AnyOf(const element::Type& type, const PartialShape& s, SymbolPredicate pred, const OutputVector& wrapped_values)
34+
: Pattern(wrapped_values, pred) {
35+
if (wrapped_values.size() != 1) {
36+
OPENVINO_THROW("AnyOf expects exactly one argument");
37+
}
38+
set_output_type(0, type, s);
39+
}
3340
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
3441
: AnyOf(
3542
type,
@@ -43,6 +50,8 @@ class OPENVINO_API AnyOf : public Pattern {
4350
/// shape of \sa node.
4451
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
4552
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
53+
AnyOf(const Output<Node>& node, SymbolPredicate pred, const OutputVector& wrapped_values)
54+
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
4655
AnyOf(const std::shared_ptr<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
4756
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
4857
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;

src/core/include/openvino/pass/pattern/op/label.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@ class OPENVINO_API Label : public Pattern {
4343
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
4444
set_output_type(0, type, s);
4545
}
46-
46+
Label(const element::Type& type,
47+
const PartialShape& s,
48+
const SymbolPredicate pred,
49+
const OutputVector& wrapped_values)
50+
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
51+
set_output_type(0, type, s);
52+
}
4753
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
4854
: Label(
4955
type,
@@ -56,6 +62,9 @@ class OPENVINO_API Label : public Pattern {
5662
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
5763
: Label(type, s, std::move(pred), OutputVector{}) {}
5864

65+
Label(const element::Type& type, const PartialShape& s, SymbolPredicate pred)
66+
: Label(type, s, std::move(pred), OutputVector{}) {}
67+
5968
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
6069
: Label(type, s, as_value_predicate(std::move(pred)), OutputVector{}) {}
6170

@@ -78,6 +87,10 @@ class OPENVINO_API Label : public Pattern {
7887
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
7988
Label(const Output<Node>& value, const ValuePredicate pred)
8089
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
90+
Label(const Output<Node>& value, const SymbolPredicate pred, const OutputVector& wrapped_values)
91+
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
92+
Label(const Output<Node>& value, const SymbolPredicate pred)
93+
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
8194

8295
Label(const Output<Node>& value, const NodePredicate pred)
8396
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
@@ -107,6 +120,8 @@ std::shared_ptr<Node> any_input();
107120

108121
OPENVINO_API
109122
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
123+
OPENVINO_API
124+
std::shared_ptr<Node> any_input(const pattern::op::SymbolPredicate& pred);
110125
} // namespace pattern
111126
} // namespace pass
112127
} // namespace ov

0 commit comments

Comments
 (0)