-
Notifications
You must be signed in to change notification settings - Fork 74
Expand file tree
/
Copy pathsetup.py
More file actions
149 lines (124 loc) · 5.58 KB
/
setup.py
File metadata and controls
149 lines (124 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, CppExtension, BuildExtension
import torch
import sys
import os
# Make build output appear immediately (pip/ninja can otherwise buffer it away).
os.environ["PYTHONUNBUFFERED"] = "1"
sys.stderr.reconfigure(line_buffering=True)
def log(msg):
    """Mirror *msg* to both stdout and stderr (stderr flushed immediately)."""
    sys.stdout.write(f"{msg}\n")
    print(msg, file=sys.stderr, flush=True)
def configure_cuda():
    """Configure the CUDA/ROCm build.

    Returns a 6-tuple:
        (extension class, source files, extension name,
         extra compiler args, extra link args, human-readable arch label).

    For NVIDIA, the target architecture is chosen in priority order:
      1. the ``CUDA_ARCHITECTURES`` env var (semicolon separated, e.g. "75;80"),
      2. the compute capability of the currently visible GPU,
      3. a fixed fallback list covering several common architectures.
    """
    fallback_archs = [
        "-gencode=arch=compute_75,code=sm_75",
        "-gencode=arch=compute_80,code=sm_80",
        "-gencode=arch=compute_89,code=sm_89",
    ]
    log("Compiling for CUDA.")
    # The "nvcc" key is also consumed by the HIP toolchain when building for ROCm.
    compiler_args = {
        "cxx": ["-O3", "-DFUSED_SSIM_CUDA"],
        "nvcc": ["-O3", "-DFUSED_SSIM_CUDA"],
    }
    if torch.version.hip:
        log("Detected AMD GPU with ROCm/HIP")
        compiler_args["nvcc"].append("-ffast-math")
        detected_arch = "AMD GPU (ROCm/HIP)"
    else:
        compiler_args["nvcc"].extend(("--maxrregcount=32", "--use_fast_math"))
        # ``None`` means "not configured yet"; each stage below fills it in on
        # success (replaces the old separate ``arch_configured`` flag).
        detected_arch = None
        # 1) Explicit override via environment variable.
        cuda_archs_env = os.environ.get('CUDA_ARCHITECTURES')
        if cuda_archs_env:
            try:
                archs = [arch.strip() for arch in cuda_archs_env.split(';')]
                # Build the complete flag list BEFORE touching compiler_args so
                # a parse failure cannot leave it partially populated ahead of
                # the device-detection fallback.
                gencode_flags = [
                    f"-gencode=arch=compute_{arch},code=sm_{arch}" for arch in archs
                ]
                log(f"Using CUDA architectures from environment: {archs}")
                compiler_args["nvcc"].extend(gencode_flags)
                detected_arch = f"env:{','.join(archs)}"
            except Exception as e:
                log(f"Failed to parse CUDA_ARCHITECTURES environment variable: {e}. Trying device detection.")
        # 2) Ask the currently visible GPU for its compute capability.
        if detected_arch is None:
            try:
                device = torch.cuda.current_device()
                major, minor = torch.cuda.get_device_capability(device)
                arch = f"sm_{major}{minor}"
                log(f"Detected GPU architecture: {arch}")
                compiler_args["nvcc"].append(f"-arch={arch}")
                detected_arch = arch
            except Exception as e:
                log(f"Failed to detect GPU architecture: {e}. Falling back to multiple architectures.")
        # 3) Last resort: target several common architectures at once.
        if detected_arch is None:
            compiler_args["nvcc"].extend(fallback_archs)
            detected_arch = "multiple architectures"
    return CUDAExtension, ["ssim.cu", "ssim3d.cu", "ext.cpp"], "fused_ssim_cuda", compiler_args, [], detected_arch
def configure_mps():
    """Build configuration for Apple's Metal (MPS) backend."""
    log("Compiling for MPS.")
    # Objective-C++ is required to compile the .mm Metal source.
    cxx_flags = ["-O3", "-std=c++17", "-ObjC++", "-Wno-unused-parameter"]
    frameworks = ["-framework", "Metal", "-framework", "Foundation"]
    return (
        CppExtension,
        ["ssim.mm", "ext.cpp"],
        "fused_ssim_mps",
        {"cxx": cxx_flags},
        frameworks,
        "Apple Silicon (MPS)",
    )
def configure_xpu():
    """Build configuration for Intel GPUs via SYCL."""
    log("Compiling for XPU.")
    # The SYCL sources must be compiled with Intel's icpx driver.
    os.environ['CXX'] = 'icpx'
    try:
        gpu_name = torch.xpu.get_device_name(0)
        log(f"Detected Intel XPU: {gpu_name}")
        arch_label = f"Intel XPU (SYCL) - {gpu_name}"
    except Exception:
        # Device query failed; fall back to a generic label.
        log("Detected Intel XPU (SYCL)")
        arch_label = "Intel XPU (SYCL)"
    return (
        CppExtension,
        ["ssim_sycl.cpp", "ext.cpp"],
        "fused_ssim_xpu",
        {"cxx": ["-O3", "-std=c++17", "-fsycl"]},
        ["-fsycl"],
        arch_label,
    )
# Select the backend for the available accelerator; CUDA is also the default
# so CPU-only build hosts still produce an NVIDIA-targeted extension.
if torch.cuda.is_available():
    backend = configure_cuda()
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    backend = configure_mps()
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
    backend = configure_xpu()
else:
    backend = configure_cuda()
extension_type, extension_files, build_name, compiler_args, link_args, detected_arch = backend
class CustomBuildExtension(BuildExtension):
    """BuildExtension that banners the targeted GPU architecture and, for
    SYCL builds, forces Intel's icpx driver for both compile and link."""

    def build_extensions(self):
        if 'xpu' in build_name:
            # distutils picked the default C++ driver; swap in icpx while
            # keeping every remaining flag intact.
            for attr in ("compiler_so", "compiler_cxx", "linker_so"):
                current = getattr(self.compiler, attr)
                setattr(self.compiler, attr, ['icpx'] + current[1:])
        bar = "=" * 50
        label = detected_arch if detected_arch else 'multiple architectures'
        print(f"\n{bar}\nBuilding with GPU architecture: {label}\n{bar}\n")
        super().build_extensions()
# Register the selected extension; the custom build class prints the
# architecture banner right before compilation starts.
_extension = extension_type(
    name=build_name,
    sources=extension_files,
    extra_compile_args=compiler_args,
    extra_link_args=link_args,
)
setup(
    name="fused_ssim",
    packages=['fused_ssim'],
    ext_modules=[_extension],
    cmdclass={'build_ext': CustomBuildExtension},
)
# Final summary so the chosen flags remain visible even after the build banner
# has scrolled past.
if "nvcc" in compiler_args:
    summary = (
        f"Setup completed. NVCC args: {compiler_args['nvcc']}. "
        f"CXX args: {compiler_args['cxx']}. Link args: {link_args}."
    )
else:
    summary = f"Setup completed. CXX args: {compiler_args['cxx']}. Link args: {link_args}."
print(summary)