-
Notifications
You must be signed in to change notification settings - Fork 2.8k
[PyTorch]: implement support for ReplicationPad{1,2,3}d (replication padding) #28271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
dd2430a
4859af0
0824a3e
388964f
a759cc7
ac6579a
d8dca07
aa6ed1f
3113ed7
d7de4a9
47f9254
d560edc
d1b983e
50a9d41
1960c64
dc8b1ae
c246fe7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -219,9 +219,9 @@ def __init__(self, pads): | |
| if ndim == 1: | ||
| self.pad = torch.nn.ReflectionPad1d(pads) | ||
| elif ndim == 2: | ||
| self.pad = torch.nn.ReflectionPad1d(pads) | ||
| self.pad = torch.nn.ReflectionPad2d(pads) | ||
| elif ndim == 3: | ||
| self.pad = torch.nn.ReflectionPad1d(pads) | ||
| self.pad = torch.nn.ReflectionPad3d(pads) | ||
| else: | ||
| raise Exception("Unsupported pads") | ||
|
|
||
|
|
@@ -244,3 +244,44 @@ def test_reflection_padnd(self, pads, dtype, ie_device, precision, ir_version): | |
| print(ndim) | ||
| self._test(*self.create_model(pads), ie_device, precision, ir_version, | ||
| kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype}) | ||
|
|
||
| class TestReplicatePad(PytorchLayerTest): | ||
rkazants marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def _prepare_input(self, ndim=4, dtype="float32"): | ||
| import numpy as np | ||
| input_5d_shape = [5,9,1,1,2,4] | ||
|
||
| return (np.random.randn(*input_5d_shape[:ndim]).astype(dtype),) | ||
|
|
||
| def create_model(self, pads): | ||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| class aten_pad(torch.nn.Module): | ||
| def __init__(self, pads): | ||
| super().__init__() | ||
| ndim = len(pads) / 2 | ||
| if ndim == 1: | ||
| self.pad = torch.nn.ReplicationPad1d(pads) | ||
| elif ndim == 2: | ||
| self.pad = torch.nn.ReplicationPad2d(pads) | ||
| elif ndim == 3: | ||
| self.pad = torch.nn.ReplicationPad3d(pads) | ||
| else: | ||
| raise Exception("Unsupported pads") | ||
|
|
||
| def forward(self, x): | ||
| return self.pad(x) | ||
|
|
||
| return aten_pad(pads), None, "aten::pad" | ||
mvafin marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @pytest.mark.parametrize("dtype", ["float32", "float64", "int32"]) | ||
| @pytest.mark.parametrize("pads", [ | ||
| (1, 2), | ||
rkazants marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| (1, 2, 3, 4), | ||
| (1, 2, 3, 4, 3, 2), | ||
| ]) | ||
| @pytest.mark.nightly | ||
| @pytest.mark.precommit_torch_export | ||
rkazants marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def test_replicate_padnd(self, pads, dtype, ie_device, precision, ir_version): | ||
| ndim = len(pads) // 2 + 2 | ||
| self._test(*self.create_model(pads), ie_device, precision, ir_version, | ||
| kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype}) | ||
Uh oh!
There was an error while loading. Please reload this page.