@@ -14,6 +14,7 @@ limitations under the License. */
1414#include " paddle/phi/backends/dynload/dynamic_loader.h"
1515#include < dirent.h>
1616
17+ #include < codecvt>
1718#include < cstdlib>
1819#include < string>
1920#include < vector>
@@ -45,6 +46,7 @@ COMMON_DECLARE_string(cusparselt_dir);
4546COMMON_DECLARE_string (curand_dir);
4647COMMON_DECLARE_string (cusolver_dir);
4748COMMON_DECLARE_string (cusparse_dir);
49+ COMMON_DECLARE_string (win_cuda_bin_dir);
4850#ifdef PADDLE_WITH_HIP
4951
5052PHI_DEFINE_string (miopen_dir,
@@ -132,8 +134,12 @@ static constexpr char* win_cufft_lib =
132134
133135static inline std::string join (const std::string& part1,
134136 const std::string& part2) {
135- // directory separator
137+ // directory separator
138+ #if defined(_WIN32)
139+ const char sep = ' \\ ' ;
140+ #else
136141 const char sep = ' /' ;
142+ #endif
137143 if (!part2.empty () && part2.front () == sep) {
138144 return part2;
139145 }
@@ -263,6 +269,26 @@ static inline void* GetDsoHandleFromSearchPath(
263269#else
264270 int dynload_flags = 0 ;
265271#endif // !_WIN32
272+ #if defined(_WIN32)
273+ std::vector<std::wstring> cuda_bin_search_path = {
274+ L" cublas" ,
275+ L" cuda_nvrtc" ,
276+ L" cuda_runtime" ,
277+ L" cudnn" ,
278+ L" cufft" ,
279+ L" curand" ,
280+ L" cusolver" ,
281+ L" cusparse" ,
282+ L" nvjitlink" ,
283+ };
284+ for (auto search_path : cuda_bin_search_path) {
285+ std::wstring_convert<std::codecvt_utf8_utf16<wchar_t >> converter;
286+ std::wstring win_path_wstring =
287+ converter.from_bytes (FLAGS_win_cuda_bin_dir);
288+ search_path = win_path_wstring + L" \\ " + search_path + L" \\ bin" ;
289+ AddDllDirectory (search_path.c_str ());
290+ }
291+ #endif
266292 std::vector<std::string> dso_names = split (dso_name, " ;" );
267293 void * dso_handle = nullptr ;
268294 for (auto const & dso : dso_names) {
@@ -324,8 +350,26 @@ void* GetCublasDsoHandle() {
324350#if defined(__APPLE__) || defined(__OSX__)
325351 return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " libcublas.dylib" );
326352#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
327- return GetDsoHandleFromSearchPath (
328- FLAGS_cuda_dir, win_cublas_lib, true , {cuda_lib_path});
353+ if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000 ) {
354+ #ifdef WITH_PIP_CUDA_LIBRARIES
355+ return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " cublas64_11.dll" );
356+ #else
357+ return GetDsoHandleFromSearchPath (
358+ FLAGS_cuda_dir, win_cublas_lib, true , {cuda_lib_path});
359+ #endif
360+ } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 12030 ) {
361+ #ifdef WITH_PIP_CUDA_LIBRARIES
362+ return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " cublas64_12.dll" );
363+ #else
364+ return GetDsoHandleFromSearchPath (
365+ FLAGS_cuda_dir, win_cublas_lib, true , {cuda_lib_path});
366+ #endif
367+ } else {
368+ std::string warning_msg (
369+ " Your CUDA_VERSION is less than 11 or greater than 12, paddle "
370+ " temporarily no longer supports" );
371+ return nullptr ;
372+ }
329373#elif defined(__linux__) && defined(PADDLE_WITH_CUDA)
330374 if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000 ) {
331375#ifdef WITH_PIP_CUDA_LIBRARIES
@@ -403,8 +447,13 @@ void* GetCUDNNDsoHandle() {
403447 " Toolkit\\ CUDA\\ v10.0\n "
404448 " You should do this according to your CUDA installation directory and "
405449 " CUDNN version." );
450+ #ifdef WITH_PIP_CUDA_LIBRARIES
451+ return GetDsoHandleFromSearchPath (
452+ FLAGS_cuda_dir, " cudnn64_8.dll" , true , {cuda_lib_path}, win_warn_meg);
453+ #else
406454 return GetDsoHandleFromSearchPath (
407- FLAGS_cudnn_dir, win_cudnn_lib, true , {cuda_lib_path}, win_warn_meg);
455+ FLAGS_cuda_dir, win_cudnn_lib, true , {cuda_lib_path}, win_warn_meg);
456+ #endif
408457#elif defined(PADDLE_WITH_HIP)
409458 return GetDsoHandleFromSearchPath (FLAGS_miopen_dir, " libMIOpen.so" , false );
410459#else
@@ -461,8 +510,13 @@ void* GetCurandDsoHandle() {
461510#if defined(__APPLE__) || defined(__OSX__)
462511 return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " libcurand.dylib" );
463512#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
513+ #ifdef WITH_PIP_CUDA_LIBRARIES
514+ return GetDsoHandleFromSearchPath (
515+ FLAGS_cuda_dir, " curand64_10.dll" , true , {cuda_lib_path});
516+ #else
464517 return GetDsoHandleFromSearchPath (
465518 FLAGS_cuda_dir, win_curand_lib, true , {cuda_lib_path});
519+ #endif
466520#elif defined(PADDLE_WITH_HIP)
467521 return GetDsoHandleFromSearchPath (FLAGS_rocm_dir, " libhiprand.so" );
468522#else
@@ -500,8 +554,13 @@ void* GetCusolverDsoHandle() {
500554#if defined(__APPLE__) || defined(__OSX__)
501555 return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " libcusolver.dylib" );
502556#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
557+ #ifdef WITH_PIP_CUDA_LIBRARIES
558+ return GetDsoHandleFromSearchPath (
559+ FLAGS_cuda_dir, " cusolver64_11.dll" , true , {cuda_lib_path});
560+ #else
503561 return GetDsoHandleFromSearchPath (
504562 FLAGS_cuda_dir, win_cusolver_lib, true , {cuda_lib_path});
563+ #endif
505564#else
506565#ifdef WITH_PIP_CUDA_LIBRARIES
507566 return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " libcusolver.so.11" );
@@ -515,8 +574,26 @@ void* GetCusparseDsoHandle() {
515574#if defined(__APPLE__) || defined(__OSX__)
516575 return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " libcusparse.dylib" );
517576#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
518- return GetDsoHandleFromSearchPath (
519- FLAGS_cuda_dir, win_cusparse_lib, true , {cuda_lib_path});
577+ if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000 ) {
578+ #ifdef WITH_PIP_CUDA_LIBRARIES
579+ return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " cusparse64_11.dll" );
580+ #else
581+ return GetDsoHandleFromSearchPath (
582+ FLAGS_cuda_dir, win_cusparse_lib, true , {cuda_lib_path});
583+ #endif
584+ } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 12030 ) {
585+ #ifdef WITH_PIP_CUDA_LIBRARIES
586+ return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " cusparse64_12.dll" );
587+ #else
588+ return GetDsoHandleFromSearchPath (
589+ FLAGS_cuda_dir, win_cusparse_lib, true , {cuda_lib_path});
590+ #endif
591+ } else {
592+ std::string warning_msg (
593+ " Your CUDA_VERSION is less than 11 or greater than 12, paddle "
594+ " temporarily no longer supports" );
595+ return nullptr ;
596+ }
520597#elif defined(__linux__) && defined(PADDLE_WITH_CUDA)
521598 if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000 ) {
522599#ifdef WITH_PIP_CUDA_LIBRARIES
@@ -709,8 +786,26 @@ void* GetCUFFTDsoHandle() {
709786 return nullptr ;
710787 }
711788#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
712- return GetDsoHandleFromSearchPath (
713- FLAGS_cuda_dir, win_cufft_lib, true , {cuda_lib_path});
789+ if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000 ) {
790+ #ifdef WITH_PIP_CUDA_LIBRARIES
791+ return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " cufft64_10.dll" );
792+ #else
793+ return GetDsoHandleFromSearchPath (
794+ FLAGS_cuda_dir, win_cufft_lib, true , {cuda_lib_path});
795+ #endif
796+ } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 12030 ) {
797+ #ifdef WITH_PIP_CUDA_LIBRARIES
798+ return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " cufft64_11.dll" );
799+ #else
800+ return GetDsoHandleFromSearchPath (
801+ FLAGS_cuda_dir, win_cufft_lib, true , {cuda_lib_path});
802+ #endif
803+ } else {
804+ std::string warning_msg (
805+ " Your CUDA_VERSION is less than 11 or greater than 12, paddle "
806+ " temporarily no longer supports" );
807+ return nullptr ;
808+ }
714809#else
715810 return GetDsoHandleFromSearchPath (FLAGS_cuda_dir, " libcufft.so" );
716811#endif
0 commit comments