Skip to content

Commit 548786a

Browse files
11happyrkazants
andauthored
[JAX FE]: add support for jax.lax.logistic (#28240)
**Overview:** This PR fixes #26576. **Testing:** - Tested the Updated code - Verified that other functionalities remain unaffected ![Screenshot from 2025-01-01 13-11-04](https://github.com/user-attachments/assets/5acfabc2-dded-4c65-b408-d4174fa3c025) **Dependencies:** - No dependencies on other pull requests **CC:** - @rkazants --------- Signed-off-by: 11happy <[email protected]> Co-authored-by: Roman Kazantsev <[email protected]>
1 parent 2e24dfa commit 548786a

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

src/frontends/jax/src/op_table.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "openvino/op/not_equal.hpp"
2020
#include "openvino/op/reduce_max.hpp"
2121
#include "openvino/op/reduce_sum.hpp"
22+
#include "openvino/op/sigmoid.hpp"
2223
#include "openvino/op/sqrt.hpp"
2324
#include "openvino/op/subtract.hpp"
2425
#include "openvino/op/tanh.hpp"
@@ -94,6 +95,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
9495
{"rsqrt", op::translate_rsqrt},
9596
{"reshape", op::translate_reshape},
9697
{"select_n", op::translate_select_n},
98+
{"logistic", op::translate_1to1_match_1_input<v0::Sigmoid>},
9799
{"slice", op::translate_slice},
98100
{"square", op::translate_square},
99101
{"sqrt", op::translate_1to1_match_1_input<v0::Sqrt>},
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (C) 2018-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import jax
5+
import numpy as np
6+
import pytest
7+
from jax import numpy as jnp
8+
9+
from jax_layer_test_class import JaxLayerTest
10+
11+
rng = np.random.default_rng(5402)
12+
13+
14+
class TestLogistic(JaxLayerTest):
15+
def _prepare_input(self):
16+
17+
input = jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type))
18+
return [input]
19+
20+
def create_model(self, input_shape, input_type):
21+
self.input_shape = input_shape
22+
self.input_type = input_type
23+
24+
def jax_logistic(input):
25+
return jax.lax.logistic(input)
26+
27+
return jax_logistic, None, 'logistic'
28+
29+
@pytest.mark.parametrize("input_shape", [[2], [3, 4], [5,6,7]])
30+
@pytest.mark.parametrize("input_type", [np.float32, np.float64])
31+
@pytest.mark.nightly
32+
@pytest.mark.precommit
33+
@pytest.mark.precommit_jax_fe
34+
def test_logistic(self, ie_device, precision, ir_version, input_shape, input_type):
35+
self._test(*self.create_model(input_shape, input_type),
36+
ie_device, precision,
37+
ir_version)

0 commit comments

Comments
 (0)