Motivation.
This RFC proposes implementing RingAttention-based Context Parallelism (CP) in vllm-omni. While the team is currently implementing Ulysses-based Sequence Parallelism (SP) (#189), we propose adding RingAttention (#273) to overcome a limitation inherent to Ulysses: its attention-head constraint.
This proposal aims to:
- Clarify the terminology between Sequence Parallelism (Ulysses-style) and Context-Parallelism (Ring-style) to align with broader community usage (LLM vs. DiT).
- Justify the necessity of RingAttention as a complementary strategy to Ulysses.
- Discuss the roadmap for supporting CP/RingAttention alongside the existing Ulysses implementation.
Concepts
- **Megatron Sequence Parallelism (Megatron-SP):**
  - Essence: Reduces activation memory per GPU to minimize re-computation (activation checkpointing) and speed up training.
  - Usage: Typically coupled tightly with Tensor Parallelism (TP).
- **Megatron Context Parallelism (Megatron-CP):**
  - Essence: An enhanced version of SP.
  - Mechanism: Incorporates Ring-Attention-like techniques (performing Ring-Attention across ranks holding identical TP-PP-DP positions) to support training alongside Megatron's various hybrid parallelism strategies.
- **DeepSpeed Ulysses Sequence Parallel (Ulysses-SP):**
  - Context: DeepSpeed ZeRO is essentially data parallelism with model-parallel characteristics; long sequences pressure single-GPU memory during the full MHA calculation.
  - Mechanism: Ulysses splits the input along the sequence dimension and uses all-to-all communication so that each GPU computes only a subset of attention heads.
- **Ring-Attention:**
  - Essence: Can be viewed as a distributed implementation of FlashAttention-2.
  - Mechanism: Each GPU computes Multi-Head Attention (MHA) only for the sequence chunk it holds, passing KV blocks around a ring.
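The ring accumulation can be sketched in a single process: each step folds one KV chunk into a running (online) softmax, which is the same merge Ring-Attention performs as KV blocks arrive from neighboring ranks. This is an illustrative sketch (`chunked_attention` is a made-up name), not the vllm-omni kernel:

```python
# Single-process sketch of Ring-Attention's per-chunk accumulation:
# a flash-attention-style online softmax over KV chunks.
import torch

def chunked_attention(q, k, v, num_chunks):
    """q, k, v: [seq, dim]. Processes KV in chunks, as if received over a ring."""
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((q.shape[0], 1), float("-inf"))
    row_sum = torch.zeros(q.shape[0], 1)
    for k_c, v_c in zip(k.chunk(num_chunks), v.chunk(num_chunks)):
        scores = (q @ k_c.T) * scale                              # [seq, chunk]
        new_max = torch.maximum(row_max, scores.max(-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)                 # rescale old accumulator
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(-1, keepdim=True)
        out = out * correction + p @ v_c
        row_max = new_max
    return out / row_sum

torch.manual_seed(0)
q, k, v = (torch.randn(8, 4) for _ in range(3))
ref = torch.softmax((q @ k.T) * 4 ** -0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v, num_chunks=4), ref, atol=1e-5)
```

In the distributed version, each iteration of the loop corresponds to one ring step in which a rank receives the next KV block from its neighbor while (ideally) overlapping the send/recv with the local attention computation.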
Terminology Alignment
- In the LLM community (e.g., Megatron-LM, vLLM): SP usually refers to splitting LayerNorm and Dropout along the sequence dimension to save activation memory in the MoE & MLP layers, while relying on Tensor Parallelism for attention. Here SP is merely a memory optimization dependent on TP, so the SP size typically must equal the TP size. (vllm issue #22693)
- In Diffusion Models (e.g., DiT): SP often refers to splitting the actual attention computation across devices along the sequence dimension.
To avoid confusion, we need a clear taxonomy that distinguishes between memory optimization techniques and distributed attention computation.
Two Main Approaches to Context Parallel:
- **Ulysses-SP (Head-based CP)** (Implemented: [Diffusion]: Diffusion Ulysses-Sequence-Parallelism support #189, [Feature]: Support Ulysses Sequence Parallelism for Diffusion Models #192):
  - Mechanism: Uses all-to-all to gather the full sequence for a subset of heads.
  - Pros: Communication is efficient (low volume).
  - Cons: Hard limit: parallelism size $\le$ number of heads.
- **RingAttention (Sequence-based CP, used in Megatron-CP)** ([Diffusion]: Diffusion Ring Attention support #273):
  - Mechanism: Keeps heads local but splits the sequence; passes KV blocks between GPUs in a ring during the attention calculation.
  - Pros: No limit on parallelism size relative to head count; scales to arbitrary context lengths.
  - Cons: Higher implementation complexity (custom kernels often required); potentially higher latency if communication isn't perfectly overlapped.
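The layout change behind Ulysses can be illustrated in a single process: the all-to-all turns a sequence-sharded tensor `[seq/P, H, d]` per rank into a head-sharded tensor `[seq, H/P, d]` per rank, so each rank runs full attention over its own head group. This also makes the head constraint concrete: there are only `H` head slots to hand out, so `P` cannot exceed `H`. Names below are illustrative, not the vllm-omni implementation:

```python
# Single-process simulation of the Ulysses all-to-all layout transform.
import torch

P, seq, H, d = 2, 8, 4, 3            # ranks, sequence length, heads, head_dim
x = torch.arange(seq * H * d, dtype=torch.float32).reshape(seq, H, d)

# Before: rank r holds sequence shard x[r*seq//P : (r+1)*seq//P]
seq_shards = list(x.chunk(P, dim=0))

# All-to-all (simulated): every rank sends its slice of head-group r to rank r
head_shards = []
for r in range(P):                    # r = destination rank, owns head group r
    pieces = [shard.chunk(P, dim=1)[r] for shard in seq_shards]
    head_shards.append(torch.cat(pieces, dim=0))  # full seq, H/P heads

assert head_shards[0].shape == (seq, H // P, d)
assert torch.equal(torch.cat(head_shards, dim=1), x)  # round-trip recovers x
```

After local attention, a second all-to-all reverses the transform so the output is sequence-sharded again for the following layers.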
Proposed Change.
1. vLLM Terminology Alignment
In the context of vLLM, the distinction between the prefill phase and the decode phase is critical. We observe that the "SP" used in DiT is architecturally similar to PCP (Prefill-Context-Parallel).
We propose the following specific definitions for vLLM internals:
- PCP (Prefill-Context-Parallel): Parallelism applied during the prefill phase, where the prompt is processed.
- DCP (Decode-Context-Parallel): Parallelism applied during the token generation phase.
2. Standardization of "Context Parallelism"
To eliminate ambiguity, we propose that Context Parallelism (CP) be adopted as the standard term for any parallelism strategy that splits the input sequence and is orthogonal to Tensor Parallelism (TP).
- Old Terminology: Sequence Parallelism (when referring to splitting attention computation).
- New Standard: Context Parallelism (CP).
This change should be reflected in:
- CLI arguments (e.g., `--enable-context-parallel` instead of `--enable-sp`).
- Internal variable naming conventions.
- Documentation and docstrings.
3. Adding a RingAttention backend to the Context Parallelism interface.
We invite the community to discuss the following:
- Do we agree on deprecating "Sequence Parallelism" in favor of "Context Parallelism" for user-facing APIs regarding sequence splitting?
- Configuration API: How should we expose the choice between Ulysses and Ring to the user under the context-parallel terminology? (e.g., `--cp-method=auto|ulysses|ring`)
- Should we adapt the RingAttention kernel from flash_attn directly?
Please share your thoughts and suggestions below.
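As a strawman for the configuration question above, the proposed surface could look like the following. Note that `--enable-context-parallel`, `--cp-method`, and `resolve_cp_method` are hypothetical names taken from this RFC's proposal, not an existing vLLM API:

```python
# Hypothetical sketch of the proposed user-facing CP configuration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-context-parallel", action="store_true")
parser.add_argument(
    "--cp-method",
    choices=["auto", "ulysses", "ring"],
    default="auto",
    help="auto: prefer Ulysses while cp_size <= num_heads, else fall back to Ring",
)
args = parser.parse_args(["--enable-context-parallel", "--cp-method", "auto"])

def resolve_cp_method(method: str, cp_size: int, num_heads: int) -> str:
    # Ulysses is capped by the head count; Ring is not.
    if method == "auto":
        return "ulysses" if cp_size <= num_heads else "ring"
    return method

assert resolve_cp_method(args.cp_method, cp_size=8, num_heads=4) == "ring"
```

An `auto` default like this would let most users ignore the distinction while still allowing an explicit override for benchmarking or debugging.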
Test Documentation
Overview
End-to-end system tests for Sequence Parallelism in diffusion models, covering both Ulysses-SP (DeepSpeed Ulysses Sequence Parallel) and Ring Attention strategies.
Test File
`tests/e2e/offline_inference/test_sequence_parallel.py`
Tested Features
| Feature | Description |
|---|---|
| Ulysses Sequence Parallelism | Shards the sequence dimension across GPUs using All-to-All communication for Q/K/V redistribution |
| Ring Attention | Shards the sequence dimension across GPUs using ring-based P2P communication to accumulate attention results |
| Hybrid Ulysses + Ring | Combines both strategies for maximum parallelism (ulysses_degree × ring_degree GPUs) |
Test Matrix
| Parameter | Values | Description |
|---|---|---|
| `ulysses_degree` | [1, 2] | Number of GPUs for Ulysses parallelism |
| `ring_degree` | [1, 2] | Number of GPUs for Ring Attention |
| `dtype` | [torch.bfloat16] | Data type for inference |
| `attn_backend` | ["sdpa"] | Attention backend (PyTorch Scaled Dot Product Attention) |
Effective Test Cases (excluding ulysses=1, ring=1):
| Test Case | ulysses_degree | ring_degree | Total GPUs | Strategy |
|---|---|---|---|---|
| Ring-only | 1 | 2 | 2 | Ring Attention |
| Ulysses-only | 2 | 1 | 2 | Ulysses-SP |
| Hybrid | 2 | 2 | 4 | Ulysses + Ring |
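The expansion of the test matrix into the three effective cases can be sketched as below; this mirrors the parametrization described here, with illustrative variable names:

```python
# Expand the (ulysses_degree, ring_degree) matrix, dropping the no-op (1, 1) case.
from itertools import product

cases = [(u, r) for u, r in product([1, 2], [1, 2]) if (u, r) != (1, 1)]
total_gpus = {case: case[0] * case[1] for case in cases}

assert cases == [(1, 2), (2, 1), (2, 2)]
assert total_gpus[(2, 2)] == 4  # hybrid needs ulysses_degree x ring_degree GPUs
```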
Test Methodology
- Baseline Generation: Run inference with `ulysses_degree=1, ring_degree=1` (single GPU, no parallelism)
- SP Generation: Run inference with the target parallel configuration
- Output Comparison: Compare generated images pixel-by-pixel
Key Functions/Classes Tested
| Component | Location | Description |
|---|---|---|
| `Omni` | `vllm_omni.Omni` | Main entry point for diffusion inference |
| `DiffusionParallelConfig` | `vllm_omni.diffusion.data` | Configuration for sequence parallelism |
| `RingParallelAttention` | `vllm_omni.diffusion.attention.parallel.ring` | Ring Attention implementation |
| `UlyssesParallelAttention` | `vllm_omni.diffusion.attention.parallel.ulysses` | Ulysses-SP implementation |
| `RingComm` | `vllm_omni.diffusion.distributed.comm` | Ring P2P communication primitives |
| `SeqAllToAll4D`/`5D` | `vllm_omni.diffusion.distributed.comm` | All-to-All communication for Ulysses |
| `ring_pytorch_attn_func` | `vllm_omni.diffusion.attention.backends.ring_pytorch_attn` | Ring Attention kernel (SDPA backend) |
Validation Criteria
```python
# Thresholds for BF16/FP16
mean_threshold = 2e-2  # Mean absolute difference
max_threshold = 2e-1   # Max absolute difference
```

The test passes if:
- Both baseline and SP runs produce valid images (non-None, correct dimensions)
- Pixel-wise difference between baseline and SP outputs is within thresholds
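A minimal sketch of how these criteria could be checked is below; `images_match` and the tensor names are illustrative, not the actual test code, but the thresholds are the ones stated above:

```python
# Pixel-wise comparison of baseline vs. sequence-parallel outputs.
import torch

mean_threshold = 2e-2  # mean absolute difference (BF16/FP16)
max_threshold = 2e-1   # max absolute difference (BF16/FP16)

def images_match(baseline: torch.Tensor, sp_output: torch.Tensor) -> bool:
    assert baseline is not None and sp_output is not None  # valid images
    assert baseline.shape == sp_output.shape               # correct dimensions
    diff = (baseline.float() - sp_output.float()).abs()
    return diff.mean().item() < mean_threshold and diff.max().item() < max_threshold

base = torch.rand(3, 64, 64)
assert images_match(base, base + 1e-3)                 # tiny numerical drift passes
assert not images_match(base, torch.zeros_like(base))  # a wrong image fails
```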
Hardware Requirements
- Minimum: 2 GPUs (for Ring-only or Ulysses-only tests)
- Full suite: 4 GPUs (for Hybrid test)
Run Command
```shell
pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py
```

CI Configuration
- Timeout: 15 minutes
- Docker shm-size: 8GB (required for NCCL P2P communication)
- GPU Queue: `gpu_4_queue` (4× L4 GPUs)
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response