@@ -65,12 +65,12 @@ struct IscloseFunctor<phi::CPUContext, T> {
6565 auto * in_a = in.data <T>();
6666 auto * in_b = other.data <T>();
6767 auto * out_data = ctx.template Alloc <bool >(output);
68- auto num = in.numel ();
68+ int64_t num = in.numel ();
6969 // *out_data = true;
70- for (int i = 0 ; i < num; i++) {
70+ for (int64_t i = 0 ; i < num; i++) {
7171 out_data[i] = true ;
7272 }
73- for (int i = 0 ; i < num; i++) {
73+ for (int64_t i = 0 ; i < num; i++) {
7474 const T a = in_a[i], b = in_b[i];
7575 bool val;
7676 if (std::isnan (a) || std::isnan (b)) {
@@ -99,12 +99,12 @@ struct IscloseFunctor<phi::CPUContext, phi::dtype::complex<T>> {
9999 auto * in_a = in.data <phi::dtype::complex <T>>();
100100 auto * in_b = other.data <phi::dtype::complex <T>>();
101101 auto * out_data = ctx.template Alloc <bool >(output);
102- auto num = in.numel ();
102+ int64_t num = in.numel ();
103103 // *out_data = true;
104- for (int i = 0 ; i < num; i++) {
104+ for (int64_t i = 0 ; i < num; i++) {
105105 out_data[i] = true ;
106106 }
107- for (int i = 0 ; i < num; i++) {
107+ for (int64_t i = 0 ; i < num; i++) {
108108 const phi::dtype::complex <T> a = in_a[i], b = in_b[i];
109109 bool val;
110110 if (std::isnan (a) || std::isnan (b)) {
@@ -122,18 +122,18 @@ struct IscloseFunctor<phi::CPUContext, phi::dtype::complex<T>> {
122122};
123123
124124#if defined(__NVCC__) || defined(__HIPCC__)
125- template <typename T>
125+ template <typename T, typename IndexType >
126126__global__ void IscloseCUDAKernel (const T* in_data,
127127 const T* other_data,
128128 const double rtol,
129129 const double atol,
130130 bool equal_nan,
131- int num,
131+ IndexType num,
132132 bool * out_data) {
133- unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x ;
133+ IndexType idx = threadIdx.x + blockIdx.x * blockDim.x ;
134134 bool val;
135135 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
136- for (int i = idx; i < num; i += blockDim.x * gridDim.x ) {
136+ for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x ) {
137137 const MPType a = static_cast <MPType>(in_data[i]);
138138 const MPType b = static_cast <MPType>(other_data[i]);
139139 if (isnan (a) || isnan (b)) {
@@ -149,17 +149,44 @@ __global__ void IscloseCUDAKernel(const T* in_data,
149149 }
150150}
151151template <>
152- __global__ void IscloseCUDAKernel<phi::dtype::complex <float >>(
152+ __global__ void IscloseCUDAKernel<phi::dtype::complex <float >, unsigned int >(
153153 const phi::dtype::complex <float >* in_data,
154154 const phi::dtype::complex <float >* other_data,
155155 const double rtol,
156156 const double atol,
157157 bool equal_nan,
158- int num,
158+ unsigned int num,
159159 bool * out_data) {
160160 unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x ;
161161 bool val;
162- for (int i = idx; i < num; i += blockDim.x * gridDim.x ) {
162+ for (unsigned int i = idx; i < num; i += blockDim.x * gridDim.x ) {
163+ const phi::dtype::complex <float > a = in_data[i];
164+ const phi::dtype::complex <float > b = other_data[i];
165+ if (isnan (a) || isnan (b)) {
166+ val = equal_nan && isnan (a) == isnan (b);
167+ } else {
168+ float left = abs (a - b);
169+ float right = atol + rtol * abs (b);
170+ float diff = abs (left - right);
171+ val = a == b || left <= right || diff <= 1e-15 ;
172+ }
173+ out_data[i] = val;
174+ // if (!val) *out_data = false;
175+ }
176+ }
177+
178+ template <>
179+ __global__ void IscloseCUDAKernel<phi::dtype::complex <float >, int64_t >(
180+ const phi::dtype::complex <float >* in_data,
181+ const phi::dtype::complex <float >* other_data,
182+ const double rtol,
183+ const double atol,
184+ bool equal_nan,
185+ int64_t num,
186+ bool * out_data) {
187+ int64_t idx = threadIdx.x + blockIdx.x * blockDim.x ;
188+ bool val;
189+ for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x ) {
163190 const phi::dtype::complex <float > a = in_data[i];
164191 const phi::dtype::complex <float > b = other_data[i];
165192 if (isnan (a) || isnan (b)) {
@@ -176,17 +203,17 @@ __global__ void IscloseCUDAKernel<phi::dtype::complex<float>>(
176203}
177204
178205template <>
179- __global__ void IscloseCUDAKernel<phi::dtype::complex <double >>(
206+ __global__ void IscloseCUDAKernel<phi::dtype::complex <double >, unsigned int >(
180207 const phi::dtype::complex <double >* in_data,
181208 const phi::dtype::complex <double >* other_data,
182209 const double rtol,
183210 const double atol,
184211 bool equal_nan,
185- int num,
212+ unsigned int num,
186213 bool * out_data) {
187214 unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x ;
188215 bool val;
189- for (int i = idx; i < num; i += blockDim.x * gridDim.x ) {
216+ for (unsigned int i = idx; i < num; i += blockDim.x * gridDim.x ) {
190217 const phi::dtype::complex <double > a = in_data[i];
191218 const phi::dtype::complex <double > b = other_data[i];
192219 if (isnan (a) || isnan (b)) {
@@ -201,6 +228,34 @@ __global__ void IscloseCUDAKernel<phi::dtype::complex<double>>(
201228 // if (!val) *out_data = false;
202229 }
203230}
231+
232+ template <>
233+ __global__ void IscloseCUDAKernel<phi::dtype::complex <double >, int64_t >(
234+ const phi::dtype::complex <double >* in_data,
235+ const phi::dtype::complex <double >* other_data,
236+ const double rtol,
237+ const double atol,
238+ bool equal_nan,
239+ int64_t num,
240+ bool * out_data) {
241+ int64_t idx = threadIdx.x + blockIdx.x * blockDim.x ;
242+ bool val;
243+ for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x ) {
244+ const phi::dtype::complex <double > a = in_data[i];
245+ const phi::dtype::complex <double > b = other_data[i];
246+ if (isnan (a) || isnan (b)) {
247+ val = equal_nan && isnan (a) == isnan (b);
248+ } else {
249+ double left = abs (a - b);
250+ double right = atol + rtol * abs (b);
251+ double diff = abs (left - right);
252+ val = a == b || left <= right || diff <= 1e-15 ;
253+ }
254+ out_data[i] = val;
255+ // if (!val) *out_data = false;
256+ }
257+ }
258+
204259template <typename T>
205260struct GetTensorValue <phi::GPUContext, T> {
206261 T operator ()(const phi::GPUContext& dev_ctx,
@@ -223,20 +278,25 @@ struct IscloseFunctor<phi::GPUContext, T> {
223278 const double atol,
224279 bool equal_nan,
225280 DenseTensor* output) {
226- int num = in.numel ();
281+ int64_t num = in.numel ();
227282 const T* in_data = in.data <T>();
228283 const T* other_data = other.data <T>();
229284 bool * out_data = dev_ctx.template Alloc <bool >(output);
230- int block = 1024 ;
231- int grid = (block - 1 + num) / block;
285+ int64_t block = 1024 ;
286+ int64_t grid = (block - 1 + num) / block;
232287 grid = (grid > block) ? block : grid;
233288#ifdef PADDLE_WITH_HIP
234289 hipMemset (out_data, true , num * sizeof (bool ));
235290#else
236291 cudaMemset (out_data, true , num * sizeof (bool ));
237292#endif
238- IscloseCUDAKernel<T><<<grid, block, 0 , dev_ctx.stream ()>>>(
239- in_data, other_data, rtol, atol, equal_nan, num, out_data);
293+ if (num + grid * block + 1 > std::numeric_limits<unsigned int >::max ()) {
294+ IscloseCUDAKernel<T, int64_t ><<<grid, block, 0 , dev_ctx.stream ()>>>(
295+ in_data, other_data, rtol, atol, equal_nan, num, out_data);
296+ } else {
297+ IscloseCUDAKernel<T, unsigned int ><<<grid, block, 0 , dev_ctx.stream ()>>>(
298+ in_data, other_data, rtol, atol, equal_nan, num, out_data);
299+ }
240300 }
241301};
242302#endif
0 commit comments