Skip to content

Commit 6ecbc4c

Browse files
added omni mooncake connectors to distributed directory, integrate omni connector to omni-llm, rebased with main (one time).
ray + mooncake connector work on 2 nodes, but needes several updates\n1.start ray inside vllm-omni \n2.placement group optimize \n3.fix close
1 parent 99c46e1 commit 6ecbc4c

27 files changed

Lines changed: 2724 additions & 66 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,4 @@ configs/development.yaml
235235
# Docker
236236
.dockerignore
237237
Dockerfile.dev
238+
discussion

examples/offline_inference/qwen2_5_omni/end2end.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,19 @@ def parse_args():
283283
default=None,
284284
help="Path to a .txt file with one prompt per line (preferred).",
285285
)
286+
parser.add_argument(
287+
"--worker-backend",
288+
type=str,
289+
default="process",
290+
choices=["process","ray"],
291+
help="backend"
292+
)
293+
parser.add_argument(
294+
"--ray-address",
295+
type=str,
296+
default=None,
297+
help="Path to a .txt file with one prompt per line (preferred).",
298+
286299

287300
return parser.parse_args()
288301

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,13 @@
1-
python end2end.py --output-wav output_audio \
2-
--query-type use_audio_in_video
1+
#python end2end.py --output-wav output_audio \
2+
# --query-type use_audio_in_video
3+
export PYTHONPATH=/workspace/omni/vllm-omni/:$PYTHONPATH
4+
CUDA_VISIBLE_DEVICES=0,1,2,3 python end2end.py --model /workspace/Qwen2.5-Omni-7B/ \
5+
--voice-type "m02" \
6+
--dit-ckpt none \
7+
--bigvgan-ckpt none \
8+
--output-wav output_audio \
9+
--prompt_type text \
10+
--init-sleep-seconds 0 \
11+
--worker-backend ray \
12+
--ray-address auto \
13+
--prompts "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words."
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Offline Example of vLLM-omni for Qwen2.5-omni
2+
3+
## 🛠️ Installation
4+
5+
Please refer to [README.md](../../../README.md)
6+
7+
## Run examples (Qwen2.5-omni)
8+
### Multiple Prompts
9+
Download dataset from [seed_tts](https://drive.google.com/file/d/1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP/edit). To get the prompt, you can:
10+
```bash
11+
tar -xf <Your Download Path>/seedtts_testset.tar
12+
cp seedtts_testset/en/meta.lst examples/offline_inference/qwen2_5_omni/meta.lst
13+
python3 examples/offline_inference/qwen2_5_omni/extract_prompts.py \
14+
--input examples/offline_inference/qwen2_5_omni/meta.lst \
15+
--output examples/offline_inference/qwen2_5_omni/top100.txt \
16+
--topk 100
17+
```
18+
Get into the example folder
19+
```bash
20+
cd examples/offline_inference/qwen2_5_omni
21+
```
22+
Then run the command below.
23+
```bash
24+
bash run_multiple_prompts.sh
25+
```
26+
### Single Prompts
27+
Get into the example folder
28+
```bash
29+
cd examples/offline_inference/qwen2_5_omni
30+
```
31+
Then run the command below.
32+
```bash
33+
bash run_single_prompt.sh
34+
```
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
import argparse
2+
import os
3+
import os as _os_env_toggle
4+
import random
5+
6+
import numpy as np
7+
import soundfile as sf
8+
import torch
9+
from utils import make_omni_prompt
10+
from vllm.sampling_params import SamplingParams
11+
12+
from vllm_omni.entrypoints.omni_llm import OmniLLM
13+
14+
_os_env_toggle.environ["VLLM_USE_V1"] = "1"
15+
16+
SEED = 42
17+
# Set all random seeds
18+
random.seed(SEED)
19+
np.random.seed(SEED)
20+
torch.manual_seed(SEED)
21+
torch.cuda.manual_seed(SEED)
22+
torch.cuda.manual_seed_all(SEED)
23+
24+
# Make PyTorch deterministic
25+
torch.backends.cudnn.deterministic = True
26+
torch.backends.cudnn.benchmark = False
27+
28+
# Set environment variables for deterministic behavior
29+
os.environ["PYTHONHASHSEED"] = str(SEED)
30+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
31+
32+
33+
def parse_args():
34+
parser = argparse.ArgumentParser()
35+
parser.add_argument(
36+
"--model",
37+
required=True,
38+
help="Path to merged model directory (will be created if downloading).",
39+
)
40+
parser.add_argument("--thinker-model", type=str, default=None)
41+
parser.add_argument("--talker-model", type=str, default=None)
42+
parser.add_argument("--code2wav-model", type=str, default=None)
43+
parser.add_argument(
44+
"--hf-hub-id",
45+
default="Qwen/Qwen2.5-Omni-7B",
46+
help="Hugging Face repo id to download if needed.",
47+
)
48+
parser.add_argument("--hf-revision", default=None, help="Optional HF revision (branch/tag/commit).")
49+
parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.")
50+
parser.add_argument("--voice-type", default="default", help="Voice type, e.g., m02, f030, default.")
51+
parser.add_argument(
52+
"--code2wav-dir",
53+
default=None,
54+
help="Path to code2wav folder (contains spk_dict.pt).",
55+
)
56+
parser.add_argument("--dit-ckpt", default=None, help="Path to DiT checkpoint file (e.g., dit.pt).")
57+
parser.add_argument("--bigvgan-ckpt", default=None, help="Path to BigVGAN checkpoint file.")
58+
parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"])
59+
parser.add_argument("--max-model-len", type=int, default=32768)
60+
parser.add_argument(
61+
"--init-sleep-seconds",
62+
type=int,
63+
default=20,
64+
help="Sleep seconds after starting each stage process to allow initialization (default: 20)",
65+
)
66+
67+
parser.add_argument("--thinker-only", action="store_true")
68+
parser.add_argument("--text-only", action="store_true")
69+
parser.add_argument("--do-wave", action="store_true")
70+
parser.add_argument(
71+
"--prompt_type",
72+
choices=[
73+
"text",
74+
"audio",
75+
"audio-long",
76+
"audio-long-chunks",
77+
"audio-long-expand-chunks",
78+
"image",
79+
"video",
80+
"video-frames",
81+
"audio-in-video",
82+
"audio-in-video-v2",
83+
"audio-multi-round",
84+
"badcase-vl",
85+
"badcase-text",
86+
"badcase-image-early-stop",
87+
"badcase-two-audios",
88+
"badcase-two-videos",
89+
"badcase-multi-round",
90+
"badcase-voice-type",
91+
"badcase-voice-type-v2",
92+
"badcase-audio-tower-1",
93+
"badcase-audio-only",
94+
],
95+
default="text",
96+
)
97+
parser.add_argument("--use-torchvision", action="store_true")
98+
parser.add_argument("--tokenize", action="store_true")
99+
parser.add_argument(
100+
"--output-wav",
101+
default="output.wav",
102+
help="[Deprecated] Output wav directory (use --output-dir).",
103+
)
104+
parser.add_argument(
105+
"--output-dir",
106+
default="outputs",
107+
help="Output directory to save text and wav files together.",
108+
)
109+
parser.add_argument(
110+
"--thinker-hidden-states-dir",
111+
default="thinker_hidden_states",
112+
help="Path to thinker hidden states directory.",
113+
)
114+
parser.add_argument(
115+
"--batch-timeout",
116+
type=int,
117+
default=5,
118+
help="Timeout for batching in seconds (default: 5)",
119+
)
120+
parser.add_argument(
121+
"--init-timeout",
122+
type=int,
123+
default=300,
124+
help="Timeout for initializing stages in seconds (default: 300)",
125+
)
126+
parser.add_argument(
127+
"--shm-threshold-bytes",
128+
type=int,
129+
default=65536,
130+
help="Threshold for using shared memory in bytes (default: 65536)",
131+
)
132+
parser.add_argument(
133+
"--enable-stats",
134+
action="store_true",
135+
default=False,
136+
help="Enable writing detailed statistics (default: disabled)",
137+
)
138+
parser.add_argument(
139+
"--txt-prompts",
140+
type=str,
141+
default=None,
142+
help="Path to a .txt file with one prompt per line (preferred).",
143+
)
144+
parser.add_argument(
145+
"--worker-backend",
146+
type=str,
147+
default="process",
148+
choices=["process","ray"],
149+
help="backend"
150+
)
151+
parser.add_argument(
152+
"--ray-address",
153+
type=str,
154+
default=None,
155+
help="Path to a .txt file with one prompt per line (preferred)."
156+
)
157+
158+
args = parser.parse_args()
159+
return args
160+
161+
162+
def main():
163+
args = parse_args()
164+
model_name = args.model
165+
try:
166+
# Preferred: load from txt file (one prompt per line)
167+
if getattr(args, "txt_prompts", None) and args.prompt_type == "text":
168+
with open(args.txt_prompts, encoding="utf-8") as f:
169+
lines = [ln.strip() for ln in f.readlines()]
170+
args.prompts = [ln for ln in lines if ln != ""]
171+
print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}")
172+
except Exception as e:
173+
print(f"[Error] Failed to load prompts: {e}")
174+
raise
175+
176+
if args.prompts is None:
177+
raise ValueError("No prompts provided. Use --prompts ... or --txt-prompts <file.txt> (with --prompt_type text)")
178+
omni_llm = OmniLLM(
179+
model=model_name,
180+
log_stats=args.enable_stats,
181+
log_file=("omni_llm_pipeline.log" if args.enable_stats else None),
182+
init_sleep_seconds=args.init_sleep_seconds,
183+
batch_timeout=args.batch_timeout,
184+
init_timeout=args.init_timeout,
185+
shm_threshold_bytes=args.shm_threshold_bytes,
186+
worker_backend=args.worker_backend,
187+
ray_address=args.ray_address
188+
)
189+
thinker_sampling_params = SamplingParams(
190+
temperature=0.0, # Deterministic - no randomness
191+
top_p=1.0, # Disable nucleus sampling
192+
top_k=-1, # Disable top-k sampling
193+
max_tokens=2048,
194+
seed=SEED, # Fixed seed for sampling
195+
detokenize=True,
196+
repetition_penalty=1.1,
197+
)
198+
talker_sampling_params = SamplingParams(
199+
temperature=0.9,
200+
top_p=0.8,
201+
top_k=40,
202+
max_tokens=2048,
203+
seed=SEED, # Fixed seed for sampling
204+
detokenize=True,
205+
repetition_penalty=1.05,
206+
stop_token_ids=[8294],
207+
)
208+
code2wav_sampling_params = SamplingParams(
209+
temperature=0.0, # Deterministic - no randomness
210+
top_p=1.0, # Disable nucleus sampling
211+
top_k=-1, # Disable top-k sampling
212+
max_tokens=2048,
213+
seed=SEED, # Fixed seed for sampling
214+
detokenize=True,
215+
repetition_penalty=1.1,
216+
)
217+
218+
sampling_params_list = [
219+
thinker_sampling_params,
220+
talker_sampling_params,
221+
code2wav_sampling_params,
222+
]
223+
import time
224+
for i in range(1):
225+
t1 = time.time()
226+
prompt = [make_omni_prompt(args, prompt) for prompt in args.prompts]
227+
print(f"prompt:{prompt}")
228+
omni_outputs = omni_llm.generate(prompt, sampling_params_list)
229+
t2 = time.time()
230+
print(f"==========> time:{t2-t1}")
231+
232+
# Determine output directory: prefer --output-dir; fallback to --output-wav
233+
output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
234+
os.makedirs(output_dir, exist_ok=True)
235+
for stage_outputs in omni_outputs:
236+
if stage_outputs.final_output_type == "text":
237+
for output in stage_outputs.request_output:
238+
request_id = int(output.request_id)
239+
text_output = output.outputs[0].text
240+
# Save aligned text file per request
241+
prompt_text = args.prompts[request_id]
242+
out_txt = os.path.join(output_dir, f"{request_id:05d}.txt")
243+
lines = []
244+
lines.append("Prompt:\n")
245+
lines.append(str(prompt_text) + "\n")
246+
lines.append("vllm_text_output:\n")
247+
lines.append(str(text_output).strip() + "\n")
248+
try:
249+
with open(out_txt, "w", encoding="utf-8") as f:
250+
f.writelines(lines)
251+
except Exception as e:
252+
print(f"[Warn] Failed writing text file {out_txt}: {e}")
253+
print(f"Request ID: {request_id}, Text saved to {out_txt}")
254+
elif stage_outputs.final_output_type == "audio":
255+
for output in stage_outputs.request_output:
256+
request_id = int(output.request_id)
257+
audio_tensor = output.multimodal_output["audio"]
258+
output_wav = os.path.join(output_dir, f"output_{output.request_id}.wav")
259+
sf.write(output_wav, audio_tensor.detach().cpu().numpy(), samplerate=24000)
260+
print(f"Request ID: {request_id}, Saved audio to {output_wav}")
261+
262+
263+
if __name__ == "__main__":
264+
main()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
from typing import Optional
4+
5+
6+
def extract_prompt(line: str) -> Optional[str]:
7+
# 提取第一个 '|' 与第二个 '|' 之间的内容
8+
i = line.find("|")
9+
if i == -1:
10+
return None
11+
j = line.find("|", i + 1)
12+
if j == -1:
13+
return None
14+
return line[i + 1 : j].strip()
15+
16+
17+
def main() -> None:
18+
parser = argparse.ArgumentParser()
19+
parser.add_argument("--input", "-i", required=True, help="Input .lst file path")
20+
parser.add_argument("--output", "-o", required=True, help="Output file path")
21+
parser.add_argument(
22+
"--topk",
23+
"-k",
24+
type=int,
25+
default=100,
26+
help="Extract the top K prompts (default: 100)",
27+
)
28+
args = parser.parse_args()
29+
30+
prompts = []
31+
with open(args.input, encoding="utf-8", errors="ignore") as f:
32+
for line in f:
33+
if len(prompts) >= args.topk:
34+
break
35+
p = extract_prompt(line.rstrip("\n"))
36+
if p:
37+
prompts.append(p)
38+
39+
with open(args.output, "w", encoding="utf-8") as f:
40+
for p in prompts:
41+
f.write(p + "\n")
42+
43+
44+
if __name__ == "__main__":
45+
main()

0 commit comments

Comments
 (0)