Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/frontends/jax/src/op/select_n.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "openvino/frontend/jax/node_context.hpp";
#include "openvino/op/ops.hpp";
#include "utils.hpp";

using namespace ov::op;

namespace ov {
namespace frontend {
namespace jax {
namespace op {

OutputVector translate_select_n(const NodeContext& context) {
num_inputs_check(context, 2);
auto num_inputs = static_cast<int>(context.get_input_size());
Output<Node> which = context.get_input(0);
OutputVector cases_vector(num_inputs - 1);
for(int ind = 1; ind < num_inputs; ++ind) {
cases_vector[ind - 1] = context.get_input(ind);
}

Output<Node> cases = std::make_shared<v0::Concat>(cases_vector, 0);
auto which_shape = which.get_shape();
std::vector<int64_t> cases_reshape_shape = {num_inputs-1,which_shape[0]};
std::vector<int64_t> which_reshape_shape = {1,which_shape[0]};

cases = std::make_shared<v1::Reshape>(cases, ov::op::v0::Constant::create(element::i64, Shape{2}, cases_reshape_shape), false);
which = std::make_shared<v1::Reshape>(which, ov::op::v0::Constant::create(element::i64, Shape{2}, which_reshape_shape), false);
Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0);
return {result};

};

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ OP_CONVERTER(translate_reduce_window_max);
OP_CONVERTER(translate_reduce_window_sum);
OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_select_n);
OP_CONVERTER(translate_slice);
OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_transpose);
Expand Down Expand Up @@ -91,6 +92,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"transpose", op::translate_transpose},
{"rsqrt", op::translate_rsqrt},
{"reshape", op::translate_reshape},
{"select_n", op::translate_select_n},
{"slice", op::translate_slice},
{"sqrt", op::translate_1to1_match_1_input<v0::Sqrt>},
{"squeeze", op::translate_squeeze},
Expand Down
41 changes: 41 additions & 0 deletions tests/layer_tests/jax_tests/test_select_n.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(5402)


class TestSelectN(JaxLayerTest):
def _prepare_input(self):
cases = []
which = rng.uniform(0,self.input_shape, self.input_shape).astype(self.input_type)
which = np.array(which)

cases = []
for i in range(self.input_shape):
cases.append(jnp.array(rng.uniform(i*10, (i+1)*10, self.input_shape).astype(self.input_type)))
cases = np.array(cases)
return (which, cases)

def create_model(self, input_shape, input_type):
self.input_shape = input_shape
self.input_type = input_type

def jax_select_n(which, cases):
return jax.lax.select_n(which, *cases)

return jax_select_n, None, 'select_n'


@pytest.mark.parametrize("input_shape", [1,2,3,4,5,6,7,8,9,10])
@pytest.mark.parametrize("input_type", [np.int32, np.int64])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_jax_fe
def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type):
self._test(*self.create_model(input_shape, input_type),
ie_device, precision,
ir_version)