diff --git a/tests/layers/usp_test.py b/tests/layers/usp_test.py
index fa14092c..f43731a6 100644
--- a/tests/layers/usp_test.py
+++ b/tests/layers/usp_test.py
@@ -11,30 +11,29 @@
     get_runtime_state,
 )
 from xfuser.core.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+from yunchang.kernels import AttnType
 
+def _init_environment():
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    os.environ["LOCAL_RANK"] = "0"
+    init_distributed_environment(rank=0, world_size=1)
+    initialize_model_parallel(ring_degree=1, ulysses_degree=1)
+
 class TestUSP(unittest.TestCase):
 
     def setUp(self):
-        self._init_environment()
+        _init_environment()
         env_info = PACKAGES_CHECKER.get_packages_info()
         self.HAS_FLASH_ATTN = env_info["has_flash_attn"]
         self.HAS_AITER = env_info["has_aiter"]
-        self.query = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
-        self.key = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
-        self.value = torch.randn(29760, 2, 5, 128, device="cuda", dtype=torch.bfloat16)
-
-
-    def _init_environment(self):
-        os.environ["RANK"] = "0"
-        os.environ["WORLD_SIZE"] = "1"
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "12355"
-        os.environ["LOCAL_RANK"] = "0"
-        init_distributed_environment(rank=0, world_size=1)
-        initialize_model_parallel(ring_degree=1, ulysses_degree=1)
-
-
-
+        self.query = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.key = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.value = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
 
     def tearDown(self):
         destroy_model_parallel()
@@ -123,4 +122,97 @@ def test_ring_attn_aiter(self):
 
         result_diff = (fsdpa_results - aiter_attn_results).abs().max()
         self.assertNotEqual(result_diff, 0) # Different implementations won't produce same output
-        self.assertAlmostEqual(result_diff.item(), 0, places=1) # Difference can be 0.15ish
\ No newline at end of file
+        self.assertAlmostEqual(result_diff.item(), 0, places=1) # Difference can be 0.15ish
+
+
+class TestUSPHybridParallel(unittest.TestCase):
+
+    def setUp(self):
+        _init_environment()
+        # Using SDPA here
+        self.HAS_FLASH_ATTN = False
+        self.HAS_AITER = False
+        self.query = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.key = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+        self.value = torch.randn(1, 24, 14867, 128, device="cuda", dtype=torch.bfloat16)
+
+        self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
+            attn_type=AttnType.TORCH
+        )
+
+    def tearDown(self):
+        destroy_model_parallel()
+        destroy_distributed_environment()
+
+    def test_usp_hybrid_equivalence(self):
+        """
+        Tests the output from USP is equivalent to hybrid seq parallel attention, i.e.
+        yunchang path
+        """
+        usp_results = usp.USP(self.query, self.key, self.value, dropout_p=0.0, is_causal=False)
+        hybrid_results = self.hybrid_seq_parallel_attn(
+            None,
+            self.query.transpose(1, 2),
+            self.key.transpose(1, 2),
+            self.value.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False
+        ).transpose(1, 2)
+
+        result_diff = (usp_results - hybrid_results).abs().max().float().cpu().numpy()
+        self.assertAlmostEqual(result_diff, 0, places=3)
+
+    def test_usp_hybrid_joint_equivalence(self):
+        """
+        Tests the output from USP with joint tensors added is equivalent to hybrid seq
+        parallel attn.
+        """
+        joint_shape = (1, 24, 64, 128)
+
+        joint_query = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
+        joint_key = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
+        joint_value = torch.randn(joint_shape, device="cuda", dtype=torch.bfloat16)
+
+        usp_results = usp.USP(
+            self.query,
+            self.key,
+            self.value,
+            dropout_p=0.0,
+            is_causal=False,
+            joint_query=joint_query,
+            joint_key=joint_key,
+            joint_value=joint_value,
+            joint_strategy="rear"
+        )
+        hybrid_results = self.hybrid_seq_parallel_attn(
+            None,
+            self.query.transpose(1, 2),
+            self.key.transpose(1, 2),
+            self.value.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+            joint_tensor_query=joint_query.transpose(1, 2),
+            joint_tensor_key=joint_key.transpose(1, 2),
+            joint_tensor_value=joint_value.transpose(1, 2),
+            joint_strategy="rear"
+        ).transpose(1, 2)
+
+        result_diff = (usp_results - hybrid_results).abs().max().float().cpu().numpy()
+        self.assertAlmostEqual(result_diff, 0, places=3)
+
+    def test_usp_joint_strategy_without_joint_tensors(self):
+        """
+        Tests that a joint strategy without joint tensors matches plain USP attention.
+        """
+        plain_results = usp.USP(self.query, self.key, self.value, dropout_p=0.0, is_causal=False)
+        joint_results = usp.USP(
+            self.query,
+            self.key,
+            self.value,
+            dropout_p=0.0,
+            is_causal=False,
+            joint_strategy="rear",
+        )
+
+        result_diff = (plain_results - joint_results).abs().max().float().cpu().numpy()
+        self.assertAlmostEqual(result_diff, 0, places=3)
diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py
index 2f8ff85c..ba4567ce 100644
--- a/xfuser/model_executor/layers/attention_processor.py
+++ b/xfuser/model_executor/layers/attention_processor.py
@@ -1231,7 +1231,7 @@ def __call__(
         ):
             hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+        elif get_sequence_parallel_world_size() > 1:
             if get_runtime_state().split_text_embed_in_sp:
                 encoder_query = None
                 encoder_key = None
