-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[JAX]: Support jax.lax.select_n operation for JAX #28025
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
a2f58f3
Implemented jax.lax.select_n
11happy 16f206d
add test for bool type & added copyrights
11happy 7f39f13
remove np.bool and correct implementation for dynamic shapes
11happy 6ffea93
reverted to previous implementation
11happy 4074780
added case_num seperate parameter for more concrete testing
11happy 372c988
implement without get_shape also check for scalar cases as well in tests
11happy 9831b7f
minor nits
11happy 37a46a6
Update tests/layer_tests/jax_tests/test_select_n.py
rkazants dba3347
gha build fix
11happy e7b13ad
Merge remote-tracking branch 'refs/remotes/origin/jax.select' into ja…
11happy 74a42e0
revert case num tests
11happy 07aa342
correct case num tests
11happy b596253
Update test_select_n.py
rkazants 09b1a97
Merge branch 'master' into jax.select
rkazants File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| // Copyright (C) 2018-2024 Intel Corporation | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
|
|
||
| #include "openvino/frontend/jax/node_context.hpp" | ||
| #include "openvino/op/concat.hpp" | ||
| #include "openvino/op/constant.hpp" | ||
| #include "openvino/op/convert.hpp" | ||
| #include "openvino/op/gather_elements.hpp" | ||
| #include "openvino/op/unsqueeze.hpp" | ||
| #include "utils.hpp" | ||
|
|
||
| using namespace ov::op; | ||
|
|
||
| namespace ov { | ||
| namespace frontend { | ||
| namespace jax { | ||
| namespace op { | ||
|
|
||
| OutputVector translate_select_n(const NodeContext& context) { | ||
| num_inputs_check(context, 2); | ||
| auto num_inputs = static_cast<int>(context.get_input_size()); | ||
| Output<Node> which = context.get_input(0); | ||
| if (which.get_element_type() == element::boolean) { | ||
| which = std::make_shared<v0::Convert>(which, element::i32); | ||
| } | ||
| auto const_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0}); | ||
| OutputVector unsqueezed_cases(num_inputs - 1); | ||
| unsqueezed_cases.reserve(num_inputs - 1); | ||
| for (int ind = 1; ind < num_inputs; ++ind) { | ||
| auto case_input = context.get_input(ind); | ||
| auto unsqueeze = std::make_shared<v0::Unsqueeze>(case_input, const_axis); | ||
| unsqueezed_cases[ind - 1] = unsqueeze; | ||
| } | ||
| Output<Node> cases = std::make_shared<v0::Concat>(unsqueezed_cases, 0); | ||
| which = | ||
| std::make_shared<v0::Unsqueeze>(which, | ||
| ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0})); | ||
| Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0); | ||
| return {result}; | ||
| }; | ||
|
|
||
| } // namespace op | ||
| } // namespace jax | ||
| } // namespace frontend | ||
| } // namespace ov | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| # Copyright (C) 2018-2024 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import jax | ||
11happy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import numpy as np | ||
| import pytest | ||
| from jax import numpy as jnp | ||
|
|
||
| from jax_layer_test_class import JaxLayerTest | ||
|
|
||
| rng = np.random.default_rng(5402) | ||
|
|
||
|
|
||
| class TestSelectN(JaxLayerTest): | ||
| def _prepare_input(self): | ||
| cases = [] | ||
| if (self.case_num == 2): | ||
| which = rng.choice([True, False], self.input_shape) | ||
| else: | ||
| which = rng.uniform(0, self.case_num, self.input_shape).astype(self.input_type) | ||
| which = np.array(which) | ||
| for i in range(self.case_num): | ||
| cases.append(jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type))) | ||
| cases = np.array(cases) | ||
| return (which, cases) | ||
|
|
||
| def create_model(self, input_shape, input_type, case_num): | ||
| self.input_shape = input_shape | ||
| self.input_type = input_type | ||
| self.case_num = case_num | ||
|
|
||
| def jax_select_n(which, cases): | ||
| return jax.lax.select_n(which, *cases) | ||
|
|
||
| return jax_select_n, None, 'select_n' | ||
|
|
||
| @pytest.mark.parametrize("input_shape", [[], [1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]) | ||
| @pytest.mark.parametrize("input_type", [np.int32, np.int64]) | ||
11happy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| @pytest.mark.parametrize("case_num", [2, 3, 4]) | ||
| @pytest.mark.nightly | ||
| @pytest.mark.precommit_jax_fe | ||
| def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type, case_num): | ||
| self._test(*self.create_model(input_shape, input_type, case_num), | ||
| ie_device, precision, | ||
| ir_version) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.