Skip to content

Commit ce96a90

Browse files
authored
[深度对齐]Divide (#75379)
* fix * fix * fix * fix * fix
1 parent 973074a commit ce96a90

File tree

2 files changed

+225
-55
lines changed

2 files changed

+225
-55
lines changed

paddle/phi/common/complex.h

Lines changed: 113 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,62 @@ HOSTDEVICE inline complex<T> operator*(const complex<T>& a,
230230
}
231231

232232
template <typename T>
233-
HOSTDEVICE inline complex<T> operator/(const complex<T>& a,
234-
const complex<T>& b) {
235-
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
236-
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
237-
return complex<T>(thrust::complex<T>(a) / thrust::complex<T>(b));
238-
#else
239-
T denominator = b.real * b.real + b.imag * b.imag;
240-
return complex<T>((a.real * b.real + a.imag * b.imag) / denominator,
241-
(a.imag * b.real - a.real * b.imag) / denominator);
242-
#endif
233+
HOSTDEVICE inline complex<T> operator/(const complex<T>& x,
234+
const complex<T>& y) {
235+
T a = x.real;
236+
T b = x.imag;
237+
T c = y.real;
238+
T d = y.imag;
239+
240+
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
241+
// the calculation below follows numpy's complex division
242+
#if defined(__GNUC__) && !defined(__clang__)
243+
// std::abs is already constexpr by gcc
244+
auto abs_c = std::abs(c);
245+
auto abs_d = std::abs(d);
246+
#else
247+
auto abs_c = c < 0 ? -c : c;
248+
auto abs_d = d < 0 ? -d : d;
249+
#endif
250+
T real_, imag_;
251+
252+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
253+
auto scl =
254+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
255+
if (abs_c >= abs_d) {
256+
#if __cplusplus >= 201703L
257+
if constexpr (std::is_same_v<T, float>) {
258+
real_ = std::fmaf(b, rat, a) * scl;
259+
imag_ = std::fmaf(-a, rat, b) * scl;
260+
} else if constexpr (std::is_same_v<T, double>) {
261+
real_ = std::fma(b, rat, a) * scl;
262+
imag_ = std::fma(-a, rat, b) * scl;
263+
} else {
264+
real_ = (a + b * rat) * scl;
265+
imag_ = (b - a * rat) * scl;
266+
}
267+
#else
268+
real_ = (a + b * rat) * scl;
269+
imag_ = (b - a * rat) * scl;
270+
#endif
271+
} else {
272+
#if __cplusplus >= 201703L
273+
if constexpr (std::is_same_v<T, float>) {
274+
real_ = std::fmaf(a, rat, b) * scl;
275+
imag_ = std::fmaf(b, rat, -a) * scl;
276+
} else if constexpr (std::is_same_v<T, double>) {
277+
real_ = std::fma(a, rat, b) * scl;
278+
imag_ = std::fma(b, rat, -a) * scl;
279+
} else {
280+
real_ = (a * rat + b) * scl;
281+
imag_ = (b * rat - a) * scl;
282+
}
283+
#else
284+
real_ = (a * rat + b) * scl;
285+
imag_ = (b * rat - a) * scl;
286+
#endif
287+
}
288+
return complex<T>(real_, imag_);
243289
}
244290

245291
template <typename T>
@@ -303,19 +349,63 @@ HOSTDEVICE inline complex<T>& operator*=(complex<T>& a, // NOLINT
303349
}
304350

305351
template <typename T>
306-
HOSTDEVICE inline complex<T>& operator/=(complex<T>& a, // NOLINT
307-
const complex<T>& b) {
308-
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
309-
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
310-
a = complex<T>(thrust::complex<T>(a.real, a.imag) /=
311-
thrust::complex<T>(b.real, b.imag));
312-
return a;
313-
#else
314-
T denominator = b.real * b.real + b.imag * b.imag;
315-
a.real = (a.real * b.real + a.imag * b.imag) / denominator;
316-
a.imag = (a.imag * b.real - a.real * b.imag) / denominator;
317-
return a;
318-
#endif
352+
HOSTDEVICE inline complex<T>& operator/=(complex<T>& x, // NOLINT
353+
const complex<T>& y) {
354+
T a = x.real;
355+
T b = x.imag;
356+
T c = y.real;
357+
T d = y.imag;
358+
359+
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
360+
// the calculation below follows numpy's complex division
361+
#if defined(__GNUC__) && !defined(__clang__)
362+
// std::abs is already constexpr by gcc
363+
auto abs_c = std::abs(c);
364+
auto abs_d = std::abs(d);
365+
#else
366+
auto abs_c = c < 0 ? -c : c;
367+
auto abs_d = d < 0 ? -d : d;
368+
#endif
369+
T real_, imag_;
370+
371+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
372+
auto scl =
373+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
374+
if (abs_c >= abs_d) {
375+
#if __cplusplus >= 201703L
376+
if constexpr (std::is_same_v<T, float>) {
377+
real_ = std::fmaf(b, rat, a) * scl;
378+
imag_ = std::fmaf(-a, rat, b) * scl;
379+
} else if constexpr (std::is_same_v<T, double>) {
380+
real_ = std::fma(b, rat, a) * scl;
381+
imag_ = std::fma(-a, rat, b) * scl;
382+
} else {
383+
real_ = (a + b * rat) * scl;
384+
imag_ = (b - a * rat) * scl;
385+
}
386+
#else
387+
real_ = (a + b * rat) * scl;
388+
imag_ = (b - a * rat) * scl;
389+
#endif
390+
} else {
391+
#if __cplusplus >= 201703L
392+
if constexpr (std::is_same_v<T, float>) {
393+
real_ = std::fmaf(a, rat, b) * scl;
394+
imag_ = std::fmaf(b, rat, -a) * scl;
395+
} else if constexpr (std::is_same_v<T, double>) {
396+
real_ = std::fma(a, rat, b) * scl;
397+
imag_ = std::fma(b, rat, -a) * scl;
398+
} else {
399+
real_ = (a * rat + b) * scl;
400+
imag_ = (b * rat - a) * scl;
401+
}
402+
#else
403+
real_ = (a * rat + b) * scl;
404+
imag_ = (b * rat - a) * scl;
405+
#endif
406+
}
407+
x = complex<T>(real_, imag_);
408+
return x;
319409
}
320410

321411
template <typename T>

paddle/phi/kernels/funcs/elementwise_functor.h

Lines changed: 112 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -142,23 +142,44 @@ struct DivideFunctor<ComplexType<T>> {
142142
#endif
143143

144144
T real_, imag_;
145+
146+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
147+
auto scl =
148+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
145149
if (abs_c >= abs_d) {
146-
if (abs_c == T(0) && abs_d == T(0)) {
147-
/* divide by zeros should yield a complex inf or nan */
148-
real_ = a / abs_c;
149-
imag_ = b / abs_d;
150+
#if __cplusplus >= 201703L
151+
if constexpr (std::is_same_v<T, float>) {
152+
real_ = std::fmaf(b, rat, a) * scl;
153+
imag_ = std::fmaf(-a, rat, b) * scl;
154+
} else if constexpr (std::is_same_v<T, double>) {
155+
real_ = std::fma(b, rat, a) * scl;
156+
imag_ = std::fma(-a, rat, b) * scl;
150157
} else {
151-
auto rat = d / c;
152-
auto scl = T(1.0) / (c + d * rat);
153158
real_ = (a + b * rat) * scl;
154159
imag_ = (b - a * rat) * scl;
155160
}
161+
#else
162+
real_ = (a + b * rat) * scl;
163+
imag_ = (b - a * rat) * scl;
164+
#endif
156165
} else {
157-
auto rat = c / d;
158-
auto scl = T(1.0) / (d + c * rat);
166+
#if __cplusplus >= 201703L
167+
if constexpr (std::is_same_v<T, float>) {
168+
real_ = std::fmaf(a, rat, b) * scl;
169+
imag_ = std::fmaf(b, rat, -a) * scl;
170+
} else if constexpr (std::is_same_v<T, double>) {
171+
real_ = std::fma(a, rat, b) * scl;
172+
imag_ = std::fma(b, rat, -a) * scl;
173+
} else {
174+
real_ = (a * rat + b) * scl;
175+
imag_ = (b * rat - a) * scl;
176+
}
177+
#else
159178
real_ = (a * rat + b) * scl;
160179
imag_ = (b * rat - a) * scl;
180+
#endif
161181
}
182+
162183
return ComplexType<T>(real_, imag_);
163184
}
164185
};
@@ -184,23 +205,44 @@ struct InverseDivideFunctor<ComplexType<T>> {
184205
#endif
185206

186207
T real_, imag_;
208+
209+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
210+
auto scl =
211+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
187212
if (abs_c >= abs_d) {
188-
if (abs_c == T(0) && abs_d == T(0)) {
189-
/* divide by zeros should yield a complex inf or nan */
190-
real_ = a / abs_c;
191-
imag_ = b / abs_d;
213+
#if __cplusplus >= 201703L
214+
if constexpr (std::is_same_v<T, float>) {
215+
real_ = std::fmaf(b, rat, a) * scl;
216+
imag_ = std::fmaf(-a, rat, b) * scl;
217+
} else if constexpr (std::is_same_v<T, double>) {
218+
real_ = std::fma(b, rat, a) * scl;
219+
imag_ = std::fma(-a, rat, b) * scl;
192220
} else {
193-
auto rat = d / c;
194-
auto scl = T(1.0) / (c + d * rat);
195221
real_ = (a + b * rat) * scl;
196222
imag_ = (b - a * rat) * scl;
197223
}
224+
#else
225+
real_ = (a + b * rat) * scl;
226+
imag_ = (b - a * rat) * scl;
227+
#endif
198228
} else {
199-
auto rat = c / d;
200-
auto scl = T(1.0) / (d + c * rat);
229+
#if __cplusplus >= 201703L
230+
if constexpr (std::is_same_v<T, float>) {
231+
real_ = std::fmaf(a, rat, b) * scl;
232+
imag_ = std::fmaf(b, rat, -a) * scl;
233+
} else if constexpr (std::is_same_v<T, double>) {
234+
real_ = std::fma(a, rat, b) * scl;
235+
imag_ = std::fma(b, rat, -a) * scl;
236+
} else {
237+
real_ = (a * rat + b) * scl;
238+
imag_ = (b * rat - a) * scl;
239+
}
240+
#else
201241
real_ = (a * rat + b) * scl;
202242
imag_ = (b * rat - a) * scl;
243+
#endif
203244
}
245+
204246
return ComplexType<T>(real_, imag_);
205247
}
206248
};
@@ -776,22 +818,41 @@ struct RemainderFunctor<ComplexType<T>> {
776818
#endif
777819

778820
T real_, imag_;
821+
auto rat = (abs_c >= abs_d) ? (d__ / c__) : (c__ / d__);
822+
auto scl = (abs_c >= abs_d) ? (T(1.0) / (c__ + d__ * rat))
823+
: (T(1.0) / (d__ + c__ * rat));
779824
if (abs_c >= abs_d) {
780-
if (abs_c == T(0) && abs_d == T(0)) {
781-
/* divide by zeros should yield a complex inf or nan */
782-
real_ = a__ / abs_c;
783-
imag_ = b__ / abs_d;
825+
#if __cplusplus >= 201703L
826+
if constexpr (std::is_same_v<T, float>) {
827+
real_ = std::fmaf(b__, rat, a__) * scl;
828+
imag_ = std::fmaf(-a__, rat, b__) * scl;
829+
} else if constexpr (std::is_same_v<T, double>) {
830+
real_ = std::fma(b__, rat, a__) * scl;
831+
imag_ = std::fma(-a__, rat, b__) * scl;
784832
} else {
785-
auto rat = d__ / c__;
786-
auto scl = T(1.0) / (c__ + d__ * rat);
787833
real_ = (a__ + b__ * rat) * scl;
788834
imag_ = (b__ - a__ * rat) * scl;
789835
}
836+
#else
837+
real_ = (a__ + b__ * rat) * scl;
838+
imag_ = (b__ - a__ * rat) * scl;
839+
#endif
790840
} else {
791-
auto rat = c__ / d__;
792-
auto scl = T(1.0) / (d__ + c__ * rat);
841+
#if __cplusplus >= 201703L
842+
if constexpr (std::is_same_v<T, float>) {
843+
real_ = std::fmaf(a__, rat, b__) * scl;
844+
imag_ = std::fmaf(b__, rat, -a__) * scl;
845+
} else if constexpr (std::is_same_v<T, double>) {
846+
real_ = std::fma(a__, rat, b__) * scl;
847+
imag_ = std::fma(b__, rat, -a__) * scl;
848+
} else {
849+
real_ = (a__ * rat + b__) * scl;
850+
imag_ = (b__ * rat - a__) * scl;
851+
}
852+
#else
793853
real_ = (a__ * rat + b__) * scl;
794854
imag_ = (b__ * rat - a__) * scl;
855+
#endif
795856
}
796857
auto q = ComplexType<T>(real_, imag_);
797858

@@ -970,22 +1031,41 @@ struct InverseRemainderFunctor<
9701031
#endif
9711032

9721033
T real_, imag_;
1034+
auto rat = (abs_c >= abs_d) ? (d__ / c__) : (c__ / d__);
1035+
auto scl = (abs_c >= abs_d) ? (T(1.0) / (c__ + d__ * rat))
1036+
: (T(1.0) / (d__ + c__ * rat));
9731037
if (abs_c >= abs_d) {
974-
if (abs_c == T(0) && abs_d == T(0)) {
975-
/* divide by zeros should yield a complex inf or nan */
976-
real_ = a__ / abs_c;
977-
imag_ = b__ / abs_d;
1038+
#if __cplusplus >= 201703L
1039+
if constexpr (std::is_same_v<T, float>) {
1040+
real_ = std::fmaf(b__, rat, a__) * scl;
1041+
imag_ = std::fmaf(-a__, rat, b__) * scl;
1042+
} else if constexpr (std::is_same_v<T, double>) {
1043+
real_ = std::fma(b__, rat, a__) * scl;
1044+
imag_ = std::fma(-a__, rat, b__) * scl;
9781045
} else {
979-
auto rat = d__ / c__;
980-
auto scl = T(1.0) / (c__ + d__ * rat);
9811046
real_ = (a__ + b__ * rat) * scl;
9821047
imag_ = (b__ - a__ * rat) * scl;
9831048
}
1049+
#else
1050+
real_ = (a__ + b__ * rat) * scl;
1051+
imag_ = (b__ - a__ * rat) * scl;
1052+
#endif
9841053
} else {
985-
auto rat = c__ / d__;
986-
auto scl = T(1.0) / (d__ + c__ * rat);
1054+
#if __cplusplus >= 201703L
1055+
if constexpr (std::is_same_v<T, float>) {
1056+
real_ = std::fmaf(a__, rat, b__) * scl;
1057+
imag_ = std::fmaf(b__, rat, -a__) * scl;
1058+
} else if constexpr (std::is_same_v<T, double>) {
1059+
real_ = std::fma(a__, rat, b__) * scl;
1060+
imag_ = std::fma(b__, rat, -a__) * scl;
1061+
} else {
1062+
real_ = (a__ * rat + b__) * scl;
1063+
imag_ = (b__ * rat - a__) * scl;
1064+
}
1065+
#else
9871066
real_ = (a__ * rat + b__) * scl;
9881067
imag_ = (b__ * rat - a__) * scl;
1068+
#endif
9891069
}
9901070
auto q = ComplexType<T>(real_, imag_);
9911071

0 commit comments

Comments
 (0)