Skip to content

Commit a5856bf

Browse files
committed
[Bwd,Sm90] For dQ, move wait_group before TMA atomic add
1 parent 72c7ba4 commit a5856bf

2 files changed

Lines changed: 18 additions & 20 deletions

File tree

flash_attn/cute/block_sparse_utils.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1352,6 +1352,12 @@ def _store_one_dQaccum_sm90(
13521352
tma_copy_bytes_dQ,
13531353
):
13541354
"""Store dQaccum for a single m_block."""
1355+
for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1356+
cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
1357+
cute.arch.barrier_arrive(
1358+
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1359+
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1360+
)
13551361
for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
13561362
cute.arch.barrier(
13571363
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
@@ -1364,12 +1370,6 @@ def _store_one_dQaccum_sm90(
13641370
tma_copy_bytes_dQ,
13651371
)
13661372
cute.arch.cp_async_bulk_commit_group()
1367-
for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1368-
cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
1369-
cute.arch.barrier_arrive(
1370-
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1371-
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1372-
)
13731373

13741374

13751375
@cute.jit

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -668,11 +668,6 @@ def kernel(
668668
qhead_per_kvhead_divmod,
669669
)
670670
if warp_idx == 1:
671-
for warp_group_idx in cutlass.range(self.num_mma_warp_groups):
672-
cute.arch.barrier_arrive(
673-
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
674-
number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
675-
)
676671
self.dQaccum_store(
677672
mdQaccum,
678673
sdQaccum,
@@ -1605,6 +1600,16 @@ def dQaccum_store(
16051600
m_block = m_block_min + iter_idx
16061601
m_block_safe = m_block
16071602

1603+
for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
1604+
cute.arch.cp_async_bulk_wait_group(
1605+
self.num_mma_warp_groups - 1 - warp_group_idx, read=True
1606+
)
1607+
cute.arch.barrier_arrive(
1608+
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1609+
number_of_threads=self.num_threads_per_warp_group
1610+
+ cute.arch.WARP_SIZE,
1611+
)
1612+
16081613
for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
16091614
cute.arch.barrier(
16101615
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
@@ -1618,15 +1623,6 @@ def dQaccum_store(
16181623
self.tma_copy_bytes["dQ"],
16191624
)
16201625
cute.arch.cp_async_bulk_commit_group()
1621-
for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
1622-
cute.arch.cp_async_bulk_wait_group(
1623-
self.num_mma_warp_groups - 1 - warp_group_idx, read=True
1624-
)
1625-
cute.arch.barrier_arrive(
1626-
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1627-
number_of_threads=self.num_threads_per_warp_group
1628-
+ cute.arch.WARP_SIZE,
1629-
)
16301626
else:
16311627
dQaccum_store_block_sparse_bwd_sm90(
16321628
blocksparse_tensors,
@@ -1643,3 +1639,5 @@ def dQaccum_store(
16431639
)
16441640
tile_scheduler.advance_to_next_work()
16451641
work_tile = tile_scheduler.get_current_work()
1642+
1643+
cute.arch.cp_async_bulk_wait_group(0, read=True)

0 commit comments

Comments (0)