import os
import unittest
from itertools import product

import torch
import torch.distributed as dist

from flash_attn import flash_attn_func

from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)
from xfuser.envs import PACKAGES_CHECKER

from yunchang.kernels import AttnType
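
# These tests compare xFuserLongContextAttention, which combines a ring process
# group with a Ulysses process group for long-context attention, against a
# single-device flash_attn_func reference computed over the full sequence.
# Each rank holds one shard of the sequence dimension, so its distributed
# output is checked against the matching shard of the reference output.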


def init_dist(backend="nccl"):
    local_rank = int(os.environ["LOCAL_RANK"])
    # ... (remainder of the helper omitted here) ...
    # The omitted body sets up the torch.distributed process group and the
    # xFuser distributed environment / model-parallel groups, and returns
    # (rank, world_size, ring_degree, ulysses_degree).


class TestRingFlashAttn(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        env_info = PACKAGES_CHECKER.get_packages_info()

        cls.batch_size = 1
        cls.num_heads = 4
        cls.head_dim = 32
        cls.seq_len = 128

        cls.rank, cls.world_size, cls.ring_degree, cls.ulysses_degree = init_dist()
        cls.device = torch.device(f"cuda:{cls.rank}")

        # Backend availability flags, used to skip configurations the current
        # environment cannot run.
        cls.HAS_AITER = env_info["has_aiter"]
        cls.HAS_FLASH_ATTENTION = env_info["has_flash_attn"]

    def setUp(self):
        torch.manual_seed(42 + self.rank)

    @classmethod
    def tearDownClass(cls):
        dist.destroy_process_group()

    def _create_test_tensors(self, dtype=torch.float16):
        """Helper to create test input tensors"""
        shape = (self.batch_size, self.seq_len, self.num_heads, self.head_dim)

        # Prepare inputs
        q = torch.randn(
            shape, device=self.device, dtype=dtype, requires_grad=False
        )
        k = torch.randn(
            shape, device=self.device, dtype=dtype, requires_grad=False
        )
        v = torch.randn(
            shape, device=self.device, dtype=dtype, requires_grad=False
        )
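
        # Broadcast from rank 0 so every rank sees identical full-sequence
        # tensors; each rank then keeps only its own shard of the sequence
        # dimension as its "local" input.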
        dist.broadcast(q, src=0)
        dist.broadcast(k, src=0)
        dist.broadcast(v, src=0)

        # Each rank's local shard along the sequence dimension (dim=1), matching
        # how the reference outputs are chunked in the tests below.
        local_q = q.chunk(self.world_size, dim=1)[self.rank]
        local_k = k.chunk(self.world_size, dim=1)[self.rank]
        local_v = v.chunk(self.world_size, dim=1)[self.rank]

        return q, k, v, local_q, local_k, local_v

    def test_xfuser_attn_layer_joint_strategy_rear(self):
        """Test xFuserLongContextAttention with joint_strategy='rear' in distributed mode."""

        backends = ["AITER", "FA"]
        dtypes = [torch.float16, torch.bfloat16]

        for backend, dtype in product(backends, dtypes):
            with self.subTest(backend=backend, dtype=dtype):
                if backend == "AITER":
                    if self.HAS_AITER and "AITER" in AttnType.__members__:
                        attn_type = AttnType.AITER
                    else:
                        self.skipTest("AITER backend not applicable")
                elif backend == "FA":
                    if self.HAS_FLASH_ATTENTION:
                        attn_type = AttnType.FA
                    else:
                        self.skipTest("FA backend not applicable")

                # Create test tensors
                q, k, v, local_q, local_k, local_v = self._create_test_tensors(dtype=dtype)
                joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = (
                    self._create_test_tensors(dtype=dtype)
                )
                joint_strategy = "rear"

                # Create attention layer
                attn_layer = xFuserLongContextAttention(
                    scatter_idx=2,
                    gather_idx=1,
                    ring_impl_type="basic",
                    attn_type=attn_type,
                    use_kv_cache=False,
                ).to(device=self.device, dtype=dtype)

                assert attn_layer.ring_pg.size() == self.ring_degree
                assert attn_layer.ulysses_pg.size() == self.ulysses_degree

                # Single-device reference over the full (base + joint) sequence
                ref_output = flash_attn_func(
                    torch.cat([q, joint_q], dim=1),
                    torch.cat([k, joint_k], dim=1),
                    torch.cat([v, joint_v], dim=1),
                    dropout_p=0.0,
                    window_size=(-1, -1),
                )

                # Split ref_output into base and joint parts
                base_out = ref_output[:, : self.seq_len]  # First half for base attention
                joint_out = ref_output[:, self.seq_len :]  # Second half for joint attention

                # Get local shard for base output
                base_out_shard = base_out.chunk(self.world_size, dim=1)[self.rank]
                # Duplicate joint output as specified
                ref_output = torch.cat([base_out_shard, joint_out], dim=1)

                # Run distributed implementation
                output = attn_layer(
                    attn=None,
                    query=local_q,
                    key=local_k,
                    value=local_v,
                    dropout_p=0.0,
                    window_size=(-1, -1),
                    joint_tensor_query=joint_q,
                    joint_tensor_key=joint_k,
                    joint_tensor_value=joint_v,
                    joint_strategy=joint_strategy,
                )
                torch.testing.assert_close(ref_output, output, rtol=1e-2, atol=1e-2)

    def test_xfuser_attn_layer(self):
        """Test xFuserLongContextAttention layer in distributed mode."""

        backends = ["AITER", "FA"]
        dtypes = [torch.float16, torch.bfloat16]

        for backend, dtype in product(backends, dtypes):
            with self.subTest(backend=backend, dtype=dtype):
                if backend == "AITER":
                    if self.HAS_AITER and "AITER" in AttnType.__members__:
                        attn_type = AttnType.AITER
                    else:
                        self.skipTest("AITER backend not applicable")
                elif backend == "FA":
                    if self.HAS_FLASH_ATTENTION:
                        attn_type = AttnType.FA
                    else:
                        self.skipTest("FA backend not applicable")

                # Create test tensors
                q, k, v, local_q, local_k, local_v = self._create_test_tensors(dtype)

                # Create attention layer
                attn_layer = xFuserLongContextAttention(
                    scatter_idx=2,
                    gather_idx=1,
                    ring_impl_type="basic",
                    use_kv_cache=False,
                    attn_type=attn_type,
                ).to(device=self.device, dtype=dtype)

                assert attn_layer.ring_pg.size() == self.ring_degree
                assert attn_layer.ulysses_pg.size() == self.ulysses_degree

                ref_output = flash_attn_func(
                    q,
                    k,
                    v,
                    dropout_p=0.0,
                    window_size=(-1, -1),
                )
                ref_output = ref_output.chunk(self.world_size, dim=1)[self.rank]

                # Run distributed implementation
                output = attn_layer(
                    attn=None,
                    query=local_q,
                    key=local_k,
                    value=local_v,
                    dropout_p=0.0,
                    window_size=(-1, -1),
                )

                assert torch.max(torch.abs(output - ref_output)) < 1e-2
                torch.testing.assert_close(ref_output, output, rtol=1e-2, atol=1e-2)


# torchrun --nproc_per_node=4 -m unittest tests/core/test_xfuser_attn.py
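#
# LOCAL_RANK (and the rank / world-size information consumed by init_dist) is
# supplied by torchrun, and each rank expects its own CUDA device.
#
# Minimal usage sketch of the layer under test, assuming the distributed
# environment has already been initialized as in init_dist above:
#
#   attn_layer = xFuserLongContextAttention(
#       scatter_idx=2, gather_idx=1, ring_impl_type="basic", use_kv_cache=False
#   )
#   out = attn_layer(
#       attn=None, query=local_q, key=local_k, value=local_v,
#       dropout_p=0.0, window_size=(-1, -1),
#   )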