
Commit 4b554d0

build flash-attn whl (PaddlePaddle#33)
* simplify code
* add files
* fix python version
* windows fixed
* del time
1 parent d98d8a3 commit 4b554d0

File tree

6 files changed (+215, -10 lines)


csrc/CMakeLists.txt

Lines changed: 47 additions & 7 deletions
@@ -1,15 +1,24 @@
 cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
 project(flash-attention LANGUAGES CXX CUDA)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
 
 find_package(Git QUIET REQUIRED)
 
 execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
                 WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
                 RESULT_VARIABLE GIT_SUBMOD_RESULT)
 
+#cmake -DWITH_ADVANCED=ON
+if (WITH_ADVANCED)
+  add_compile_definitions(PADDLE_WITH_ADVANCED)
+endif()
+
 add_definitions("-DFLASH_ATTN_WITH_TORCH=0")
 
 set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)
+set(BINARY_DIR ${CMAKE_BINARY_DIR})
 
 set(FA2_SOURCES_CU
   flash_attn/src/cuda_utils.cu
@@ -55,6 +64,7 @@ target_include_directories(flashattn PRIVATE
   flash_attn
   ${CUTLASS_3_DIR}/include)
 
+if (WITH_ADVANCED)
 set(FA1_SOURCES_CU
   flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
   flash_attn_with_bias_and_mask/src/cuda_utils.cu
@@ -65,6 +75,12 @@ set(FA1_SOURCES_CU
   flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
   flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
   flash_attn_with_bias_and_mask/src/utils.cu)
+else()
+set(FA1_SOURCES_CU
+  flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
+  flash_attn_with_bias_and_mask/src/cuda_utils.cu
+  flash_attn_with_bias_and_mask/src/utils.cu)
+endif()
 
 add_library(flashattn_with_bias_mask STATIC
   flash_attn_with_bias_and_mask/
@@ -83,18 +99,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)
 
 add_dependencies(flashattn flashattn_with_bias_mask)
 
+set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")
 
-if (NOT DEFINED NVCC_ARCH_BIN)
-  message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
-endif()
-
-if (NVCC_ARCH_BIN STREQUAL "")
-  message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
-endif()
+message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}")
 
 STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})
 
 set(FA_GENCODE_OPTION "SHELL:")
+
 foreach(arch ${FA_NVCC_ARCH_BIN})
   if(${arch} GREATER_EQUAL 80)
     set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
@@ -131,7 +143,35 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
   "${FA_GENCODE_OPTION}"
   >)
 
+
 INSTALL(TARGETS flashattn
   LIBRARY DESTINATION "lib")
 
 INSTALL(FILES capi/flash_attn.h DESTINATION "include")
+
+if (WITH_ADVANCED)
+  if(WIN32)
+    set(target_output_name "flashattn")
+  else()
+    set(target_output_name "libflashattn")
+  endif()
+  set_target_properties(flashattn PROPERTIES
+    OUTPUT_NAME ${target_output_name}_advanced
+    PREFIX ""
+  )
+
+  configure_file(${CMAKE_SOURCE_DIR}/env_dict.py.in ${CMAKE_SOURCE_DIR}/env_dict.py @ONLY)
+  set_target_properties(flashattn PROPERTIES
+    LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/paddle_flash_attn/
+  )
+  add_custom_target(build_whl
+    COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel
+    WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
+    DEPENDS flashattn
+    COMMENT "Running build wheel"
+  )
+
+  add_custom_target(default_target DEPENDS build_whl)
+
+  set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target)
+endif()
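
Note (not part of this commit): a minimal Python sketch of how the options and the build_whl target above might be driven. The build directory and SM architecture are illustrative assumptions, and it presumes a CMake new enough to support -S/-B (3.13+).

# Hypothetical driver for the wheel build configured by this CMakeLists.txt.
# Run from csrc/; the build directory and arch value are examples only.
import subprocess

def build_wheel(build_dir="build", arch="80"):
    # Configure with the advanced kernels enabled and the target CUDA architecture.
    subprocess.run(
        ["cmake", "-S", ".", "-B", build_dir,
         "-DWITH_ADVANCED=ON", f"-DNVCC_ARCH_BIN={arch}"],
        check=True,
    )
    # Build flashattn, then let the build_whl target run `setup.py bdist_wheel`.
    subprocess.run(
        ["cmake", "--build", build_dir, "--target", "build_whl"],
        check=True,
    )

if __name__ == "__main__":
    build_wheel()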

csrc/env_dict.py.in

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+env_dict = {
+    'CMAKE_BINARY_DIR': '@CMAKE_BINARY_DIR@'
+}
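
For illustration only: once the configure_file(... @ONLY) step in CMakeLists.txt runs, the generated csrc/env_dict.py would look roughly like the following (the build path is an assumed example); setup.py imports it to locate the compiled library.

# Hypothetical configure_file() output; the path below is illustrative.
env_dict = {
    'CMAKE_BINARY_DIR': '/path/to/flash-attention/csrc/build'
}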

csrc/flash_attn/src/flash_bwd_launch_template.h

Lines changed: 2 additions & 2 deletions
@@ -68,8 +68,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
       BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-          BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
-            BOOL_SWITCH(is_deterministic, Is_deterministic, [&] {
+          BOOL_SWITCH_ADVANCED(is_attn_mask, Is_attn_mask, [&] {
+            BOOL_SWITCH_ADVANCED(is_deterministic, Is_deterministic, [&] {
               auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst, Is_attn_mask && !IsCausalConst, Is_deterministic>;
               // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
               if (smem_size_dq_dk_dv >= 48 * 1024) {

csrc/flash_attn/src/flash_fwd_launch_template.h

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
       BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
         BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-          BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
+          BOOL_SWITCH_ADVANCED(is_attn_mask, Is_attn_mask, [&] {
             BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
               // Will only return softmax if dropout, to reduce compilation time.
               auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal, Is_equal_seq_qk>;

csrc/flash_attn/src/static_switch.h

Lines changed: 19 additions & 0 deletions
@@ -25,6 +25,25 @@
     } \
   }()
 
+#ifdef PADDLE_WITH_ADVANCED
+#define BOOL_SWITCH_ADVANCED(COND, CONST_NAME, ...) \
+  [&] {                                             \
+    if (COND) {                                     \
+      constexpr static bool CONST_NAME = true;      \
+      return __VA_ARGS__();                         \
+    } else {                                        \
+      constexpr static bool CONST_NAME = false;     \
+      return __VA_ARGS__();                         \
+    }                                               \
+  }()
+#else
+#define BOOL_SWITCH_ADVANCED(COND, CONST_NAME, ...) \
+  [&] {                                             \
+    constexpr static bool CONST_NAME = false;       \
+    return __VA_ARGS__();                           \
+  }()
+#endif
+
 #define FP16_SWITCH(COND, ...) \
   [&] { \
     if (COND) { \

csrc/setup.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import ast
+import os
+import re
+import subprocess
+import sys
+from pathlib import Path
+
+from env_dict import env_dict
+from setuptools import setup
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+
+with open("../../README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
+
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+PACKAGE_NAME = "paddle-flash-attn"
+
+
+def get_platform():
+    """
+    Returns the platform name as used in wheel filenames.
+    """
+    if sys.platform.startswith('linux'):
+        return 'linux_x86_64'
+    elif sys.platform == 'win32':
+        return 'win_amd64'
+    else:
+        raise ValueError(f'Unsupported platform: {sys.platform}')
+
+
+def get_cuda_version():
+    try:
+        result = subprocess.run(
+            ['nvcc', '--version'],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+        if result.returncode == 0:
+            output_lines = result.stdout.split('\n')
+            for line in output_lines:
+                if line.startswith('Cuda compilation tools'):
+                    cuda_version = (
+                        line.split('release')[1].strip().split(',')[0]
+                    )
+                    return cuda_version
+        else:
+            print("Error:", result.stderr)
+
+    except Exception as e:
+        print("Error:", str(e))
+
+    return None
+
+
+def get_package_version():
+    with open(Path(cur_dir) / "../flash_attn" / "__init__.py", "r") as f:
+        version_match = re.search(
+            r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE
+        )
+    public_version = ast.literal_eval(version_match.group(1))
+    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
+    if local_version:
+        return f"{public_version}+{local_version}"
+    else:
+        return str(public_version)
+
+
+def get_package_data():
+    binary_dir = env_dict.get("CMAKE_BINARY_DIR")
+    lib = os.path.join(
+        os.path.abspath(os.path.dirname(__file__)),
+        binary_dir + '/paddle_flash_attn/*',
+    )
+    package_data = {'paddle_flash_attn': [lib]}
+    return package_data
+
+
+class CustomWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
+    find an existing wheel (which is currently the case for all flash attention installs). We use
+    the environment parameters to detect whether there is already a pre-built version of a compatible
+    wheel available and short-circuit the standard full build pipeline.
+    """
+
+    def run(self):
+        self.run_command('build_ext')
+        super().run()
+        cuda_version = get_cuda_version()
+        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+        platform_name = get_platform()
+        flash_version = get_package_version()
+        wheel_name = 'paddle_flash_attn'
+
+        # Determine wheel URL based on CUDA version, python version and OS
+        impl_tag, abi_tag, plat_tag = self.get_tag()
+        archive_basename = (
+            f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
+        )
+        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+        print("Raw wheel path", wheel_path)
+        wheel_filename = f'{wheel_name}-{flash_version}+cu{cuda_version}-{impl_tag}-{abi_tag}-{platform_name}.whl'
+        os.rename(wheel_path, os.path.join(self.dist_dir, wheel_filename))
+
+
+setup(
+    name=PACKAGE_NAME,
+    version=get_package_version(),
+    packages=['paddle_flash_attn'],
+    package_data=get_package_data(),
+    author_email="[email protected]",
+    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/PaddlePaddle/flash-attention",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: BSD License",
+        "Operating System :: Unix",
+    ],
+    cmdclass={
+        'bdist_wheel': CustomWheelsCommand,
+    },
+    python_requires=">=3.7",
+)
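
As a rough illustration of the renaming performed in CustomWheelsCommand.run(): the final wheel filename combines the package version, the detected CUDA release, the interpreter tags reported by bdist_wheel, and the platform. A minimal sketch with made-up example values:

# Sketch of the wheel-naming scheme; all concrete values below are illustrative.
def final_wheel_name(flash_version, cuda_version, impl_tag, abi_tag, platform_name):
    # Mirrors the f-string used in CustomWheelsCommand.run().
    return (
        f"paddle_flash_attn-{flash_version}+cu{cuda_version}"
        f"-{impl_tag}-{abi_tag}-{platform_name}.whl"
    )

# e.g. paddle_flash_attn-2.0.8+cu12.0-cp310-cp310-linux_x86_64.whl
print(final_wheel_name("2.0.8", "12.0", "cp310", "cp310", "linux_x86_64"))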
