Skip to content

Commit 3716b20

Browse files
author
Qirui Yang
committed
Add token sharding functions and tests for context parallelism
1 parent 7a51dd0 commit 3716b20

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

vllm/v1/attention/backends/cp_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def _cp_shard_positions_for_prefill(
4848
# Compute the token index ranges for the two shards handled by this rank
4949
chunk0_start = cp_rank * cp_shard_size
5050
chunk1_start = (2 * cp_size - cp_rank - 1) * cp_shard_size
51-
5251
chunk0_arange = arange_np[chunk0_start:chunk0_start + cp_shard_size]
5352
chunk1_arange = arange_np[chunk1_start:chunk1_start + cp_shard_size]
5453

vllm/v1/worker/block_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from vllm.logger import init_logger
1111
from vllm.utils import cdiv
1212
from vllm.v1.utils import CpuGpuBuffer
13+
from vllm.distributed.parallel_state import get_context_parallel_world_size
1314

1415
logger = init_logger(__name__)
1516

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
7373
create_fast_prefill_custom_backend,
7474
reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
75+
from vllm.v1.attention.backends.cp_utils import (
76+
cp_shard_positions_for_prefill, cp_get_computed_positions)
7577
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
7678
# yapf conflicts with isort for this block
7779
# yapf: disable

0 commit comments

Comments (0)