Commit b8ebdf7

Add joint tensor and KV cache support to USP method (#586)
* Add initial USP_join calls
* Use Ulysses rank rather than sequence rank
* Add KV-cache support and proper joint tensor support to USP
* Make Hunyuan use USP rather than Yunchang
* Make Flux use USP rather than Yunchang
* Enable the new USP function to support all SPs again
* Add tests to compare new USP to yunchang
* Refactor USP to be less repetitive
* Remove inline importing in tests
1 parent 6e5af04 commit b8ebdf7
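
For orientation, here is a minimal sketch of calling the reworked USP entry point on a single GPU with joint tensors, closely following the test added in this commit. The import paths and tensor shapes are lifted from the diffs below and should be treated as illustrative, not canonical.

import os
import torch
# Import paths are assumptions based on the test imports and file paths in this commit.
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)
from xfuser.model_executor.layers import usp

# Single-process "distributed" setup, mirroring the new test's _init_environment().
os.environ.update({"RANK": "0", "WORLD_SIZE": "1", "LOCAL_RANK": "0",
                   "MASTER_ADDR": "localhost", "MASTER_PORT": "12355"})
init_distributed_environment(rank=0, world_size=1)
initialize_model_parallel(ring_degree=1, ulysses_degree=1)

# USP takes (batch, heads, seq, head_dim); the joint tensors share that layout.
q = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
jq = torch.randn(1, 24, 64, 128, device="cuda", dtype=torch.bfloat16)
jk, jv = torch.randn_like(jq), torch.randn_like(jq)

out = usp.USP(
    q, k, v,
    dropout_p=0.0, is_causal=False,
    joint_query=jq, joint_key=jk, joint_value=jv,
    joint_strategy="rear",  # append the joint tokens after the main sequence
)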

File tree (4 files changed: +211 −51 lines)

  tests/layers/usp_test.py
  xfuser/model_executor/layers/attention_processor.py
  xfuser/model_executor/layers/usp.py
  xfuser/model_executor/models/transformers/transformer_flux.py

tests/layers/usp_test.py

Lines changed: 92 additions & 17 deletions
@@ -11,30 +11,29 @@
     get_runtime_state,
 )
 from xfuser.core.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+from yunchang.kernels import AttnType


+def _init_environment():
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    os.environ["LOCAL_RANK"] = "0"
+    init_distributed_environment(rank=0, world_size=1)
+    initialize_model_parallel(ring_degree=1, ulysses_degree=1)
+
 class TestUSP(unittest.TestCase):

     def setUp(self):
-        self._init_environment()
+        _init_environment()
         env_info = PACKAGES_CHECKER.get_packages_info()
         self.HAS_FLASH_ATTN = env_info["has_flash_attn"]
         self.HAS_AITER = env_info["has_aiter"]
-        self.query = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
-        self.key = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
-        self.value = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
-
-
-    def _init_environment(self):
-        os.environ["RANK"] = "0"
-        os.environ["WORLD_SIZE"] = "1"
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "12355"
-        os.environ["LOCAL_RANK"] = "0"
-        init_distributed_environment(rank=0, world_size=1)
-        initialize_model_parallel(ring_degree=1, ulysses_degree=1)
-
-
+        self.query = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.key = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.value = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)

     def tearDown(self):
         destroy_model_parallel()
@@ -123,4 +122,80 @@ def test_ring_attn_aiter(self):

         result_diff = (fsdpa_results - aiter_attn_results).abs().max()
         self.assertNotEqual(result_diff, 0) # Different implementations won't produce same output
-        self.assertAlmostEqual(result_diff.item(), 0, places=1) # Difference can be 0.15ish
+        self.assertAlmostEqual(result_diff.item(), 0, places=1) # Difference can be 0.15ish
+
+
+class TestUSPHybridParallel(unittest.TestCase):
+
+    def setUp(self):
+        _init_environment()
+        # Using SDPA here
+        self.HAS_FLASH_ATTN = False
+        self.HAS_AITER = False
+        self.query = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.key = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.value = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+
+        self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
+            attn_type=AttnType.TORCH
+        )
+
+    def tearDown(self):
+        destroy_model_parallel()
+        destroy_distributed_environment()
+
+    def test_usp_hybrid_equivalence(self):
+        """
+        Tests the output from USP is equivalent to hybrid seq parallel attention, i.e
+        yunchang path
+        """
+        usp_results = usp.USP(self.query, self.key, self.value, dropout_p=0.0, is_causal=False)
+        hybrid_results = self.hybrid_seq_parallel_attn(
+            None,
+            self.query.transpose(1, 2),
+            self.key.transpose(1, 2),
+            self.value.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False
+        ).transpose(1, 2)
+
+        result_diff = (usp_results - hybrid_results).abs().max().float().cpu().numpy()
+        self.assertAlmostEqual(result_diff, 0, places=3)
+
+    def test_usp_hybrid_joint_equivalence(self):
+        """
+        Tests the output from USP with joint tensors added is equivalent to hybrid seq
+        parallel attn.
+        """
+        joint_shape = (1, 24, 64, 128)
+
+        joint_query = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
+        joint_key = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
+        joint_value = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
+
+        usp_results = usp.USP(
+            self.query,
+            self.key,
+            self.value,
+            dropout_p=0.0,
+            is_causal=False,
+            joint_query=joint_query,
+            joint_key=joint_key,
+            joint_value=joint_value,
+            joint_strategy="rear"
+        )
+        hybrid_results = self.hybrid_seq_parallel_attn(
+            None,
+            self.query.transpose(1, 2),
+            self.key.transpose(1, 2),
+            self.value.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+            joint_tensor_query=joint_query.transpose(1, 2),
+            joint_tensor_key=joint_key.transpose(1, 2),
+            joint_tensor_value=joint_value.transpose(1, 2),
+            joint_strategy="rear"
+        ).transpose(1, 2)
+
+        result_diff = (usp_results - hybrid_results).abs().max().float().cpu().numpy()
+        self.assertAlmostEqual(result_diff, 0, places=3)
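
A note on tensor layout, since it explains the transpose(1, 2) calls that bracket every hybrid call in these tests: USP consumes and returns (batch, heads, seq, head_dim), while the yunchang hybrid path expects (batch, seq, heads, head_dim). A tiny sketch of the conversion, using the shapes from the tests above (purely illustrative):

import torch

x_usp = torch.randn(1, 24, 14867, 128)  # (B, H, S, D): what USP consumes and returns
x_hybrid = x_usp.transpose(1, 2)         # (B, S, H, D): what hybrid_seq_parallel_attn expects
assert x_hybrid.shape == (1, 14867, 24, 128)
# The hybrid output is transposed back the same way before comparing against USP's output.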

xfuser/model_executor/layers/attention_processor.py

Lines changed: 7 additions & 15 deletions
@@ -1231,7 +1231,7 @@ def __call__(
         ):
             hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+        elif get_sequence_parallel_world_size() > 1:
             if get_runtime_state().split_text_embed_in_sp:
                 encoder_query = None
                 encoder_key = None
@@ -1247,27 +1247,19 @@ def __call__(
                     [num_query_tokens, num_encoder_hidden_states_tokens], dim=2
                 )

-                encoder_query = encoder_query.transpose(1, 2)
-                encoder_key = encoder_key.transpose(1, 2)
-                encoder_value = encoder_value.transpose(1, 2)
-
-            query = query.transpose(1, 2)
-            key = key.transpose(1, 2)
-            value = value.transpose(1, 2)
-
-            hidden_states = self.hybrid_seq_parallel_attn(
-                None,
+            hidden_states = USP(
                 query,
                 key,
                 value,
                 dropout_p=0.0,
-                causal=False,
-                joint_tensor_query=encoder_query,
-                joint_tensor_key=encoder_key,
-                joint_tensor_value=encoder_value,
+                is_causal=False,
+                joint_query=encoder_query,
+                joint_key=encoder_key,
+                joint_value=encoder_value,
                 joint_strategy="rear",
             )

+            hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.flatten(2, 3)
         else:
             if HAS_FLASH_ATTN:

xfuser/model_executor/layers/usp.py

Lines changed: 96 additions & 9 deletions
@@ -15,10 +15,13 @@
     get_sequence_parallel_world_size,
     get_ulysses_parallel_world_size,
     get_ring_parallel_world_size,
+    get_sequence_parallel_rank,
+    get_ulysses_parallel_rank,
 )

 from packaging.version import parse
 from xfuser.envs import PACKAGES_CHECKER
+from xfuser.core.cache_manager.cache_manager import get_cache_manager
 env_info = PACKAGES_CHECKER.get_packages_info()
 HAS_FLASH_ATTN = env_info["has_flash_attn"]
 if HAS_FLASH_ATTN:
@@ -163,6 +166,7 @@ def _ft_c_output_all_to_all(x):
     x = x.reshape(world_size, s // world_size, b, -1, d).permute(2, 0, 3, 1, 4).reshape(b, -1, s // world_size, d)
     return x

+
 def _aiter_attn_call(query, key, value, dropout_p, is_causal):
     """
     Performs the necessary tensor permutes and
@@ -217,21 +221,104 @@ def _attention(query, key, value, dropout_p, is_causal):
         query, key, value, dropout_p=dropout_p, is_causal=is_causal
     )

-def USP(query, key, value, dropout_p=0.0, is_causal=False):
-    if get_sequence_parallel_world_size() == 1:
-        out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
-    elif get_ulysses_parallel_world_size() == 1:
-        out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
-    elif get_ulysses_parallel_world_size() > 1:
+
+def _preprocess_joint_tensors(joint_key, joint_value):
+    """
+    Preprocess the joint key and value tensors for Ulysses parallelism.
+    """
+    ulysses_world_size = get_ulysses_parallel_world_size()
+    ulysses_rank = get_ulysses_parallel_rank()
+    attn_heads_per_ulysses_rank = (
+        joint_key.shape[1] // ulysses_world_size
+    )
+    joint_key = joint_key.transpose(1,2)
+    joint_value = joint_value.transpose(1,2)
+    joint_key = joint_key[
+        ...,
+        attn_heads_per_ulysses_rank
+        * ulysses_rank : attn_heads_per_ulysses_rank
+        * (ulysses_rank + 1),
+        :, ].transpose(1,2)
+    joint_value = joint_value[
+        ...,
+        attn_heads_per_ulysses_rank
+        * ulysses_rank : attn_heads_per_ulysses_rank
+        * (ulysses_rank + 1),
+        :,
+    ].transpose(1,2)
+    return joint_key, joint_value
+
+def _concat_joint_tensor(tensor, joint_tensor, joint_strategy, dim):
+    """
+    Concatenate the joint tensor to the main tensor based on the joint strategy.
+    """
+    if joint_strategy == "rear":
+        tensor = torch.cat([tensor, joint_tensor], dim=dim)
+    elif joint_strategy == "front":
+        tensor = torch.cat([joint_tensor, tensor], dim=dim)
+    else:
+        raise ValueError(f"Invalid joint_strategy: {joint_strategy}")
+    return tensor
+
+def _update_and_get_kv_cache(key, value, attn_layer):
+    """
+    Update and get the key and value cache for pipeline parallelism.
+    """
+    key, value = get_cache_manager().update_and_get_kv_cache(
+        new_kv=[key.transpose(1, 2), value.transpose(1, 2)],
+        layer=attn_layer,
+        slice_dim=1,
+        layer_type="attn",
+    )
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    return key, value
+
+def USP(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    joint_query: torch.Tensor | None = None,
+    joint_key: torch.Tensor | None = None,
+    joint_value: torch.Tensor | None = None,
+    joint_strategy: str | None = None,
+    attn_layer=None,
+):
+    """
+    Unified Sequence Parallelism (USP) attention call, supporting combinations of Ulysses and
+    Ring attention. Also supports joint tensors and key-value caching for pipeline parallelism.
+    """
+
+    if joint_strategy:
+        query = _concat_joint_tensor(query, joint_query, joint_strategy, dim=2)
+        joint_key, joint_value = _preprocess_joint_tensors(joint_key, joint_value)
+
+    if get_ulysses_parallel_world_size() > 1:
         query = _ft_c_input_all_to_all(query)
         key = _ft_c_input_all_to_all(key)
         value = _ft_c_input_all_to_all(value)

-        if get_ring_parallel_world_size() == 1:
+    if attn_layer:
+        key, value = _update_and_get_kv_cache(key, value, attn_layer)
+    if joint_strategy:
+        key = _concat_joint_tensor(key, joint_key, joint_strategy, dim=2)
+        value = _concat_joint_tensor(value, joint_value, joint_strategy, dim=2)
+
+    if get_sequence_parallel_world_size() == 1: # No SP
+        out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
+
+    elif get_ulysses_parallel_world_size() == 1: # Ring only
+        out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
+
+    else:
+        if get_ring_parallel_world_size() == 1: # Ulysses only
             out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
-        else:
+        else: # USP
             out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
-
         out = _ft_c_output_all_to_all(out)

     return out
+
+
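
The head slicing in _preprocess_joint_tensors keeps only the attention heads owned by the current Ulysses rank, so the joint key/value tokens line up with the heads each rank holds after the all-to-all. An equivalent standalone sketch of that arithmetic (world size, rank, and shapes are chosen purely for illustration):

import torch

# Illustrative setup: 24 heads split across a Ulysses group of 4 ranks; this is rank 1.
ulysses_world_size, ulysses_rank = 4, 1
joint_key = torch.randn(1, 24, 64, 128)  # (batch, heads, joint_seq, head_dim)

heads_per_rank = joint_key.shape[1] // ulysses_world_size  # 24 // 4 = 6
start = heads_per_rank * ulysses_rank                      # 6
stop = heads_per_rank * (ulysses_rank + 1)                 # 12

# Equivalent to the helper's transpose + slice + transpose: rank 1 keeps heads 6..11.
local_joint_key = joint_key[:, start:stop]
assert local_joint_key.shape == (1, 6, 64, 128)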

xfuser/model_executor/models/transformers/transformer_flux.py

Lines changed: 16 additions & 10 deletions
@@ -160,9 +160,12 @@ def __call__(
         key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
         uses_pipeline_parallelism = get_runtime_state().num_pipeline_patch > 1
         if not uses_pipeline_parallelism:
-            query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
             hidden_states = USP(query, key, value)
             hidden_states = hidden_states.transpose(1, 2)
         else:
@@ -172,26 +175,29 @@ def __call__(
                 encoder_hidden_states_value_proj = None
             else:
                 encoder_hidden_states_query_proj, query = query.split(
-                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
+                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
                 )
                 encoder_hidden_states_key_proj, key = key.split(
-                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
+                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
                 )
                 encoder_hidden_states_value_proj, value = value.split(
-                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
+                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
                 )
-            hidden_states = self.hybrid_seq_parallel_attn(
-                attn if get_runtime_state().num_pipeline_patch > 1 else None,
+            hidden_states = USP(
                 query,
                 key,
                 value,
                 dropout_p=0.0,
-                causal=False,
-                joint_tensor_query=encoder_hidden_states_query_proj,
-                joint_tensor_key=encoder_hidden_states_key_proj,
-                joint_tensor_value=encoder_hidden_states_value_proj,
+                is_causal=False,
+                joint_query=encoder_hidden_states_query_proj,
+                joint_key=encoder_hidden_states_key_proj,
+                joint_value=encoder_hidden_states_value_proj,
                 joint_strategy="front",
+                attn_layer=attn,
             )
+            hidden_states = hidden_states.transpose(1, 2)
+
+
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
