Skip to content

Commit ee7ddbf

Browse files
committed
Add paths of cl and CUDA before loading anything
1 parent c79a578 commit ee7ddbf

File tree

4 files changed

+53
-13
lines changed

4 files changed

+53
-13
lines changed

python/triton/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
"""isort:skip_file"""
22
__version__ = '3.0.0'
33

4+
# Users may not know how to put cl and the CUDA toolkit on PATH, so both are
# prepended here, before anything else in the package is loaded.
import os

if os.name == "nt":
    from .windows_utils import find_cuda, find_msvc_winsdk

    # Only the include directories are used; the second element is ignored.
    msvc_winsdk_inc_dirs, _ = find_msvc_winsdk()
    if msvc_winsdk_inc_dirs:
        # Derive the compiler directory from the first include directory by
        # swapping its "\include" component for "\bin\Hostx64\x64".
        cl_path = msvc_winsdk_inc_dirs[0].replace(r"\include", r"\bin\Hostx64\x64")
        os.environ["PATH"] = os.pathsep.join((cl_path, os.environ["PATH"]))

    # find_cuda() returns a falsy bin path when no toolkit was located.
    cuda_bin_path, _, _ = find_cuda()
    if cuda_bin_path:
        os.environ["PATH"] = os.pathsep.join((cuda_bin_path, os.environ["PATH"]))
15+
416
# ---------------------------------------
517
# Note: import order is significant here.
618

python/triton/runtime/build.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import setuptools
99

1010
if os.name == "nt":
11-
from .windows import find_msvc_winsdk, find_python
11+
from triton.windows_utils import find_msvc_winsdk, find_python
1212

1313

1414
@contextlib.contextmanager
@@ -44,12 +44,6 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
4444
# try to avoid setuptools if possible
4545
cc = os.environ.get("CC")
4646
if cc is None:
47-
# Users may not know how to add cl to PATH. Let's do it for them
48-
if os.name == "nt":
49-
msvc_winsdk_inc_dirs, _ = find_msvc_winsdk()
50-
if msvc_winsdk_inc_dirs:
51-
cl_path = msvc_winsdk_inc_dirs[0].replace(r"\include", r"\bin\Hostx64\x64")
52-
os.environ["PATH"] = cl_path + os.pathsep + os.environ["PATH"]
5347
# TODO: support more things here.
5448
cl = shutil.which("cl")
5549
gcc = shutil.which("gcc")

python/triton/runtime/windows.py renamed to python/triton/windows_utils.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ def unparse_version(t, prefix=""):
2121
return prefix + ".".join([str(x) for x in t])
2222

2323

24-
def max_version(versions, prefix=""):
    """Return the largest version among *versions*, or ``None`` if none parse.

    Every entry is passed through ``parse_version`` together with *prefix*;
    entries for which it returns ``None`` are ignored. The largest remaining
    value is turned back into a string by ``unparse_version`` using the same
    *prefix*.
    """
    parsed = [
        candidate
        for candidate in (parse_version(name, prefix) for name in versions)
        if candidate is not None
    ]
    # max() raises on an empty sequence, so report "nothing found" explicitly.
    if not parsed:
        return None
    return unparse_version(max(parsed), prefix)
2931

3032

@@ -105,6 +107,10 @@ def find_msvc():
105107
return [], []
106108

107109
version = max_version(os.listdir(msvc_base_path))
110+
if version is None:
111+
print("WARNING: Failed to find MSVC.")
112+
return [], []
113+
108114
return (
109115
[str(msvc_base_path / version / "include")],
110116
[str(msvc_base_path / version / "lib" / "x64")],
@@ -144,6 +150,10 @@ def find_winsdk():
144150
return [], []
145151

146152
version = max_version(os.listdir(winsdk_base_path / "Include"))
153+
if version is None:
154+
print("WARNING: Failed to find Windows SDK.")
155+
return [], []
156+
147157
return (
148158
[
149159
str(winsdk_base_path / "Include" / version / "shared"),
@@ -174,3 +184,25 @@ def find_python():
174184
rf"C:\Python{version}\libs",
175185
]
176186
return python_lib_dirs
187+
188+
189+
def _cuda_version_key(path):
    """Return a numeric sort key for a toolkit directory named like ``v12.4``.

    The final path component is stripped of its leading ``v`` and split on
    dots. Numeric components become integers so that ``v12.10`` orders above
    ``v12.2``; any non-numeric component maps to ``-1`` so it sorts lowest.
    """
    components = Path(path).name.lstrip("v").split(".")
    return tuple(int(c) if c.isdigit() else -1 for c in components)


@functools.cache
def find_cuda():
    """Locate the CUDA toolkit directories.

    ``CUDA_PATH`` is used when it is set and names an existing path.
    Otherwise the default install location is searched for ``v12*``
    directories and the one with the highest numeric version is chosen.

    Returns:
        A ``(bin_dir, include_dirs, lib_dirs)`` tuple, where ``bin_dir`` is a
        string and the other two are single-element lists of strings. When no
        toolkit is found the result is ``(None, [], [])``.

    The result is memoised by ``functools.cache``, so a change to
    ``CUDA_PATH`` after the first call is not picked up.
    """
    cuda_base_path = None
    env_path = os.environ.get("CUDA_PATH")
    if env_path is not None and Path(env_path).exists():
        cuda_base_path = Path(env_path)

    if cuda_base_path is None:
        paths = glob(r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12*")
        if not paths:
            return None, [], []
        # Compare versions numerically; a plain string sort would rank
        # "v12.10" below "v12.2".
        cuda_base_path = Path(max(paths, key=_cuda_version_key))

    return (
        str(cuda_base_path / "bin"),
        [str(cuda_base_path / "include")],
        [str(cuda_base_path / "lib" / "x64")],
    )

third_party/nvidia/backend/driver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
# Paths below are resolved relative to the real (symlink-free) location of
# this file.
dirname = os.path.dirname(os.path.realpath(__file__))
include_dir = [os.path.join(dirname, "include")]
if os.name == "nt":
    from triton.windows_utils import find_cuda

    # Only the include directories from find_cuda() are needed here; they are
    # appended after the include directory next to this file.
    _, cuda_inc_dirs, _ = find_cuda()
    include_dir.extend(cuda_inc_dirs)
libdevice_dir = os.path.join(dirname, "lib")
libraries = ['cuda']
1920

2021

2122
@functools.lru_cache()
2223
def libcuda_dirs():
2324
if os.name == "nt":
24-
return [os.path.join(cuda_path, "lib", "x64")]
25+
_, _, cuda_lib_dirs = find_cuda()
26+
return cuda_lib_dirs
2527

2628
env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
2729
if env_libcuda_path:

0 commit comments

Comments
 (0)