
Commit cd06115

Add possibility to call AttnType.Aiter (#562)
1 parent 52e74e8 commit cd06115

File tree

2 files changed: +175 −207 lines

tests/core/test_xfuser_attn.py

Lines changed: 136 additions & 91 deletions
@@ -1,4 +1,5 @@
 import unittest
+from itertools import product
 import torch
 import torch.distributed as dist
 from xfuser.core.long_ctx_attention.ring.ring_flash_attn import (
@@ -18,7 +19,9 @@
     init_distributed_environment,
     initialize_model_parallel,
 )
+from xfuser.envs import PACKAGES_CHECKER
 
+from yunchang.kernels import AttnType
 
 def init_dist(backend="nccl"):
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -52,35 +55,39 @@ def init_dist(backend="nccl"):
 class TestRingFlashAttn(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
+        env_info = PACKAGES_CHECKER.get_packages_info()
+
         cls.batch_size = 1
         cls.num_heads = 4
         cls.head_dim = 32
         cls.seq_len = 128
-        cls.dtype = torch.float16
 
         cls.rank, cls.world_size, cls.ring_degree, cls.ulysses_degree = init_dist()
         cls.device = torch.device(f"cuda:{cls.rank}")
 
+        cls.HAS_AITER = env_info["has_aiter"]
+        cls.HAS_FLASH_ATTENTION = env_info["has_flash_attn"]
+
     def setUp(self):
         torch.manual_seed(42 + self.rank)
 
     @classmethod
     def tearDownClass(cls):
         dist.destroy_process_group()
 
-    def _create_test_tensors(self):
+    def _create_test_tensors(self, dtype=torch.float16):
         """Helper to create test input tensors"""
         shape = (self.batch_size, self.seq_len, self.num_heads, self.head_dim)
 
         # Prepare inputs
         q = torch.randn(
-            shape, device=self.device, dtype=self.dtype, requires_grad=False
+            shape, device=self.device, dtype=dtype, requires_grad=False
         )
         k = torch.randn(
-            shape, device=self.device, dtype=self.dtype, requires_grad=False
+            shape, device=self.device, dtype=dtype, requires_grad=False
         )
         v = torch.randn(
-            shape, device=self.device, dtype=self.dtype, requires_grad=False
+            shape, device=self.device, dtype=dtype, requires_grad=False
         )
 
         dist.broadcast(q, src=0)
@@ -94,96 +101,134 @@ def _create_test_tensors(self):
 
     def test_xfuser_attn_layer_joint_strategy_rear(self):
         """Test xFuserLongContextAttention layer in distributed mode"""
-        # Create test tensors
-        q, k, v, local_q, local_k, local_v = self._create_test_tensors()
-        joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = (
-            self._create_test_tensors()
-        )
-        joint_strategy = "rear"
-
-        attn = None
-
-        # Create attention layer
-        attn_layer = xFuserLongContextAttention(
-            scatter_idx=2,
-            gather_idx=1,
-            ring_impl_type="basic",
-            use_kv_cache=False,
-        ).to(device=self.device, dtype=self.dtype)
-
-        assert attn_layer.ring_pg.size() == self.ring_degree
-        assert attn_layer.ulysses_pg.size() == self.ulysses_degree
-
-        ref_output = flash_attn_func(
-            torch.cat([q, joint_q], dim=1),
-            torch.cat([k, joint_k], dim=1),
-            torch.cat([v, joint_v], dim=1),
-            dropout_p=0.0,
-            window_size=(-1, -1),
-        )
 
-        # Split ref_output into base and joint parts
-        base_out = ref_output[:, : self.seq_len, ::]  # First half for base attention
-        joint_out = ref_output[:, self.seq_len :, ::]  # Second half for joint attention
-
-        # Get local shard for base output
-        base_out_shard = base_out.chunk(self.world_size, dim=1)[self.rank]
-        # Duplicate joint output as specified
-        ref_output = torch.cat([base_out_shard, joint_out], dim=1)
-
-        # Run distributed implementation
-        output = attn_layer(
-            attn=None,
-            query=local_q,
-            key=local_k,
-            value=local_v,
-            dropout_p=0.0,
-            window_size=(-1, -1),
-            joint_tensor_query=joint_q,
-            joint_tensor_key=joint_k,
-            joint_tensor_value=joint_v,
-            joint_strategy=joint_strategy,
-        )
-        # assert torch.max(torch.abs(output - ref_output)) < 1e-3
-        torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)
+        backends = ["AITER", "FA"]
+        dtypes = [torch.float16, torch.bfloat16]
+
+        for backend, dtype in product(backends, dtypes):
+            with self.subTest(backend=backend, dtype=dtype):
+
+                if backend=="AITER":
+                    if self.HAS_AITER and 'AITER' in AttnType.__members__:
+                        attn_type=AttnType.AITER
+                    else:
+                        self.skipTest("AITER backend not applicable")
+                if backend=="FA":
+                    if self.HAS_FLASH_ATTENTION:
+                        attn_type=AttnType.FA
+                    else:
+                        self.skipTest("FA backend not applicable")
+
+                # Create test tensors
+                q, k, v, local_q, local_k, local_v = self._create_test_tensors(dtype=dtype)
+                joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = (
+                    self._create_test_tensors(dtype=dtype)
+                )
+                joint_strategy = "rear"
+
+                attn = None
+
+                # Create attention layer
+                attn_layer = xFuserLongContextAttention(
+                    scatter_idx=2,
+                    gather_idx=1,
+                    ring_impl_type="basic",
+                    attn_type=attn_type,
+                    use_kv_cache=False,
+                ).to(device=self.device, dtype=torch.float16)
+
+                assert attn_layer.ring_pg.size() == self.ring_degree
+                assert attn_layer.ulysses_pg.size() == self.ulysses_degree
+
+                ref_output = flash_attn_func(
+                    torch.cat([q, joint_q], dim=1),
+                    torch.cat([k, joint_k], dim=1),
+                    torch.cat([v, joint_v], dim=1),
+                    dropout_p=0.0,
+                    window_size=(-1, -1),
+                )
+
+                # Split ref_output into base and joint parts
+                base_out = ref_output[:, : self.seq_len, ::]  # First half for base attention
+                joint_out = ref_output[:, self.seq_len :, ::]  # Second half for joint attention
+
+                # Get local shard for base output
+                base_out_shard = base_out.chunk(self.world_size, dim=1)[self.rank]
+                # Duplicate joint output as specified
+                ref_output = torch.cat([base_out_shard, joint_out], dim=1)
+
+                # Run distributed implementation
+                output = attn_layer(
+                    attn=None,
+                    query=local_q,
+                    key=local_k,
+                    value=local_v,
+                    dropout_p=0.0,
+                    window_size=(-1, -1),
+                    joint_tensor_query=joint_q,
+                    joint_tensor_key=joint_k,
+                    joint_tensor_value=joint_v,
+                    joint_strategy=joint_strategy,
+                )
+                # assert torch.max(torch.abs(output - ref_output)) < 1e-3
+                torch.testing.assert_close(ref_output, output, rtol=1e-2, atol=1e-2)
 
     def test_xfuser_attn_layer(self):
         """Test xFuserLongContextAttention layer in distributed mode"""
-        # Create test tensors
-        q, k, v, local_q, local_k, local_v = self._create_test_tensors()
-        attn = None
-
-        # Create attention layer
-        attn_layer = xFuserLongContextAttention(
-            scatter_idx=2,
-            gather_idx=1,
-            ring_impl_type="basic",
-            use_kv_cache=False,
-        ).to(device=self.device, dtype=self.dtype)
-
-        assert attn_layer.ring_pg.size() == self.ring_degree
-        assert attn_layer.ulysses_pg.size() == self.ulysses_degree
-
-        ref_output = flash_attn_func(
-            q,
-            k,
-            v,
-            dropout_p=0.0,
-            window_size=(-1, -1),
-        )
-        ref_output = ref_output.chunk(self.world_size, dim=1)[self.rank]
-
-        # Run distributed implementation
-        output = attn_layer(
-            attn=None,
-            query=local_q,
-            key=local_k,
-            value=local_v,
-            dropout_p=0.0,
-            window_size=(-1, -1),
-        )
-        assert torch.max(torch.abs(output - ref_output)) < 1e-3
-        torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)
+
+        backends = ["AITER", "FA"]
+        dtypes = [torch.float16, torch.bfloat16]
+
+        for backend, dtype in product(backends, dtypes):
+            with self.subTest(backend=backend, dtype=dtype):
+
+                if backend=="AITER":
+                    if self.HAS_AITER and 'AITER' in AttnType.__members__:
+                        attn_type=AttnType.AITER
+                    else:
+                        self.skipTest("AITER backend not applicable")
+                if backend=="FA":
+                    if self.HAS_FLASH_ATTENTION:
+                        attn_type=AttnType.FA
+                    else:
+                        self.skipTest("FA backend not applicable")
+
+                # Create test tensors
+                q, k, v, local_q, local_k, local_v = self._create_test_tensors(dtype)
+
+                # Create attention layer
+                attn_layer = xFuserLongContextAttention(
+                    scatter_idx=2,
+                    gather_idx=1,
+                    ring_impl_type="basic",
+                    use_kv_cache=False,
+                    attn_type=attn_type,
+                ).to(device=self.device, dtype=dtype)
+
+                assert attn_layer.ring_pg.size() == self.ring_degree
+                assert attn_layer.ulysses_pg.size() == self.ulysses_degree
+
+                ref_output = flash_attn_func(
+                    q,
+                    k,
+                    v,
+                    dropout_p=0.0,
+                    window_size=(-1, -1),
+                )
+                ref_output = ref_output.chunk(self.world_size, dim=1)[self.rank]
+
+                # Run distributed implementation
+                output = attn_layer(
+                    attn=None,
+                    query=local_q,
+                    key=local_k,
+                    value=local_v,
+                    dropout_p=0.0,
+                    window_size=(-1, -1),
+                )
+
+                assert torch.max(torch.abs(output - ref_output)) < 1e-2
+                torch.testing.assert_close(ref_output, output, rtol=1e-2, atol=1e-2)
 
 
 # torchrun --nproc_per_node=4 -m unittest tests/core/test_xfuser_attn.py
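
For context on what the new tests exercise: the commit threads an attn_type argument through xFuserLongContextAttention so callers can select the AITER backend (AttnType.AITER) instead of flash-attention. Below is a minimal usage sketch, not part of this diff; it mirrors the guard used in the tests above, the import path of xFuserLongContextAttention is an assumption for illustration, and the distributed setup (init_distributed_environment / initialize_model_parallel under torchrun) is taken as already done, as in this test file.

# Minimal sketch (not part of this diff): select the AITER backend when available,
# otherwise fall back to flash-attention, mirroring the guard in the tests above.
import torch
from yunchang.kernels import AttnType
from xfuser.envs import PACKAGES_CHECKER
from xfuser.core.long_ctx_attention import xFuserLongContextAttention  # assumed import path

env_info = PACKAGES_CHECKER.get_packages_info()
if env_info["has_aiter"] and "AITER" in AttnType.__members__:
    attn_type = AttnType.AITER  # backend option enabled by this commit
else:
    attn_type = AttnType.FA     # flash-attention fallback

attn_layer = xFuserLongContextAttention(
    scatter_idx=2,
    gather_idx=1,
    ring_impl_type="basic",
    use_kv_cache=False,
    attn_type=attn_type,        # argument added by this commit
).to(device=torch.device("cuda"), dtype=torch.float16)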
