Skip to content

Commit e804bd5

Browse files
author
Feiyu Chan
authored
fix dynload for cufft on windows (#51)
1. fix dynload for cufft on windows; 2. fix unittests.
1 parent 1e16889 commit e804bd5

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

paddle/fluid/platform/dynload/dynamic_loader.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ static constexpr char* win_cusolver_lib =
109109
static constexpr char* win_cusparse_lib =
110110
"cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
111111
".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll";
112+
static constexpr char* win_cufft_lib =
113+
"cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
114+
".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_10.dll";
112115
#else
113116
static constexpr char* win_curand_lib =
114117
"curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
@@ -122,7 +125,9 @@ static constexpr char* win_cusolver_lib =
122125
static constexpr char* win_cusparse_lib =
123126
"cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
124127
".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll";
125-
static constexpr char* win_cufft_lib = "cufft64_" CUDA_MAJOR_VERSION ".dll";
128+
static constexpr char* win_cufft_lib =
129+
"cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
130+
".dll;cufft64_" CUDA_VERSION_MAJOR ".dll";
126131
#endif // CUDA_VERSION
127132
#endif
128133

python/paddle/fluid/tests/unittests/fft/test_fft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,8 +534,8 @@ def test_hfft2(self):
534534
[('test_n_nagative',
535535
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
536536
(-2, -1), 'backward', ValueError), \
537-
('test_n_equal_input_length',
538-
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1),
537+
('test_zero_point',
538+
np.random.randn(4, 4, 1) + 1j * np.random.randn(4, 4, 1), None, (-2, -1),
539539
"backward", ValueError), \
540540
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
541541
(0, 0), (-2, -1), 'backward', ValueError), \

0 commit comments

Comments
 (0)