File tree Expand file tree Collapse file tree 1 file changed +24
-4
lines changed Expand file tree Collapse file tree 1 file changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -186,20 +186,40 @@ def find_python():
186186 return python_lib_dirs
187187
188188
189+ def check_cuda (cuda_base_path ):
190+ return all (
191+ x .exists ()
192+ for x in [
193+ cuda_base_path / "bin" / "cudart64_12.dll" ,
194+ cuda_base_path / "bin" / "ptxas.exe" ,
195+ cuda_base_path / "include" / "cuda.h" ,
196+ cuda_base_path / "lib" / "x64" / "cuda.lib" ,
197+ ]
198+ )
199+
200+
189201@functools .cache
190202def find_cuda ():
191203 cuda_base_path = os .environ .get ("CUDA_PATH" )
192204 if cuda_base_path is not None :
193205 cuda_base_path = Path (cuda_base_path )
194- if not cuda_base_path . exists ( ):
206+ if not check_cuda ( cuda_base_path ):
195207 cuda_base_path = None
196208
197209 if cuda_base_path is None :
198210 paths = glob (r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12*" )
199- if not paths :
200- return None , [], []
201211 # Heuristic to find the highest version
202- cuda_base_path = Path (sorted (paths )[- 1 ])
212+ paths = sorted (paths )[::- 1 ]
213+ for path in paths :
214+ cuda_base_path = Path (path )
215+ if check_cuda (cuda_base_path ):
216+ break
217+ else :
218+ cuda_base_path = None
219+
220+ if cuda_base_path is None :
221+ print ("WARNING: Failed to find CUDA." )
222+ return None , [], []
203223
204224 return (
205225 str (cuda_base_path / "bin" ),
You can’t perform that action at this time.
0 commit comments