Skip to content

Commit ed4e384

Browse files
cszdrgzhengshengning
authored andcommitted
[深度对齐]Divide (PaddlePaddle#75379)
* fix * fix * fix * fix * fix
1 parent 30b7970 commit ed4e384

File tree

2 files changed

+225
-55
lines changed

2 files changed

+225
-55
lines changed

paddle/phi/common/complex.h

Lines changed: 113 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,62 @@ HOSTDEVICE inline complex<T> operator*(const complex<T>& a,
230230
}
231231

232232
template <typename T>
233-
HOSTDEVICE inline complex<T> operator/(const complex<T>& a,
234-
const complex<T>& b) {
235-
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
236-
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
237-
return complex<T>(thrust::complex<T>(a) / thrust::complex<T>(b));
238-
#else
239-
T denominator = b.real * b.real + b.imag * b.imag;
240-
return complex<T>((a.real * b.real + a.imag * b.imag) / denominator,
241-
(a.imag * b.real - a.real * b.imag) / denominator);
242-
#endif
233+
HOSTDEVICE inline complex<T> operator/(const complex<T>& x,
234+
const complex<T>& y) {
235+
T a = x.real;
236+
T b = x.imag;
237+
T c = y.real;
238+
T d = y.imag;
239+
240+
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
241+
// the calculation below follows numpy's complex division
242+
#if defined(__GNUC__) && !defined(__clang__)
243+
// std::abs is already constexpr by gcc
244+
auto abs_c = std::abs(c);
245+
auto abs_d = std::abs(d);
246+
#else
247+
auto abs_c = c < 0 ? -c : c;
248+
auto abs_d = d < 0 ? -d : d;
249+
#endif
250+
T real_, imag_;
251+
252+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
253+
auto scl =
254+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
255+
if (abs_c >= abs_d) {
256+
#if __cplusplus >= 201703L
257+
if constexpr (std::is_same_v<T, float>) {
258+
real_ = std::fmaf(b, rat, a) * scl;
259+
imag_ = std::fmaf(-a, rat, b) * scl;
260+
} else if constexpr (std::is_same_v<T, double>) {
261+
real_ = std::fma(b, rat, a) * scl;
262+
imag_ = std::fma(-a, rat, b) * scl;
263+
} else {
264+
real_ = (a + b * rat) * scl;
265+
imag_ = (b - a * rat) * scl;
266+
}
267+
#else
268+
real_ = (a + b * rat) * scl;
269+
imag_ = (b - a * rat) * scl;
270+
#endif
271+
} else {
272+
#if __cplusplus >= 201703L
273+
if constexpr (std::is_same_v<T, float>) {
274+
real_ = std::fmaf(a, rat, b) * scl;
275+
imag_ = std::fmaf(b, rat, -a) * scl;
276+
} else if constexpr (std::is_same_v<T, double>) {
277+
real_ = std::fma(a, rat, b) * scl;
278+
imag_ = std::fma(b, rat, -a) * scl;
279+
} else {
280+
real_ = (a * rat + b) * scl;
281+
imag_ = (b * rat - a) * scl;
282+
}
283+
#else
284+
real_ = (a * rat + b) * scl;
285+
imag_ = (b * rat - a) * scl;
286+
#endif
287+
}
288+
return complex<T>(real_, imag_);
243289
}
244290

245291
template <typename T>
@@ -303,19 +349,63 @@ HOSTDEVICE inline complex<T>& operator*=(complex<T>& a, // NOLINT
303349
}
304350

305351
template <typename T>
306-
HOSTDEVICE inline complex<T>& operator/=(complex<T>& a, // NOLINT
307-
const complex<T>& b) {
308-
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
309-
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
310-
a = complex<T>(thrust::complex<T>(a.real, a.imag) /=
311-
thrust::complex<T>(b.real, b.imag));
312-
return a;
313-
#else
314-
T denominator = b.real * b.real + b.imag * b.imag;
315-
a.real = (a.real * b.real + a.imag * b.imag) / denominator;
316-
a.imag = (a.imag * b.real - a.real * b.imag) / denominator;
317-
return a;
318-
#endif
352+
HOSTDEVICE inline complex<T>& operator/=(complex<T>& x, // NOLINT
353+
const complex<T>& y) {
354+
T a = x.real;
355+
T b = x.imag;
356+
T c = y.real;
357+
T d = y.imag;
358+
359+
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
360+
// the calculation below follows numpy's complex division
361+
#if defined(__GNUC__) && !defined(__clang__)
362+
// std::abs is already constexpr by gcc
363+
auto abs_c = std::abs(c);
364+
auto abs_d = std::abs(d);
365+
#else
366+
auto abs_c = c < 0 ? -c : c;
367+
auto abs_d = d < 0 ? -d : d;
368+
#endif
369+
T real_, imag_;
370+
371+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
372+
auto scl =
373+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
374+
if (abs_c >= abs_d) {
375+
#if __cplusplus >= 201703L
376+
if constexpr (std::is_same_v<T, float>) {
377+
real_ = std::fmaf(b, rat, a) * scl;
378+
imag_ = std::fmaf(-a, rat, b) * scl;
379+
} else if constexpr (std::is_same_v<T, double>) {
380+
real_ = std::fma(b, rat, a) * scl;
381+
imag_ = std::fma(-a, rat, b) * scl;
382+
} else {
383+
real_ = (a + b * rat) * scl;
384+
imag_ = (b - a * rat) * scl;
385+
}
386+
#else
387+
real_ = (a + b * rat) * scl;
388+
imag_ = (b - a * rat) * scl;
389+
#endif
390+
} else {
391+
#if __cplusplus >= 201703L
392+
if constexpr (std::is_same_v<T, float>) {
393+
real_ = std::fmaf(a, rat, b) * scl;
394+
imag_ = std::fmaf(b, rat, -a) * scl;
395+
} else if constexpr (std::is_same_v<T, double>) {
396+
real_ = std::fma(a, rat, b) * scl;
397+
imag_ = std::fma(b, rat, -a) * scl;
398+
} else {
399+
real_ = (a * rat + b) * scl;
400+
imag_ = (b * rat - a) * scl;
401+
}
402+
#else
403+
real_ = (a * rat + b) * scl;
404+
imag_ = (b * rat - a) * scl;
405+
#endif
406+
}
407+
x = complex<T>(real_, imag_);
408+
return x;
319409
}
320410

321411
template <typename T>

paddle/phi/kernels/funcs/elementwise_functor.h

Lines changed: 112 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -145,23 +145,44 @@ struct DivideFunctor<ComplexType<T>> {
145145
#endif
146146

147147
T real_, imag_;
148+
149+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
150+
auto scl =
151+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
148152
if (abs_c >= abs_d) {
149-
if (abs_c == T(0) && abs_d == T(0)) {
150-
/* divide by zeros should yield a complex inf or nan */
151-
real_ = a / abs_c;
152-
imag_ = b / abs_d;
153+
#if __cplusplus >= 201703L
154+
if constexpr (std::is_same_v<T, float>) {
155+
real_ = std::fmaf(b, rat, a) * scl;
156+
imag_ = std::fmaf(-a, rat, b) * scl;
157+
} else if constexpr (std::is_same_v<T, double>) {
158+
real_ = std::fma(b, rat, a) * scl;
159+
imag_ = std::fma(-a, rat, b) * scl;
153160
} else {
154-
auto rat = d / c;
155-
auto scl = T(1.0) / (c + d * rat);
156161
real_ = (a + b * rat) * scl;
157162
imag_ = (b - a * rat) * scl;
158163
}
164+
#else
165+
real_ = (a + b * rat) * scl;
166+
imag_ = (b - a * rat) * scl;
167+
#endif
159168
} else {
160-
auto rat = c / d;
161-
auto scl = T(1.0) / (d + c * rat);
169+
#if __cplusplus >= 201703L
170+
if constexpr (std::is_same_v<T, float>) {
171+
real_ = std::fmaf(a, rat, b) * scl;
172+
imag_ = std::fmaf(b, rat, -a) * scl;
173+
} else if constexpr (std::is_same_v<T, double>) {
174+
real_ = std::fma(a, rat, b) * scl;
175+
imag_ = std::fma(b, rat, -a) * scl;
176+
} else {
177+
real_ = (a * rat + b) * scl;
178+
imag_ = (b * rat - a) * scl;
179+
}
180+
#else
162181
real_ = (a * rat + b) * scl;
163182
imag_ = (b * rat - a) * scl;
183+
#endif
164184
}
185+
165186
return ComplexType<T>(real_, imag_);
166187
}
167188
};
@@ -187,23 +208,44 @@ struct InverseDivideFunctor<ComplexType<T>> {
187208
#endif
188209

189210
T real_, imag_;
211+
212+
auto rat = (abs_c >= abs_d) ? (d / c) : (c / d);
213+
auto scl =
214+
(abs_c >= abs_d) ? (T(1.0) / (c + d * rat)) : (T(1.0) / (d + c * rat));
190215
if (abs_c >= abs_d) {
191-
if (abs_c == T(0) && abs_d == T(0)) {
192-
/* divide by zeros should yield a complex inf or nan */
193-
real_ = a / abs_c;
194-
imag_ = b / abs_d;
216+
#if __cplusplus >= 201703L
217+
if constexpr (std::is_same_v<T, float>) {
218+
real_ = std::fmaf(b, rat, a) * scl;
219+
imag_ = std::fmaf(-a, rat, b) * scl;
220+
} else if constexpr (std::is_same_v<T, double>) {
221+
real_ = std::fma(b, rat, a) * scl;
222+
imag_ = std::fma(-a, rat, b) * scl;
195223
} else {
196-
auto rat = d / c;
197-
auto scl = T(1.0) / (c + d * rat);
198224
real_ = (a + b * rat) * scl;
199225
imag_ = (b - a * rat) * scl;
200226
}
227+
#else
228+
real_ = (a + b * rat) * scl;
229+
imag_ = (b - a * rat) * scl;
230+
#endif
201231
} else {
202-
auto rat = c / d;
203-
auto scl = T(1.0) / (d + c * rat);
232+
#if __cplusplus >= 201703L
233+
if constexpr (std::is_same_v<T, float>) {
234+
real_ = std::fmaf(a, rat, b) * scl;
235+
imag_ = std::fmaf(b, rat, -a) * scl;
236+
} else if constexpr (std::is_same_v<T, double>) {
237+
real_ = std::fma(a, rat, b) * scl;
238+
imag_ = std::fma(b, rat, -a) * scl;
239+
} else {
240+
real_ = (a * rat + b) * scl;
241+
imag_ = (b * rat - a) * scl;
242+
}
243+
#else
204244
real_ = (a * rat + b) * scl;
205245
imag_ = (b * rat - a) * scl;
246+
#endif
206247
}
248+
207249
return ComplexType<T>(real_, imag_);
208250
}
209251
};
@@ -779,22 +821,41 @@ struct RemainderFunctor<ComplexType<T>> {
779821
#endif
780822

781823
T real_, imag_;
824+
auto rat = (abs_c >= abs_d) ? (d__ / c__) : (c__ / d__);
825+
auto scl = (abs_c >= abs_d) ? (T(1.0) / (c__ + d__ * rat))
826+
: (T(1.0) / (d__ + c__ * rat));
782827
if (abs_c >= abs_d) {
783-
if (abs_c == T(0) && abs_d == T(0)) {
784-
/* divide by zeros should yield a complex inf or nan */
785-
real_ = a__ / abs_c;
786-
imag_ = b__ / abs_d;
828+
#if __cplusplus >= 201703L
829+
if constexpr (std::is_same_v<T, float>) {
830+
real_ = std::fmaf(b__, rat, a__) * scl;
831+
imag_ = std::fmaf(-a__, rat, b__) * scl;
832+
} else if constexpr (std::is_same_v<T, double>) {
833+
real_ = std::fma(b__, rat, a__) * scl;
834+
imag_ = std::fma(-a__, rat, b__) * scl;
787835
} else {
788-
auto rat = d__ / c__;
789-
auto scl = T(1.0) / (c__ + d__ * rat);
790836
real_ = (a__ + b__ * rat) * scl;
791837
imag_ = (b__ - a__ * rat) * scl;
792838
}
839+
#else
840+
real_ = (a__ + b__ * rat) * scl;
841+
imag_ = (b__ - a__ * rat) * scl;
842+
#endif
793843
} else {
794-
auto rat = c__ / d__;
795-
auto scl = T(1.0) / (d__ + c__ * rat);
844+
#if __cplusplus >= 201703L
845+
if constexpr (std::is_same_v<T, float>) {
846+
real_ = std::fmaf(a__, rat, b__) * scl;
847+
imag_ = std::fmaf(b__, rat, -a__) * scl;
848+
} else if constexpr (std::is_same_v<T, double>) {
849+
real_ = std::fma(a__, rat, b__) * scl;
850+
imag_ = std::fma(b__, rat, -a__) * scl;
851+
} else {
852+
real_ = (a__ * rat + b__) * scl;
853+
imag_ = (b__ * rat - a__) * scl;
854+
}
855+
#else
796856
real_ = (a__ * rat + b__) * scl;
797857
imag_ = (b__ * rat - a__) * scl;
858+
#endif
798859
}
799860
auto q = ComplexType<T>(real_, imag_);
800861

@@ -973,22 +1034,41 @@ struct InverseRemainderFunctor<
9731034
#endif
9741035

9751036
T real_, imag_;
1037+
auto rat = (abs_c >= abs_d) ? (d__ / c__) : (c__ / d__);
1038+
auto scl = (abs_c >= abs_d) ? (T(1.0) / (c__ + d__ * rat))
1039+
: (T(1.0) / (d__ + c__ * rat));
9761040
if (abs_c >= abs_d) {
977-
if (abs_c == T(0) && abs_d == T(0)) {
978-
/* divide by zeros should yield a complex inf or nan */
979-
real_ = a__ / abs_c;
980-
imag_ = b__ / abs_d;
1041+
#if __cplusplus >= 201703L
1042+
if constexpr (std::is_same_v<T, float>) {
1043+
real_ = std::fmaf(b__, rat, a__) * scl;
1044+
imag_ = std::fmaf(-a__, rat, b__) * scl;
1045+
} else if constexpr (std::is_same_v<T, double>) {
1046+
real_ = std::fma(b__, rat, a__) * scl;
1047+
imag_ = std::fma(-a__, rat, b__) * scl;
9811048
} else {
982-
auto rat = d__ / c__;
983-
auto scl = T(1.0) / (c__ + d__ * rat);
9841049
real_ = (a__ + b__ * rat) * scl;
9851050
imag_ = (b__ - a__ * rat) * scl;
9861051
}
1052+
#else
1053+
real_ = (a__ + b__ * rat) * scl;
1054+
imag_ = (b__ - a__ * rat) * scl;
1055+
#endif
9871056
} else {
988-
auto rat = c__ / d__;
989-
auto scl = T(1.0) / (d__ + c__ * rat);
1057+
#if __cplusplus >= 201703L
1058+
if constexpr (std::is_same_v<T, float>) {
1059+
real_ = std::fmaf(a__, rat, b__) * scl;
1060+
imag_ = std::fmaf(b__, rat, -a__) * scl;
1061+
} else if constexpr (std::is_same_v<T, double>) {
1062+
real_ = std::fma(a__, rat, b__) * scl;
1063+
imag_ = std::fma(b__, rat, -a__) * scl;
1064+
} else {
1065+
real_ = (a__ * rat + b__) * scl;
1066+
imag_ = (b__ * rat - a__) * scl;
1067+
}
1068+
#else
9901069
real_ = (a__ * rat + b__) * scl;
9911070
imag_ = (b__ * rat - a__) * scl;
1071+
#endif
9921072
}
9931073
auto q = ComplexType<T>(real_, imag_);
9941074

0 commit comments

Comments
 (0)