From aa875cb5861f5126a713c5ae45f8d82173ef0f1b Mon Sep 17 00:00:00 2001 From: 11happy Date: Wed, 1 Jan 2025 13:08:05 +0530 Subject: [PATCH 1/4] implement jax.lax.logistic Signed-off-by: 11happy --- src/frontends/jax/src/op/logistic.cpp | 26 ++++++++++++++ src/frontends/jax/src/op_table.cpp | 2 ++ tests/layer_tests/jax_tests/test_logistic.py | 36 ++++++++++++++++++++ 3 files changed, 64 insertions(+) create mode 100644 src/frontends/jax/src/op/logistic.cpp create mode 100644 tests/layer_tests/jax_tests/test_logistic.py diff --git a/src/frontends/jax/src/op/logistic.cpp b/src/frontends/jax/src/op/logistic.cpp new file mode 100644 index 00000000000000..38b7ed49e62452 --- /dev/null +++ b/src/frontends/jax/src/op/logistic.cpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/jax/node_context.hpp" +#include "openvino/op/sigmoid.hpp" +#include "utils.hpp" + +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace jax { +namespace op { + +OutputVector translate_logistic(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto input = context.get_input(0); + auto logistic = std::make_shared(input); + return {logistic}; +}; + +} // namespace op +} // namespace jax +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp index 9c492dfa3e119d..8ee31ff69c7f08 100644 --- a/src/frontends/jax/src/op_table.cpp +++ b/src/frontends/jax/src/op_table.cpp @@ -53,6 +53,7 @@ OP_CONVERTER(translate_reduce_window_sum); OP_CONVERTER(translate_reshape); OP_CONVERTER(translate_rsqrt); OP_CONVERTER(translate_select_n); +OP_CONVERTER(translate_logistic); OP_CONVERTER(translate_slice); OP_CONVERTER(translate_square); OP_CONVERTER(translate_squeeze); @@ -94,6 +95,7 @@ const std::map get_supported_ops_jaxpr() { {"rsqrt", op::translate_rsqrt}, {"reshape", op::translate_reshape}, {"select_n", op::translate_select_n}, + {"logistic", op::translate_logistic}, {"slice", op::translate_slice}, {"square", op::translate_square}, {"sqrt", op::translate_1to1_match_1_input}, diff --git a/tests/layer_tests/jax_tests/test_logistic.py b/tests/layer_tests/jax_tests/test_logistic.py new file mode 100644 index 00000000000000..a08a7dfa2e0eb5 --- /dev/null +++ b/tests/layer_tests/jax_tests/test_logistic.py @@ -0,0 +1,36 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +from jax_layer_test_class import JaxLayerTest + +rng = np.random.default_rng(5402) + + +class TestLogistic(JaxLayerTest): + def _prepare_input(self): + + input = jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type)) + return [input] + + def create_model(self, input_shape, input_type): + self.input_shape = input_shape + self.input_type = input_type + + def jax_logistic(input): + return jax.lax.logistic(input) + + return jax_logistic, None, 'logistic' + + @pytest.mark.parametrize("input_shape", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + @pytest.mark.parametrize("input_type", [np.float32, np.float64]) + @pytest.mark.nightly + @pytest.mark.precommit_jax_fe + def test_logistic(self, ie_device, precision, ir_version, input_shape, input_type): + self._test(*self.create_model(input_shape, input_type), + ie_device, precision, + ir_version) From 9fce3ee0e9be0bac8a5da2ab02888fd79de2c84f Mon Sep 17 00:00:00 2001 From: 11happy Date: Wed, 1 Jan 2025 18:30:38 +0530 Subject: [PATCH 2/4] simplify implementation Signed-off-by: 11happy --- src/frontends/jax/src/op/logistic.cpp | 26 -------------------- src/frontends/jax/src/op_table.cpp | 3 ++- tests/layer_tests/jax_tests/test_logistic.py | 3 ++- 3 files changed, 4 insertions(+), 28 deletions(-) delete mode 100644 src/frontends/jax/src/op/logistic.cpp diff --git a/src/frontends/jax/src/op/logistic.cpp b/src/frontends/jax/src/op/logistic.cpp deleted file mode 100644 index 38b7ed49e62452..00000000000000 --- a/src/frontends/jax/src/op/logistic.cpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (C) 2018-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "openvino/frontend/jax/node_context.hpp" -#include "openvino/op/sigmoid.hpp" -#include "utils.hpp" - -using namespace ov::op; - -namespace ov { -namespace frontend { -namespace jax { -namespace op { - -OutputVector translate_logistic(const NodeContext& context) { - num_inputs_check(context, 1, 1); - auto input = context.get_input(0); - auto logistic = std::make_shared(input); - return {logistic}; -}; - -} // namespace op -} // namespace jax -} // namespace frontend -} // namespace ov \ No newline at end of file diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp index 8ee31ff69c7f08..ef528f446700fd 100644 --- a/src/frontends/jax/src/op_table.cpp +++ b/src/frontends/jax/src/op_table.cpp @@ -20,6 +20,7 @@ #include "openvino/op/reduce_max.hpp" #include "openvino/op/reduce_sum.hpp" #include "openvino/op/sqrt.hpp" +#include "openvino/op/sigmoid.hpp" #include "openvino/op/subtract.hpp" #include "openvino/op/tanh.hpp" #include "utils.hpp" @@ -95,7 +96,7 @@ const std::map get_supported_ops_jaxpr() { {"rsqrt", op::translate_rsqrt}, {"reshape", op::translate_reshape}, {"select_n", op::translate_select_n}, - {"logistic", op::translate_logistic}, + {"logistic", op::translate_1to1_match_1_input}, {"slice", op::translate_slice}, {"square", op::translate_square}, {"sqrt", op::translate_1to1_match_1_input}, diff --git a/tests/layer_tests/jax_tests/test_logistic.py b/tests/layer_tests/jax_tests/test_logistic.py index a08a7dfa2e0eb5..19e5f0a81eba30 100644 --- a/tests/layer_tests/jax_tests/test_logistic.py +++ b/tests/layer_tests/jax_tests/test_logistic.py @@ -26,9 +26,10 @@ def jax_logistic(input): return jax_logistic, None, 'logistic' - @pytest.mark.parametrize("input_shape", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + @pytest.mark.parametrize("input_shape", [[2], [3, 4], [5,6,7]]) @pytest.mark.parametrize("input_type", [np.float32, np.float64]) @pytest.mark.nightly + @pytest.mark.precommit @pytest.mark.precommit_jax_fe def test_logistic(self, ie_device, precision, ir_version, input_shape, input_type): self._test(*self.create_model(input_shape, input_type), From 6c19bb326faeca85e9c527547c1a7401dc30a251 Mon Sep 17 00:00:00 2001 From: Bhuminjay Soni Date: Thu, 2 Jan 2025 10:20:49 +0530 Subject: [PATCH 3/4] Update src/frontends/jax/src/op_table.cpp Co-authored-by: Roman Kazantsev --- src/frontends/jax/src/op_table.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp index ef528f446700fd..78f7108a7f9294 100644 --- a/src/frontends/jax/src/op_table.cpp +++ b/src/frontends/jax/src/op_table.cpp @@ -54,7 +54,6 @@ OP_CONVERTER(translate_reduce_window_sum); OP_CONVERTER(translate_reshape); OP_CONVERTER(translate_rsqrt); OP_CONVERTER(translate_select_n); -OP_CONVERTER(translate_logistic); OP_CONVERTER(translate_slice); OP_CONVERTER(translate_square); OP_CONVERTER(translate_squeeze); From ff96484477ad764b9ec02ec22a5adb5fe8c7c0e6 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Thu, 2 Jan 2025 13:00:46 +0400 Subject: [PATCH 4/4] Apply suggestions from code review --- src/frontends/jax/src/op_table.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp index 78f7108a7f9294..6ae0e6adc7c469 100644 --- a/src/frontends/jax/src/op_table.cpp +++ b/src/frontends/jax/src/op_table.cpp @@ -19,8 +19,8 @@ #include "openvino/op/not_equal.hpp" #include "openvino/op/reduce_max.hpp" #include "openvino/op/reduce_sum.hpp" -#include "openvino/op/sqrt.hpp" #include "openvino/op/sigmoid.hpp" +#include "openvino/op/sqrt.hpp" #include "openvino/op/subtract.hpp" #include "openvino/op/tanh.hpp" #include "utils.hpp"