@@ -1247,27 +1247,19 @@ def __call__(
                     [num_query_tokens, num_encoder_hidden_states_tokens], dim=2
                 )
 
-                encoder_query = encoder_query.transpose(1, 2)
-                encoder_key = encoder_key.transpose(1, 2)
-                encoder_value = encoder_value.transpose(1, 2)
-
-            query = query.transpose(1, 2)
-            key = key.transpose(1, 2)
-            value = value.transpose(1, 2)
-
-            hidden_states = self.hybrid_seq_parallel_attn(
-                None,
+            hidden_states = USP(
                 query,
                 key,
                 value,
                 dropout_p=0.0,
-                causal=False,
-                joint_tensor_query=encoder_query,
-                joint_tensor_key=encoder_key,
-                joint_tensor_value=encoder_value,
+                is_causal=False,
+                joint_query=encoder_query,
+                joint_key=encoder_key,
+                joint_value=encoder_value,
                 joint_strategy="rear",
             )
 
+            hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.flatten(2, 3)
         else:
             if HAS_FLASH_ATTN:
diff --git a/xfuser/model_executor/layers/usp.py b/xfuser/model_executor/layers/usp.py
index 661bdabb..731b160e 100644
--- a/xfuser/model_executor/layers/usp.py
+++ b/xfuser/model_executor/layers/usp.py
@@ -15,10 +15,13 @@
     get_sequence_parallel_world_size,
     get_ulysses_parallel_world_size,
     get_ring_parallel_world_size,
+    get_sequence_parallel_rank,
+    get_ulysses_parallel_rank,
 )
 
 from packaging.version import parse
 from xfuser.envs import PACKAGES_CHECKER
+from xfuser.core.cache_manager.cache_manager import get_cache_manager
 env_info = PACKAGES_CHECKER.get_packages_info()
 HAS_FLASH_ATTN = env_info["has_flash_attn"]
 if HAS_FLASH_ATTN:
