
[Feature] Implement Decode Context Parallel in SGLang #12196

@staugust

Description


Checklist

Motivation

Decoding very long sequences can overwhelm device memory due to KV cache growth, and tensor parallelism (TP) alone often isn't sufficient. I propose adding Decode Context Parallel (Decode CP) to SGLang, inspired by vLLM's approach and supported by the paper "Helix Parallelism: Rethinking Sharding Strategies for Interactive Multi-Million-Token LLM Decoding". Decode CP partitions the context along the sequence dimension and processes it across multiple devices (or compute units on a single device) during the decode phase, enabling efficient long-context handling and avoiding KV cache redundancy. Using MQA as an example:

  • With TP2: each rank keeps a full copy of the sequence's KV cache (MQA has a single KV head, so TP cannot shard it), so the maximum sequence length is still bounded by the memory of a single device.
  • With DP2: each device processes a different sequence; for any given sequence, its KV cache resides entirely on one device, and the maximum length is again limited by that device's memory.
  • With CP2: increasing cp_size partitions the KV cache across devices, eliminating redundancy and enabling much longer sequences, with the effective memory budget scaling roughly linearly with the number of CP shards.

The same applies to MLA and GQA when the TP size exceeds the number of KV heads, since the KV cache must then be replicated across TP ranks.
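
To put rough numbers on this, here is a back-of-the-envelope sketch in plain Python (no SGLang APIs; the model shape and the accounting are illustrative assumptions) comparing the per-device KV cache footprint of an MQA model under TP and under Decode CP:

```python
def kv_cache_bytes_per_device(
    seq_len: int,
    num_layers: int,
    num_kv_heads: int,      # 1 for MQA
    head_dim: int,
    dtype_bytes: int = 2,   # fp16/bf16
    tp_size: int = 1,
    cp_size: int = 1,
) -> int:
    """Approximate per-device KV cache size for one sequence.

    Assumptions (illustrative, not SGLang's actual accounting):
    - TP shards KV heads; with MQA (a single KV head) the cache is replicated,
      so TP does not reduce the per-device footprint.
    - CP shards tokens along the sequence dimension, so the footprint shrinks
      roughly linearly with cp_size.
    """
    kv_heads_per_rank = max(1, num_kv_heads // tp_size)   # replication floor for MQA
    tokens_per_rank = (seq_len + cp_size - 1) // cp_size  # sequence-dim shard
    # Factor 2 accounts for storing both K and V.
    return 2 * num_layers * kv_heads_per_rank * head_dim * dtype_bytes * tokens_per_rank


# Example: a hypothetical 32-layer MQA model decoding a 256k-token sequence.
cfg = dict(seq_len=256_000, num_layers=32, num_kv_heads=1, head_dim=128)
for tp, cp in [(1, 1), (2, 1), (1, 2), (1, 4)]:
    gib = kv_cache_bytes_per_device(**cfg, tp_size=tp, cp_size=cp) / 2**30
    print(f"TP{tp} CP{cp}: ~{gib:.2f} GiB of KV cache per device")
```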

Goals

  • Support much longer context lengths (e.g., 64k/128k/256k) by reducing the per-device KV cache footprint.

High-Level Proposal

  • KV cache partitioning and management:
    • Split the KV cache along the sequence dimension or block-wise; ensure compatibility with paged attention and efficient allocation/eviction.
  • Cross-shard attention during decode:
    • Implement communication patterns such as ring attention or reduce-scatter/all-gather so each decode step can access the required context across shards (see the attention-merge sketch after this list).
  • Scheduling and topology:
    • Introduce a context_parallel_size option in the scheduler; route requests based on sequence length and resource availability; co-plan with TP/DP (see the routing sketch after this list).
  • Compute-communication overlap:
    • Overlap attention computation with inter-device communication to reduce synchronization overhead.
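
To make the cross-shard attention step concrete, here is a minimal single-process sketch in PyTorch that simulates cp_size shards of the KV cache for one decode-step query. It illustrates the standard log-sum-exp (LSE) merge of partial attention outputs and is not SGLang's or vLLM's actual implementation; in a real Decode CP setup, the per-shard output and LSE would be the quantities exchanged via all-gather or a ring pattern.

```python
import torch

torch.manual_seed(0)
head_dim, seq_len, cp_size = 64, 1024, 4
scale = head_dim ** -0.5

q = torch.randn(head_dim)           # decode-step query for one head
k = torch.randn(seq_len, head_dim)  # full KV cache for that head
v = torch.randn(seq_len, head_dim)


def partial_attention(q, k_shard, v_shard):
    """Attention over one CP shard's slice of the KV cache.

    Returns the shard-local softmax output and its log-sum-exp; these two
    quantities are all that needs to cross shard boundaries.
    """
    scores = (k_shard @ q) * scale        # [tokens_in_shard]
    lse = torch.logsumexp(scores, dim=0)  # scalar
    probs = torch.softmax(scores, dim=0)
    return probs @ v_shard, lse           # [head_dim], scalar


# Each simulated "rank" holds a contiguous slice of the sequence.
outs, lses = zip(*(
    partial_attention(q, k_chunk, v_chunk)
    for k_chunk, v_chunk in zip(k.chunk(cp_size), v.chunk(cp_size))
))
outs, lses = torch.stack(list(outs)), torch.stack(list(lses))

# LSE-weighted merge: softmax over the per-shard LSEs recovers exact attention.
weights = torch.softmax(lses, dim=0)      # [cp_size]
merged = (weights[:, None] * outs).sum(dim=0)

# Reference: attention over the full, un-sharded KV cache.
reference = torch.softmax((k @ q) * scale, dim=0) @ v
print("matches full attention:", torch.allclose(merged, reference, atol=1e-5))
print("max abs error:", (merged - reference).abs().max().item())
```

The merge is exact, which is why sequence-dimension sharding composes cleanly with paged attention: each shard only needs its local blocks plus one small exchange per decode step.

For the scheduling item, routing could start from a simple length-based heuristic. The sketch below is purely hypothetical: choose_cp_size, the byte figures, and max_context_parallel_size are illustrative placeholders, not an existing SGLang interface.

```python
def choose_cp_size(
    prompt_len: int,
    max_new_tokens: int,
    kv_bytes_per_token: int,
    per_device_kv_budget_bytes: int,
    max_context_parallel_size: int = 8,
) -> int:
    """Pick the smallest power-of-two CP degree whose sharded KV cache fits.

    Hypothetical heuristic: short requests stay at cp_size=1 (no extra
    communication); long requests get just enough CP shards to fit the budget.
    """
    total_kv = (prompt_len + max_new_tokens) * kv_bytes_per_token
    cp_size = 1
    while cp_size < max_context_parallel_size and total_kv / cp_size > per_device_kv_budget_bytes:
        cp_size *= 2
    return cp_size


# Example: ~16 KiB of KV per token, 4 GiB of KV budget per device.
for prompt_len in (8_000, 128_000, 512_000, 1_000_000):
    cp = choose_cp_size(prompt_len, 4_096, 16 * 1024, 4 * 2**30)
    print(f"prompt {prompt_len:>9,} tokens -> context_parallel_size={cp}")
```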


Related resources

No response
