|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import os as _os_env_toggle |
| 4 | +import random |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import soundfile as sf |
| 8 | +import torch |
| 9 | +from utils import make_omni_prompt |
| 10 | +from vllm.sampling_params import SamplingParams |
| 11 | + |
| 12 | +from vllm_omni.entrypoints.omni_llm import OmniLLM |
| 13 | + |
# Single seed shared by every random number generator touched below.
SEED = 42

# Environment switches are exported first so that anything initialised later in
# this process (and any child process that inherits the environment) sees them.
# NOTE(review): VLLM_USE_V1 presumably selects vLLM's V1 engine - confirm against
# the vLLM version in use.
_os_env_toggle.environ["VLLM_USE_V1"] = "1"
os.environ.update(
    {
        # Only influences interpreters started after this point; the hash seed
        # of the current process was fixed at interpreter start-up.
        "PYTHONHASHSEED": str(SEED),
        # cuBLAS workspace setting listed in PyTorch's reproducibility notes.
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    }
)

# Seed Python, NumPy and PyTorch (CPU, current CUDA device, all CUDA devices).
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)
del _seed_fn  # keep the module namespace free of the loop variable

# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
| 31 | + |
| 32 | + |
def parse_args(argv=None) -> argparse.Namespace:
    """Define the command-line interface of this example and parse it.

    Args:
        argv: Optional sequence of argument strings. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        The populated ``argparse.Namespace``; attribute names are the option
        names with dashes replaced by underscores (e.g. ``args.txt_prompts``).
    """
    parser = argparse.ArgumentParser()

    # --- Model locations -------------------------------------------------
    parser.add_argument(
        "--model",
        required=True,
        help="Path to merged model directory (will be created if downloading).",
    )
    parser.add_argument("--thinker-model", type=str, default=None)
    parser.add_argument("--talker-model", type=str, default=None)
    parser.add_argument("--code2wav-model", type=str, default=None)
    parser.add_argument(
        "--hf-hub-id",
        default="Qwen/Qwen2.5-Omni-7B",
        help="Hugging Face repo id to download if needed.",
    )
    parser.add_argument("--hf-revision", default=None, help="Optional HF revision (branch/tag/commit).")

    # --- Prompt / voice selection ---------------------------------------
    parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.")
    parser.add_argument("--voice-type", default="default", help="Voice type, e.g., m02, f030, default.")
    parser.add_argument(
        "--code2wav-dir",
        default=None,
        help="Path to code2wav folder (contains spk_dict.pt).",
    )
    parser.add_argument("--dit-ckpt", default=None, help="Path to DiT checkpoint file (e.g., dit.pt).")
    parser.add_argument("--bigvgan-ckpt", default=None, help="Path to BigVGAN checkpoint file.")
    parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--max-model-len", type=int, default=32768)
    parser.add_argument(
        "--init-sleep-seconds",
        type=int,
        default=20,
        help="Sleep seconds after starting each stage process to allow initialization (default: 20)",
    )

    parser.add_argument("--thinker-only", action="store_true")
    parser.add_argument("--text-only", action="store_true")
    parser.add_argument("--do-wave", action="store_true")
    # Every other flag is dash-separated; accept that spelling too while
    # keeping the historical underscore form. Both store into
    # ``args.prompt_type``.
    parser.add_argument(
        "--prompt-type",
        "--prompt_type",
        dest="prompt_type",
        choices=[
            "text",
            "audio",
            "audio-long",
            "audio-long-chunks",
            "audio-long-expand-chunks",
            "image",
            "video",
            "video-frames",
            "audio-in-video",
            "audio-in-video-v2",
            "audio-multi-round",
            "badcase-vl",
            "badcase-text",
            "badcase-image-early-stop",
            "badcase-two-audios",
            "badcase-two-videos",
            "badcase-multi-round",
            "badcase-voice-type",
            "badcase-voice-type-v2",
            "badcase-audio-tower-1",
            "badcase-audio-only",
        ],
        default="text",
    )
    parser.add_argument("--use-torchvision", action="store_true")
    parser.add_argument("--tokenize", action="store_true")

    # --- Output locations -----------------------------------------------
    parser.add_argument(
        "--output-wav",
        default="output.wav",
        help="[Deprecated] Output wav directory (use --output-dir).",
    )
    parser.add_argument(
        "--output-dir",
        default="outputs",
        help="Output directory to save text and wav files together.",
    )
    parser.add_argument(
        "--thinker-hidden-states-dir",
        default="thinker_hidden_states",
        help="Path to thinker hidden states directory.",
    )

    # --- Pipeline runtime knobs -----------------------------------------
    parser.add_argument(
        "--batch-timeout",
        type=int,
        default=5,
        help="Timeout for batching in seconds (default: 5)",
    )
    parser.add_argument(
        "--init-timeout",
        type=int,
        default=300,
        help="Timeout for initializing stages in seconds (default: 300)",
    )
    parser.add_argument(
        "--shm-threshold-bytes",
        type=int,
        default=65536,
        help="Threshold for using shared memory in bytes (default: 65536)",
    )
    parser.add_argument(
        "--enable-stats",
        action="store_true",
        default=False,
        help="Enable writing detailed statistics (default: disabled)",
    )
    parser.add_argument(
        "--txt-prompts",
        type=str,
        default=None,
        help="Path to a .txt file with one prompt per line (preferred).",
    )
    parser.add_argument(
        "--worker-backend",
        type=str,
        default="process",
        choices=["process", "ray"],
        help="Backend used to run the stage workers (default: process).",
    )
    parser.add_argument(
        "--ray-address",
        type=str,
        default=None,
        # The previous help text was a copy of the --txt-prompts description.
        help="Address of the Ray cluster to use with --worker-backend ray (default: None).",
    )

    return parser.parse_args(argv)
