Skip to content

Commit 7590573

Browse files
committed
fix issues
1 parent b49aa86 commit 7590573

2 files changed

Lines changed: 38 additions & 8 deletions

File tree

setup.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,21 +159,29 @@ def run(self):
159159
os.chdir('thirdparties')
160160

161161
# flash-attention
162+
# MI3xx has both CK backend and Triton backend.
162163
if (rocm_arch := os.environ['ROCM_ARCH']) in INSTINCT_ARCH:
163164
subprocess.check_call(['git', 'clone', '--recursive', 'https://github.com/ROCm/flash-attention.git'])
164165
os.chdir('flash-attention')
165166
num_jobs = os.cpu_count() - 1
166167
subprocess.check_call(['pip', 'install', '-v', '.', f'MAX_JOBS={num_jobs}'], shell=True)
167168
os.chdir('..')
169+
# Only Triton backend supports Radeon GPUs
170+
elif (rocm_arch := os.environ['ROCM_ARCH']) in RADEON_ARCH:
171+
subprocess.check_call(['git', 'clone', '--recursive', 'https://github.com/ROCm/flash-attention.git'])
172+
os.chdir('flash-attention')
173+
subprocess.check_call(['git', 'checkout', 'main_perf'])
174+
subprocess.check_call(['FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"', 'python', 'setup.py', 'install', ], shell=True)
175+
os.chdir('..')
168176

169-
# # xformers
170-
rocm_arch = os.environ['ROCM_ARCH']
171-
subprocess.check_call(['git', 'clone', 'https://github.com/ROCm/xformers.git'])
172-
os.chdir('xformers')
173-
subprocess.check_call(['git', 'submodule', 'update', '--init', '--recursive'])
174-
os.environ['PYTORCH_ROCM_ARCH'] = rocm_arch
175-
subprocess.check_call(['python', 'setup.py', 'install'])
176-
os.chdir('..')
177+
# only install xformers in Instinct GPUs
178+
if (rocm_arch := os.environ['ROCM_ARCH']) in INSTINCT_ARCH:
179+
subprocess.check_call(['git', 'clone', 'https://github.com/ROCm/xformers.git'])
180+
os.chdir('xformers')
181+
subprocess.check_call(['git', 'submodule', 'update', '--init', '--recursive'])
182+
os.environ['PYTORCH_ROCM_ARCH'] = rocm_arch
183+
subprocess.check_call(['python', 'setup.py', 'install'])
184+
os.chdir('..')
177185

178186
# bitsandbytes
179187
subprocess.check_call(['git', 'clone', '--recurse-submodules', 'https://github.com/ROCm/bitsandbytes'])

use_existing_torch.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project and Unsloth project
3+
# Copying from https://github.com/vllm-project/vllm/blob/main/use_existing_torch.py
4+
5+
import glob
6+
7+
requires_files = glob.glob('requirements/*.txt')
8+
requires_files += ["pyproject.toml"]
9+
for file in requires_files:
10+
print(f">>> cleaning {file}")
11+
with open(file) as f:
12+
lines = f.readlines()
13+
if "torch" in "".join(lines).lower():
14+
print("removed:")
15+
with open(file, 'w') as f:
16+
for line in lines:
17+
if 'torch' not in line.lower():
18+
f.write(line)
19+
else:
20+
print(line.strip())
21+
print(f"<<< done cleaning {file}")
22+
print()

0 commit comments

Comments
 (0)