Commit b7604ae
Add FP8 quantization support of DiT for Z-Image
Signed-off-by: lishunyang <[email protected]>
1 parent 741f7e2 commit b7604ae

9 files changed, +363 -4 lines


examples/offline_inference/text_to_image/text_to_image.py

Lines changed: 11 additions & 0 deletions
@@ -113,6 +113,15 @@ def parse_args() -> argparse.Namespace:
         default=1,
         help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        default=None,
+        choices=["fp8"],
+        help="Quantization method for the transformer. "
+        "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). "
+        "Default: None (no quantization, uses BF16).",
+    )
     parser.add_argument(
         "--vae_use_slicing",
         action="store_true",
@@ -180,6 +189,7 @@ def main():
         parallel_config=parallel_config,
         enforce_eager=args.enforce_eager,
         enable_cpu_offload=args.enable_cpu_offload,
+        quantization=args.quantization,
     )

     if profiler_enabled:
@@ -192,6 +202,7 @@ def main():
     print(f" Model: {args.model}")
     print(f" Inference steps: {args.num_inference_steps}")
     print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}")
+    print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}")
     print(
         f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, "
         f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}"
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for FP8 quantization config."""
+
+import pytest
+
+
+def test_fp8_config_creation():
+    """Test that FP8 config can be created."""
+    from vllm_omni.diffusion.quantization import get_diffusion_quant_config
+
+    config = get_diffusion_quant_config("fp8")
+    assert config is not None
+    assert config.get_name() == "fp8"
+
+
+def test_vllm_config_extraction():
+    """Test that vLLM config can be extracted from diffusion config."""
+    from vllm_omni.diffusion.quantization import (
+        get_diffusion_quant_config,
+        get_vllm_quant_config_for_layers,
+    )
+
+    diff_config = get_diffusion_quant_config("fp8")
+    vllm_config = get_vllm_quant_config_for_layers(diff_config)
+    assert vllm_config is not None
+    assert vllm_config.activation_scheme == "dynamic"
+
+
+def test_none_quantization():
+    """Test that None quantization returns None config."""
+    from vllm_omni.diffusion.quantization import (
+        get_diffusion_quant_config,
+        get_vllm_quant_config_for_layers,
+    )
+
+    config = get_diffusion_quant_config(None)
+    assert config is None
+    vllm_config = get_vllm_quant_config_for_layers(config)
+    assert vllm_config is None
+
+
+def test_invalid_quantization():
+    """Test that invalid quantization method raises error."""
+    from vllm_omni.diffusion.quantization import get_diffusion_quant_config
+
+    with pytest.raises(ValueError, match="Unknown quantization method"):
+        get_diffusion_quant_config("invalid_method")
+
+
+def test_fp8_config_with_custom_params():
+    """Test FP8 config with custom parameters."""
+    from vllm_omni.diffusion.quantization import get_diffusion_quant_config
+
+    config = get_diffusion_quant_config(
+        "fp8",
+        activation_scheme="static",
+        ignored_layers=["proj_out"],
+    )
+    assert config is not None
+    assert config.activation_scheme == "static"
+    assert "proj_out" in config.ignored_layers
+
+
+def test_supported_methods():
+    """Test that supported methods list is correct."""
+    from vllm_omni.diffusion.quantization import SUPPORTED_QUANTIZATION_METHODS
+
+    assert "fp8" in SUPPORTED_QUANTIZATION_METHODS

vllm_omni/diffusion/data.py

Lines changed: 18 additions & 0 deletions
@@ -16,6 +16,9 @@

 from vllm_omni.diffusion.utils.network_utils import is_port_available

+# Import after TYPE_CHECKING to avoid circular imports at runtime
+# The actual import is deferred to __post_init__ to avoid import order issues
+
 logger = init_logger(__name__)


@@ -358,6 +361,11 @@ class OmniDiffusionConfig:
     # Omni configuration (injected from stage config)
     omni_kv_config: dict[str, Any] = field(default_factory=dict)

+    # Quantization settings
+    # Supported methods: "fp8" (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs)
+    quantization: str | None = None
+    quantization_config: Any | None = None  # DiffusionQuantizationConfig or dict
+
     def settle_port(self, port: int, port_inc: int = 42, max_attempts: int = 100) -> int:
         """
         Find an available port with retry logic.
@@ -444,6 +452,16 @@ def __post_init__(self):
             # If it's neither dict nor DiffusionCacheConfig, convert to empty config
             self.cache_config = DiffusionCacheConfig()

+        # Convert quantization config (deferred import to avoid circular imports)
+        if self.quantization is not None or self.quantization_config is not None:
+            from vllm_omni.diffusion.quantization import get_diffusion_quant_config
+
+            if isinstance(self.quantization_config, dict):
+                quant_method = self.quantization_config.pop("method", self.quantization)
+                self.quantization_config = get_diffusion_quant_config(quant_method, **self.quantization_config)
+            elif self.quantization_config is None and self.quantization is not None:
+                self.quantization_config = get_diffusion_quant_config(self.quantization)
+
         if self.max_cpu_loras is None:
             self.max_cpu_loras = 1
         elif self.max_cpu_loras < 1:
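The net effect is that callers can pass either a bare method string or a dict, and __post_init__ normalizes both into a config object. A hedged illustration, assuming OmniDiffusionConfig's remaining fields all have defaults (the diff does not show its full field list):

    from vllm_omni.diffusion.data import OmniDiffusionConfig

    # Method string only: expanded via get_diffusion_quant_config("fp8").
    cfg_a = OmniDiffusionConfig(quantization="fp8")

    # Dict form: "method" is popped, the rest become keyword arguments.
    cfg_b = OmniDiffusionConfig(
        quantization_config={"method": "fp8", "ignored_layers": ["proj_out"]},
    )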

vllm_omni/diffusion/models/z_image/pipeline_z_image.py

Lines changed: 4 additions & 1 deletion
@@ -37,6 +37,7 @@
 from vllm_omni.diffusion.models.z_image.z_image_transformer import (
     ZImageTransformer2DModel,
 )
+from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers
 from vllm_omni.diffusion.request import OmniDiffusionRequest
 from vllm_omni.model_executor.model_loader.weight_utils import (
     download_weights_from_hf_specific,
@@ -173,7 +174,9 @@ def __init__(
         self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
             self._execution_device
         )
-        self.transformer = ZImageTransformer2DModel()
+        # Get vLLM quantization config for linear layers
+        quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)
+        self.transformer = ZImageTransformer2DModel(quant_config=quant_config)
         self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)

         # Note: Context parallelism is applied centrally in registry.initialize_model()
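Because the helper passes None through (see the tests above), the unquantized case needs no separate code path; a short illustration:

    # With quantization disabled, od_config.quantization_config is None, so
    # construction reduces to the pre-change behavior of ZImageTransformer2DModel():
    quant_config = get_vllm_quant_config_for_layers(None)  # -> None
    transformer = ZImageTransformer2DModel(quant_config=quant_config)  # BF16 path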

vllm_omni/diffusion/models/z_image/z_image_transformer.py

Lines changed: 36 additions & 3 deletions
@@ -18,6 +18,7 @@

 import math
 from collections.abc import Iterable
+from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
@@ -32,6 +33,11 @@
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization.base_config import (
+        QuantizationConfig,
+    )
+
 from vllm_omni.diffusion.attention.layer import Attention
 from vllm_omni.diffusion.cache.base import CachedTransformer
 from vllm_omni.diffusion.distributed.sp_plan import (
@@ -250,6 +256,7 @@ def __init__(
         num_kv_heads: int,
         qk_norm: bool = True,
         eps: float = 1e-6,
+        quant_config: "QuantizationConfig | None" = None,
     ) -> None:
         super().__init__()
         self.dim = dim
@@ -264,6 +271,7 @@ def __init__(
             total_num_heads=num_heads,
             total_num_kv_heads=num_kv_heads,
             bias=False,
+            quant_config=quant_config,
         )

         assert qk_norm is True
@@ -281,6 +289,7 @@ def __init__(
                     bias=False,
                     input_is_parallel=True,
                     return_bias=False,
+                    quant_config=quant_config,
                 )
             ]
         )
@@ -343,13 +352,19 @@ def forward(


 class FeedForward(nn.Module):
-    def __init__(self, dim: int, hidden_dim: int):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        quant_config: "QuantizationConfig | None" = None,
+    ):
         super().__init__()
         self.w13 = MergedColumnParallelLinear(
             dim,
             [hidden_dim] * 2,
             bias=False,
             return_bias=False,
+            quant_config=quant_config,
         )
         self.act = SiluAndMul()
         self.w2 = RowParallelLinear(
@@ -358,6 +373,7 @@ def __init__(self, dim: int, hidden_dim: int):
             bias=False,
             input_is_parallel=True,
             return_bias=False,
+            quant_config=quant_config,
         )

     def forward(self, x):
@@ -374,6 +390,7 @@ def __init__(
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
+        quant_config: "QuantizationConfig | None" = None,
     ):
         super().__init__()
         self.dim = dim
@@ -384,9 +401,14 @@ def __init__(
             num_kv_heads=n_kv_heads,
             qk_norm=qk_norm,
             eps=1e-5,
+            quant_config=quant_config,
         )

-        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
+        self.feed_forward = FeedForward(
+            dim=dim,
+            hidden_dim=int(dim / 3 * 8),
+            quant_config=quant_config,
+        )
         self.layer_id = layer_id

         self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
@@ -589,6 +611,7 @@ def __init__(
         t_scale=1000.0,
         axes_dims=[32, 48, 48],
         axes_lens=[1024, 512, 512],
+        quant_config: "QuantizationConfig | None" = None,
     ) -> None:
         super().__init__()
         self.dtype = torch.bfloat16
@@ -648,6 +671,7 @@ def __init__(
                     norm_eps,
                     qk_norm,
                     modulation=True,
+                    quant_config=quant_config,
                 )
                 for layer_id in range(n_refiner_layers)
             ]
@@ -662,6 +686,7 @@ def __init__(
                     norm_eps,
                     qk_norm,
                     modulation=False,
+                    quant_config=quant_config,
                 )
                 for layer_id in range(n_refiner_layers)
             ]
@@ -677,7 +702,15 @@ def __init__(

         self.layers = nn.ModuleList(
             [
-                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
+                ZImageTransformerBlock(
+                    layer_id,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    norm_eps,
+                    qk_norm,
+                    quant_config=quant_config,
+                )
                 for layer_id in range(n_layers)
             ]
         )