@@ -159,21 +159,29 @@ def run(self):
159159 os .chdir ('thirdparties' )
160160
161161 # flash-attention
162+ # MI3xx has both CK backend and Triton backend.
162163 if (rocm_arch := os .environ ['ROCM_ARCH' ]) in INSTINCT_ARCH :
163164 subprocess .check_call (['git' , 'clone' , '--recursive' , 'https://github.com/ROCm/flash-attention.git' ])
164165 os .chdir ('flash-attention' )
165166 num_jobs = os .cpu_count () - 1
166167 subprocess .check_call (['pip' , 'install' , '-v' , '.' , f'MAX_JOBS={ num_jobs } ' ], shell = True )
167168 os .chdir ('..' )
169+ # Only Triton backend supports Radeon GPUs
170+ elif (rocm_arch := os .environ ['ROCM_ARCH' ]) in RADEON_ARCH :
171+ subprocess .check_call (['git' , 'clone' , '--recursive' , 'https://github.com/ROCm/flash-attention.git' ])
172+ os .chdir ('flash-attention' )
173+ subprocess .check_call (['git' , 'checkout' , 'main_perf' ])
174+ subprocess .check_call (['FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"' , 'python' , 'setup.py' , 'install' , ], shell = True )
175+ os .chdir ('..' )
168176
169- # # xformers
170- rocm_arch = os .environ ['ROCM_ARCH' ]
171- subprocess .check_call (['git' , 'clone' , 'https://github.com/ROCm/xformers.git' ])
172- os .chdir ('xformers' )
173- subprocess .check_call (['git' , 'submodule' , 'update' , '--init' , '--recursive' ])
174- os .environ ['PYTORCH_ROCM_ARCH' ] = rocm_arch
175- subprocess .check_call (['python' , 'setup.py' , 'install' ])
176- os .chdir ('..' )
177+ # only install xformers in Instinct GPUs
178+ if ( rocm_arch : = os .environ ['ROCM_ARCH' ]) in INSTINCT_ARCH :
179+ subprocess .check_call (['git' , 'clone' , 'https://github.com/ROCm/xformers.git' ])
180+ os .chdir ('xformers' )
181+ subprocess .check_call (['git' , 'submodule' , 'update' , '--init' , '--recursive' ])
182+ os .environ ['PYTORCH_ROCM_ARCH' ] = rocm_arch
183+ subprocess .check_call (['python' , 'setup.py' , 'install' ])
184+ os .chdir ('..' )
177185
178186 # bitsandbytes
179187 subprocess .check_call (['git' , 'clone' , '--recurse-submodules' , 'https://github.com/ROCm/bitsandbytes' ])
0 commit comments