
Commit 6020d24

umiswing and kuizhiqing authored and committed
part-3 cherry from: [cherry-pick] Integration flash attention 2 (PaddlePaddle#56015)
* [FlashAttn] add flash randomness control (PaddlePaddle#52902)
* add flash randomness control
* fix VLOG undefined
* [WIP] Integration flash attention 2 (PaddlePaddle#55758)
* Work for fa-2 padded fwd. Code to be cleaned.
* Work for fa2 unpadded fwd.
* Work for padded bwd; dk gets a small diff on np.random.seed(0)
* Anyway I pass paddle's utest, except returning softmax without dropout.
* Clean code.
* Modify interface.
* Clean code and add some checks.
* Easy compile for dev.
* Fix ci.
* Fix ci-build.
* Add std c++17 option again.
* Limit max jobs when compiling fa2.
* Remove const_cast
* Add fwd params, to be cleaned.
* Clean code.
* Add bwd params.
* Clean code.
* Add enforce.
* Use v2.0.4
* Pass RNG state to fa2 capi
* Fix review.
* Add assert
* Skip compile for sm less than 80.

---------

Co-authored-by: Chitsing KUI <[email protected]>
1 parent 64f30b7 commit 6020d24
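For orientation, below is a minimal sketch of the dropout-enabled FlashAttention call path this series wires up. The paddle.nn.functional.flash_attention.flash_attention entry point and its keyword names are an assumption made for illustration (the diff itself only shows the op's q/k/v and fixed_seed_offset inputs and the rng_name attribute); it is not part of this commit and needs a GPU build with FlashAttention enabled.

import paddle
# Assumed public entry point; the op behind it is the "flash_attn" op patched below.
from paddle.nn.functional.flash_attention import flash_attention

# [batch, seqlen, num_heads, head_dim]; FlashAttention kernels run in half precision.
q = paddle.randn([2, 128, 8, 64], dtype="float16")
k = paddle.randn([2, 128, 8, 64], dtype="float16")
v = paddle.randn([2, 128, 8, 64], dtype="float16")

# With dropout > 0, the dropout mask depends on the RNG state that this series
# forwards to the FlashAttention-2 C API (see "Pass RNG state to fa2 capi" above).
out, _ = flash_attention(q, k, v, dropout=0.1, causal=True)
print(out.shape)  # [2, 128, 8, 64]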

File tree

3 files changed, +106 -1 lines changed


cmake/external/flashattn.cmake

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
-set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
+set(FLASHATTN_TAG b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a)
 
 set(FLASHATTN_INCLUDE_DIR
   "${FLASHATTN_INSTALL_DIR}/include"

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,10 @@
 
 PD_DECLARE_bool(cudnn_deterministic);
 
+#ifdef PADDLE_WITH_FLASHATTN
+#include "paddle/phi/backends/dynload/flashattn.h"
+#endif
+
 namespace phi {
 
 int get_num_split() {
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import logging

from ...utils.log_utils import get_logger

_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
    DistributedOperatorImplContainer,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0


class DistributedFlashAttn(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))


# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def is_auto_compatible(self, dist_op):
        return True

    @staticmethod
    def forward(ctx, *args, **kwargs):
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)

        if (
            is_enable_auto_rand_ctrl()
            and not op_dist_attr.is_recompute
            and rank_id in op_dist_attr.process_mesh.process_ids
        ):
            assert (
                op_dist_attr is not None
            ), f"forward op [{str(src_op)}] don't have dist attribute !"

            if (
                len(kwargs.get('fixed_seed_offset', [])) > 0
                or len(src_op.input("fixed_seed_offset")) > 0
            ):
                # TODO(kuizhiqing) recompute should go here
                pass
            else:
                # determinate rng
                q_var = main_block._var_recursive(kwargs['q'][0])
                k_var = main_block._var_recursive(kwargs['k'][0])
                q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
                k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
                process_mesh = op_dist_attr.process_mesh
                dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]

                rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
                assert rng_name is not None and rng_name != ""

                src_op._set_attr('rng_name', rng_name)

        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        # dropout backward is deterministic by mask, and not need for random state control
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "flash_attn", DistributedFlashAttnImpl0("random_control")
)
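To make the dims-mapping derivation in DistributedFlashAttnImpl0.forward easier to follow, the snippet below replays just that list arithmetic with a made-up sharding; the mapping values and the axis interpretation are hypothetical, and only the slicing mirrors the code above before the result is handed to determinate_rng to pick a named generator (presumably so that ranks holding the same shard draw matching dropout masks).

# Hypothetical dims mapping for q laid out as [batch, seqlen, num_heads, head_dim]:
# batch sharded over mesh axis 0, num_heads over mesh axis 1, the rest replicated (-1).
q_dims_mapping = [0, -1, 1, -1]

# Same derivation as in DistributedFlashAttnImpl0.forward: keep the first three axes
# and append a fourth axis that reuses the num_heads mapping.
dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
print(dims_mapping)  # [0, -1, 1, 1]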
