Skip to content

Commit 9e886c7

Browse files
[API] isclose support bigtensor (#72516)
* isclose support bigtensor
* refine
1 parent ddfc630 commit 9e886c7

File tree

1 file changed: +81 -21 lines changed

paddle/phi/kernels/impl/isclose_kernel_impl.h

Lines changed: 81 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -65,12 +65,12 @@ struct IscloseFunctor<phi::CPUContext, T> {
6565
auto* in_a = in.data<T>();
6666
auto* in_b = other.data<T>();
6767
auto* out_data = ctx.template Alloc<bool>(output);
68-
auto num = in.numel();
68+
int64_t num = in.numel();
6969
// *out_data = true;
70-
for (int i = 0; i < num; i++) {
70+
for (int64_t i = 0; i < num; i++) {
7171
out_data[i] = true;
7272
}
73-
for (int i = 0; i < num; i++) {
73+
for (int64_t i = 0; i < num; i++) {
7474
const T a = in_a[i], b = in_b[i];
7575
bool val;
7676
if (std::isnan(a) || std::isnan(b)) {
@@ -99,12 +99,12 @@ struct IscloseFunctor<phi::CPUContext, phi::dtype::complex<T>> {
9999
auto* in_a = in.data<phi::dtype::complex<T>>();
100100
auto* in_b = other.data<phi::dtype::complex<T>>();
101101
auto* out_data = ctx.template Alloc<bool>(output);
102-
auto num = in.numel();
102+
int64_t num = in.numel();
103103
// *out_data = true;
104-
for (int i = 0; i < num; i++) {
104+
for (int64_t i = 0; i < num; i++) {
105105
out_data[i] = true;
106106
}
107-
for (int i = 0; i < num; i++) {
107+
for (int64_t i = 0; i < num; i++) {
108108
const phi::dtype::complex<T> a = in_a[i], b = in_b[i];
109109
bool val;
110110
if (std::isnan(a) || std::isnan(b)) {
@@ -122,18 +122,18 @@ struct IscloseFunctor<phi::CPUContext, phi::dtype::complex<T>> {
122122
};
123123

124124
#if defined(__NVCC__) || defined(__HIPCC__)
125-
template <typename T>
125+
template <typename T, typename IndexType>
126126
__global__ void IscloseCUDAKernel(const T* in_data,
127127
const T* other_data,
128128
const double rtol,
129129
const double atol,
130130
bool equal_nan,
131-
int num,
131+
IndexType num,
132132
bool* out_data) {
133-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
133+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
134134
bool val;
135135
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
136-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
136+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
137137
const MPType a = static_cast<MPType>(in_data[i]);
138138
const MPType b = static_cast<MPType>(other_data[i]);
139139
if (isnan(a) || isnan(b)) {
@@ -149,17 +149,44 @@ __global__ void IscloseCUDAKernel(const T* in_data,
149149
}
150150
}
151151
template <>
152-
__global__ void IscloseCUDAKernel<phi::dtype::complex<float>>(
152+
__global__ void IscloseCUDAKernel<phi::dtype::complex<float>, unsigned int>(
153153
const phi::dtype::complex<float>* in_data,
154154
const phi::dtype::complex<float>* other_data,
155155
const double rtol,
156156
const double atol,
157157
bool equal_nan,
158-
int num,
158+
unsigned int num,
159159
bool* out_data) {
160160
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
161161
bool val;
162-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
162+
for (unsigned int i = idx; i < num; i += blockDim.x * gridDim.x) {
163+
const phi::dtype::complex<float> a = in_data[i];
164+
const phi::dtype::complex<float> b = other_data[i];
165+
if (isnan(a) || isnan(b)) {
166+
val = equal_nan && isnan(a) == isnan(b);
167+
} else {
168+
float left = abs(a - b);
169+
float right = atol + rtol * abs(b);
170+
float diff = abs(left - right);
171+
val = a == b || left <= right || diff <= 1e-15;
172+
}
173+
out_data[i] = val;
174+
// if (!val) *out_data = false;
175+
}
176+
}
177+
178+
template <>
179+
__global__ void IscloseCUDAKernel<phi::dtype::complex<float>, int64_t>(
180+
const phi::dtype::complex<float>* in_data,
181+
const phi::dtype::complex<float>* other_data,
182+
const double rtol,
183+
const double atol,
184+
bool equal_nan,
185+
int64_t num,
186+
bool* out_data) {
187+
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
188+
bool val;
189+
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
163190
const phi::dtype::complex<float> a = in_data[i];
164191
const phi::dtype::complex<float> b = other_data[i];
165192
if (isnan(a) || isnan(b)) {
@@ -176,17 +203,17 @@ __global__ void IscloseCUDAKernel<phi::dtype::complex<float>>(
176203
}
177204

178205
template <>
179-
__global__ void IscloseCUDAKernel<phi::dtype::complex<double>>(
206+
__global__ void IscloseCUDAKernel<phi::dtype::complex<double>, unsigned int>(
180207
const phi::dtype::complex<double>* in_data,
181208
const phi::dtype::complex<double>* other_data,
182209
const double rtol,
183210
const double atol,
184211
bool equal_nan,
185-
int num,
212+
unsigned int num,
186213
bool* out_data) {
187214
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
188215
bool val;
189-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
216+
for (unsigned int i = idx; i < num; i += blockDim.x * gridDim.x) {
190217
const phi::dtype::complex<double> a = in_data[i];
191218
const phi::dtype::complex<double> b = other_data[i];
192219
if (isnan(a) || isnan(b)) {
@@ -201,6 +228,34 @@ __global__ void IscloseCUDAKernel<phi::dtype::complex<double>>(
201228
// if (!val) *out_data = false;
202229
}
203230
}
231+
232+
template <>
233+
__global__ void IscloseCUDAKernel<phi::dtype::complex<double>, int64_t>(
234+
const phi::dtype::complex<double>* in_data,
235+
const phi::dtype::complex<double>* other_data,
236+
const double rtol,
237+
const double atol,
238+
bool equal_nan,
239+
int64_t num,
240+
bool* out_data) {
241+
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
242+
bool val;
243+
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
244+
const phi::dtype::complex<double> a = in_data[i];
245+
const phi::dtype::complex<double> b = other_data[i];
246+
if (isnan(a) || isnan(b)) {
247+
val = equal_nan && isnan(a) == isnan(b);
248+
} else {
249+
double left = abs(a - b);
250+
double right = atol + rtol * abs(b);
251+
double diff = abs(left - right);
252+
val = a == b || left <= right || diff <= 1e-15;
253+
}
254+
out_data[i] = val;
255+
// if (!val) *out_data = false;
256+
}
257+
}
258+
204259
template <typename T>
205260
struct GetTensorValue<phi::GPUContext, T> {
206261
T operator()(const phi::GPUContext& dev_ctx,
@@ -223,20 +278,25 @@ struct IscloseFunctor<phi::GPUContext, T> {
223278
const double atol,
224279
bool equal_nan,
225280
DenseTensor* output) {
226-
int num = in.numel();
281+
int64_t num = in.numel();
227282
const T* in_data = in.data<T>();
228283
const T* other_data = other.data<T>();
229284
bool* out_data = dev_ctx.template Alloc<bool>(output);
230-
int block = 1024;
231-
int grid = (block - 1 + num) / block;
285+
int64_t block = 1024;
286+
int64_t grid = (block - 1 + num) / block;
232287
grid = (grid > block) ? block : grid;
233288
#ifdef PADDLE_WITH_HIP
234289
hipMemset(out_data, true, num * sizeof(bool));
235290
#else
236291
cudaMemset(out_data, true, num * sizeof(bool));
237292
#endif
238-
IscloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
239-
in_data, other_data, rtol, atol, equal_nan, num, out_data);
293+
if (num + grid * block + 1 > std::numeric_limits<unsigned int>::max()) {
294+
IscloseCUDAKernel<T, int64_t><<<grid, block, 0, dev_ctx.stream()>>>(
295+
in_data, other_data, rtol, atol, equal_nan, num, out_data);
296+
} else {
297+
IscloseCUDAKernel<T, unsigned int><<<grid, block, 0, dev_ctx.stream()>>>(
298+
in_data, other_data, rtol, atol, equal_nan, num, out_data);
299+
}
240300
}
241301
};
242302
#endif

0 commit comments

Comments
 (0)