Skip to content

Commit 7a1d6a7

Browse files
committed
Check CUDA installation
1 parent 4ad0461 commit 7a1d6a7

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

python/triton/windows_utils.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,20 +186,40 @@ def find_python():
186186
return python_lib_dirs
187187

188188

189+
def check_cuda(cuda_base_path):
190+
return all(
191+
x.exists()
192+
for x in [
193+
cuda_base_path / "bin" / "cudart64_12.dll",
194+
cuda_base_path / "bin" / "ptxas.exe",
195+
cuda_base_path / "include" / "cuda.h",
196+
cuda_base_path / "lib" / "x64" / "cuda.lib",
197+
]
198+
)
199+
200+
189201
@functools.cache
190202
def find_cuda():
191203
cuda_base_path = os.environ.get("CUDA_PATH")
192204
if cuda_base_path is not None:
193205
cuda_base_path = Path(cuda_base_path)
194-
if not cuda_base_path.exists():
206+
if not check_cuda(cuda_base_path):
195207
cuda_base_path = None
196208

197209
if cuda_base_path is None:
198210
paths = glob(r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12*")
199-
if not paths:
200-
return None, [], []
201211
# Heuristic to find the highest version
202-
cuda_base_path = Path(sorted(paths)[-1])
212+
paths = sorted(paths)[::-1]
213+
for path in paths:
214+
cuda_base_path = Path(path)
215+
if check_cuda(cuda_base_path):
216+
break
217+
else:
218+
cuda_base_path = None
219+
220+
if cuda_base_path is None:
221+
print("WARNING: Failed to find CUDA.")
222+
return None, [], []
203223

204224
return (
205225
str(cuda_base_path / "bin"),

0 commit comments

Comments
 (0)