-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Added support for aten::randperm and aten::polar #29585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 20 commits
0656890
98c1c61
725c474
4176919
e32628f
1483255
8a16331
e8b35e8
16f11fa
cd18e01
113d07e
c6ec86e
bf75e7b
c8dac15
dcc5d37
6bd7914
21f7245
8e42724
db261fc
bbb6952
9cc31cd
7c9e745
7b4b3b7
e8f42fa
b27a3f2
f041050
e7e3999
e4d204d
0102220
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| #include "openvino/frontend/complex_type_mark.hpp" | ||
| #include "openvino/op/convert.hpp" | ||
| #include "openvino/op/cos.hpp" | ||
| #include "openvino/op/multiply.hpp" | ||
| #include "openvino/op/sin.hpp" | ||
| #include "utils.hpp" | ||
|
|
||
| namespace ov { | ||
| namespace frontend { | ||
| namespace pytorch { | ||
| namespace op { | ||
|
|
||
| using namespace ov::op; | ||
|
|
||
| OutputVector translate_polar(const NodeContext& context) { | ||
| num_inputs_check(context, 2, 3); | ||
| auto abs = context.get_input(0); | ||
| auto angle = context.get_input(1); | ||
| auto cos_node = context.mark_node(std::make_shared<v0::Cos>(angle)); | ||
| auto real = context.mark_node(std::make_shared<v1::Multiply>(abs, cos_node)); | ||
| auto sin_node = context.mark_node(std::make_shared<v0::Sin>(angle)); | ||
| auto imag = context.mark_node(std::make_shared<v1::Multiply>(abs, sin_node)); | ||
| auto complex_tensor = context.mark_node(std::make_shared<ComplexTypeMark>(real, imag)); | ||
| return {complex_tensor}; | ||
| } | ||
|
|
||
| } // namespace op | ||
| } // namespace pytorch | ||
| } // namespace frontend | ||
| } // namespace ov | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| #include "openvino/frontend/pytorch/node_context.hpp" | ||
vijaykr338 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| #include "openvino/op/constant.hpp" | ||
| #include "openvino/op/random_uniform.hpp" | ||
| #include "openvino/op/shape_of.hpp" | ||
| #include "openvino/op/topk.hpp" | ||
| #include "utils.hpp" | ||
|
|
||
| namespace ov { | ||
| namespace frontend { | ||
| namespace pytorch { | ||
| namespace op { | ||
|
|
||
| using namespace ov::op; | ||
|
|
||
| OutputVector translate_randperm(const NodeContext& context) { | ||
| auto num_inputs = context.get_input_size(); | ||
| int64_t n = context.const_input<int64_t>(0); | ||
mvafin marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| int dtype_value = 4; | ||
| if (num_inputs == 1) { | ||
| } else if (num_inputs == 2) { | ||
| if (!context.input_is_none(1)) { | ||
| dtype_value = context.const_input<int>(1); | ||
| OPENVINO_ASSERT(dtype_value == 4, | ||
| "Only dtype value 4 (int64) is supported for aten::randperm, got: ", | ||
| dtype_value); | ||
| } | ||
| } else if (num_inputs == 5) { | ||
| if (!context.input_is_none(1)) { | ||
| dtype_value = context.const_input<int>(1); | ||
| OPENVINO_ASSERT(dtype_value == 4, | ||
vijaykr338 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| "Only dtype value 4 (int64) is supported for aten::randperm, got: ", | ||
| dtype_value); | ||
| } | ||
| } else { | ||
| PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); | ||
| } | ||
| if (n == 0) { | ||
| auto const_empty = std::make_shared<v0::Constant>(element::i64, Shape{0}, std::vector<int64_t>{}); | ||
| return {context.mark_node(const_empty)}; | ||
| } | ||
| auto shape = v0::Constant::create(element::i64, Shape{1}, {n}); | ||
| auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f}); | ||
| auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f}); | ||
| auto random_tensor = context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_val, max_val, element::f32)); | ||
| const int64_t axis = 0; | ||
| auto k = v0::Constant::create(element::i64, Shape{}, {n}); | ||
| auto topk = context.mark_node(std::make_shared<v11::TopK>(random_tensor, | ||
| k, | ||
| axis, | ||
| ov::op::TopKMode::MIN, | ||
| ov::op::TopKSortType::SORT_VALUES, | ||
| element::i64, | ||
| false)); | ||
| return {topk->output(1)}; | ||
| } | ||
|
|
||
| } // namespace op | ||
| } // namespace pytorch | ||
| } // namespace frontend | ||
| } // namespace ov | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,42 @@ | ||||||
| # Copyright (C) 2018-2025 Intel Corporation | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
|
|
||||||
| import numpy as np | ||||||
| import pytest | ||||||
| import torch | ||||||
| import openvino as ov | ||||||
| from pytorch_layer_test_class import PytorchLayerTest | ||||||
|
|
||||||
| class TestPolar(PytorchLayerTest): | ||||||
| def _prepare_input(self, input_shape=(1, 1000), dtype=np.float32): | ||||||
| return ( | ||||||
| np.random.uniform(0, 10, input_shape).astype(dtype), | ||||||
| np.random.uniform(-np.pi, np.pi, input_shape).astype(dtype) | ||||||
| ) | ||||||
|
|
||||||
| def create_model(self): | ||||||
| class PolarModel(torch.nn.Module): | ||||||
| def forward(self, abs, angle): | ||||||
| complex_tensor = torch.polar(abs, angle) | ||||||
| return torch.view_as_real(complex_tensor) | ||||||
| ref_net = None | ||||||
| return PolarModel(), None, "aten::polar" | ||||||
|
|
||||||
| @pytest.mark.parametrize("input_case", [ | ||||||
| (1, 1000), | ||||||
| (2, 500), | ||||||
| (5, 200), | ||||||
| (10, 100), | ||||||
| ]) | ||||||
| @pytest.mark.parametrize("dtype", [ | ||||||
| np.float32, | ||||||
| np.float64 | ||||||
| ]) | ||||||
| @pytest.mark.nightly | ||||||
| @pytest.mark.precommit | ||||||
| def test_polar(self, input_case, dtype, ie_device, precision, ir_version): | ||||||
| atol = 1e-4 if precision == "FP32" else 1e-3 | ||||||
| rtol = 1e-4 | ||||||
|
||||||
| atol = 1e-4 if precision == "FP32" else 1e-3 | |
| rtol = 1e-4 |
you are not using atol and rtol here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah yes, I have removed them now, I was initially using atol and rtol as I was worried about the floating point differences that might be produced, but the default test layer handles it fine!
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| import pytest | ||
vijaykr338 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import torch | ||
| import numpy as np | ||
| from pytorch_layer_test_class import PytorchLayerTest | ||
|
|
||
| class TestSortedRandperm(PytorchLayerTest): | ||
| def _prepare_input(self): | ||
| return (np.array([self.n], dtype=np.int64),) | ||
|
|
||
| def create_model(self, n, num_inputs, dtype_value=None): | ||
| class AtenSortedRandperm(torch.nn.Module): | ||
| def __init__(self, n, num_inputs, dtype_value): | ||
| super().__init__() | ||
| self.n = torch.tensor(n, dtype=torch.int64) | ||
| self.num_inputs = num_inputs | ||
| self.dtype = torch.int64 if dtype_value == 4 else None | ||
|
|
||
| def forward(self, x): | ||
| if self.num_inputs == 1: | ||
| p = torch.randperm(self.n) | ||
| elif self.num_inputs == 2: | ||
| p = torch.randperm(self.n, dtype=self.dtype) | ||
| elif self.num_inputs == 5: | ||
| p = torch.randperm(self.n, dtype=self.dtype, layout=torch.strided, | ||
| device=x.device, pin_memory=False) | ||
| else: | ||
| raise ValueError("Invalid num_inputs") | ||
| sorted_p, _ = torch.sort(p) | ||
|
||
| return sorted_p | ||
| return AtenSortedRandperm(n, num_inputs, dtype_value), None, "aten::randperm" | ||
|
|
||
| @pytest.mark.parametrize(("n", "num_inputs", "dtype_value"), [ | ||
| (0, 1, None), | ||
| (1, 1, None), | ||
| (5, 1, None), | ||
| (5, 2, 4), | ||
| (5, 5, 4), | ||
| ]) | ||
| def test_sorted_randperm(self, n, num_inputs, dtype_value, ie_device, precision, ir_version): | ||
vijaykr338 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.n = n | ||
| self._test(*self.create_model(n, num_inputs, dtype_value), ie_device, precision, ir_version) | ||
Uh oh!
There was an error while loading. Please reload this page.