| 160 | + |
| 161 | + |
# Sample rate handed to soundfile when writing generated audio.
_OUTPUT_SAMPLE_RATE = 24000


def _read_prompts_file(path):
    """Return the stripped, non-empty lines of the UTF-8 text file at *path*."""
    with open(path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line]


def _build_sampling_params():
    """Return the sampling parameters for the thinker, talker and code2wav stages, in that order."""
    thinker_sampling_params = SamplingParams(
        temperature=0.0,  # Deterministic - no randomness
        top_p=1.0,  # Disable nucleus sampling
        top_k=-1,  # Disable top-k sampling
        max_tokens=2048,
        seed=SEED,  # Fixed seed for sampling
        detokenize=True,
        repetition_penalty=1.1,
    )
    talker_sampling_params = SamplingParams(
        temperature=0.9,
        top_p=0.8,
        top_k=40,
        max_tokens=2048,
        seed=SEED,  # Fixed seed for sampling
        detokenize=True,
        repetition_penalty=1.05,
        # NOTE(review): 8294 is presumably the talker's end-of-codec token - confirm.
        stop_token_ids=[8294],
    )
    code2wav_sampling_params = SamplingParams(
        temperature=0.0,  # Deterministic - no randomness
        top_p=1.0,  # Disable nucleus sampling
        top_k=-1,  # Disable top-k sampling
        max_tokens=2048,
        seed=SEED,  # Fixed seed for sampling
        detokenize=True,
        repetition_penalty=1.1,
    )
    return [thinker_sampling_params, talker_sampling_params, code2wav_sampling_params]


def _save_text_output(output, prompts, output_dir):
    """Write the prompt and generated text of one request to ``<output_dir>/<id:05d>.txt``."""
    request_id = int(output.request_id)
    text_output = output.outputs[0].text
    # The request id doubles as the index of the originating prompt.
    prompt_text = prompts[request_id]
    out_txt = os.path.join(output_dir, f"{request_id:05d}.txt")
    lines = [
        "Prompt:\n",
        str(prompt_text) + "\n",
        "vllm_text_output:\n",
        str(text_output).strip() + "\n",
    ]
    try:
        with open(out_txt, "w", encoding="utf-8") as f:
            f.writelines(lines)
    except OSError as e:
        # Best effort: a failed write must not abort the remaining requests.
        print(f"[Warn] Failed writing text file {out_txt}: {e}")
    else:
        # Only claim success when the file was actually written.
        print(f"Request ID: {request_id}, Text saved to {out_txt}")


def _save_audio_output(output, output_dir):
    """Write the generated audio of one request to ``<output_dir>/output_<request_id>.wav``."""
    request_id = int(output.request_id)
    audio = output.multimodal_output["audio"].detach().cpu()
    # NumPy has no bfloat16 dtype and soundfile does not accept float16, so
    # half-precision audio is upcast; float32/float64/integer data is untouched.
    if audio.is_floating_point() and audio.dtype not in (torch.float32, torch.float64):
        audio = audio.float()
    output_wav = os.path.join(output_dir, f"output_{output.request_id}.wav")
    sf.write(output_wav, audio.numpy(), samplerate=_OUTPUT_SAMPLE_RATE)
    print(f"Request ID: {request_id}, Saved audio to {output_wav}")


def main():
    """Run the example end to end.

    Parses the command line, optionally loads prompts from ``--txt-prompts``
    (only when the prompt type is ``text``), builds an ``OmniLLM`` pipeline,
    generates once for all prompts, and writes each stage's text or audio
    results into the output directory.

    Raises:
        ValueError: If no prompt is available after argument processing.
        OSError, UnicodeError: Re-raised if the prompts file cannot be read.
    """
    import time

    args = parse_args()

    # Preferred: load from txt file (one prompt per line)
    if args.txt_prompts and args.prompt_type == "text":
        try:
            args.prompts = _read_prompts_file(args.txt_prompts)
        except (OSError, UnicodeError) as e:
            print(f"[Error] Failed to load prompts: {e}")
            raise
        print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}")

    # ``not`` also rejects an empty prompts file instead of running with nothing.
    if not args.prompts:
        raise ValueError("No prompts provided. Use --prompts ... or --txt-prompts <file.txt> (with --prompt_type text)")

    omni_llm = OmniLLM(
        model=args.model,
        log_stats=args.enable_stats,
        log_file=("omni_llm_pipeline.log" if args.enable_stats else None),
        init_sleep_seconds=args.init_sleep_seconds,
        batch_timeout=args.batch_timeout,
        init_timeout=args.init_timeout,
        shm_threshold_bytes=args.shm_threshold_bytes,
        worker_backend=args.worker_backend,
        ray_address=args.ray_address,
    )
    sampling_params_list = _build_sampling_params()

    # Time prompt construction plus generation with a monotonic clock.
    start = time.perf_counter()
    omni_prompts = [make_omni_prompt(args, text) for text in args.prompts]
    print(f"prompt:{omni_prompts}")
    omni_outputs = omni_llm.generate(omni_prompts, sampling_params_list)
    elapsed = time.perf_counter() - start
    print(f"==========> time:{elapsed}")

    # Determine output directory: prefer --output-dir; fallback to --output-wav
    output_dir = args.output_dir or args.output_wav
    os.makedirs(output_dir, exist_ok=True)

    for stage_outputs in omni_outputs:
        if stage_outputs.final_output_type == "text":
            for output in stage_outputs.request_output:
                _save_text_output(output, args.prompts, output_dir)
        elif stage_outputs.final_output_type == "audio":
            for output in stage_outputs.request_output:
                _save_audio_output(output, output_dir)


if __name__ == "__main__":
    main()
0 commit comments