Skip to content

Commit 6aacba7

Browse files
Hmm-1224chiruu12
authored andcommitted
Add OpenVINO backend support for argmin and argmax (keras-team#21060)
* Update numpy.py * Update excluded_concrete_tests.txt * all issues fixed * Update numpy.py * numpy.py reformatted * Update excluded_concrete_tests.txt * Update excluded_concrete_tests.txt * Update excluded_concrete_tests.txt * Update excluded_concrete_tests.txt * Update excluded_concrete_tests.txt * Update excluded_concrete_tests.txt
1 parent 3c276d8 commit 6aacba7

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ NumpyDtypeTest::test_absolute_bool
55
NumpyDtypeTest::test_add_
66
NumpyDtypeTest::test_all
77
NumpyDtypeTest::test_any
8-
NumpyDtypeTest::test_argmax
9-
NumpyDtypeTest::test_argmin
108
NumpyDtypeTest::test_argpartition
119
NumpyDtypeTest::test_array
1210
NumpyDtypeTest::test_bitwise
@@ -77,8 +75,6 @@ NumpyDtypeTest::test_square_bool
7775
HistogramTest
7876
NumpyOneInputOpsCorrectnessTest::test_all
7977
NumpyOneInputOpsCorrectnessTest::test_any
80-
NumpyOneInputOpsCorrectnessTest::test_argmax
81-
NumpyOneInputOpsCorrectnessTest::test_argmin
8278
NumpyOneInputOpsCorrectnessTest::test_argpartition
8379
NumpyOneInputOpsCorrectnessTest::test_array
8480
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
@@ -161,4 +157,4 @@ NumpyTwoInputOpsCorrectnessTest::test_quantile
161157
NumpyTwoInputOpsCorrectnessTest::test_take_along_axis
162158
NumpyTwoInputOpsCorrectnessTest::test_tensordot
163159
NumpyTwoInputOpsCorrectnessTest::test_vdot
164-
NumpyTwoInputOpsCorrectnessTest::test_where
160+
NumpyTwoInputOpsCorrectnessTest::test_where

keras/src/backend/openvino/numpy.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,11 +328,67 @@ def arctanh(x):
328328

329329

330330
def argmax(x, axis=None, keepdims=False):
331-
raise NotImplementedError("`argmax` is not supported with openvino backend")
331+
x = get_ov_output(x)
332+
x_shape = x.get_partial_shape()
333+
rank = x_shape.rank.get_length()
334+
if rank == 0:
335+
return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))
336+
if axis is None:
337+
flatten_shape = ov_opset.constant(
338+
[-1] + [1] * (rank - 1), Type.i32
339+
).output(0)
340+
x = ov_opset.reshape(x, flatten_shape, False).output(0)
341+
axis = 0
342+
k = ov_opset.constant(1, Type.i32).output(0)
343+
else:
344+
if axis < 0:
345+
axis = rank + axis
346+
k = ov_opset.constant(1, Type.i32).output(0)
347+
topk_outputs = ov_opset.topk(
348+
x,
349+
k=k,
350+
axis=axis,
351+
mode="max",
352+
sort="value",
353+
stable=True,
354+
index_element_type=Type.i32,
355+
)
356+
topk_indices = topk_outputs.output(1)
357+
if not keepdims:
358+
topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0)
359+
return OpenVINOKerasTensor(topk_indices)
332360

333361

334362
def argmin(x, axis=None, keepdims=False):
335-
raise NotImplementedError("`argmin` is not supported with openvino backend")
363+
x = get_ov_output(x)
364+
x_shape = x.get_partial_shape()
365+
rank = x_shape.rank.get_length()
366+
if rank == 0:
367+
return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))
368+
if axis is None:
369+
flatten_shape = ov_opset.constant(
370+
[-1] + [1] * (rank - 1), Type.i32
371+
).output(0)
372+
x = ov_opset.reshape(x, flatten_shape, False).output(0)
373+
axis = 0
374+
k = ov_opset.constant(1, Type.i32).output(0)
375+
else:
376+
if axis < 0:
377+
axis = rank + axis
378+
k = ov_opset.constant(1, Type.i32).output(0)
379+
topk_outputs = ov_opset.topk(
380+
x,
381+
k=k,
382+
axis=axis,
383+
mode="min",
384+
sort="value",
385+
stable=True,
386+
index_element_type=Type.i32,
387+
)
388+
topk_indices = topk_outputs.output(1)
389+
if not keepdims:
390+
topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0)
391+
return OpenVINOKerasTensor(topk_indices)
336392

337393

338394
def argsort(x, axis=-1):

0 commit comments

Comments
 (0)