Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 92 additions & 17 deletions tests/layers/usp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,29 @@
get_runtime_state,
)
from xfuser.core.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from yunchang.kernels import AttnType


def _init_environment():
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
os.environ["LOCAL_RANK"] = "0"
init_distributed_environment(rank=0, world_size=1)
initialize_model_parallel(ring_degree=1, ulysses_degree=1)

class TestUSP(unittest.TestCase):

def setUp(self):
self._init_environment()
_init_environment()
env_info = PACKAGES_CHECKER.get_packages_info()
self.HAS_FLASH_ATTN = env_info["has_flash_attn"]
self.HAS_AITER = env_info["has_aiter"]
self.query = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
self.key = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
self.value = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)


def _init_environment(self):
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
os.environ["LOCAL_RANK"] = "0"
init_distributed_environment(rank=0, world_size=1)
initialize_model_parallel(ring_degree=1, ulysses_degree=1)


self.query = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
self.key = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
self.value = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)

def tearDown(self):
destroy_model_parallel()
Expand Down Expand Up @@ -123,4 +122,80 @@ def test_ring_attn_aiter(self):

result_diff = (fsdpa_results - aiter_attn_results).abs().max()
self.assertNotEqual(result_diff, 0) # Different implementations won't produce same output
self.assertAlmostEqual(result_diff.item(), 0, places=1) # Difference can be 0.15ish
self.assertAlmostEqual(result_diff.item(), 0, places=1) # Difference can be 0.15ish


class TestUSPHybridParallel(unittest.TestCase):

def setUp(self):
_init_environment()
# Using SDPA here
self.HAS_FLASH_ATTN = False
self.HAS_AITER = False
self.query = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
self.key = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
self.value = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)

self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
attn_type=AttnType.TORCH
)

def tearDown(self):
destroy_model_parallel()
destroy_distributed_environment()

def test_usp_hybrid_equivalence(self):
"""
Tests the output from USP is equivalent to hybrid seq parallel attention, i.e
yunchang path
"""
usp_results = usp.USP(self.query, self.key, self.value, dropout_p=0.0, is_causal=False)
hybrid_results = self.hybrid_seq_parallel_attn(
None,
self.query.transpose(1, 2),
self.key.transpose(1, 2),
self.value.transpose(1, 2),
dropout_p=0.0,
causal=False
).transpose(1, 2)

result_diff = (usp_results - hybrid_results).abs().max().float().cpu().numpy()
self.assertAlmostEqual(result_diff, 0, places=3)

def test_usp_hybrid_joint_equivalence(self):
"""
Tests the output from USP with joint tensors added is equivalent to hybrid seq
parallel attn.
"""
joint_shape = (1, 24, 64, 128)

joint_query = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
joint_key = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
joint_value = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)

usp_results = usp.USP(
self.query,
self.key,
self.value,
dropout_p=0.0,
is_causal=False,
joint_query=joint_query,
joint_key=joint_key,
joint_value=joint_value,
joint_strategy="rear"
)
hybrid_results = self.hybrid_seq_parallel_attn(
None,
self.query.transpose(1, 2),
self.key.transpose(1, 2),
self.value.transpose(1, 2),
dropout_p=0.0,
causal=False,
joint_tensor_query=joint_query.transpose(1, 2),
joint_tensor_key=joint_key.transpose(1, 2),
joint_tensor_value=joint_value.transpose(1, 2),
joint_strategy="rear"
).transpose(1, 2)

result_diff = (usp_results - hybrid_results).abs().max().float().cpu().numpy()
self.assertAlmostEqual(result_diff, 0, places=3)
22 changes: 7 additions & 15 deletions xfuser/model_executor/layers/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ def __call__(
):
hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
elif get_sequence_parallel_world_size() > 1:
if get_runtime_state().split_text_embed_in_sp:
encoder_query = None
encoder_key = None
Expand All @@ -1247,27 +1247,19 @@ def __call__(
[num_query_tokens, num_encoder_hidden_states_tokens], dim=2
)

encoder_query = encoder_query.transpose(1, 2)
encoder_key = encoder_key.transpose(1, 2)
encoder_value = encoder_value.transpose(1, 2)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

hidden_states = self.hybrid_seq_parallel_attn(
None,
hidden_states = USP(
query,
key,
value,
dropout_p=0.0,
causal=False,
joint_tensor_query=encoder_query,
joint_tensor_key=encoder_key,
joint_tensor_value=encoder_value,
is_causal=False,
joint_query=encoder_query,
joint_key=encoder_key,
joint_value=encoder_value,
joint_strategy="rear",
)

hidden_states = hidden_states.transpose(1, 2)
hidden_states = hidden_states.flatten(2, 3)
else:
if HAS_FLASH_ATTN:
Expand Down
105 changes: 96 additions & 9 deletions xfuser/model_executor/layers/usp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
get_sequence_parallel_world_size,
get_ulysses_parallel_world_size,
get_ring_parallel_world_size,
get_sequence_parallel_rank,
get_ulysses_parallel_rank,
)

