Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions examples/online_serving/qwen2_5_omni/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,58 @@ def parse_args():
default=None,
help="Path to custom stage configs YAML file (optional).",
)
parser.add_argument(
"--log-stats",
action="store_true",
help="Enable statistics logging for AsyncOmni.",
)
parser.add_argument(
"--log-file",
type=str,
default=None,
help="Path prefix for AsyncOmni log files.",
)
parser.add_argument(
"--init-sleep-seconds",
type=int,
default=30,
help="Seconds to sleep between starting stage processes.",
)
parser.add_argument(
"--shm-threshold-bytes",
type=int,
default=65536,
help="Threshold in bytes for using shared memory IPC.",
)
parser.add_argument(
"--batch-timeout",
type=int,
default=10,
help="Batching timeout (seconds) inside each stage.",
)
parser.add_argument(
"--init-timeout",
type=int,
default=ASYNC_INIT_TIMEOUT,
help="Timeout (seconds) for initializing all stages.",
)
return parser.parse_args()


def build_async_omni_cli_args(base_args: argparse.Namespace) -> argparse.Namespace:
"""Construct the minimal CLI args Namespace expected by AsyncOmni."""
return argparse.Namespace(
model=base_args.model,
stage_configs_path=getattr(base_args, "stage_configs_path", None),
log_stats=bool(getattr(base_args, "log_stats", False)),
log_file=getattr(base_args, "log_file", None),
init_sleep_seconds=int(getattr(base_args, "init_sleep_seconds", 30)),
shm_threshold_bytes=int(getattr(base_args, "shm_threshold_bytes", 65536)),
batch_timeout=int(getattr(base_args, "batch_timeout", 10)),
init_timeout=int(getattr(base_args, "init_timeout", ASYNC_INIT_TIMEOUT)),
)


def build_sampling_params(seed: int, model_key: str) -> list[SamplingParams]:
"""Build SamplingParams objects by reusing the dict definitions."""
return [SamplingParams(**params_dict) for params_dict in build_sampling_params_dict(seed, model_key)]
Expand Down Expand Up @@ -500,11 +549,8 @@ def signal_handler(sig, frame):
print(f"Using custom stage configs: {args.stage_configs_path}")

sampling_params = build_sampling_params(SEED, model_name)
omni = AsyncOmni(
model=args.model,
stage_configs_path=args.stage_configs_path,
init_timeout=ASYNC_INIT_TIMEOUT,
)
cli_args = build_async_omni_cli_args(args)
omni = AsyncOmni(model=args.model, cli_args=cli_args)
print("✓ AsyncOmni initialized successfully")
prompt_args_template = create_prompt_args(args)

Expand Down
56 changes: 51 additions & 5 deletions examples/online_serving/qwen3_omni/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,58 @@ def parse_args():
default=None,
help="Path to custom stage configs YAML file (optional).",
)
parser.add_argument(
"--log-stats",
action="store_true",
help="Enable statistics logging for AsyncOmni.",
)
parser.add_argument(
"--log-file",
type=str,
default=None,
help="Path prefix for AsyncOmni log files.",
)
parser.add_argument(
"--init-sleep-seconds",
type=int,
default=30,
help="Seconds to sleep between starting stage processes.",
)
parser.add_argument(
"--shm-threshold-bytes",
type=int,
default=65536,
help="Threshold in bytes for using shared memory IPC.",
)
parser.add_argument(
"--batch-timeout",
type=int,
default=10,
help="Batching timeout (seconds) inside each stage.",
)
parser.add_argument(
"--init-timeout",
type=int,
default=ASYNC_INIT_TIMEOUT,
help="Timeout (seconds) for initializing all stages.",
)
return parser.parse_args()


def build_async_omni_cli_args(base_args: argparse.Namespace) -> argparse.Namespace:
"""Construct the minimal CLI args Namespace expected by AsyncOmni."""
return argparse.Namespace(
model=base_args.model,
stage_configs_path=getattr(base_args, "stage_configs_path", None),
log_stats=bool(getattr(base_args, "log_stats", False)),
log_file=getattr(base_args, "log_file", None),
init_sleep_seconds=int(getattr(base_args, "init_sleep_seconds", 30)),
shm_threshold_bytes=int(getattr(base_args, "shm_threshold_bytes", 65536)),
batch_timeout=int(getattr(base_args, "batch_timeout", 10)),
init_timeout=int(getattr(base_args, "init_timeout", ASYNC_INIT_TIMEOUT)),
)


def build_sampling_params(seed: int, model_key: str) -> list[SamplingParams]:
"""Build SamplingParams objects by reusing the dict definitions."""
return [SamplingParams(**params_dict) for params_dict in build_sampling_params_dict(seed, model_key)]
Expand Down Expand Up @@ -506,11 +555,8 @@ def signal_handler(sig, frame):
print(f"Using custom stage configs: {args.stage_configs_path}")

sampling_params = build_sampling_params(SEED, model_name)
omni = AsyncOmni(
model=args.model,
stage_configs_path=args.stage_configs_path,
init_timeout=ASYNC_INIT_TIMEOUT,
)
cli_args = build_async_omni_cli_args(args)
omni = AsyncOmni(model=args.model, cli_args=cli_args)
print("✓ AsyncOmni initialized successfully")
prompt_args_template = create_prompt_args(args)

Expand Down