Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/nullary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ void PartialRecvInferMeta(int ring_id,
}

void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) {
PADDLE_ENFORCE_GT(
n,
0,
errors::InvalidArgument("The upper bound %d isn't greater than 0.", n));
out->set_dims(common::make_ddim({n}));
out->set_dtype(dtype);
}
Expand Down
3 changes: 3 additions & 0 deletions test/legacy_test/test_randperm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api
from paddle.static import Program, program_guard


Expand Down Expand Up @@ -156,13 +157,15 @@ def verify_output(self, outs):


class TestRandpermOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with program_guard(Program(), Program()):
self.assertRaises(ValueError, paddle.randperm, -3)
self.assertRaises(TypeError, paddle.randperm, 10, 'int8')


class TestRandpermAPI(unittest.TestCase):
@test_with_pir_api
def test_out(self):
n = 10
place = (
Expand Down