121 changes: 111 additions & 10 deletions docs/user_guide/acceleration/parallelism_acceleration.md
The following parallelism methods are currently supported in vLLM-Omni:

2. [Ring-Attention](#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded

3. [CFG-Parallel](#cfg-parallel) - runs the positive and negative prompts of classifier-free guidance (CFG) on different devices, then merges the results on a single device to perform the scheduler step
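The merge on the gathering device is the standard true-CFG combination. A minimal single-process sketch of that math (tensor shapes and the helper name are illustrative, not part of the vLLM-Omni API):

```python
import torch

def merge_true_cfg(noise_pred: torch.Tensor,
                   neg_noise_pred: torch.Tensor,
                   true_cfg_scale: float) -> torch.Tensor:
    """Combine the positive/negative branch predictions after gathering."""
    return neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

# Illustrative shapes: (batch, seq, channels)
pos = torch.randn(1, 16, 8)
neg = torch.randn(1, 16, 8)
merged = merge_true_cfg(pos, neg, 4.0)
```

With `true_cfg_scale == 1.0` the negative branch cancels out and the merge reduces to the positive prediction, which is why CFG-Parallel only pays off for non-trivial scales.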

The following table shows which models are currently supported by parallelism method:

### ImageGen

| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel |
|-------|------------------|------------|---------|--------------|
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ❌ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ❌ | ❌ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ❌ | ❌ |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ |

### VideoGen

To measure the parallelism methods, we run benchmarks with the **Qwen/Qwen-Image** model.

If a diffusion model has been deployed in vLLM-Omni and supports single-card inference, you can refer to the following instructions to parallelize it with [Ulysses-SP](https://arxiv.org/pdf/2309.14509).


This section uses **Qwen-Image** (`QwenImageTransformer2DModel`) as the reference implementation. Qwen-Image is a **dual-stream** transformer (text + image) that performs **joint attention** across the concatenated sequences. Because of that, when enabling sequence parallel you typically:

- Chunk **image tokens** (`hidden_states`) across SP ranks along the **sequence dimension**.
omni = Omni(

outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50)
```


### CFG-Parallel

#### Offline Inference

CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=...)`. The recommended configuration is `cfg_parallel_size=2` (one rank for the positive branch and one rank for the negative branch).

An example of offline inference using CFG-Parallel (image-to-image) is shown below:

```python
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig

omni = Omni(
    model="Qwen/Qwen-Image-Edit",
    parallel_config=DiffusionParallelConfig(cfg_parallel_size=2),
)

outputs = omni.generate(
    prompt="turn this cat to a dog",
    negative_prompt="low quality, blurry",
    true_cfg_scale=4.0,
    pil_image=input_image,
    num_inference_steps=50,
)
```

Notes:

- CFG-Parallel is only effective when **true CFG** is enabled (i.e., `true_cfg_scale > 1` and a `negative_prompt` is provided).

#### How to parallelize a pipeline

This section describes how to add CFG-Parallel to a diffusion **pipeline**. We use the Qwen-Image pipeline (`vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py`) as the reference implementation.

In `QwenImagePipeline`, each diffusion step runs two denoiser forward passes sequentially:

- positive (prompt-conditioned)
- negative (negative-prompt-conditioned)

CFG-Parallel assigns these two branches to different ranks in the **CFG group** and synchronizes the results.

Below is an example of CFG-Parallel implementation:

```python
def diffuse(
    self,
    ...
):
    # Enable CFG-parallel: rank0 computes positive, rank1 computes negative.
    cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1

    self.transformer.do_true_cfg = do_true_cfg

    if cfg_parallel_ready:
        cfg_group = get_cfg_group()
        cfg_rank = get_classifier_free_guidance_rank()

        if cfg_rank == 0:
            local_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                encoder_hidden_states_mask=prompt_embeds_mask,
                encoder_hidden_states=prompt_embeds,
                img_shapes=img_shapes,
                txt_seq_lens=txt_seq_lens,
                attention_kwargs=self.attention_kwargs,
                return_dict=False,
            )[0]
        else:
            local_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                encoder_hidden_states_mask=negative_prompt_embeds_mask,
                encoder_hidden_states=negative_prompt_embeds,
                img_shapes=img_shapes,
                txt_seq_lens=negative_txt_seq_lens,
                attention_kwargs=self.attention_kwargs,
                return_dict=False,
            )[0]

        # Every rank receives both branch predictions.
        gathered = cfg_group.all_gather(local_pred, separate_tensors=True)
        if cfg_rank == 0:
            noise_pred = gathered[0]
            neg_noise_pred = gathered[1]
            comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
            # Rescale so the combined prediction keeps the conditional branch's norm.
            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
            noise_pred = comb_pred * (cond_norm / noise_norm)
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        # Keep all ranks in sync for the next step.
        cfg_group.broadcast(latents, src=0)
    else:
        # Fallback: run positive then negative sequentially on one rank.
        ...
```
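The norm-preserving rescale at the end of the rank-0 branch can be checked numerically in isolation. A single-process sketch of the same math (function name is illustrative):

```python
import torch

def combine_and_rescale(noise_pred: torch.Tensor,
                        neg_noise_pred: torch.Tensor,
                        true_cfg_scale: float) -> torch.Tensor:
    """CFG combine, then rescale so the result keeps the conditional norm."""
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (cond_norm / noise_norm)

pos = torch.randn(2, 16, 8)
neg = torch.randn(2, 16, 8)
out = combine_and_rescale(pos, neg, 4.0)
# Per-token norms of the output match the conditional branch by construction.
```

The rescale keeps the guided prediction from drifting to larger magnitudes than the conditional prediction, which helps stability at high `true_cfg_scale`.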
57 changes: 39 additions & 18 deletions docs/user_guide/diffusion_acceleration.md
vLLM-Omni currently supports two main cache acceleration backends:

Both methods can provide significant speedups (typically **1.5x-2.0x**) while maintaining high output quality.

vLLM-Omni also supports parallelism methods for diffusion models, including:

1. [Ulysses-SP](acceleration/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads.

2. [Ring-Attention](acceleration/parallelism_acceleration.md#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded.

3. [CFG-Parallel](acceleration/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step.

## Quick Comparison

### Cache Methods
The following table shows which models are currently supported by each acceleration method.

### ImageGen

| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel |
|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:------------:|
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ |

### VideoGen

ulysses_degree = 2

omni = Omni(
    model="Qwen/Qwen-Image",
    parallel_config=DiffusionParallelConfig(ulysses_degree=ulysses_degree),
)

outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048)
ulysses_degree = 2

omni = Omni(
    model="Qwen/Qwen-Image-Edit",
    parallel_config=DiffusionParallelConfig(ulysses_degree=ulysses_degree),
)

outputs = omni.generate(prompt="turn this cat to a dog",
omni = Omni(
outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048)
```

### Using CFG-Parallel

CFG-Parallel splits the CFG positive/negative branches across GPUs. Use it when you set a non-trivial `true_cfg_scale` together with a `negative_prompt`.

Run image-to-image:

```python
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
cfg_parallel_size = 2

omni = Omni(
    model="Qwen/Qwen-Image-Edit",
    parallel_config=DiffusionParallelConfig(cfg_parallel_size=cfg_parallel_size),
)

outputs = omni.generate(
    prompt="turn this cat to a dog",
    pil_image=input_image,
    num_inference_steps=50,
    true_cfg_scale=4.0,
)
```
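When a script may also run on a single-GPU machine, the setting can be gated on the visible device count. A defensive sketch (assumes a CUDA build of PyTorch; on other backends substitute the matching device-count query):

```python
import torch

# Fall back to plain sequential CFG when fewer than two GPUs are visible.
cfg_parallel_size = 2 if torch.cuda.device_count() >= 2 else 1
```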

## Documentation

For detailed information on each acceleration method:

- **[TeaCache Guide](acceleration/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices
- **[Cache-DiT Acceleration Guide](acceleration/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters
- **[Sequence Parallelism](acceleration/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration.
- **[CFG-Parallel](acceleration/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks.
25 changes: 22 additions & 3 deletions examples/offline_inference/image_to_image/image_edit.py
--layers 4 \
--color-format "RGBA"

Usage (with CFG Parallel):
python image_edit.py \
--image input.png \
--prompt "Edit description" \
--cfg_parallel_size 2 \
--num_inference_steps 50 \
--cfg_scale 4.0

For more options, run:
python image_edit.py --help
"""
def parse_args() -> argparse.Namespace:
        default=0.2,
        help="[tea_cache] Threshold for accumulated relative L1 distance.",
    )
    parser.add_argument(
        "--cfg_parallel_size",
        type=int,
        default=1,
        choices=[1, 2],
        help="Number of GPUs used for classifier-free guidance (CFG) parallelism.",
    )
    return parser.parse_args()


def main():
    # Enable VAE memory optimizations on NPU
    vae_use_slicing = is_npu()
    vae_use_tiling = is_npu()
    parallel_config = DiffusionParallelConfig(
        ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size
    )

    # Configure cache based on backend type
    cache_config = None
    if args.cache_backend == "cache_dit":
            print(f" Image {idx + 1} size: {img.size}")
    else:
        print(f" Input image size: {input_image.size}")
    print(
        f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}"
    )
    print(f"{'=' * 60}\n")

generation_start = time.perf_counter()
16 changes: 14 additions & 2 deletions examples/offline_inference/text_to_image/text_to_image.py
def parse_args() -> argparse.Namespace:
        default=1,
        help="Number of GPUs used for ring sequence parallelism.",
    )
    parser.add_argument(
        "--cfg_parallel_size",
        type=int,
        default=1,
        choices=[1, 2],
        help="Number of GPUs used for classifier-free guidance (CFG) parallelism.",
    )
    return parser.parse_args()


def main():
    }

    # assert args.ring_degree == 1, "Ring attention is not supported yet"
    parallel_config = DiffusionParallelConfig(
        ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size
    )

    omni = Omni(
        model=args.model,
        vae_use_slicing=vae_use_slicing,
    print(f" Model: {args.model}")
    print(f" Inference steps: {args.num_inference_steps}")
    print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}")
    print(
        f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}"
    )
    print(f" Image size: {args.width}x{args.height}")
    print(f"{'=' * 60}\n")

19 changes: 17 additions & 2 deletions vllm_omni/diffusion/cache/teacache/hook.py
from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig
from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
from vllm_omni.diffusion.cache.teacache.state import TeaCacheState
from vllm_omni.diffusion.distributed.parallel_state import (
    get_classifier_free_guidance_rank,
    get_classifier_free_guidance_world_size,
)
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook, StateManager


class TeaCacheHook(ModelHook):
    Key features:
    - Zero changes to model code
    - CFG-aware with separate states for positive/negative branches
    - CFG-parallel compatible: properly detects branch identity across ranks
    - Model-specific polynomial rescaling
    - Auto-detection of model types

    def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
        # GENERIC CACHING LOGIC (works for all models)
        # ============================================================================
        # Set context based on CFG branch for separate state tracking.
        # With CFG-parallel, each rank processes only one branch:
        #   - cfg_rank 0: positive branch
        #   - cfg_rank > 0: negative branch
        # Without CFG-parallel, branches alternate within a single rank.
        if module.do_true_cfg:
            cfg_parallel_size = get_classifier_free_guidance_world_size()
            if cfg_parallel_size > 1:
                cfg_rank = get_classifier_free_guidance_rank()
                cache_branch = "negative" if cfg_rank > 0 else "positive"
            else:
                # No CFG-parallel: use forward counter to alternate branches
                cache_branch = "negative" if self._forward_cnt % 2 == 1 else "positive"
        else:
            cache_branch = "positive"
