|
48 | 48 | def is_valid_kv_cache_layout(value: str) -> bool: |
49 | 49 | return value in get_args(KVCacheLayoutType) |
50 | 50 |
|
| 51 | +@dataclass |
| 52 | +class PrefillContextParallelMetadata: |
| 53 | + """ |
| 54 | + Attention metadata for prefill context parallelism (pcp) |
| 55 | + """ |
| 56 | + q_head_indices: torch.Tensor |
| 57 | + q_tail_indices: torch.Tensor |
| 58 | + q_head_start_loc: torch.Tensor |
| 59 | + kv_for_head_indices: torch.Tensor |
| 60 | + kv_for_tail_indices: torch.Tensor |
| 61 | + kv_for_head_indptr: torch.Tensor |
| 62 | + kv_for_tail_indptr: torch.Tensor |
| 63 | + q_full_indices: torch.Tensor |
51 | 64 |
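For readers unfamiliar with prefill context parallelism: the head/tail naming above typically reflects a zigzag partitioning of the query, where each rank owns an early ("head") and a late ("tail") chunk of every sequence so that causal-attention work stays balanced across ranks. The sketch below shows one plausible way such indices could be derived; the helper name and the exact 2 * cp_size chunking are illustrative assumptions, not code from this diff.

```python
import torch

def zigzag_head_tail_indices(seq_len: int, cp_size: int, cp_rank: int):
    # Hypothetical helper (not from this PR): split one prefill sequence into
    # 2 * cp_size equal chunks; rank r takes chunk r as its "head" and chunk
    # 2 * cp_size - 1 - r as its "tail", so every rank sees a comparable mix
    # of short- and long-context tokens under a causal mask.
    chunk = seq_len // (2 * cp_size)
    head_start = cp_rank * chunk
    tail_start = (2 * cp_size - 1 - cp_rank) * chunk
    q_head = torch.arange(head_start, head_start + chunk, dtype=torch.int64)
    q_tail = torch.arange(tail_start, tail_start + chunk, dtype=torch.int64)
    return q_head, q_tail

# seq_len=16, cp_size=2: rank 0 owns tokens 0-3 and 12-15,
# rank 1 owns tokens 4-7 and 8-11.
print(zigzag_head_tail_indices(16, 2, 0))
```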
|
52 | 65 | @dataclass |
53 | 66 | class CommonAttentionMetadata: |
@@ -94,6 +107,7 @@ class CommonAttentionMetadata: |
94 | 107 | dcp_local_seq_lens: torch.Tensor | None = None |
95 | 108 | """Sequence lengths of the local rank in decode context parallelism world""" |
96 | 109 |
|
| 110 | + pcp_metadata: PrefillContextParallelMetadata | None = None  # set when prefill context parallelism (pcp) is enabled |
97 | 111 |
|
98 | 112 | def slice_query_start_locs( |
99 | 113 | query_start_loc: torch.Tensor, |
@@ -1077,35 +1091,35 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): |
1077 | 1091 | return nums_dict, batch_ptr, token_chunk_offset_ptr |
1078 | 1092 |
|
1079 | 1093 |
|
1080 | | -def get_dcp_local_seq_lens( |
| 1094 | +def get_cp_local_seq_lens( |
1081 | 1095 | seq_lens: torch.Tensor, |
1082 | | - dcp_size: int = 1, |
1083 | | - dcp_rank: int | None = None, |
| 1096 | + cp_size: int = 1, |
| 1097 | + cp_rank: int | None = None, |
1084 | 1098 | cp_kv_cache_interleave_size: int = 1, |
1085 | 1099 | ) -> torch.Tensor: |
1086 | 1100 | """While using dcp, kv_cache size stored on each rank may be different, |
1087 | 1101 | use this function to calculate split decode seq_lens of each dcp rank. |
1088 | 1102 | Only consider dcp now, we can extend the case of cp based on this. |
1089 | 1103 | """ |
1090 | 1104 | num_requests = seq_lens.size(0) |
1091 | | - if dcp_rank is None: |
| 1105 | + if cp_rank is None: |
1092 | 1106 | rank_offsets = ( |
1093 | | - torch.arange(dcp_size, dtype=torch.int32) |
| 1107 | + torch.arange(cp_size, dtype=torch.int32) |
1094 | 1108 | .unsqueeze(0) |
1095 | 1109 | .repeat(num_requests, 1) |
1096 | 1110 | ) |
1097 | 1111 | else: |
1098 | | - rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32) |
| 1112 | + rank_offsets = torch.tensor([[cp_rank]], dtype=torch.int32) |
1099 | 1113 | seq_lens_tiled = ( |
1100 | 1114 | seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1]) |
1101 | 1115 | ) |
1102 | 1116 | base = ( |
1103 | 1117 | seq_lens_tiled |
1104 | 1118 | // cp_kv_cache_interleave_size |
1105 | | - // dcp_size |
| 1119 | + // cp_size |
1106 | 1120 | * cp_kv_cache_interleave_size |
1107 | 1121 | ) |
1108 | | - remainder = seq_lens_tiled - base * dcp_size |
| 1122 | + remainder = seq_lens_tiled - base * cp_size |
1109 | 1123 | remainder = torch.clip( |
1110 | 1124 | remainder - rank_offsets * cp_kv_cache_interleave_size, |
1111 | 1125 | 0, |
|
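The interleave arithmetic deals tokens out to ranks in blocks of cp_kv_cache_interleave_size, round-robin: `base` counts the full rounds every rank receives, and the clipped `remainder` hands the leftover blocks to the lowest-numbered ranks. A small usage sketch follows, assuming the truncated tail of the function returns `base + remainder`, as the pre-rename `get_dcp_local_seq_lens` did:

```python
import torch

seq_lens = torch.tensor([10], dtype=torch.int32)

# cp_size=4, interleave=1: base = 10 // 1 // 4 * 1 = 2 per rank, leaving a
# remainder of 10 - 2 * 4 = 2, which clips to 1 extra token on ranks 0 and 1.
print(get_cp_local_seq_lens(seq_lens, cp_size=4))  # tensor([[3, 3, 2, 2]])

# interleave=2 deals blocks of 2 tokens: blocks 0-3 go to ranks 0-3 and the
# fifth block wraps back around to rank 0.
print(get_cp_local_seq_lens(seq_lens, cp_size=4,
                            cp_kv_cache_interleave_size=2))  # tensor([[4, 2, 2, 2]])
```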