Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/frontends/pytorch/src/op/pop.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_pop(const NodeContext& context) {
num_inputs_check(context, 1, 2);

auto list_elems = get_list_as_outputs(context.get_input(0));
if (list_elems.empty())
throw std::runtime_error("pop from empty list");
size_t list_size = list_elems.size();

int64_t pop_index = -1;
if (!context.input_is_none(1)) {
auto node = context.get_input(1);
auto constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(node.get_node_shared_ptr());
if (!constant)
throw std::runtime_error("pop index must be a constant integer");
auto values = constant->cast_vector<int64_t>();
if (values.empty())
throw std::runtime_error("pop index constant is empty");
pop_index = values[0];
}

if (pop_index == -1) {
pop_index = list_size - 1;
} else if (pop_index < 0) {
pop_index += list_size;
}
if (pop_index < 0 || pop_index >= static_cast<int64_t>(list_size))
throw std::runtime_error("pop index out of range");

auto result = list_elems[pop_index];
list_elems.erase(list_elems.begin() + pop_index);

if (!context.input_is_none(1)) {
context.mutate_input(1, result);
}

return {result};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ OP_CONVERTER(translate_permute);
OP_CONVERTER(translate_pairwise_distance);
OP_CONVERTER(translate_pixel_shuffle);
OP_CONVERTER(translate_pixel_unshuffle);
OP_CONVERTER(translate_pop);
OP_CONVERTER(translate_polar);
OP_CONVERTER(translate_pow);
OP_CONVERTER(translate_prod);
Expand Down Expand Up @@ -650,6 +651,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::pixel_shuffle", op::translate_pixel_shuffle},
{"aten::pixel_unshuffle", op::translate_pixel_unshuffle},
{"aten::prelu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
{"aten::pop", op::translate_pop},
{"aten::polar", op::translate_polar},
{"aten::pow", op::translate_pow},
{"aten::pow_", op::translate_pow},
Expand Down
56 changes: 56 additions & 0 deletions tests/layer_tests/pytorch_tests/test_pop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest


class aten_pop(torch.nn.Module):
def __init__(self, pop_index: int = -1):
super().__init__()
self.pop_index = pop_index

def forward(self, x):
a = torch.tensor(1, dtype=x.dtype)
b = torch.tensor(2, dtype=x.dtype)
lst = [a, b]
popped = lst.pop(self.pop_index)
return popped.reshape(())


class aten_pop_out(torch.nn.Module):
def __init__(self, pop_index: int = -1):
super().__init__()
self.pop_index = pop_index

def forward(self, x):
a = torch.tensor(1, dtype=x.dtype)
b = torch.tensor(2, dtype=x.dtype)
lst = [a, b]
popped = lst.pop(self.pop_index)
return popped.reshape(()), torch.tensor(self.pop_index, dtype=torch.int64)


class TestPop(PytorchLayerTest):
def _prepare_input(self):
data = np.random.randn(2, 3).astype(np.float32)
return (data,)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("pop_index", [-1, 0])
def test_pop_no_out(self, pop_index, ie_device, precision, ir_version):
model = aten_pop(pop_index=pop_index)
self._test(model, None, "aten::pop", ie_device, precision, ir_version,
kwargs_to_prepare_input={})

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("pop_index", [-1, 0])
def test_pop_with_out(self, pop_index, ie_device, precision, ir_version):
model = aten_pop_out(pop_index=pop_index)
self._test(model, None, "aten::pop", ie_device, precision, ir_version,
kwargs_to_prepare_input={})
Loading