From 62f3acd291b249c761469f3fc00c5f22d6f5d76c Mon Sep 17 00:00:00 2001
From: Enigmatisms
Date: Mon, 18 Aug 2025 06:54:49 +0000
Subject: [PATCH 1/3] [PHI] Aligned uint8 and int16 atomic funcs

---
 paddle/phi/backends/gpu/gpu_primitives.h      |  89 ++++++++
 .../kernels/cpu/put_along_axis_grad_kernel.cc |   1 +
 .../phi/kernels/cpu/put_along_axis_kernel.cc  |   1 +
 .../cpu/take_along_axis_grad_kernel.cc        |   1 +
 .../phi/kernels/cpu/take_along_axis_kernel.cc |   1 +
 .../kernels/funcs/gather_scatter_functor.cu   |  36 +---
 .../kernels/funcs/gather_scatter_functor.h    |  27 +--
 .../kernels/gpu/put_along_axis_grad_kernel.cu |   2 +
 .../phi/kernels/gpu/put_along_axis_kernel.cu  |   2 +
 .../gpu/take_along_axis_grad_kernel.cu        |   2 +
 .../phi/kernels/gpu/take_along_axis_kernel.cu |   2 +
 test/legacy_test/test_put_along_axis_op.py    | 199 ++++++++++++++++--
 test/legacy_test/test_take_along_axis_op.py   |  36 ++++
 13 files changed, 340 insertions(+), 59 deletions(-)

diff --git a/paddle/phi/backends/gpu/gpu_primitives.h b/paddle/phi/backends/gpu/gpu_primitives.h
index cb2f45db4b7d4c..b028e5a0ee9e08 100644
--- a/paddle/phi/backends/gpu/gpu_primitives.h
+++ b/paddle/phi/backends/gpu/gpu_primitives.h
@@ -457,6 +457,60 @@ CUDA_ATOMIC_WRAPPER(Mul, float) {
   return __int_as_float(old);
 }
 
+__device__ __forceinline__ uint32_t __loadAligned(const uintptr_t base_addr,
+                                                  uint32_t mask,
+                                                  uint32_t shift) {
+  // read the 4-byte aligned word and extract the masked lane
+  uint32_t aligned_value = *reinterpret_cast<const uint32_t *>(base_addr);
+  return (aligned_value & mask) >> shift;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, uint8_t) {
+  // get 4-byte aligned base address
+  uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);
+  uint32_t offset = reinterpret_cast<uintptr_t>(address) - base_addr;
+  uint32_t shift = offset * 8;
+  uint32_t mask = 0xFFU << shift;
+
+  uint32_t old32 = __loadAligned(base_addr, mask, shift), assumed32 = 0;
+
+  do {
+    assumed32 = old32;
+    uint8_t current = static_cast<uint8_t>((old32 & mask) >> shift);
+    uint8_t new_val = current * val;
+    uint32_t new32 =
+        (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);
+
+    old32 =
+        atomicCAS(reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);
+  } while (assumed32 != old32);
+
+  return static_cast<uint8_t>((old32 & mask) >> shift);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, int16_t) {
+  // get 4-byte aligned base address
+  uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);
+  uint32_t offset = (reinterpret_cast<uintptr_t>(address) - base_addr) / 2;
+  uint32_t shift = offset * 16;
+  uint32_t mask = 0xFFFFU << shift;
+
+  uint32_t old32 = __loadAligned(base_addr, mask, shift), assumed32 = 0;
+
+  do {
+    assumed32 = old32;
+    int16_t current = static_cast<int16_t>((old32 & mask) >> shift);
+    int16_t new_val = current * val;
+    uint32_t new32 =
+        (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);
+
+    old32 =
+        atomicCAS(reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);
+  } while (assumed32 != old32);
+
+  return static_cast<int16_t>((old32 & mask) >> shift);
+}
+
 CUDA_ATOMIC_WRAPPER(Mul, double) {
   unsigned long long int *const address_as_ull =  // NOLINT
       reinterpret_cast<unsigned long long int *>(address);  // NOLINT
@@ -943,6 +997,41 @@ CUDA_ATOMIC_WRAPPER(Min, phi::dtype::bfloat16) {
   }
 }
 
+#define DEFINE_ATOMIC_MINMAX(Dtype, OpType, operator)                        \
+  __device__ __forceinline__ Dtype CudaAtomic##OpType(Dtype *address,        \
+                                                      const Dtype val) {     \
+    uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);       \
+    uint32_t offset_bytes = reinterpret_cast<uintptr_t>(address) - base_addr; \
+    uint32_t shift = 0, mask = 0;                                            \
+    if constexpr (sizeof(Dtype) == 1) {                                      \
+      shift = offset_bytes * 8;                                              \
+      mask = 0xFFU << shift;                                                 \
+    } else {                                                                 \
+      shift = (offset_bytes / 2) * 16;                                       \
+      mask = 0xFFFFU << shift;                                               \
+    }                                                                        \
+    Dtype current = 0;                                                       \
+    Dtype new_val = 0;                                                       \
+    uint32_t assumed32 = 0, old32 = __loadAligned(base_addr, mask, shift);   \
+    do {                                                                     \
+      assumed32 = old32;                                                     \
+      current = static_cast<Dtype>((old32 & mask) >> shift);                 \
+      new_val = operator(current, val);                                      \
+      uint32_t new32 =                                                       \
+          (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);       \
+      old32 = atomicCAS(                                                     \
+          reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);        \
+    } while (assumed32 != old32);                                            \
+    return current;                                                          \
+  }
+
+DEFINE_ATOMIC_MINMAX(int16_t, Min, min)
+DEFINE_ATOMIC_MINMAX(int16_t, Max, max)
+DEFINE_ATOMIC_MINMAX(uint8_t, Min, min)
+DEFINE_ATOMIC_MINMAX(uint8_t, Max, max)
+
+#undef DEFINE_ATOMIC_MINMAX
+
 #ifdef PADDLE_WITH_CUDA
 /*
  * One thead block deals with elementwise atomicAdd for vector of len.
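All three wrappers above lean on the same trick: atomicCAS operates on 32-bit words, so each sub-word element is updated through its enclosing 4-byte-aligned word: mask out the lane, splice in the new value, and retry with atomicCAS until no other thread has touched the word in between. Every wrapper returns the lane's value from before the update, matching CUDA's built-in atomics. Below is a single-threaded Python sketch of one such update; the names (atomic_mul_uint8_sketch, word) are illustrative only, not part of the patch.

    def atomic_mul_uint8_sketch(word: list, byte_offset: int, val: int) -> int:
        """Multiply the uint8 lane at byte_offset of a 32-bit word by val."""
        shift = byte_offset * 8            # lanes 0..3 sit at bits 0/8/16/24
        mask = 0xFF << shift
        while True:
            old32 = word[0]                        # __loadAligned
            current = (old32 & mask) >> shift      # extract the uint8 lane
            new_val = (current * val) & 0xFF       # uint8 wrap-around
            new32 = (old32 & ~mask) | (new_val << shift)
            if word[0] == old32:                   # stand-in for the CAS compare
                word[0] = new32
                return current                     # atomics return the old value

    word = [0x04030201]                            # lanes 0..3 hold 1, 2, 3, 4
    assert atomic_mul_uint8_sketch(word, 1, 10) == 2
    assert word[0] == 0x04031401                   # lane 1: 2 * 10 = 20 (0x14)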
diff --git a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
index d1cb1c070ee7da..fd2cd8b0401728 100644
--- a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
@@ -180,5 +180,6 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}
diff --git a/paddle/phi/kernels/cpu/put_along_axis_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc
index c1bb2e3af280f5..ed096c6e1359d7 100644
--- a/paddle/phi/kernels/cpu/put_along_axis_kernel.cc
+++ b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc
@@ -103,5 +103,6 @@ PD_REGISTER_KERNEL(put_along_axis,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}
diff --git a/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
index 5abc80811310f8..fe8881813dc9f5 100644
--- a/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
@@ -66,5 +66,6 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}
diff --git a/paddle/phi/kernels/cpu/take_along_axis_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc
index 8adeec21ae6cd9..33b623df1fab10 100644
--- a/paddle/phi/kernels/cpu/take_along_axis_kernel.cc
+++ b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc
@@ -65,5 +65,6 @@ PD_REGISTER_KERNEL(take_along_axis,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}
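These four registrations are what make int16 visible to the CPU dispatcher for put/take_along_axis and their gradients. A minimal smoke test of the kind of call they enable, mirroring the positional call style (arr, indices, values, axis, reduce, include_self, broadcast) used by the tests later in this series; treat it as an illustrative sketch rather than part of the test suite:

    import numpy as np
    import paddle

    paddle.set_device("cpu")
    x = paddle.to_tensor(np.arange(16, dtype=np.int16).reshape(4, 4))
    index = paddle.to_tensor(np.zeros((4, 4), dtype=np.int64))
    value = paddle.to_tensor(np.full((4, 4), 2, dtype=np.int16))
    # Duplicate indices accumulate: each row scatters four 2s into column 0,
    # so x[i, 0] is multiplied by 2**4 (include_self=True, broadcast=False).
    out = paddle.tensor.put_along_axis(x, index, value, 1, "mul", True, False)
    print(out.numpy())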
diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
index 0814c5882dab84..5151132bf83d50 100644
--- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu
+++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
@@ -31,65 +31,37 @@ static TensorAssign tensor_assign;
 
 class ReduceAdd {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicAdd(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data += *src_data;
-  }
 };
 static ReduceAdd reduce_add;
 
 class ReduceMul {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicMul(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data *= *src_data;
-  }
 };
 static ReduceMul reduce_mul;
 
 class ReduceMax {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicMax(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data = *src_data > *self_data ? *src_data : *self_data;
-  }
 };
 static ReduceMax reduce_max;
 
 class ReduceMin {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicMin(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data = *src_data < *self_data ? *src_data : *self_data;
-  }
 };
 static ReduceMin reduce_min;
diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.h b/paddle/phi/kernels/funcs/gather_scatter_functor.h
index d27b42d499f2f5..4f2a9dd26d7a82 100644
--- a/paddle/phi/kernels/funcs/gather_scatter_functor.h
+++ b/paddle/phi/kernels/funcs/gather_scatter_functor.h
@@ -29,7 +29,8 @@ namespace funcs {
   Instantiate_Template_Function_index_t(func, phi::dtype::float16)     \
       Instantiate_Template_Function_index_t(func,                      \
                                             phi::dtype::bfloat16)      \
-          Instantiate_Template_Function_index_t(func, unsigned char)
+          Instantiate_Template_Function_index_t(func, unsigned char)   \
+              Instantiate_Template_Function_index_t(func, int16_t)
 
 #define Instantiate_Template_Function_index_t(func, tensor_t) \
   template void func(phi::DenseTensor input,                  \
@@ -45,17 +46,19 @@ namespace funcs {
                      bool include_self,                       \
                      const phi::DeviceContext& dev_ctx);
 
-#define Instantiate_Template_Function_With_Out(func)                      \
-  Instantiate_Template_Function_index_t_With_Out(func, int)               \
-      Instantiate_Template_Function_index_t_With_Out(func, float)         \
-          Instantiate_Template_Function_index_t_With_Out(func, double)    \
-              Instantiate_Template_Function_index_t_With_Out(func, int64_t) \
-                  Instantiate_Template_Function_index_t_With_Out(         \
-                      func, phi::dtype::float16)                          \
-                      Instantiate_Template_Function_index_t_With_Out(     \
-                          func, phi::dtype::bfloat16)                     \
-                          Instantiate_Template_Function_index_t_With_Out( \
-                              func, unsigned char)
+#define Instantiate_Template_Function_With_Out(func)                      \
+  Instantiate_Template_Function_index_t_With_Out(func, int)               \
+      Instantiate_Template_Function_index_t_With_Out(func, float)         \
+          Instantiate_Template_Function_index_t_With_Out(func, double)    \
+              Instantiate_Template_Function_index_t_With_Out(func, int64_t) \
+                  Instantiate_Template_Function_index_t_With_Out(         \
+                      func, phi::dtype::float16)                          \
+                      Instantiate_Template_Function_index_t_With_Out(     \
+                          func, phi::dtype::bfloat16)                     \
+                          Instantiate_Template_Function_index_t_With_Out( \
+                              func, unsigned char)                        \
+                              Instantiate_Template_Function_index_t_With_Out( \
+                                  func, int16_t)
 #define Instantiate_Template_Function_index_t_With_Out(func, tensor_t) \
   template void func(phi::DenseTensor input,                           \
                      int dim,                                          \
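The deleted overloads were the old escape hatch for uint8, which had no CudaAtomicMul and fell back to a plain, non-atomic read-modify-write on device. That fallback can silently lose updates when several threads scatter through duplicate indices; a single-threaded Python caricature of the hazard:

    # Two "threads" both read the same cell before either writes back:
    x = 5
    t1_read = x
    t2_read = x
    x = t1_read * 2   # thread 1 writes 10
    x = t2_read * 3   # thread 2 overwrites with 15; the *2 is lost
    assert x == 15    # an atomic mul would have produced 30

With CudaAtomicMul/Min/Max now defined for uint8 and int16 in gpu_primitives.h, the single atomic path suffices for every registered dtype.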
diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
index 640001c4ffc385..db5d1c655e2904 100644
--- a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
@@ -179,5 +179,7 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
                    double,
                    int64_t,
                    int,
+                   int16_t,
+                   uint8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/put_along_axis_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu
index bb2d4ec542c70a..86e1387f0f029e 100644
--- a/paddle/phi/kernels/gpu/put_along_axis_kernel.cu
+++ b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu
@@ -102,6 +102,8 @@ PD_REGISTER_KERNEL(put_along_axis,
                    float,
                    double,
                    int64_t,
+                   uint8_t,
+                   int16_t,
                    int,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
index d23f0c0c6ee503..935ef6fcb7b4d3 100644
--- a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
@@ -73,5 +73,7 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
                    double,
                    int64_t,
                    int,
+                   int16_t,
+                   uint8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu
index 10ff63488fbcc7..12f717591fb75f 100644
--- a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu
+++ b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu
@@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(take_along_axis,
                    double,
                    int64_t,
                    int,
+                   int16_t,
+                   uint8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py
index 96e994f01e5301..c1a683021edd47 100644
--- a/test/legacy_test/test_put_along_axis_op.py
+++ b/test/legacy_test/test_put_along_axis_op.py
@@ -79,6 +79,89 @@ def init_data(self):
         self.axis_type = "int64"
 
 
+class TestPutAlongAxisInt16OpBase(TestPutAlongAxisOp):
+    def init_data(self):
+        self.set_type()
+        self.x_shape = (10, 10, 10)
+        self.index_type = "int64"
+        self.axis = 1
+        self.axis_type = "int64"
+        self.set_reduce_op()
+        self.set_value_and_index()
+
+    def set_type(self):
+        self.dtype = np.int16
+        self.x_type = "int16"
+        self.value_type = "int16"
+
+    def set_value_and_index(self):
+        self.value = np.array([99]).astype(self.value_type)
+        self.index = np.array([[[0]]]).astype(self.index_type)
+
+    def set_reduce_op(self):
+        self.reduce_op = "assign"
+
+    def test_check_grad(self):
+        """int16 cannot pass check_grad's dtype check for the multiply op."""
+        pass
+
+
+class TestPutAlongAxisUInt8OpBase(TestPutAlongAxisInt16OpBase):
+    def set_type(self):
+        self.dtype = np.uint8
+        self.x_type = "uint8"
+        self.value_type = "uint8"
+
+    def set_reduce_op(self):
+        self.reduce_op = "assign"
+        self.value = np.array([127]).astype(self.value_type)
+        self.index = np.array([[[0]]]).astype(self.index_type)
+
+    def test_check_grad(self):
+        """uint8 cannot pass check_grad's dtype check for the multiply op."""
+        pass
+
+
+class TestPutAlongAxisInt16OpAdd(TestPutAlongAxisInt16OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "add"
+
+
+class TestPutAlongAxisInt16OpMul(TestPutAlongAxisInt16OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "mul"
+
+
+class TestPutAlongAxisInt16OpAMin(TestPutAlongAxisInt16OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "amin"
+
+
+class TestPutAlongAxisInt16OpAMax(TestPutAlongAxisInt16OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "amax"
+
+
+class TestPutAlongAxisUInt8OpAdd(TestPutAlongAxisUInt8OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "add"
+
+
+class TestPutAlongAxisUInt8OpMul(TestPutAlongAxisUInt8OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "mul"
+
+
+class TestPutAlongAxisUInt8OpAMin(TestPutAlongAxisUInt8OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "amin"
+
+
+class TestPutAlongAxisUInt8OpAMax(TestPutAlongAxisUInt8OpBase):
+    def set_reduce_op(self):
+        self.reduce_op = "amax"
+
+
 class TestPutAlongAxisFP16Op(TestPutAlongAxisOp):
     def init_data(self):
         self.dtype = np.float16
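The Int16/UInt8 OpTest classes above pin down a single-element scatter: one value (99, or 127 for uint8) written through index [[[0]]] along axis 1, with one subclass per reduce op. As a toy numpy reference of what, say, the amin variant should produce (an illustrative sketch only, assuming include_self semantics and no index broadcasting; the real checks go through the OpTest machinery):

    import numpy as np

    x = np.random.randint(-100, 100, (10, 10, 10)).astype(np.int16)
    expected = x.copy()
    # Only x[0, 0, 0] is addressed, reduced with min against the value 99.
    expected[0, 0, 0] = min(expected[0, 0, 0], 99)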
@@ -1259,31 +1342,63 @@ def run(place):
     not core.is_compiled_with_cuda(),
     "core is not compiled with CUDA",
 )
-class TestPutAlongAxisAPIMulUint8(unittest.TestCase):
+class TestPutAlongAxisAPIReduceLowBits(unittest.TestCase):
     def setUp(self):
         np.random.seed(0)
-        self.dtype = 'uint8'
-        self.x_type = "uint8"
-        self.x_shape = (10, 10, 10)
-        self.value_type = "uint8"
-        self.value = np.random.randint(1, 5, (5, 5, 5)).astype(self.value_type)
+        self.setup_dtype()
+        self.set_range()
+        self.set_op_to_test()
+        self.x_shape = (8, 8)
+        self.value = np.random.randint(*self.ranges, (8, 8)).astype(
+            self.value_type
+        )
         self.index_type = "int64"
-        self.index = np.zeros((5, 5, 5)).astype(self.index_type)
+        self.index = np.ones((8, 8), dtype=np.int64)
         self.axis = 1
         self.axis_type = "int64"
         self.op_type = "put_along_axis"
         self.prim_op_type = "prim"
         self.public_python_api = paddle.tensor.put_along_axis
         self.python_api = paddle.tensor.put_along_axis
-        self.xnp = np.random.randint(1, 5, self.x_shape).astype(self.x_type)
+        self.xnp = np.random.randint(*self.ranges, self.x_shape).astype(
+            self.x_type
+        )
+        self.input_filter()
         # numpy put_along_axis is an inplace operation.
         self.target = copy.deepcopy(self.xnp)
-        for i in range(5):
-            for j in range(5):
-                for k in range(5):
-                    self.target[i, self.index[i, j, k], k] *= self.value[
-                        i, j, k
-                    ]
+        if self.op == "mul":
+            host_op = lambda x, y: x * y
+        elif self.op == "amax":
+            host_op = lambda x, y: max(x, y)
+        elif self.op == "amin":
+            host_op = lambda x, y: min(x, y)
+        else:
+            raise ValueError(
+                f"Unsupported reduce op for put along axis: {self.op}"
+            )
+        for i in range(8):
+            for j in range(8):
+                self.target[i, self.index[i, j]] = host_op(
+                    self.target[i, self.index[i, j]], self.value[i, j]
+                )
+
+    def input_filter(self):
+        if self.ranges[0] <= 0 and self.op == "mul":
+            is_zero = self.value == 0
+            self.value[is_zero] = 1
+            is_zero = self.xnp == 0
+            self.xnp[is_zero] = 1
+
+    def setup_dtype(self):
+        self.dtype = 'uint8'
+        self.x_type = "uint8"
+        self.value_type = "uint8"
+
+    def set_range(self):
+        self.ranges = [1, 5]
+
+    def set_op_to_test(self):
+        self.op = "mul"
 
     def test_api_dygraph(self):
         def run(place):
@@ -1296,7 +1411,7 @@ def run(place):
                 index_tensor,
                 value_tensor,
                 self.axis,
-                "mul",
+                self.op,
                 True,
                 False,
             )
@@ -1306,6 +1421,60 @@ def run(place):
         run(paddle.CUDAPlace(0))
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestPutAlongAxisAPIMulInt16(TestPutAlongAxisAPIReduceLowBits):
+    def setup_dtype(self):
+        self.dtype = 'int16'
+        self.x_type = "int16"
+        self.value_type = "int16"
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestPutAlongAxisAPIMinInt16(TestPutAlongAxisAPIMulInt16):
+    def set_range(self):
+        self.ranges = [-32760, 32761]
+
+    def set_op_to_test(self):
+        self.op = "amin"
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestPutAlongAxisAPIMaxInt16(TestPutAlongAxisAPIMinInt16):
+    def set_op_to_test(self):
+        self.op = "amax"
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestPutAlongAxisAPIMinUInt8(TestPutAlongAxisAPIReduceLowBits):
+    def set_range(self):
+        self.ranges = [0, 256]
+
+    def set_op_to_test(self):
+        self.op = "amin"
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestPutAlongAxisAPIMaxUInt8(TestPutAlongAxisAPIMinUInt8):
+
+    def set_op_to_test(self):
+        self.op = "amax"
+
+
 class TestPutAlongAxisDynamicShape(unittest.TestCase):
     def setUp(self):
         np.random.seed(2024)
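The double loop that builds self.target in TestPutAlongAxisAPIReduceLowBits above exists because np.put_along_axis has no reduce argument: repeated indices have to be folded in one element at a time. A tiny standalone rendition of the same reference computation (illustrative only, using amin with every column scattering into column 1):

    import numpy as np

    x = np.array([[5, 7], [2, 9]], dtype=np.int16)
    value = np.array([[3, 4], [8, 1]], dtype=np.int16)
    index = np.ones((2, 2), dtype=np.int64)  # both columns scatter into col 1
    target = x.copy()
    for i in range(2):
        for j in range(2):
            target[i, index[i, j]] = min(target[i, index[i, j]], value[i, j])
    print(target)  # [[5 3] [2 1]]: column 1 reduced with "amin", column 0 kept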
diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py
index 72b266b4dccd78..15569180d2b856 100644
--- a/test/legacy_test/test_take_along_axis_op.py
+++ b/test/legacy_test/test_take_along_axis_op.py
@@ -462,6 +462,42 @@ def test_check_grad(self):
         )
 
 
+class TestTakeAlongAxisInt16(TestTakeAlongAxisOp):
+    def init_data(self):
+        self.dtype = np.int16
+        self.x_type = "int16"
+        self.x_shape = (5, 5, 5)
+        self.index_type = "int32"
+        self.axis = 2
+        dim_size = self.x_shape[self.axis]
+        self.index = np.random.randint(
+            -dim_size, dim_size, size=(5, 1, 1)
+        ).astype(self.index_type)
+        self.axis_type = "int64"
+
+    def test_check_grad(self):
+        """int16 neither requires nor supports grad check."""
+        pass
+
+
+class TestTakeAlongAxisUInt8(TestTakeAlongAxisOp):
+    def init_data(self):
+        self.dtype = np.uint8
+        self.x_type = "uint8"
+        self.x_shape = (5, 5, 5)
+        self.index_type = "int32"
+        self.axis = 2
+        dim_size = self.x_shape[self.axis]
+        self.index = np.random.randint(
+            -dim_size, dim_size, size=(5, 1, 1)
+        ).astype(self.index_type)
+        self.axis_type = "int64"
+
+    def test_check_grad(self):
+        """uint8 neither requires nor supports grad check."""
+        pass
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()

From 8bf312838888f3bf49a4a72055e5a671a7528e8d Mon Sep 17 00:00:00 2001
From: Enigmatisms
Date: Mon, 18 Aug 2025 07:29:51 +0000
Subject: [PATCH 2/3] [PHI] Removed some of the GPU-only constraints.

---
 test/legacy_test/test_put_along_axis_op.py | 28 ++++------------------
 1 file changed, 4 insertions(+), 24 deletions(-)

diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py
index c1a683021edd47..a8a934b532ec19 100644
--- a/test/legacy_test/test_put_along_axis_op.py
+++ b/test/legacy_test/test_put_along_axis_op.py
@@ -80,6 +80,8 @@ def init_data(self):
 
 
 class TestPutAlongAxisInt16OpBase(TestPutAlongAxisOp):
+    no_need_check_grad = True
+
     def init_data(self):
         self.set_type()
         self.x_shape = (10, 10, 10)
@@ -107,6 +109,8 @@ def test_check_grad(self):
 
 
 class TestPutAlongAxisUInt8OpBase(TestPutAlongAxisInt16OpBase):
+    no_need_check_grad = True
+
     def set_type(self):
         self.dtype = np.uint8
         self.x_type = "uint8"
@@ -1338,10 +1342,6 @@ def run(place):
         run(paddle.CUDAPlace(0))
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not compiled with CUDA",
-)
 class TestPutAlongAxisAPIReduceLowBits(unittest.TestCase):
     def setUp(self):
         np.random.seed(0)
@@ -1421,10 +1421,6 @@ def run(place):
         run(paddle.CUDAPlace(0))
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not compiled with CUDA",
-)
 class TestPutAlongAxisAPIMulInt16(TestPutAlongAxisAPIReduceLowBits):
     def setup_dtype(self):
         self.dtype = 'int16'
@@ -1432,10 +1428,6 @@ def setup_dtype(self):
         self.value_type = "int16"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not compiled with CUDA",
-)
 class TestPutAlongAxisAPIMinInt16(TestPutAlongAxisAPIMulInt16):
     def set_range(self):
         self.ranges = [-32760, 32761]
@@ -1444,19 +1436,11 @@ def set_op_to_test(self):
         self.op = "amin"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not compiled with CUDA",
-)
 class TestPutAlongAxisAPIMaxInt16(TestPutAlongAxisAPIMinInt16):
     def set_op_to_test(self):
         self.op = "amax"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not compiled with CUDA",
-)
 class TestPutAlongAxisAPIMinUInt8(TestPutAlongAxisAPIReduceLowBits):
     def set_range(self):
         self.ranges = [0, 256]
@@ -1465,10 +1449,6 @@ def set_op_to_test(self):
         self.op = "amin"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not compiled with CUDA",
-)
 class TestPutAlongAxisAPIMaxUInt8(TestPutAlongAxisAPIMinUInt8):
 
     def set_op_to_test(self):
         self.op = "amax"

From 23f9ee88a316f22eae4e7b8a1bd192687eea9b12 Mon Sep 17 00:00:00 2001
From: Enigmatisms
Date: Mon, 18 Aug 2025 08:06:40 +0000
Subject: [PATCH 3/3] [PHI] Fixed put_along_axis CPU-side test error

---
 test/legacy_test/test_put_along_axis_op.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py
index a8a934b532ec19..4d310af2fca7df 100644
--- a/test/legacy_test/test_put_along_axis_op.py
+++ b/test/legacy_test/test_put_along_axis_op.py
@@ -1418,7 +1418,11 @@ def run(place):
         out_ref = self.target
         np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)
 
-        run(paddle.CUDAPlace(0))
+        run(
+            paddle.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )