Skip to content

Commit 9de79f8

Browse files
authored
Merge branch 'main' into add-bagel-example-scripts
2 parents cfd81dc + bfbf3e5 commit 9de79f8

11 files changed

Lines changed: 315 additions & 8 deletions

File tree

vllm_omni/diffusion/offload.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from torch import nn
1919
from vllm.logger import init_logger
2020

21+
from vllm_omni.platforms import current_omni_platform
22+
2123
if TYPE_CHECKING:
2224
from vllm_omni.diffusion.data import OmniDiffusionConfig
2325

@@ -63,8 +65,8 @@ def _to_cpu(self, module: nn.Module) -> None:
6365
module.to("cpu", non_blocking=True)
6466

6567
# Release allocator blocks when tensors leave the GPU.
66-
if previous_device.type == "cuda" and torch.cuda.is_available():
67-
torch.cuda.empty_cache()
68+
if previous_device.type != "cpu":
69+
current_omni_platform.empty_cache()
6870

6971
if self.pin_memory:
7072
for p in module.parameters():
@@ -87,15 +89,19 @@ def _dit_pre_hook(self, module: nn.Module, args: tuple) -> None:
8789
for enc in self.encoders:
8890
self._to_cpu(enc)
8991
self._to_gpu(module)
90-
torch.cuda.synchronize()
92+
93+
current_omni_platform.synchronize()
94+
9195
logger.debug("Swapped: encoders -> CPU, DiT -> GPU")
9296

9397
def _encoder_pre_hook(self, module: nn.Module, args: tuple) -> None:
9498
"""Before encoder forward: offload DiT, load encoder."""
9599
for dit_mod in self.dits:
96100
self._to_cpu(dit_mod)
97101
self._to_gpu(module)
98-
torch.cuda.synchronize()
102+
103+
current_omni_platform.synchronize()
104+
99105
logger.debug("Swapped: DiT -> CPU, encoder -> GPU")
100106

101107
def register(self) -> None:
@@ -166,7 +172,10 @@ def apply_offload_hooks(
166172
try:
167173
device = next(dit_modules[0].parameters()).device
168174
except StopIteration:
169-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
175+
try:
176+
device = current_omni_platform.get_torch_device()
177+
except (NotImplementedError, AttributeError):
178+
device = torch.device("cpu")
170179

171180
# Collect all encoders
172181
encoders: list[nn.Module] = []
@@ -184,9 +193,10 @@ def apply_offload_hooks(
184193
pin = getattr(od_config, "pin_cpu_memory", True)
185194
for dit_mod in dit_modules:
186195
dit_mod.to("cpu")
187-
if torch.cuda.is_available():
188-
torch.cuda.empty_cache()
189-
if pin and torch.cuda.is_available():
196+
197+
current_omni_platform.empty_cache()
198+
199+
if pin:
190200
for dit_mod in dit_modules:
191201
for p in dit_mod.parameters():
192202
if p.data.device.type == "cpu" and not p.data.is_pinned():

vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,10 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
736736
kaiser_window = torch.kaiser_window(
737737
kernel_size, beta=beta, periodic=False, dtype=torch.float32, device="cpu"
738738
).to("npu")
739+
elif current_omni_platform.is_xpu():
740+
kaiser_window = torch.kaiser_window(
741+
kernel_size, beta=beta, periodic=False, dtype=torch.float32, device="cpu"
742+
).to("xpu")
739743
else:
740744
kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)
741745

vllm_omni/platforms/interface.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ class OmniPlatform(Platform):
3131
def is_npu(self) -> bool:
    """Return True when ``self._omni_enum`` equals ``OmniPlatformEnum.NPU``."""
    return self._omni_enum == OmniPlatformEnum.NPU
3333

34+
def is_xpu(self) -> bool:
    """Return True when ``self._omni_enum`` equals ``OmniPlatformEnum.XPU``."""
    return self._omni_enum == OmniPlatformEnum.XPU
36+
37+
def is_cuda(self) -> bool:
    """Return True when ``self._omni_enum`` equals ``OmniPlatformEnum.CUDA``."""
    return self._omni_enum == OmniPlatformEnum.CUDA
39+
40+
def is_rocm(self) -> bool:
    """Return True when ``self._omni_enum`` equals ``OmniPlatformEnum.ROCM``."""
    return self._omni_enum == OmniPlatformEnum.ROCM
42+
3443
@classmethod
3544
def get_omni_ar_worker_cls(cls) -> str:
3645
raise NotImplementedError
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# stage config for running qwen2.5-omni with architecture of OmniLLM.
2+
3+
# The following config has been verified on 2x 1550-64G XPUs.
4+
stage_args:
5+
- stage_id: 0
6+
stage_type: llm # Use llm stage type to launch OmniLLM
7+
runtime:
8+
process: true # Run this stage in a separate process
9+
devices: "0" # Visible devices for this stage
10+
max_batch_size: 1
11+
engine_args:
12+
model_stage: thinker
13+
model_arch: Qwen2_5OmniForConditionalGeneration
14+
worker_type: ar
15+
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
16+
gpu_memory_utilization: 0.8
17+
enforce_eager: false
18+
trust_remote_code: true
19+
engine_output_type: latent
20+
enable_prefix_caching: false
21+
is_comprehension: true
22+
final_output: true
23+
final_output_type: text
24+
default_sampling_params:
25+
temperature: 0.0
26+
top_p: 1.0
27+
top_k: -1
28+
max_tokens: 2048
29+
seed: 42
30+
detokenize: True
31+
repetition_penalty: 1.1
32+
- stage_id: 1
33+
stage_type: llm # Use llm stage type to launch OmniLLM
34+
runtime:
35+
process: true
36+
devices: "1"
37+
max_batch_size: 1
38+
engine_args:
39+
model_stage: talker
40+
model_arch: Qwen2_5OmniForConditionalGeneration
41+
worker_type: ar
42+
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
43+
gpu_memory_utilization: 0.8
44+
enforce_eager: false
45+
trust_remote_code: true
46+
enable_prefix_caching: false
47+
engine_output_type: latent
48+
engine_input_source: [0]
49+
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
50+
default_sampling_params:
51+
temperature: 0.9
52+
top_p: 0.8
53+
top_k: 40
54+
max_tokens: 2048
55+
seed: 42
56+
detokenize: True
57+
repetition_penalty: 1.05
58+
stop_token_ids: [8294]
59+
60+
- stage_id: 2
61+
stage_type: llm # Use llm stage type to launch OmniLLM
62+
runtime:
63+
process: true
64+
devices: "0"
65+
max_batch_size: 1
66+
engine_args:
67+
model_stage: code2wav
68+
model_arch: Qwen2_5OmniForConditionalGeneration
69+
worker_type: generation
70+
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
71+
gpu_memory_utilization: 0.15
72+
enforce_eager: true
73+
trust_remote_code: true
74+
enable_prefix_caching: false
75+
engine_output_type: audio
76+
engine_input_source: [1]
77+
final_output: true
78+
final_output_type: audio
79+
default_sampling_params:
80+
temperature: 0.0
81+
top_p: 1.0
82+
top_k: -1
83+
max_tokens: 2048
84+
seed: 42
85+
detokenize: True
86+
repetition_penalty: 1.1
87+
88+
# Top-level runtime config (concise): default windows and stage edges
89+
runtime:
90+
enabled: true
91+
defaults:
92+
window_size: -1 # Simplified: trigger downstream only after full upstream completion
93+
max_inflight: 1 # Simplified: process serially within each stage
94+
95+
edges:
96+
- from: 0 # thinker → talker: trigger only after receiving full input (-1)
97+
to: 1
98+
window_size: -1
99+
- from: 1 # talker → code2wav: trigger only after receiving full input (-1)
100+
to: 2
101+
window_size: -1
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
2+
# Stage 0: Thinker (multimodal understanding + text generation)
3+
# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
4+
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
5+
6+
# The following config has been verified on 4x 1550-64G XPUs.
7+
stage_args:
8+
- stage_id: 0
9+
stage_type: llm # Use llm stage type to launch OmniLLM
10+
runtime:
11+
devices: "0,1"
12+
max_batch_size: 1
13+
engine_args:
14+
model_stage: thinker
15+
model_arch: Qwen3OmniMoeForConditionalGeneration
16+
worker_type: ar
17+
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
18+
gpu_memory_utilization: 0.8
19+
enforce_eager: true
20+
trust_remote_code: true
21+
engine_output_type: latent # Output hidden states for talker
22+
distributed_executor_backend: "mp"
23+
enable_prefix_caching: false
24+
max_num_batched_tokens: 32768
25+
hf_config_name: thinker_config
26+
tensor_parallel_size: 2
27+
final_output: true
28+
final_output_type: text
29+
is_comprehension: true
30+
default_sampling_params:
31+
temperature: 0.4
32+
top_p: 0.9
33+
top_k: 1
34+
max_tokens: 2048
35+
seed: 42
36+
detokenize: True
37+
repetition_penalty: 1.05
38+
39+
- stage_id: 1
40+
stage_type: llm # Use llm stage type to launch OmniLLM
41+
runtime:
42+
devices: "2"
43+
max_batch_size: 1
44+
engine_args:
45+
model_stage: talker
46+
model_arch: Qwen3OmniMoeForConditionalGeneration
47+
worker_type: ar
48+
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
49+
gpu_memory_utilization: 0.3
50+
enforce_eager: true
51+
trust_remote_code: true
52+
engine_output_type: latent # Output codec codes for code2wav
53+
enable_prefix_caching: false
54+
max_num_batched_tokens: 32768
55+
distributed_executor_backend: "mp"
56+
hf_config_name: talker_config
57+
engine_input_source: [0]
58+
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
59+
# final_output: true
60+
# final_output_type: text
61+
default_sampling_params:
62+
temperature: 0.9
63+
top_k: 50
64+
max_tokens: 4096
65+
seed: 42
66+
detokenize: False
67+
repetition_penalty: 1.05
68+
stop_token_ids: [2150]
69+
70+
- stage_id: 2
71+
stage_type: llm # Use llm stage type to launch OmniLLM
72+
runtime:
73+
devices: "3"
74+
max_batch_size: 1
75+
engine_args:
76+
model_stage: code2wav
77+
model_arch: Qwen3OmniMoeForConditionalGeneration
78+
worker_type: generation
79+
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
80+
enforce_eager: true
81+
trust_remote_code: true
82+
enable_prefix_caching: false
83+
engine_output_type: audio # Final output: audio waveform
84+
gpu_memory_utilization: 0.1
85+
distributed_executor_backend: "mp"
86+
max_num_batched_tokens: 1000000
87+
hf_config_name: thinker_config
88+
engine_input_source: [1]
89+
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
90+
final_output: true
91+
final_output_type: audio
92+
default_sampling_params:
93+
temperature: 0.0
94+
top_p: 1.0
95+
top_k: -1
96+
max_tokens: 65536
97+
seed: 42
98+
detokenize: True
99+
repetition_penalty: 1.1

vllm_omni/platforms/xpu/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from contextlib import contextmanager
2+
3+
import torch
4+
5+
6+
@contextmanager
7+
def torch_cuda_wrapper():
8+
try:
9+
# replace cuda APIs with xpu APIs, this should work by default
10+
torch.cuda.Stream = torch.xpu.Stream
11+
torch.cuda.default_stream = torch.xpu.current_stream
12+
torch.cuda.current_stream = torch.xpu.current_stream
13+
torch.cuda.stream = torch.xpu.stream
14+
yield
15+
finally:
16+
pass

vllm_omni/platforms/xpu/worker/__init__.py

Whitespace-only changes.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import torch
5+
6+
from vllm_omni.platforms.xpu.utils import torch_cuda_wrapper
7+
from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner
8+
9+
10+
class XPUARModelRunner(GPUARModelRunner):
    """``GPUARModelRunner`` variant that targets XPU devices."""

    def __init__(self, *args, **kwargs):
        """Run the parent constructor while ``torch_cuda_wrapper`` is active.

        All positional and keyword arguments are forwarded unchanged to
        ``GPUARModelRunner.__init__``.
        """
        with torch_cuda_wrapper():
            super().__init__(*args, **kwargs)

    def _init_device_properties(self):
        """Set ``num_sms`` to ``None`` instead of querying device properties."""
        # NOTE(review): presumably an SM count is CUDA-specific and has no
        # XPU equivalent here -- confirm against the base runner.
        self.num_sms = None

    def _sync_device(self) -> None:
        """Synchronize the device via ``torch.xpu.synchronize()``."""
        torch.xpu.synchronize()
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from vllm.v1.worker.xpu_worker import XPUWorker
5+
6+
from vllm_omni.platforms.xpu.worker.xpu_ar_model_runner import XPUARModelRunner
7+
from vllm_omni.worker.mixins import OmniWorkerMixin
8+
9+
10+
class XPUARWorker(OmniWorkerMixin, XPUWorker):
    """AR worker on XPU, used for the thinker/talker stages of an Omni model."""

    def init_device(self):
        """Initialise the device, then install an ``XPUARModelRunner``.

        The parent ``init_device`` runs first; afterwards ``self.model_runner``
        is set to a runner built from ``self.vllm_config`` and ``self.device``.
        """
        super().init_device()
        runner: XPUARModelRunner = XPUARModelRunner(self.vllm_config, self.device)
        self.model_runner = runner
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import torch
5+
6+
from vllm_omni.platforms.xpu.utils import torch_cuda_wrapper
7+
from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner
8+
9+
10+
class XPUGenerationModelRunner(GPUGenerationModelRunner):
    """``GPUGenerationModelRunner`` variant that targets XPU devices."""

    def __init__(self, *args, **kwargs):
        """Run the parent constructor while ``torch_cuda_wrapper`` is active.

        All positional and keyword arguments are forwarded unchanged to
        ``GPUGenerationModelRunner.__init__``.
        """
        with torch_cuda_wrapper():
            super().__init__(*args, **kwargs)

    def _init_device_properties(self):
        """Set ``num_sms`` to ``None`` instead of querying device properties."""
        # NOTE(review): presumably an SM count is CUDA-specific and has no
        # XPU equivalent here -- confirm against the base runner.
        self.num_sms = None

    def _sync_device(self) -> None:
        """Synchronize the device via ``torch.xpu.synchronize()``."""
        torch.xpu.synchronize()

0 commit comments

Comments
 (0)