Skip to content

Commit c79a578

Browse files
committed
Detect MSVC using hard-coded path
1 parent 54af0cf commit c79a578

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

python/triton/runtime/windows.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import sysconfig
77
import winreg
8+
from glob import glob
89
from pathlib import Path
910

1011

@@ -20,6 +21,13 @@ def unparse_version(t, prefix=""):
2021
return prefix + ".".join([str(x) for x in t])
2122

2223

24+
def max_version(versions):
25+
versions = [parse_version(x) for x in versions]
26+
versions = [x for x in versions if x is not None]
27+
version = unparse_version(max(versions))
28+
return version
29+
30+
2331
def find_msvc_base_vswhere():
2432
program_files = os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)")
2533
vswhere_path = (
@@ -72,18 +80,31 @@ def find_msvc_base_envpath():
7280
return None
7381

7482

83+
def find_msvc_base_hardcoded():
84+
msvc_base_path = Path(r"C:\Program Files (x86)\Microsoft Visual Studio")
85+
if not msvc_base_path.exists():
86+
msvc_base_path = Path(r"C:\Program Files\Microsoft Visual Studio")
87+
if not msvc_base_path.exists():
88+
return None
89+
90+
paths = glob(str(msvc_base_path / "*" / "*" / "VC" / "Tools" / "MSVC"))
91+
if not paths:
92+
return None
93+
# Heuristic to find the highest version
94+
return Path(sorted(paths)[-1])
95+
96+
7597
def find_msvc():
7698
msvc_base_path = find_msvc_base_vswhere()
7799
if msvc_base_path is None:
78100
msvc_base_path = find_msvc_base_envpath()
101+
if msvc_base_path is None:
102+
msvc_base_path = find_msvc_base_hardcoded()
79103
if msvc_base_path is None:
80104
print("WARNING: Failed to find MSVC.")
81105
return [], []
82106

83-
versions = [parse_version(x) for x in os.listdir(msvc_base_path)]
84-
versions = [x for x in versions if x is not None]
85-
version = unparse_version(max(versions))
86-
107+
version = max_version(os.listdir(msvc_base_path))
87108
return (
88109
[str(msvc_base_path / version / "include")],
89110
[str(msvc_base_path / version / "lib" / "x64")],
@@ -122,10 +143,7 @@ def find_winsdk():
122143
print("WARNING: Failed to find Windows SDK.")
123144
return [], []
124145

125-
versions = [parse_version(x) for x in os.listdir(winsdk_base_path / "Include")]
126-
versions = [x for x in versions if x is not None]
127-
version = unparse_version(max(versions))
128-
146+
version = max_version(os.listdir(winsdk_base_path / "Include"))
129147
return (
130148
[
131149
str(winsdk_base_path / "Include" / version / "shared"),

0 commit comments

Comments
 (0)