@@ -163,6 +166,7 @@ def _ft_c_output_all_to_all(x):
     x = x.reshape(world_size, s // world_size, b, -1, d).permute(2, 0, 3, 1, 4).reshape(b, -1, s // world_size, d)
     return x
 
+
 def _aiter_attn_call(query, key, value, dropout_p, is_causal):
     """
     Performs the necessary tensor permutes and
@@ -217,21 +221,130 @@ def _attention(query, key, value, dropout_p, is_causal):
             query, key, value, dropout_p=dropout_p, is_causal=is_causal
         )
 
-def USP(query, key, value, dropout_p=0.0, is_causal=False):
-    if get_sequence_parallel_world_size() == 1:
-        out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
-    elif get_ulysses_parallel_world_size() == 1:
-        out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
-    elif get_ulysses_parallel_world_size() > 1:
+
+def _preprocess_joint_tensors(joint_key, joint_value):
+    """
+    Preprocess the joint key and value tensors for Ulysses parallelism.
+
+    Keeps only the slice of dim 1 that belongs to the local Ulysses rank
+    (``shape[1] // ulysses_world_size`` entries per rank).
+    NOTE(review): dim 1 is presumably the attention-head axis - confirm.
+    """
+    ulysses_world_size = get_ulysses_parallel_world_size()
+    ulysses_rank = get_ulysses_parallel_rank()
+    attn_heads_per_ulysses_rank = joint_key.shape[1] // ulysses_world_size
+    head_slice = slice(
+        attn_heads_per_ulysses_rank * ulysses_rank,
+        attn_heads_per_ulysses_rank * (ulysses_rank + 1),
+    )
+    # For 4-D inputs, slicing dim 1 directly selects the same elements as a
+    # transpose(1, 2) -> slice dim -2 -> transpose(1, 2) round trip.
+    return joint_key[:, head_slice], joint_value[:, head_slice]
+
+def _concat_joint_tensor(tensor, joint_tensor, joint_strategy, dim):
+    """
+    Concatenate the joint tensor to the main tensor based on the joint strategy.
+
+    "rear" appends ``joint_tensor`` after ``tensor`` along ``dim`` and "front"
+    prepends it; any other strategy raises ValueError.
+    """
+    if joint_strategy == "rear":
+        tensor = torch.cat([tensor, joint_tensor], dim=dim)
+    elif joint_strategy == "front":
+        tensor = torch.cat([joint_tensor, tensor], dim=dim)
+    else:
+        raise ValueError(f"Invalid joint_strategy: {joint_strategy}")
+    return tensor
+
+def _update_and_get_kv_cache(key, value, attn_layer):
+    """
+    Update and get the key and value cache for pipeline parallelism.
+
+    Key and value are passed to the cache manager with dims 1 and 2 swapped; the
+    results are swapped back and made contiguous before being returned.
+    """
+    key, value = get_cache_manager().update_and_get_kv_cache(
+        new_kv=[key.transpose(1, 2), value.transpose(1, 2)],
+        layer=attn_layer,
+        slice_dim=1,
+        layer_type="attn",
+    )
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    return key, value
+
+def USP(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        joint_query: torch.Tensor | None = None,
+        joint_key: torch.Tensor | None = None,
+        joint_value: torch.Tensor | None = None,
+        joint_strategy: str | None = None,
+        attn_layer=None,
+    ):
+    """
+    Unified Sequence Parallelism (USP) attention call, supporting combinations of Ulysses and
+    Ring attention. Also supports joint tensors and key-value caching for pipeline parallelism.
+
+    Args:
+        query, key, value: Attention inputs. NOTE(review): the callers and tests in this
+            change pass (batch, heads, seq, head_dim) tensors - confirm.
+        dropout_p: Dropout probability forwarded to the attention call.
+        is_causal: Causal flag forwarded to the attention call.
+        joint_query, joint_key, joint_value: Optional tensors concatenated to the main
+            tensors along dim 2. Pass all three or none; when all are None the joint
+            path is skipped even if ``joint_strategy`` is set.
+        joint_strategy: "rear" or "front", i.e. where the joint tensors are concatenated.
+        attn_layer: When given, key/value go through the cache manager for this layer.
+
+    Returns:
+        The attention output.
+
+    Raises:
+        ValueError: If only some joint tensors are given or the strategy is invalid.
+    """
+
+    joint_tensors = (joint_query, joint_key, joint_value)
+    # Callers pass ``joint_strategy`` unconditionally and set the joint tensors to None
+    # when there is nothing to join, so the tensors decide whether the joint path runs.
+    use_joint = bool(joint_strategy) and any(t is not None for t in joint_tensors)
+    if use_joint and any(t is None for t in joint_tensors):
+        raise ValueError("joint_query, joint_key and joint_value must be provided together.")
+
+    if use_joint:
+        query = _concat_joint_tensor(query, joint_query, joint_strategy, dim=2)
+        joint_key, joint_value = _preprocess_joint_tensors(joint_key, joint_value)
+
+    if get_ulysses_parallel_world_size() > 1:
         query = _ft_c_input_all_to_all(query)
         key = _ft_c_input_all_to_all(key)
         value = _ft_c_input_all_to_all(value)
 
-        if get_ring_parallel_world_size() == 1:
+    # An explicit None check avoids relying on the truthiness of the layer object.
+    if attn_layer is not None:
+        key, value = _update_and_get_kv_cache(key, value, attn_layer)
+    if use_joint:
+        key = _concat_joint_tensor(key, joint_key, joint_strategy, dim=2)
+        value = _concat_joint_tensor(value, joint_value, joint_strategy, dim=2)
+
+    if get_sequence_parallel_world_size() == 1: # No SP
+        out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
+
+    elif get_ulysses_parallel_world_size() == 1: # Ring only
+        out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
+
+    else:
+        if get_ring_parallel_world_size() == 1: # Ulysses only
             out = _attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
-        else:
+        else: # USP
             out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
 
+        # Both Ulysses paths sent q/k/v through the input all-to-all, so the output
+        # must go through the output all-to-all before it is returned.
         out = _ft_c_output_all_to_all(out)
 
     return out
+
+
diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py
index 00096e14..33a779b8 100644
--- a/xfuser/model_executor/models/transformers/transformer_flux.py
+++ b/xfuser/model_executor/models/transformers/transformer_flux.py
@@ -160,9 +160,12 @@ def __call__(
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
 
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
         uses_pipeline_parallelism = get_runtime_state().num_pipeline_patch > 1
         if not uses_pipeline_parallelism:
-            query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
             hidden_states = USP(query, key, value)
             hidden_states = hidden_states.transpose(1, 2)
         else:
@@ -172,26 +175,29 @@ def __call__(
                 encoder_hidden_states_value_proj = None
             else:
                 encoder_hidden_states_query_proj, query = query.split(
-                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
+                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
                 )
                 encoder_hidden_states_key_proj, key = key.split(
-                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
+                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
                 )
                 encoder_hidden_states_value_proj, value = value.split(
-                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
+                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
                 )
-            hidden_states = self.hybrid_seq_parallel_attn(
-                attn if get_runtime_state().num_pipeline_patch > 1 else None,
+            hidden_states = USP(
                 query,
                 key,
                 value,
                 dropout_p=0.0,
-                causal=False,
-                joint_tensor_query=encoder_hidden_states_query_proj,
-                joint_tensor_key=encoder_hidden_states_key_proj,
-                joint_tensor_value=encoder_hidden_states_value_proj,
+                is_causal=False,
+                joint_query=encoder_hidden_states_query_proj,
+                joint_key=encoder_hidden_states_key_proj,
+                joint_value=encoder_hidden_states_value_proj,
                 joint_strategy="front",
+                attn_layer=attn,
             )
+            hidden_states = hidden_states.transpose(1, 2)
+
+
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 