1 Overview
Refactor the Diffusion subsystem by separating the monolithic GPUWorker into distinct GPUDiffusionWorker (infrastructure) and GPUDiffusionModelRunner (execution) components, and introduce DiffusionKVManager to manage KV cache lifecycle for diffusion models.
1.1 Motivation
The current Diffusion subsystem in vLLM-Omni uses a monolithic GPUWorker class that combines infrastructure setup (device initialization, distributed environment) with model execution logic (forward pass, KV cache operations). This design violates the separation of concerns principle and creates several issues:
- Inconsistency with AR subsystem: The AR (Autoregressive) subsystem properly separates GPUARWorker (infrastructure only) from GPUARModelRunner (execution + KV cache operations). This inconsistency makes the codebase harder to maintain and extend.
- Limited KV cache support: The current Diffusion worker only supports KV cache receiving from previous stages but lacks the ability to extract and transfer KV cache to downstream stages. This prevents diffusion models from being used as intermediate stages in multi-stage pipelines.
- No memory lifecycle management: Unlike AR, which uses vLLM's KVCacheManager for block-based memory management, Diffusion has no equivalent mechanism to track and manage KV cache lifecycle.
This feature is primarily designed to support BAGEL model stage separation. BAGEL is a multimodal model that combines AR (BAGEL-7B-MoT) and DiT (Diffusion Transformer) components through a unified KV cache mechanism (NaiveCache). The model supports:
- Multimodal input: Text + Image encoded via ViT + VAE into KV cache
- Multimodal output: Image generation via diffusion denoising loop
By enabling KV cache transfer between AR and DiT stages, we can deploy BAGEL across multiple GPU nodes:
- Stage 1 (AR): BAGEL-7B-MoT processes text/image input, generates KV cache
- Stage 2 (DiT): BAGEL diffusion generates images using transferred KV cache
```mermaid
graph TB
    S_AR[OmniARScheduler]
    S_DiT[SyncScheduler]
    DKVM[DiffusionKVManager]
    W_AR[GPUARWorker]
    W_DiT[GPUDiffusionWorker]
    R_AR[GPUARModelRunner]
    R_DiT[GPUDiffusionModelRunner]
    KV[OmniConnector]

    %% Relationships
    S_AR -->|SchedulerOutput + BlockIDs| W_AR
    W_AR -->|Delegates| R_AR
    R_AR -->|Extracts Standard KV| KV
    S_DiT -->|Request List| W_DiT
    S_DiT -->|Owns / Manages| DKVM
    W_DiT -->|Delegates| R_DiT
    KV -->|Pushes/Pulls Standard KV| DKVM
    DKVM -->|Provides KV| R_DiT
    R_DiT -->|Extracts Output KV| KV

    %% Styles
    style W_DiT fill:#ffb347,stroke:#333,stroke-width:2px
    style R_DiT fill:#99ff99,stroke:#333,stroke-width:2px
    style DKVM fill:#99ff99,stroke:#333,stroke-width:2px
```
1.2 Target
Feature
In Scope:
- Separate GPUWorker into GPUDiffusionWorker (infrastructure) and GPUDiffusionModelRunner (execution)
  - GPUDiffusionWorker responsibilities:
    - CUDA device initialization
    - Distributed environment setup (NCCL, process groups)
    - Instantiate GPUDiffusionModelRunner
    - Memory management (sleep/wake)
  - GPUDiffusionModelRunner responsibilities:
    - Model/pipeline loading in init
    - Forward pass execution (execute_model)
    - KV cache receiving from previous stages (existing)
    - KV cache extraction using a model utility method to organize the format (NEW)
    - KV cache transfer to downstream stages (NEW)
- Introduce DiffusionKVManager for KV cache lifecycle tracking
Out of Scope:
- PagedAttention for diffusion models
- Performance optimization for KV transfer (phase 2)
Accuracy
- Data integrity: All KV data transferred via connectors must be bit-exact (no corruption or loss)
- Deterministic behavior: Same inputs must produce the same outputs regardless of the KV transfer path
- Compatibility: Must work with existing BAGEL stage configurations without breaking changes
- NaiveCache format conversion: Leverage the model's utility methods to convert standard KV cache to/from BAGEL's NaiveCache structure during execution
- kv_lens / ropes propagation: Preserve kv_lens and ropes metadata for correct position encoding
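To make the conversion and metadata requirements concrete, here is a minimal pure-Python sketch. `StandardKV`, `NaiveCacheLike`, and `to_naive_cache` are hypothetical names used only for illustration; BAGEL's real NaiveCache is more involved, and the per-layer entries would be torch.Tensors rather than placeholder strings.

```python
from dataclasses import dataclass, field

@dataclass
class StandardKV:
    """Hypothetical connector-side (standard) KV representation."""
    key_cache: list          # one entry per layer
    value_cache: list        # one entry per layer
    kv_lens: list            # per-request KV lengths
    ropes: list              # rotary position metadata

@dataclass
class NaiveCacheLike:
    """Stand-in for BAGEL's NaiveCache: layer-indexed key/value pairs plus metadata."""
    key_values: dict = field(default_factory=dict)
    kv_lens: list = field(default_factory=list)
    ropes: list = field(default_factory=list)

def to_naive_cache(std: StandardKV) -> NaiveCacheLike:
    # Re-key the flat per-layer lists into the layer-indexed structure the
    # diffusion pipeline consumes; kv_lens/ropes are carried over unchanged.
    cache = NaiveCacheLike(kv_lens=list(std.kv_lens), ropes=list(std.ropes))
    for layer, (k, v) in enumerate(zip(std.key_cache, std.value_cache)):
        cache.key_values[layer] = (k, v)
    return cache
```

The point of the sketch is that conversion only re-keys per-layer references and copies the small kv_lens/ropes lists; the KV tensors themselves pass through untouched, which is what makes bit-exact transfer feasible.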
Performance
- Latency overhead: KV extraction should add < 5ms per request (GPU → CPU copy)
- Memory: KV cache extraction should use efficient tensor slicing without full tensor duplication
- BAGEL specific: Support up to 36 layers of KV cache with seq_len up to 4096 tokens
- Trade-offs:
- Synchronous KV transfer ensures data integrity but adds latency (async transfer as future optimization)
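The "efficient tensor slicing" requirement means extracting views rather than copies. A tiny stand-in using Python's memoryview illustrates the view-vs-copy distinction (torch.Tensor basic slicing behaves analogously: it returns a view over the same storage):

```python
# Zero-copy extraction sketch: slicing a memoryview (like basic slicing of a
# torch.Tensor) yields a view over the same buffer, not a duplicate.
buf = bytearray(range(16))      # stand-in for a KV cache buffer
full = memoryview(buf)
kv_len = 4
view = full[:kv_len]            # view of the first kv_len entries; no copy made

buf[0] = 99                     # mutate the underlying storage
assert view[0] == 99            # the view observes the change: memory is shared

snapshot = bytes(full[:kv_len]) # an explicit copy, by contrast, is detached
buf[0] = 7
assert snapshot[0] == 99        # the copy retains the old value
```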
2 Design
2.1 Overview of Design
The refactoring follows the established AR subsystem pattern with a three-layer architecture:
Design Principles:
- Worker = Infrastructure Only: GPUDiffusionWorker handles device setup and instantiates the Runner. It does NOT load models or execute inference.
- Runner = Execution + Data I/O: GPUDiffusionModelRunner owns the model lifecycle, executes forward passes, and manages all KV cache operations.
- KVManager = Lifecycle Tracking: DiffusionKVManager tracks active requests and their KV cache metadata.
Components Interaction:
┌─────────────────────────────────────────────────────────────────────┐
│ SyncScheduler │
│ - schedule() │
│ - triggers KVManager.receive_from_connector() │
│ - owns DiffusionKVManager │
└─────────────────────────────────────────────────────────────────────┘
│
│ requests
▼
┌─────────────────────────────────────────────────────────────────────┐
│ GPUDiffusionWorker │
│ - init_device() │
│ - owns GPUDiffusionModelRunner │
│ - execute_model() → pass-through to Runner │
│ - sleep() / wake_up() │
└─────────────────────────────────────────────────────────────────────┘
│
│ delegates
▼
┌─────────────────────────────────────────────────────────────────────┐
│ GPUDiffusionModelRunner │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ __init__: │ │
│ │ - _load_model() → load BagelPipeline │ │
│ │ - _init_connector() → create OmniConnector │ │
│ │ - kv_manager → DiffusionKVManager reference │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ execute_model(requests): │ │
│ │ 1. kv_manager.get_kv_cache(req) │ │
│ │ → returns Standard/Raw KV Cache │ │
│ │ 2. Model Utility Conversion │ │
│ │ → convert raw bytes to Model-Specific format │ │
│ │ → e.g. BAGEL utility -> NaiveCache │ │
│ │ 3. BagelPipeline.forward(req) │ │
│ │ → use converted KV cache │ │
│ │ 4. _handle_kv_transfer() (if needed) │ │
│ │ → extract output KV (Standard Format) │ │
│ │ → transfer via OmniConnector │ │
│ │ 5. kv_manager.free(req) │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
│
┌───────────┴───────────┐
│ │
┌───────────────────┐ │
│ DiffusionKVManager│◄──────────────┤
│ - receive_kv() │ standard KV │ ┌───────────────────┐
│ - get_kv() │ └───┤ OmniConnector │
│ - free() │ │ - get() │
└───────────────────┘ └───────────────────┘
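The five numbered steps in the Runner box above can be walked through with in-memory stubs. Everything below is an illustrative sketch, not the real implementation: the class bodies are stand-ins and the string-based "KV cache" merely traces the data flow.

```python
class StubConnector:
    """Minimal in-memory stand-in for OmniConnector."""
    def __init__(self):
        self.store = {}
    def put(self, req_id, kv):
        self.store[req_id] = kv
    def get(self, req_id):
        return self.store[req_id]

class DiffusionKVManagerSketch:
    """Stand-in for DiffusionKVManager lifecycle tracking."""
    def __init__(self):
        self._kv = {}
    def receive_from_connector(self, connector, req_id):
        self._kv[req_id] = connector.get(req_id)    # scheduler-triggered pull
    def get_kv_cache(self, req_id):
        return self._kv[req_id]                     # step 1
    def free(self, req_id):
        self._kv.pop(req_id, None)                  # step 5

def execute_model(req_id, kv_manager, connector, needs_transfer=True):
    raw_kv = kv_manager.get_kv_cache(req_id)        # 1. fetch standard/raw KV
    converted = {"naive_cache": raw_kv}             # 2. model-utility conversion (stub)
    output_kv = converted["naive_cache"] + "+out"   # 3. pipeline forward pass (stub)
    if needs_transfer:
        connector.put(req_id + ":out", output_kv)   # 4. extract + transfer downstream
    kv_manager.free(req_id)                         # 5. release the request's KV
    return output_kv
```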
2.2 API Design
Current Component Changes
1. vllm_omni/diffusion/worker/gpu_worker.py (GPUWorker class)
- Change: Refactor into two separate classes: GPUDiffusionWorker and GPUDiffusionModelRunner
- Why: Current class mixes infrastructure and execution logic
- Impact:
  - init_device_and_model() split into Worker's init_device() and Runner's _load_model()
  - execute_model() moves to Runner
  - _receive_kv_cache_for_request() moves to Runner with NaiveCache support
2. vllm_omni/diffusion/models/bagel/pipeline_bagel.py (BagelPipeline class)
- Change: Add a method to expose gen_context["past_key_values"] for extraction
- Why: Runner needs access to the internal NaiveCache after the forward pass
- Impact: Add get_past_key_values() and get_kv_metadata() methods
- Code Location: BagelPipeline class
3. vllm_omni/diffusion/scheduler.py (Scheduler class)
- Change: Add get_finished_requests_needing_kv_transfer() method
- Why: Enable scheduler to signal which requests need KV transfer
- Impact: New method, no breaking changes
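A minimal sketch of what the new scheduler hook could look like; the `needs_kv_transfer` attribute and the surrounding bookkeeping are assumptions for illustration, not the real Scheduler internals:

```python
class SchedulerSketch:
    """Illustrative fragment; the real Scheduler tracks far more state."""
    def __init__(self):
        self._finished = []     # requests whose forward pass has completed

    def mark_finished(self, request):
        self._finished.append(request)

    def get_finished_requests_needing_kv_transfer(self):
        # Return only finished requests flagged for a downstream KV consumer;
        # `needs_kv_transfer` is a hypothetical per-request attribute.
        return [r for r in self._finished
                if getattr(r, "needs_kv_transfer", False)]
```

Keeping this as a read-only query preserves the "new method, no breaking changes" property: existing callers of the Scheduler are unaffected.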
New APIs
1. vllm_omni/diffusion/worker/gpu_diffusion_worker.py
```python
class GPUDiffusionWorker:
    """Infrastructure wrapper for diffusion model execution."""

    def __init__(self, local_rank: int, rank: int, od_config: OmniDiffusionConfig):
        ...

    def init_device(self) -> None:
        """Initialize CUDA device and distributed environment."""
        ...

    def execute_model(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
        """Pass-through to model runner."""
        ...

    def sleep(self, level: int = 1) -> bool:
        ...

    def wake_up(self, tags: list[str] | None = None) -> bool:
        ...
```

2. vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py
```python
class GPUDiffusionModelRunner:
    """Execution layer for diffusion models with BAGEL KV cache support."""

    def __init__(self, config: OmniDiffusionConfig, device: torch.device):
        ...

    def execute_model(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
        ...

    def _extract_kv_cache(self, request: OmniDiffusionRequest) -> KVCacheTransferData:
        """
        Extract KV cache, organizing the format using the model's utility method.

        Returns:
            KVCacheTransferData with:
            - layer_blocks["key_cache"]: list[torch.Tensor] per layer
            - layer_blocks["value_cache"]: list[torch.Tensor] per layer
            - block_ids: [] (empty for dense tensors)
            - metadata: {"kv_lens": [...], "ropes": [...], "num_layers": N}
        """
        ...
```

3. vllm_omni/diffusion/core/kv_manager.py
```python
class DiffusionKVManager:
    """Manages KV cache lifecycle for diffusion requests."""

    def receive_from_connector(self, connector: OmniConnector, request: OmniDiffusionRequest) -> None:
        """Fetch standard KV cache from connector and store it."""
        ...

    def get_kv_cache(self, request_id: str) -> Any:
        """Retrieve stored KV cache for a request (Standard/Raw format)."""
        ...

    def allocate(self, request: OmniDiffusionRequest) -> None:
        ...

    def free(self, request: OmniDiffusionRequest) -> None:
        ...
```

2.3 API Call Dependency
┌─────────────────────────────────────────────────────────────────────┐
│ Stage 1: AR (BAGEL-7B-MoT) │
│ 1. Process text/image input │
│ 2. Generate KV cache (PagedAttention format) │
│ 3. _extract_kv_cache() → KVCacheTransferData │
│ 4. connector.put() → transfer to Stage 2 │
└─────────────────────────────────────────────────────────────────────┘
│
KV Cache Transfer via OmniConnector
(SharedMemory or Mooncake)
│
▼
┌─────────────────────────────────────────────────────────────────────┐
│ Stage 2: DiT (BAGEL-7B-MoT) │
│ 1. Scheduler triggers KVManager.receive_from_connector(req) │
│ - Manager pulls Standard KV from Connector │
│ 2. Runner.execute_model(req) │
│ - kv_data = KVManager.get_kv_cache(req) │
│ - converted_kv = ModelUtility.convert(kv_data) │
│ - BagelPipeline.forward(req, past_key_values=converted_kv) │
│ 3. (Optional) _handle_kv_transfer() for downstream stages │
│ - extract output KV (Standard Format) │
│ - transfer via OmniConnector │
└─────────────────────────────────────────────────────────────────────┘
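The payload crossing the connector boundary between the two stages is the KVCacheTransferData layout specified in Section 2.2. A sketch of assembling it, with plain dicts and placeholder strings standing in for the real dataclass and torch.Tensors:

```python
def build_kv_transfer_payload(key_layers, value_layers, kv_lens, ropes):
    """Assemble a transfer payload per the _extract_kv_cache docstring:
    per-layer key/value lists, empty block_ids for dense tensors, and
    kv_lens/ropes metadata for position encoding on the receiving side.
    (Illustrative helper; not part of the actual vLLM-Omni API.)"""
    assert len(key_layers) == len(value_layers), "key/value layer count mismatch"
    return {
        "layer_blocks": {
            "key_cache": list(key_layers),
            "value_cache": list(value_layers),
        },
        "block_ids": [],  # dense tensors: no paged blocks to reference
        "metadata": {
            "kv_lens": list(kv_lens),
            "ropes": list(ropes),
            "num_layers": len(key_layers),
        },
    }
```

Keeping block_ids empty is what distinguishes the dense diffusion-side payload from the AR side's PagedAttention-format cache.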
3 Test Cases
3.1 Unit Test (UT) Design
File: tests/diffusion/worker/test_gpu_diffusion_worker.py
1. test_worker_init_and_delegation()
- Purpose: Verify Worker initializes correctly and delegates execution to Runner.
- Assertions: Device set, Runner created, execute_model delegation works.
File: tests/diffusion/worker/test_gpu_diffusion_model_runner.py
2. test_runner_kv_operations_via_manager()
- Purpose: Verify Runner retrieves KV from Manager and converts it using model utility.
- Assertions: kv_manager.get_kv_cache is called; model utility conversion is invoked; pipeline receives converted KV.
File: tests/diffusion/core/test_kv_manager.py
3. test_kv_manager_data_handling()
- Purpose: Verify Manager can receive data from connector and provide it to runner.
- Assertions: Data fetched from connector; data stored in Manager; data retrievable via get_kv_cache; data cleared on free.
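A pytest-style sketch of this test against in-memory stubs; the manager surface follows Section 2.2, while FakeConnector and the stub's internals are assumptions made for illustration:

```python
class FakeConnector:
    """In-memory connector double keyed by request id."""
    def __init__(self, data):
        self._data = data
    def get(self, request_id):
        return self._data[request_id]

class DiffusionKVManagerStub:
    """Tiny stand-in implementing only the Section 2.2 surface under test."""
    def __init__(self):
        self._kv = {}
    def receive_from_connector(self, connector, request_id):
        self._kv[request_id] = connector.get(request_id)
    def get_kv_cache(self, request_id):
        return self._kv[request_id]
    def free(self, request_id):
        self._kv.pop(request_id, None)

def test_kv_manager_data_handling():
    connector = FakeConnector({"req-1": b"standard-kv"})
    manager = DiffusionKVManagerStub()
    manager.receive_from_connector(connector, "req-1")      # fetched from connector
    assert manager.get_kv_cache("req-1") == b"standard-kv"  # stored and retrievable
    manager.free("req-1")                                   # cleared on free
    assert "req-1" not in manager._kv
```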
3.2 Smoke Test (ST) Design
File: tests/e2e/diffusion/test_bagel_kv_transfer.py
1. test_kv_transfer_end_to_end()
- Purpose: Verify end-to-end KV transfer in a multi-stage pipeline (AR → DiT).
- Setup: Two stages (AR: BAGEL-7B-MoT, DiT: BAGEL-7B-MoT) connected via SharedMemory/Mooncake.
- Test Steps:
- Run AR stage to generate KV cache.
- Transfer KV to DiT stage.
- Verify DiT stage receives valid KV and generates correct output.
- Assertions: KV data transfer integrity, successful generation.
Last Updated: Jan 15, 2026
Author: Wang Zhipeng