-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Added support for aten::randperm and aten::polar #29585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
0656890
added aten::take function in pytorch, but tests are not running
vijaykr338 98c1c61
implemented aten::str, aten::delete but unable to write their tests
vijaykr338 725c474
Merge branch 'master' into master
vijaykr338 4176919
added support for aten::randperm with tests
vijaykr338 e32628f
Delete src/frontends/pytorch/src/op/delete.cpp
vijaykr338 1483255
cleaned
vijaykr338 8a16331
Delete tests/layer_tests/pytorch_tests/test_take.py
vijaykr338 e8b35e8
Delete src/frontends/pytorch/src/op/take.cpp
vijaykr338 16f11fa
Delete src/frontends/pytorch/src/op/str.cpp
vijaykr338 cd18e01
added support for aten::polar
vijaykr338 113d07e
fixed code style for aten::polar and aten::randperm
vijaykr338 c6ec86e
Merge branch 'master' into work2
vijaykr338 bf75e7b
Merge branch 'master' into work2
vijaykr338 c8dac15
Merge branch 'master' into work2
vijaykr338 dcc5d37
fixed coding style with clang-format-9 and added the suggested changes
vijaykr338 6bd7914
Update src/frontends/pytorch/src/op_table.cpp
vijaykr338 21f7245
Update test_polar.py
vijaykr338 8e42724
Merge branch 'master' into work2
vijaykr338 db261fc
Update polar.cpp
vijaykr338 bbb6952
Update test_randperm.py
vijaykr338 9cc31cd
Update tests/layer_tests/pytorch_tests/test_randperm.py
vijaykr338 7c9e745
Update polar.cpp
vijaykr338 7b4b3b7
Update randperm.cpp
vijaykr338 e8f42fa
Update test_randperm.py
vijaykr338 b27a3f2
Update test_randperm.py, added the mark
vijaykr338 f041050
Update test_randperm.py
vijaykr338 e7e3999
Update test_randperm.py
vijaykr338 e4d204d
Update test_polar.py
vijaykr338 0102220
Merge branch 'master' into work2
mvafin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| // Copyright (C) 2018-2025 Intel Corporation | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
|
|
||
| #include "openvino/frontend/complex_type_mark.hpp" | ||
| #include "openvino/op/convert.hpp" | ||
| #include "openvino/op/cos.hpp" | ||
| #include "openvino/op/multiply.hpp" | ||
| #include "openvino/op/sin.hpp" | ||
| #include "utils.hpp" | ||
|
|
||
| namespace ov { | ||
| namespace frontend { | ||
| namespace pytorch { | ||
| namespace op { | ||
|
|
||
| using namespace ov::op; | ||
|
|
||
| OutputVector translate_polar(const NodeContext& context) { | ||
| num_inputs_check(context, 2, 3); | ||
| auto abs = context.get_input(0); | ||
| auto angle = context.get_input(1); | ||
| auto cos_node = context.mark_node(std::make_shared<v0::Cos>(angle)); | ||
| auto real = context.mark_node(std::make_shared<v1::Multiply>(abs, cos_node)); | ||
| auto sin_node = context.mark_node(std::make_shared<v0::Sin>(angle)); | ||
| auto imag = context.mark_node(std::make_shared<v1::Multiply>(abs, sin_node)); | ||
| auto complex_tensor = context.mark_node(std::make_shared<ComplexTypeMark>(real, imag)); | ||
| return {complex_tensor}; | ||
| } | ||
|
|
||
| } // namespace op | ||
| } // namespace pytorch | ||
| } // namespace frontend | ||
| } // namespace ov | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| // Copyright (C) 2018-2025 Intel Corporation | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
|
|
||
| #include "openvino/frontend/pytorch/node_context.hpp" | ||
vijaykr338 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| #include "openvino/op/constant.hpp" | ||
| #include "openvino/op/random_uniform.hpp" | ||
| #include "openvino/op/topk.hpp" | ||
| #include "openvino/op/unsqueeze.hpp" | ||
| #include "utils.hpp" | ||
|
|
||
| namespace ov { | ||
| namespace frontend { | ||
| namespace pytorch { | ||
| namespace op { | ||
|
|
||
| using namespace ov::op; | ||
|
|
||
| OutputVector translate_randperm(const NodeContext& context) { | ||
| auto num_inputs = context.get_input_size(); | ||
| auto n_node = context.get_input(0); | ||
| int dtype_value = 4; | ||
| if (num_inputs == 1) { | ||
| } else if (num_inputs == 2) { | ||
| if (!context.input_is_none(1)) { | ||
| dtype_value = context.const_input<int>(1); | ||
| PYTORCH_OP_CONVERSION_CHECK(dtype_value == 4, | ||
| "Only dtype value 4 (int64) is supported for aten::randperm, got: ", | ||
| dtype_value); | ||
| } | ||
| } else if (num_inputs == 5) { | ||
| if (!context.input_is_none(1)) { | ||
| dtype_value = context.const_input<int>(1); | ||
| PYTORCH_OP_CONVERSION_CHECK(dtype_value == 4, | ||
| "Only dtype value 4 (int64) is supported for aten::randperm, got: ", | ||
| dtype_value); | ||
| } | ||
| } else { | ||
| PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); | ||
| } | ||
| auto axis_zero = v0::Constant::create(element::i64, Shape{1}, {0}); | ||
| auto shape = context.mark_node(std::make_shared<v0::Unsqueeze>(n_node, axis_zero)); | ||
| auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f}); | ||
| auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f}); | ||
| auto random_tensor = context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_val, max_val, element::f32)); | ||
| const int64_t axis = 0; | ||
| auto topk = context.mark_node(std::make_shared<v11::TopK>(random_tensor, | ||
| n_node, | ||
| axis, | ||
| ov::op::TopKMode::MIN, | ||
| ov::op::TopKSortType::SORT_VALUES, | ||
| element::i64, | ||
| false)); | ||
| return {topk->output(1)}; | ||
| } | ||
|
|
||
| } // namespace op | ||
| } // namespace pytorch | ||
| } // namespace frontend | ||
| } // namespace ov | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| # Copyright (C) 2018-2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import torch | ||
| import openvino as ov | ||
| from pytorch_layer_test_class import PytorchLayerTest | ||
|
|
||
| class TestPolar(PytorchLayerTest): | ||
| def _prepare_input(self, input_shape=(1, 1000), dtype=np.float32): | ||
| return ( | ||
| np.random.uniform(0, 10, input_shape).astype(dtype), | ||
| np.random.uniform(-np.pi, np.pi, input_shape).astype(dtype) | ||
| ) | ||
|
|
||
| def create_model(self): | ||
| class PolarModel(torch.nn.Module): | ||
| def forward(self, abs, angle): | ||
| complex_tensor = torch.polar(abs, angle) | ||
| return torch.view_as_real(complex_tensor) | ||
| ref_net = None | ||
| return PolarModel(), None, "aten::polar" | ||
|
|
||
| @pytest.mark.parametrize("input_case", [ | ||
| (1, 1000), | ||
| (2, 500), | ||
| (5, 200), | ||
| (10, 100), | ||
| ]) | ||
| @pytest.mark.parametrize("dtype", [ | ||
| np.float32, | ||
| np.float64 | ||
| ]) | ||
| @pytest.mark.nightly | ||
| @pytest.mark.precommit | ||
| def test_polar(self, input_case, dtype, ie_device, precision, ir_version): | ||
| self._test(*self.create_model(), ie_device, precision, ir_version, | ||
| kwargs_to_prepare_input={"input_shape": input_case, "dtype": dtype}, | ||
| trace_model=True, use_convert_model=True, dynamic_shapes=False) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # Copyright (C) 2018-2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import pytest | ||
vijaykr338 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import torch | ||
| import numpy as np | ||
| from pytorch_layer_test_class import PytorchLayerTest | ||
|
|
||
| class TestSortedRandperm(PytorchLayerTest): | ||
| def _prepare_input(self): | ||
| return (np.arange(self.n, dtype=np.int64),) | ||
|
|
||
| def create_model(self, n, num_inputs, dtype_value=None): | ||
| class AtenSortedRandperm(torch.nn.Module): | ||
| def __init__(self, n, num_inputs, dtype_value): | ||
| super().__init__() | ||
| self.n = n | ||
| self.num_inputs = num_inputs | ||
| self.dtype = torch.int64 if dtype_value == 4 else None | ||
|
|
||
| def forward(self, x): | ||
| if self.num_inputs == 1: | ||
| p = torch.randperm(self.n) | ||
| elif self.num_inputs == 2: | ||
| p = torch.randperm(self.n, dtype=self.dtype) | ||
| elif self.num_inputs == 5: | ||
| p = torch.randperm(self.n, dtype=self.dtype, layout=torch.strided, | ||
| device=x.device, pin_memory=False) | ||
| else: | ||
| raise ValueError("Invalid num_inputs") | ||
| # sort to get a deterministic order for verifying the permutation. | ||
| x_permuted = x[p] | ||
| sorted_tensor, _ = torch.sort(x_permuted) | ||
vijaykr338 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return sorted_tensor | ||
|
|
||
| return AtenSortedRandperm(n, num_inputs, dtype_value), None, "aten::randperm" | ||
|
|
||
| @pytest.mark.parametrize(("n", "num_inputs", "dtype_value"), [ | ||
| (0, 1, None), | ||
| (1, 1, None), | ||
| (5, 1, None), | ||
| (5, 2, 4), | ||
| (5, 5, 4), | ||
| ]) | ||
| @pytest.mark.nightly | ||
| @pytest.mark.precommit | ||
| def test_sorted_randperm(self, n, num_inputs, dtype_value, ie_device, precision, ir_version): | ||
vijaykr338 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.n = n | ||
| self._test(*self.create_model(n, num_inputs, dtype_value), ie_device, precision, ir_version) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.