diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake
index 39a7480997f9c3..c8461f57a575aa 100644
--- a/cmake/external/flashattn.cmake
+++ b/cmake/external/flashattn.cmake
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
-set(FLASHATTN_TAG a96f8024714455fb86a326e20c3b7f700ec50772)
+set(FLASHATTN_TAG 5fc132ac11e78d26471ca09e5ba0cd817c3424d8)
 
 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 6d33afeb94898a..bf6db9bd8d5c0f 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1323,15 +1323,15 @@ void FusedRopeGradInferMeta(const MetaTensor& sin,
                         "[batch_size, seq_len, num_heads, head_dim],"
                         "but got %u.",
                         input_dims.size()));
-  if (dout_q) {
+  if (dout_q && dq) {
     dq->set_dims(dout_q.dims());
     dq->set_dtype(dout_q.dtype());
   }
-  if (dout_k) {
+  if (dout_k && dk) {
     dk->set_dims(dout_k.dims());
     dk->set_dtype(dout_k.dtype());
   }
-  if (dout_v) {
+  if (dout_v && dv) {
     dv->set_dims(dout_v.dims());
     dv->set_dtype(dout_v.dtype());
   }
diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index b64f5435a94da5..d65a4e3955f350 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -20,6 +20,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/gpu/flash_attn_utils.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 PD_DECLARE_bool(cudnn_deterministic);
 
@@ -51,42 +52,53 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                  DenseTensor* dk,
                                  DenseTensor* dv) {
 #ifdef PADDLE_WITH_FLASHATTN
+  // q,k,v [total_*, num_heads, head_dim]
+  auto dims = q.dims();
+
+  const int64_t batch_size = cu_seqlens_q.numel() - 1;
+  const int64_t num_heads = dims[1];
+  const int64_t head_size_og = dout.dims()[2];
+  const int64_t head_size = dims[2];
+  const int64_t total_k = k.dims()[0];
+  const int64_t num_heads_k = k.dims()[1];
+
+  bool is_mha = (num_heads == num_heads_k);
+
   void* dq_ptr = nullptr;
   void* dk_ptr = nullptr;
   void* dv_ptr = nullptr;
 
-  ctx.template Alloc<T>(dq);
-  dq_ptr = dq->data();
+  DenseTensor dq_tmp;
+  if (dq) {
+    dq_ptr = ctx.template Alloc<T>(dq);
+  } else {
+    dq_tmp.Resize(dims);
+    dq_ptr = ctx.template Alloc<T>(&dq_tmp);
+  }
+
+  std::initializer_list<int64_t> dk_dv_shape = {
+      total_k, num_heads_k, num_heads / num_heads_k, head_size};
 
   DenseTensor dk_tmp;
-  if (dk) {
+  if (dk && is_mha) {
     ctx.template Alloc<T>(dk);
     dk_ptr = dk->data();
   } else {
-    dk_tmp = EmptyLike<T, Context>(ctx, k);
-    dk_ptr = dk_tmp.data();
+    dk_tmp.Resize(dk_dv_shape);
+    dk_ptr = ctx.template Alloc<T>(&dk_tmp);
   }
 
   DenseTensor dv_tmp;
-  if (dv) {
+  if (dv && is_mha) {
     ctx.template Alloc<T>(dv);
     dv_ptr = dv->data();
   } else {
-    dv_tmp = EmptyLike<T, Context>(ctx, v);
-    dv_ptr = dv_tmp.data();
+    dv_tmp.Resize(dk_dv_shape);
+    dv_ptr = ctx.template Alloc<T>(&dv_tmp);
   }
 
   const cudaStream_t stream = ctx.stream();
 
-  // q,k,v [total_*, num_heads, head_dim]
-  auto dims = q.dims();
-
-  const int64_t batch_size = cu_seqlens_q.numel() - 1;
-  const int64_t num_heads = dims[1];
-  const int64_t head_size_og = dout.dims()[2];
-  const int64_t head_size = dims[2];
-  const int64_t num_heads_k = k.dims()[1];
-
   int num_splits = get_num_split();
 
   // TODO(umiswing): add shape check
@@ -150,6 +162,14 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
       params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
       params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
   CheckFlashAttnStatus(succ);
+  if (!is_mha) {
+    if (dk) {
+      phi::SumKernel<T, Context>(ctx, dk_tmp, {2}, dk->type(), false, dk);
+    }
+    if (dv) {
+      phi::SumKernel<T, Context>(ctx, dv_tmp, {2}, dv->type(), false, dv);
+    }
+  }
 #else
   RaiseNotSupportedError();
 #endif
@@ -171,44 +191,53 @@ void FlashAttnGradKernel(const Context& ctx,
                          DenseTensor* dk,
                          DenseTensor* dv) {
 #ifdef PADDLE_WITH_FLASHATTN
+  // q, k, v [batch_size, seq_len, num_heads, head_dim]
+  const auto& dims = q.dims();
+
+  const int64_t batch_size = dims[0];
+  const int64_t seqlen_q = dims[1];
+  const int64_t num_heads = dims[2];
+  const int64_t head_size_og = dout.dims()[3];
+  const int64_t head_size = dims[3];
+  const int64_t seqlen_k = k.dims()[1];
+  const int64_t num_heads_k = k.dims()[2];
+
+  bool is_mha = (num_heads == num_heads_k);
+
   void* dq_ptr = nullptr;
   void* dk_ptr = nullptr;
   void* dv_ptr = nullptr;
 
-  ctx.template Alloc<T>(dq);
-  dq_ptr = dq->data();
+  DenseTensor dq_tmp;
+  if (dq) {
+    dq_ptr = ctx.template Alloc<T>(dq);
+  } else {
+    dq_tmp.Resize(dims);
+    dq_ptr = ctx.template Alloc<T>(&dq_tmp);
+  }
 
   DenseTensor dk_tmp;
-  if (dk) {
+  std::initializer_list<int64_t> dk_dv_shape = {
+      batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
+  if (dk && is_mha) {
     ctx.template Alloc<T>(dk);
     dk_ptr = dk->data();
   } else {
-    dk_tmp = EmptyLike<T, Context>(ctx, k);
-    dk_ptr = dk_tmp.data();
+    dk_tmp.Resize(dk_dv_shape);
+    dk_ptr = ctx.template Alloc<T>(&dk_tmp);
   }
 
   DenseTensor dv_tmp;
-  if (dv) {
+  if (dv && is_mha) {
     ctx.template Alloc<T>(dv);
     dv_ptr = dv->data();
   } else {
-    dv_tmp = EmptyLike<T, Context>(ctx, v);
-    dv_ptr = dv_tmp.data();
+    dv_tmp.Resize(dk_dv_shape);
+    dv_ptr = ctx.template Alloc<T>(&dv_tmp);
   }
 
   const cudaStream_t stream = ctx.stream();
 
-  // q, k, v [batch_size, seq_len, num_heads, head_dim]
-  const auto& dims = q.dims();
-
-  const int64_t batch_size = dims[0];
-  const int64_t seqlen_q = dims[1];
-  const int64_t num_heads = dims[2];
-  const int64_t head_size_og = dout.dims()[3];
-  const int64_t head_size = dims[3];
-  const int64_t seqlen_k = k.dims()[1];
-  const int64_t num_heads_k = k.dims()[2];
-
   // TODO(umiswing): add shape check
   PADDLE_ENFORCE_EQ(
       head_size_og,
@@ -281,6 +310,14 @@ void FlashAttnGradKernel(const Context& ctx,
      params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
      params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
   CheckFlashAttnStatus(succ);
+  if (!is_mha) {
+    if (dk) {
+      phi::SumKernel<T, Context>(ctx, dk_tmp, {3}, dk->type(), false, dk);
+    }
+    if (dv) {
+      phi::SumKernel<T, Context>(ctx, dv_tmp, {3}, dv->type(), false, dv);
+    }
+  }
 #else
   RaiseNotSupportedError();
 #endif
diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h
index a9caec4bb5202c..f68236f5391aa7 100644
--- a/paddle/phi/kernels/gpu/flash_attn_utils.h
+++ b/paddle/phi/kernels/gpu/flash_attn_utils.h
@@ -114,7 +114,7 @@ struct FlashAttnParamsBase {
         max_seqlen_q(_max_seqlen_q),
         max_seqlen_k(_max_seqlen_k),
         num_heads(_num_heads),
-        num_heads_k(_num_heads),
+        num_heads_k(_num_heads_k),
         head_size(_head_size),
         softmax_scale(_scale),
         causal(_causal),
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 867d4066c1eaca..5f589dd4c69022 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -643,7 +643,8 @@ def send_forward_backward_recv_forward_backward(
         if _timers is not None:
            _timers("send_forward_backward_recv_forward_backward").start()
 
-        self._send_meta(output_tensor)
+        if output_tensor is not None:
+            self._send_meta(output_tensor)
         if recv_prev:
             self._recv_meta()
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index b74d4a87af20e2..986fbff73d62e0 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -202,6 +202,8 @@ def test_unpadded(self):
                 fetches_result[0], out_, rtol=5e-03, atol=1e-03
             )
 
+        paddle.disable_static()
+
     def test_all(self):
         print(
             f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
         )
@@ -318,6 +320,8 @@ def test_all(self):
                 fetches_result[0], out_, rtol=5e-03, atol=1e-03
             )
 
+        paddle.disable_static()
+
 
 @unittest.skipIf(
     not is_flashattn_supported(),
     "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
     "and device's compute capability must be 7.5 or 8.x",
 )
@@ -536,5 +540,287 @@ def test_all(self):
         )
 
 
+@unittest.skipIf(
+    not is_flashattn_supported(),
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 7.5 or 8.x",
+)
+class TestFlashAttentionGQA(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 2
+        self.num_head = 8
+        self.seq_len = 8192
+        self.head_dim = 128
+        self.num_group = 2
+        self.dtype = paddle.bfloat16
+
+    def gen_unpadded_data(self, dtype):
+        seq_len_q = np.random.randint(
+            low=1, high=self.seq_len, size=[self.batch_size]
+        )
+        seq_len_k = np.random.randint(
+            low=1, high=self.seq_len, size=[self.batch_size]
+        )
+        cu_seqlen_q = paddle.to_tensor(
+            [0] + np.cumsum(seq_len_q).tolist(), dtype=paddle.int32
+        )
+        cu_seqlen_k = paddle.to_tensor(
+            [0] + np.cumsum(seq_len_k).tolist(), dtype=paddle.int32
+        )
+
+        qs, ks, vs = [], [], []
+        for i in range(self.batch_size):
+            tmp_q = (
+                paddle.randn(
+                    [seq_len_q[i] * self.num_head * self.head_dim], dtype=dtype
+                )
+                / 1e2
+            )
+            tmp_k = (
+                paddle.randn(
+                    [
+                        seq_len_k[i]
+                        * self.num_head
+                        * self.head_dim
+                        // self.num_group
+                    ],
+                    dtype=dtype,
+                )
+                / 1e2
+            )
+            tmp_v = (
+                paddle.randn(
+                    [
+                        seq_len_k[i]
+                        * self.num_head
+                        * self.head_dim
+                        // self.num_group
+                    ],
+                    dtype=dtype,
+                )
+                / 1e2
+            )
+            qs.append(tmp_q)
+            ks.append(tmp_k)
+            vs.append(tmp_v)
+
+        q = paddle.concat(qs, axis=0).reshape(
+            [-1, self.num_head, self.head_dim]
+        )
+        k = paddle.concat(ks, axis=0).reshape(
+            [-1, self.num_head // self.num_group, self.head_dim]
+        )
+        v = paddle.concat(vs, axis=0).reshape(
+            [-1, self.num_head // self.num_group, self.head_dim]
+        )
+        return q, k, v, cu_seqlen_q, cu_seqlen_k
+
+    def gen_test_data(self, dtype, use_unpadded):
+        assert self.num_head % self.num_group == 0
+        if use_unpadded:
+            q, k, v, cu_seqlen_q, cu_seqlen_k = self.gen_unpadded_data(dtype)
+        else:
+            q = (
+                paddle.randn(
+                    [
+                        self.batch_size,
+                        self.seq_len,
+                        self.num_head,
+                        self.head_dim,
+                    ],
+                    dtype=dtype,
+                )
+                / 1e2
+            )
+            k = (
+                paddle.randn(
+                    [
+                        self.batch_size,
+                        self.seq_len,
+                        self.num_head // self.num_group,
+                        self.head_dim,
+                    ],
+                    dtype=dtype,
+                )
+                / 1e2
+            )
+            v = (
+                paddle.randn(
+                    [
+                        self.batch_size,
+                        self.seq_len,
+                        self.num_head // self.num_group,
+                        self.head_dim,
+                    ],
+                    dtype=dtype,
+                )
+                / 1e2
+            )
+            cu_seqlen_q = None
+            cu_seqlen_k = None
+        out_grad = paddle.randn(q.shape, dtype=dtype) / 1e2
+        return q, k, v, cu_seqlen_q, cu_seqlen_k, out_grad
+
+    def clone_tensor(self, tensor):
+        if tensor is None:
+            return None
+        elif isinstance(tensor, (list, tuple)):
+            return [self.clone_tensor(t) for t in tensor]
+        else:
+            tensor = tensor.detach().clone()
+            tensor.stop_gradient = False
+            return tensor
+
+    @paddle.no_grad()
+    def convert_dtype(self, tensors):
+        ret = []
+        for t in tensors:
+            if t.dtype in [paddle.float16, paddle.bfloat16]:
+                t = t.astype(paddle.float32)
+            t = t.numpy()
+            ret.append(t)
+        return ret
+
+    def calc_fa(
+        self, q, k, v, cu_seqlen_q, cu_seqlen_k, out_grad, causal, use_unpadded
+    ):
+        q, k, v = self.clone_tensor([q, k, v])
+        if use_unpadded:
+            scale = self.head_dim ** (-0.5)
+            out = flash_attn_unpadded(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlen_q,
+                cu_seqlens_k=cu_seqlen_k,
+                max_seqlen_q=self.seq_len,
+                max_seqlen_k=self.seq_len,
+                scale=scale,
+                causal=causal,
+            )
+        else:
+            out = flash_attention(q, k, v, causal=causal)
+        out = out[0]
+        out.backward(out_grad)
+        return self.convert_dtype([out, q.grad, k.grad, v.grad])
+
+    def calc_raw_attn(
+        self, q, k, v, cu_seqlen_q, cu_seqlen_k, out_grad, causal, use_unpadded
+    ):
+        q, k, v = self.clone_tensor([q, k, v])
+        if use_unpadded:
+            qq, q_mask = self.pad(q, cu_seqlen_q, self.seq_len)
+            kk, k_mask = self.pad(k, cu_seqlen_k, self.seq_len)
+            vv, _ = self.pad(v, cu_seqlen_k, self.seq_len)
+            qk_mask = paddle.matmul(q_mask, k_mask, transpose_y=True)
+            qk_mask = qk_mask.reshape(
+                [self.batch_size, 1, self.seq_len, self.seq_len]
+            )
+            qk_mask[qk_mask == 0] = -1e6
+            qk_mask[qk_mask == 1] = 0
+        else:
+            qq, kk, vv = q, k, v
+
+        assert len(qq.shape) == 4, qq.shape
+        assert len(kk.shape) == 4, kk.shape
+        assert len(vv.shape) == 4, vv.shape
+        perm = [0, 2, 1, 3]
+        qq = paddle.transpose(qq, perm)
+        kk = paddle.transpose(kk, perm)
+        kk = paddle.stack([kk] * self.num_group, axis=2).reshape(qq.shape)
+        vv = paddle.transpose(vv, perm)
+        vv = paddle.stack([vv] * self.num_group, axis=2).reshape(qq.shape)
+        scale = self.head_dim ** (-0.5)
+        weight = paddle.matmul(qq * scale, kk, transpose_y=True)
+        if use_unpadded:
+            weight += qk_mask
+        if causal:
+            shape = weight.shape[-2:]
+            mask = paddle.full(shape, -np.inf, dtype=weight.dtype)
+            mask = paddle.triu(mask, diagonal=1)
+            weight += mask
+
+        weight = weight.astype(paddle.float32)
+        weight = F.softmax(weight)
+        out = paddle.matmul(weight.astype(vv.dtype), vv)
+        out = paddle.transpose(out, perm)
+        if use_unpadded:
+            out = self.unpad(out, cu_seqlen_q)
+        out.backward(out_grad)
+        return self.convert_dtype([out, q.grad, k.grad, v.grad])
+
+    def pad(self, x, cu_seqlen, max_seqlen):
+        cu_seqlen_cpu = cu_seqlen.numpy()
+        split_sections = []
+        for i in range(len(cu_seqlen_cpu) - 1):
+            split_sections.append(cu_seqlen_cpu[i + 1] - cu_seqlen_cpu[i])
+
+        tmp_xs = paddle.split(x, split_sections)
+        batch_size = len(tmp_xs)
+        tmp_masks = []
+        tmp_x_pads = []
+        for i in range(batch_size):
+            tmp_mask = paddle.ones([max_seqlen], dtype=x.dtype)
+            tmp_mask[split_sections[i] :] = 0
+            tmp_mask = tmp_mask.reshape([1, -1, 1])
+            tmp_masks.append(tmp_mask)
+
+            tmp_shape = tmp_xs[i].shape
+            tmp_pad = paddle.zeros(
+                [max_seqlen - tmp_shape[0]] + list(tmp_shape[1:]), dtype=x.dtype
+            )
+            tmp_x = paddle.concat([tmp_xs[i], tmp_pad]).unsqueeze(0)
+            tmp_x_pads.append(tmp_x)
+
+        x_pad = paddle.concat(tmp_x_pads)
+        mask = paddle.concat(tmp_masks)
+        return x_pad, mask
+
+    def unpad(self, x, cu_seqlen):
+        cu_seqlen_cpu = cu_seqlen.numpy()
+        xs = paddle.split(x, x.shape[0])
+        tmp_xs = []
+        for i in range(len(cu_seqlen_cpu) - 1):
+            tmp = xs[i].squeeze(0)[: cu_seqlen_cpu[i + 1] - cu_seqlen_cpu[i]]
+            tmp_xs.append(tmp)
+
+        unpad_x = paddle.concat(tmp_xs)
+        return unpad_x
+
+    def test_main(self):
+        for causal in [False, True]:
+            for use_unpadded in [False, True]:
+                (
+                    q,
+                    k,
+                    v,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    out_grad,
+                ) = self.gen_test_data(self.dtype, use_unpadded)
+                fa_out = self.calc_fa(
+                    q,
+                    k,
+                    v,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    out_grad,
+                    causal,
+                    use_unpadded,
+                )
+                raw_out = self.calc_raw_attn(
+                    q,
+                    k,
+                    v,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    out_grad,
+                    causal,
+                    use_unpadded,
+                )
+                assert len(fa_out) == len(raw_out)
+                for t1, t2 in zip(fa_out, raw_out):
+                    np.testing.assert_allclose(t1, t2, atol=1e-2, rtol=1e-2)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/third_party/flashattn b/third_party/flashattn
index a96f8024714455..5fc132ac11e78d 160000
--- a/third_party/flashattn
+++ b/third_party/flashattn
@@ -1 +1 @@
-Subproject commit a96f8024714455fb86a326e20c3b7f700ec50772
+Subproject commit 5fc132ac11e78d26471ca09e5ba0cd817c3424d8
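
Note (illustrative, not part of the patch): for grouped-query attention the backward kernels above stage dk/dv in a scratch tensor that carries an explicit group axis of size num_heads / num_heads_k, then collapse that axis with phi::SumKernel (axis 2 in the unpadded layout, axis 3 in the padded layout), since every query head in a group back-propagates into the same shared KV head. A minimal NumPy sketch of that reduction, using hypothetical shapes:

import numpy as np

# hypothetical sizes: 8 query heads share 2 KV heads, so group = 4
batch, seqlen_k, num_heads, num_heads_k, head_dim = 2, 4, 8, 2, 16
group = num_heads // num_heads_k

# scratch dK with an explicit group axis, mirroring dk_dv_shape in
# FlashAttnGradKernel: [batch, seqlen_k, num_heads_k, group, head_dim]
dk_tmp = np.random.randn(batch, seqlen_k, num_heads_k, group, head_dim)

# the reduction the kernel delegates to phi::SumKernel over the group axis
dk = dk_tmp.sum(axis=3)
assert dk.shape == (batch, seqlen_k, num_heads_k, head_dim)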