@@ -149,13 +149,6 @@ def validate_and_update_archs(archs):
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])

-# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
-# See https://github.com/pytorch/pytorch/pull/70650
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
-    generator_flag = ["-DOLD_GENERATOR_PATH"]
-
 check_if_cuda_home_none("flash_attn")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
@@ -271,7 +264,7 @@ def validate_and_update_archs(archs):
271264 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu" ,
272265 ],
273266 extra_compile_args = {
274- "cxx" : ["-O3" , "-std=c++17" ] + generator_flag ,
267+ "cxx" : ["-O3" , "-std=c++17" ],
275268 "nvcc" : append_nvcc_threads (
276269 [
277270 "-O3" ,
@@ -293,7 +286,6 @@ def validate_and_update_archs(archs):
293286 # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
294287 # "-DFLASHATTENTION_DISABLE_LOCAL",
295288 ]
296- + generator_flag
297289 + cc_flag
298290 ),
299291 },
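
The removed generator_flag block probed for ATen/CUDAGeneratorImpl.h at its pre-pytorch/pytorch#70650 location and, when found, passed -DOLD_GENERATOR_PATH to both the cxx and nvcc compiles; this diff drops the probe and the flag everywhere it was appended. Below is a minimal sketch of an explicit minimum-version guard that could stand in for the probe, assuming the build now only supports PyTorch releases that ship ATen/cuda/CUDAGeneratorImpl.h; the 1.11 cutoff is an assumption for illustration, not something stated in this diff.

# Hypothetical sketch, not part of this diff: fail early instead of probing
# for the old header location, assuming PyTorch >= 1.11 (the release line in
# which ATen/CUDAGeneratorImpl.h moved to ATen/cuda/CUDAGeneratorImpl.h).
import torch

TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

if (TORCH_MAJOR, TORCH_MINOR) < (1, 11):
    raise RuntimeError(
        f"flash_attn requires a newer PyTorch; found torch {torch.__version__}"
    )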