from packaging.version import parse
from xfuser.envs import PACKAGES_CHECKER
from xfuser.core.cache_manager.cache_manager import get_cache_manager
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_FLASH_ATTN = env_info["has_flash_attn"]
if HAS_FLASH_ATTN:
Expand Down Expand Up @@ -163,6 +166,7 @@ def _ft_c_output_all_to_all(x):
x = x.reshape(world_size, s // world_size, b, -1, d).permute(2, 0, 3, 1, 4).reshape(b, -1, s // world_size, d)
return x


def _aiter_attn_call(query, key, value, dropout_p, is_causal):
"""
Performs the necessary tensor permutes and
Expand Down Expand Up @@ -217,21 +221,104 @@ def _attention(query, key, value, dropout_p, is_causal):
query, key, value, dropout_p=dropout_p, is_causal=is_causal
)

def USP(query, key, value, dropout_p=0.0, is_causal=False):
if get_sequence_parallel_world_size() == 1:
out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
elif get_ulysses_parallel_world_size() == 1:
out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
elif get_ulysses_parallel_world_size() > 1:

def _preprocess_joint_tensors(joint_key, joint_value):
"""
Preprocess the joint key and value tensors for Ulysses parallelism.
"""
ulysses_world_size = get_ulysses_parallel_world_size()
ulysses_rank = get_ulysses_parallel_rank()
attn_heads_per_ulysses_rank = (
joint_key.shape[1] // ulysses_world_size
)
joint_key = joint_key.transpose(1,2)
joint_value = joint_value.transpose(1,2)
joint_key = joint_key[
...,
attn_heads_per_ulysses_rank
* ulysses_rank : attn_heads_per_ulysses_rank
* (ulysses_rank + 1),
:, ].transpose(1,2)
joint_value = joint_value[
...,
attn_heads_per_ulysses_rank
* ulysses_rank : attn_heads_per_ulysses_rank
* (ulysses_rank + 1),
:,
].transpose(1,2)
return joint_key, joint_value

def _concat_joint_tensor(tensor, joint_tensor, joint_strategy, dim):
"""
Concatenate the joint tensor to the main tensor based on the joint strategy.
"""
if joint_strategy == "rear":
tensor = torch.cat([tensor, joint_tensor], dim=dim)
elif joint_strategy == "front":
tensor = torch.cat([joint_tensor, tensor], dim=dim)
else:
raise ValueError(f"Invalid joint_strategy: {joint_strategy}")
return tensor

def _update_and_get_kv_cache(key, value, attn_layer):
"""
Update and get the key and value cache for pipeline parallelism.
"""
key, value = get_cache_manager().update_and_get_kv_cache(
new_kv=[key.transpose(1, 2), value.transpose(1, 2)],
layer=attn_layer,
slice_dim=1,
layer_type="attn",
)
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
return key, value

def USP(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
joint_query: torch.Tensor | None = None,
joint_key: torch.Tensor | None = None,
joint_value: torch.Tensor | None = None,
joint_strategy: str | None = None,
attn_layer=None,
):
"""
Unified Sequence Parallelism (USP) attention call, supporting combinations of Ulysses and
Ring attention. Also supports joint tensors and key-value caching for pipeline parallelism.
"""

if joint_strategy:
query = _concat_joint_tensor(query, joint_query, joint_strategy, dim=2)
joint_key, joint_value = _preprocess_joint_tensors(joint_key, joint_value)

if get_ulysses_parallel_world_size() > 1:
query = _ft_c_input_all_to_all(query)
key = _ft_c_input_all_to_all(key)
value = _ft_c_input_all_to_all(value)

if get_ring_parallel_world_size() == 1:
if attn_layer:
key, value = _update_and_get_kv_cache(key, value, attn_layer)
if joint_strategy:
key = _concat_joint_tensor(key, joint_key, joint_strategy, dim=2)
value = _concat_joint_tensor(value, joint_value, joint_strategy, dim=2)

if get_sequence_parallel_world_size() == 1: # No SP
out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)

elif get_ulysses_parallel_world_size() == 1: # Ring only
out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)

else:
if get_ring_parallel_world_size() == 1: # Ulysses only
out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
else:
else: # USP
out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)

out = _ft_c_output_all_to_all(out)

return out


26 changes: 16 additions & 10 deletions xfuser/model_executor/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,12 @@ def __call__(
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

uses_pipeline_parallelism = get_runtime_state().num_pipeline_patch > 1
if not uses_pipeline_parallelism:
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
hidden_states = USP(query, key, value)
hidden_states = hidden_states.transpose(1, 2)
else:
Expand All @@ -172,26 +175,29 @@ def __call__(
encoder_hidden_states_value_proj = None
else:
encoder_hidden_states_query_proj, query = query.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
[num_encoder_hidden_states_tokens, num_query_tokens], dim=2
)
encoder_hidden_states_key_proj, key = key.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
[num_encoder_hidden_states_tokens, num_query_tokens], dim=2
)
encoder_hidden_states_value_proj, value = value.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
[num_encoder_hidden_states_tokens, num_query_tokens], dim=2
)
hidden_states = self.hybrid_seq_parallel_attn(
attn if get_runtime_state().num_pipeline_patch > 1 else None,
hidden_states = USP(
query,
key,
value,
dropout_p=0.0,
causal=False,
joint_tensor_query=encoder_hidden_states_query_proj,
joint_tensor_key=encoder_hidden_states_key_proj,
joint_tensor_value=encoder_hidden_states_value_proj,
is_causal=False,
joint_query=encoder_hidden_states_query_proj,
joint_key=encoder_hidden_states_key_proj,
joint_value=encoder_hidden_states_value_proj,
joint_strategy="front",
attn_layer=attn,
)
hidden_states = hidden_states.transpose(1, 2)


hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

Expand Down