Skip to content

Commit 6445c77

Browse files
DrRyanHuangSecretXV
authored andcommitted
【PIR api adaptor No.242、228】 Migrate unique_consecutive/moveaxis into pir (PaddlePaddle#58688)
1 parent 0ec982e commit 6445c77

9 files changed

Lines changed: 19 additions & 19 deletions

File tree

paddle/phi/api/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2683,7 +2683,7 @@
26832683
backward: uniform_inplace_grad
26842684

26852685
- op : unique_consecutive
2686-
args : (Tensor x, bool return_inverse = false, bool return_counts = false, int[] axis = {}, int dtype = 5)
2686+
args : (Tensor x, bool return_inverse = false, bool return_counts = false, int[] axis = {}, DataType dtype = DataType::FLOAT32)
26872687
output : Tensor(out), Tensor(index), Tensor(counts)
26882688
infer_meta :
26892689
func : UniqueConsecutiveInferMeta

paddle/phi/infermeta/unary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4852,7 +4852,7 @@ void UniqueConsecutiveInferMeta(const MetaTensor& x,
48524852
bool return_inverse,
48534853
bool return_counts,
48544854
const std::vector<int>& axis,
4855-
int dtype,
4855+
DataType dtype,
48564856
MetaTensor* out,
48574857
MetaTensor* index,
48584858
MetaTensor* counts) {

paddle/phi/infermeta/unary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ void UniqueConsecutiveInferMeta(const MetaTensor& x,
716716
bool return_inverse,
717717
bool return_counts,
718718
const std::vector<int>& axis,
719-
int dtype,
719+
DataType dtype,
720720
MetaTensor* out,
721721
MetaTensor* index,
722722
MetaTensor* counts);

paddle/phi/kernels/cpu/unique_consecutive_kernel.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
3030
bool return_inverse,
3131
bool return_counts,
3232
const std::vector<int>& axis,
33-
int dtype,
33+
DataType dtype,
3434
DenseTensor* out,
3535
DenseTensor* index,
3636
DenseTensor* counts) {
37-
auto data_type = phi::TransToPhiDataType(dtype);
38-
if (data_type == phi::DataType::INT32) {
37+
if (dtype == phi::DataType::INT32) {
3938
PADDLE_ENFORCE_LE(
4039
x.numel(),
4140
INT_MAX,
@@ -48,14 +47,14 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
4847

4948
if (axis.empty()) {
5049
phi::VisitDataTypeTiny(
51-
data_type,
50+
dtype,
5251
UniqueConsecutiveFlattenedTensorFunctor<Context, T>(
5352
dev_ctx, x, out, return_inverse, return_counts, index, counts));
5453
} else {
5554
int valid_axis = axis[0];
5655
if (valid_axis < 0) valid_axis += x.dims().size();
5756
phi::VisitDataTypeTiny(
58-
data_type,
57+
dtype,
5958
UniqueConsecutiveDimFunctor<Context, T>(dev_ctx,
6059
x,
6160
out,

paddle/phi/kernels/gpu/unique_consecutive_kernel.cu

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
2929
bool return_inverse,
3030
bool return_counts,
3131
const std::vector<int>& axis,
32-
int dtype,
32+
DataType dtype,
3333
DenseTensor* out,
3434
DenseTensor* index,
3535
DenseTensor* counts) {
36-
auto data_type = phi::TransToPhiDataType(dtype);
37-
if (data_type == phi::DataType::INT32) {
36+
if (dtype == phi::DataType::INT32) {
3837
PADDLE_ENFORCE_LE(
3938
x.numel() + 1,
4039
INT_MAX,
@@ -48,15 +47,15 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
4847
// if 'axis' is not required, flatten the Tensor.
4948
if (axis.empty()) {
5049
phi::VisitDataTypeTiny(
51-
data_type,
50+
dtype,
5251
UniqueConsecutiveFlattenedCUDAFunctor<Context, T>(
5352
dev_ctx, x, out, return_inverse, return_counts, index, counts));
5453
} else {
5554
// 'axis' is required.
5655
int valid_axis = axis[0];
5756
if (valid_axis < 0) valid_axis += x.dims().size();
5857
phi::VisitDataTypeTiny(
59-
data_type,
58+
dtype,
6059
UniqueConsecutiveDimsCUDAFunctor<Context, T>(dev_ctx,
6160
x,
6261
out,

paddle/phi/kernels/unique_consecutive_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
2626
bool return_inverse,
2727
bool return_counts,
2828
const std::vector<int>& axis,
29-
int dtype,
29+
DataType dtype,
3030
DenseTensor* out,
3131
DenseTensor* index,
3232
DenseTensor* counts);

python/paddle/tensor/manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2482,7 +2482,7 @@ def unique_consecutive(
24822482
else:
24832483
axis = [axis]
24842484
attr_dtype = convert_np_dtype_to_dtype_(dtype)
2485-
if in_dynamic_mode():
2485+
if in_dynamic_or_pir_mode():
24862486
out, inverse, counts = _C_ops.unique_consecutive(
24872487
x, return_inverse, return_counts, axis, attr_dtype
24882488
)

test/legacy_test/test_transpose_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ def test_moveaxis3(self):
710710
self.assertEqual(out.shape, [2, 3])
711711
paddle.enable_static()
712712

713+
@test_with_pir_api
713714
def test_error(self):
714715
x = paddle.randn([2, 3, 4, 5])
715716
# src must have the same number with dst

test/legacy_test/test_unique_consecutive_op.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import paddle
2121
from paddle import base
2222
from paddle.base import core
23+
from paddle.pir_utils import test_with_pir_api
2324

2425

2526
def reference_unique_consecutive(
@@ -203,6 +204,7 @@ def setUp(self):
203204
if core.is_compiled_with_cuda():
204205
self.places.append(base.CUDAPlace(0))
205206

207+
@test_with_pir_api
206208
def check_static_result(self, place):
207209
with base.program_guard(base.Program(), base.Program()):
208210
paddle.enable_static()
@@ -217,7 +219,6 @@ def check_static_result(self, place):
217219
x_np = np.random.randint(20, size=100).astype("float32")
218220
exe = base.Executor(place)
219221
fetches = exe.run(
220-
base.default_main_program(),
221222
feed={"input_x": x_np},
222223
fetch_list=[result],
223224
)
@@ -240,6 +241,7 @@ def setUp(self):
240241
if core.is_compiled_with_cuda():
241242
self.places.append(base.CUDAPlace(0))
242243

244+
@test_with_pir_api
243245
def check_static_result(self, place):
244246
with base.program_guard(base.Program(), base.Program()):
245247
paddle.enable_static()
@@ -256,7 +258,6 @@ def check_static_result(self, place):
256258
x_np = np.random.randint(20, size=100).astype("float32")
257259
exe = base.Executor(place)
258260
fetches = exe.run(
259-
base.default_main_program(),
260261
feed={"input_x": x_np},
261262
fetch_list=[result],
262263
)
@@ -281,6 +282,7 @@ def setUp(self):
281282
if core.is_compiled_with_cuda():
282283
self.places.append(base.CUDAPlace(0))
283284

285+
@test_with_pir_api
284286
def check_static_result(self, place):
285287
with base.program_guard(base.Program(), base.Program()):
286288
paddle.enable_static()
@@ -297,7 +299,6 @@ def check_static_result(self, place):
297299
x_np = np.random.randint(20, size=100).astype("float32")
298300
exe = base.Executor(place)
299301
fetches = exe.run(
300-
base.default_main_program(),
301302
feed={"input_x": x_np},
302303
fetch_list=[result],
303304
)
@@ -347,7 +348,7 @@ def setUp(self):
347348
}
348349

349350
def test_check_output(self):
350-
self.check_output()
351+
self.check_output(check_pir=True)
351352

352353

353354
if __name__ == "__main__":

0 commit comments

Comments
 (0)