Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9e7323f
Support Bagel Model
princepride Jan 10, 2026
5d5db43
remove reqest_ids and request_id assert exist
princepride Jan 10, 2026
9085814
move bagel end2end to examples&delete useless parameters in yaml
princepride Jan 10, 2026
5f4a472
remove useless comment
princepride Jan 10, 2026
bb094c7
fix some problem
princepride Jan 14, 2026
339faca
xxx
princepride Jan 14, 2026
3d07a2c
xxx
princepride Jan 14, 2026
d533646
adjust test path logits
princepride Jan 16, 2026
095e7b8
adjust kv transer unit-test and add online-inference
princepride Jan 18, 2026
169a744
fix pre-commit error
princepride Jan 18, 2026
055894f
fix pre-commit error
princepride Jan 18, 2026
7948da1
move inject_omni_kv_config to utils
princepride Jan 18, 2026
72110c0
simplify ar model runner extract and transfer kv cache code
princepride Jan 18, 2026
01a90b3
simplify ar model runner extract and transfer kv cache code
princepride Jan 18, 2026
695b1de
xxx
princepride Jan 18, 2026
0bea3bc
move customer config under /transformers_utils/configs
princepride Jan 19, 2026
866442c
move customer config under /transformers_utils/configs
princepride Jan 19, 2026
f8fafdd
move customer config under /transformers_utils/configs
princepride Jan 19, 2026
e87d8f9
move customer processor under /transformers_utils/processors
princepride Jan 19, 2026
1c0622e
remove useless comment
princepride Jan 19, 2026
2655414
remove useless comment
princepride Jan 19, 2026
e01d690
fix pre-commit error
princepride Jan 21, 2026
33b20e1
remove ar bagel because vllm already have
princepride Jan 21, 2026
f2f620f
remove ar bagel because vllm already have
princepride Jan 21, 2026
8894708
remove useless code
princepride Jan 21, 2026
ecad2ea
fix some bug
princepride Jan 22, 2026
4a76c7f
add test_kv_flow to buildkite
princepride Jan 22, 2026
ca78f31
Merge branch 'main' into new-bagel-model-stage
princepride Jan 22, 2026
1a4ecab
remove duplicate code
princepride Jan 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions examples/offline_inference/bagel/end2end.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import argparse
import os


def parse_args():
    """Build the CLI parser for the Bagel end-to-end example and parse argv.

    Returns:
        argparse.Namespace with model/prompt selection, modality, and
        OmniLLM initialization options.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model and prompt selection.
    add(
        "--model",
        default="ByteDance-Seed/BAGEL-7B-MoT",
        help="Path to merged model directory.",
    )
    add("--prompts", nargs="+", default=None, help="Input text prompts.")
    add(
        "--txt-prompts",
        type=str,
        default=None,
        help="Path to a .txt file with one prompt per line (preferred).",
    )
    add("--prompt_type", default="text", choices=["text"])
    add(
        "--modality",
        default="text2img",
        choices=["text2img", "img2img", "img2text", "text2text"],
        help="Modality mode to control stage execution.",
    )
    add(
        "--image-path",
        type=str,
        default=None,
        help="Path to input image for img2img.",
    )

    # OmniLLM init args.
    add("--enable-stats", action="store_true", default=False)
    add("--init-sleep-seconds", type=int, default=20)
    add("--batch-timeout", type=int, default=5)
    add("--init-timeout", type=int, default=300)
    add("--shm-threshold-bytes", type=int, default=65536)
    add("--worker-backend", type=str, default="process", choices=["process", "ray"])
    add("--ray-address", type=str, default=None)
    add("--stage-configs-path", type=str, default=None)
    add("--steps", type=int, default=50, help="Number of inference steps.")

    return parser.parse_args()


def main():
    """Run the Bagel end-to-end example.

    Loads prompts (txt file, CLI values, or a built-in default), dispatches
    to the requested modality, saves any generated images to PNG files, and
    prints the raw outputs.
    """
    args = parse_args()
    _load_prompts(args)

    if args.modality == "img2img":
        omni_outputs = _run_img2img(args)
    else:
        omni_outputs = _run_omni(args)

    _save_output_images(omni_outputs)
    print(omni_outputs)


def _load_prompts(args) -> None:
    """Populate ``args.prompts`` in place.

    A --txt-prompts file (one prompt per line, blank lines dropped) takes
    precedence over --prompts; if neither is provided a default test prompt
    is used.
    """
    try:
        # Preferred: load from txt file (one prompt per line).
        if args.txt_prompts and args.prompt_type == "text":
            with open(args.txt_prompts, encoding="utf-8") as f:
                args.prompts = [line.strip() for line in f if line.strip()]
            print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}")
    except Exception as e:
        print(f"[Error] Failed to load prompts: {e}")
        raise

    if args.prompts is None:
        # Default prompt for text2img test if none provided.
        args.prompts = ["<|im_start|>A cute cat<|im_end|>"]
        print(f"[Info] No prompts provided, using default: {args.prompts}")


def _run_img2img(args) -> list:
    """Run Stage 1 only (image-to-image) via OmniDiffusion.

    Returns the generation results normalized to a list.
    """
    from PIL import Image

    from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion

    print("[Info] Running in img2img mode (Stage 1 only)")
    client = OmniDiffusion(model=args.model)

    # TODO(review): unify these ad-hoc kwargs into SamplingParams-like classes.
    generate_kwargs = {
        "prompt": args.prompts,
        "seed": 52,
        "need_kv_receive": False,
        "num_inference_steps": args.steps,
    }

    if args.image_path:
        if os.path.exists(args.image_path):
            generate_kwargs["pil_image"] = Image.open(args.image_path).convert("RGB")
        else:
            print(f"[Warning] Image path {args.image_path} does not exist.")

    result = client.generate(**generate_kwargs)
    # generate() may return a single output or a list; normalize for iteration.
    return result if isinstance(result, list) else [result]


def _run_omni(args) -> list:
    """Run the full multi-stage pipeline via Omni.

    Formats prompts according to the requested modality, tunes the per-stage
    sampling parameters, and returns the generation outputs as a list.
    """
    import copy

    from PIL import Image

    from vllm_omni.entrypoints.omni import Omni

    omni_kwargs = {}
    if args.stage_configs_path:
        omni_kwargs["stage_configs_path"] = args.stage_configs_path
    omni_kwargs.update(
        {
            "log_stats": args.enable_stats,
            "init_sleep_seconds": args.init_sleep_seconds,
            "batch_timeout": args.batch_timeout,
            "init_timeout": args.init_timeout,
            "shm_threshold_bytes": args.shm_threshold_bytes,
            "worker_backend": args.worker_backend,
            "ray_address": args.ray_address,
        }
    )

    omni = Omni(model=args.model, **omni_kwargs)

    formatted_prompts = []
    for p in args.prompts:
        if args.modality == "img2text":
            if args.image_path:
                loaded_image = Image.open(args.image_path).convert("RGB")
                final_prompt_text = f"<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n"
                formatted_prompts.append(
                    {
                        "prompt": final_prompt_text,
                        "multi_modal_data": {"image": loaded_image},
                        "modalities": ["text"],
                    }
                )
            else:
                # Previously the prompt was dropped silently; surface the problem.
                print(f"[Warning] img2text requires --image-path; skipping prompt: {p!r}")
        elif args.modality == "text2text":
            final_prompt_text = f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
            formatted_prompts.append({"prompt": final_prompt_text, "modalities": ["text"]})
        else:
            # text2img
            formatted_prompts.append(
                {"prompt": f"<|im_start|>{p}<|im_end|>", "modalities": ["image"]}
            )

    params_list = copy.deepcopy(omni.default_sampling_params_list)
    if args.modality == "text2img":
        # Stage 0 (the AR model) only needs to hand off to the diffusion stage.
        params_list[0]["max_tokens"] = 1
    if len(params_list) > 1:
        params_list[1]["num_inference_steps"] = args.steps

    return list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))


def _save_output_images(omni_outputs) -> None:
    """Save every image found on the outputs (and their per-stage outputs) as PNG."""
    for i, req_output in enumerate(omni_outputs):
        images = getattr(req_output, "images", None)
        if not images and hasattr(req_output, "output"):
            out = req_output.output
            images = out if isinstance(out, list) else [out]

        if images:
            for j, img in enumerate(images):
                img.save(f"output_{i}_{j}.png")

        # Per-stage outputs may also carry images (e.g. the diffusion stage).
        if getattr(req_output, "request_output", None):
            for stage_out in req_output.request_output:
                if getattr(stage_out, "images", None):
                    for k, img in enumerate(stage_out.images):
                        save_path = f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png"
                        img.save(save_path)
                        print(f"[Info] Saved stage output image to {save_path}")


# Entry point: run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
6 changes: 6 additions & 0 deletions examples/offline_inference/bagel/run_single_prompt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
# Run the Bagel end-to-end example with a single chat-formatted prompt.

# NOTE(review): inside double quotes bash keeps "\n" as a literal backslash-n
# (no newline). If real newlines are intended, use $'...' instead — confirm
# what end2end.py expects before changing.
prompt="<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"

# Quote the expansion: the prompt contains spaces and would otherwise be
# word-split into many separate --prompts values.
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
    --prompt_type text \
    --init-sleep-seconds 0 \
    --prompts "${prompt}"
184 changes: 184 additions & 0 deletions examples/online_serving/bagel/openai_chat_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
Bagel OpenAI-compatible chat client for image generation and multimodal tasks.

Usage:
python openai_chat_client.py --prompt "A cute cat" --output output.png
python openai_chat_client.py --prompt "Describe this image" --image-url https://example.com/image.png
"""

import argparse
import base64
from pathlib import Path

import requests


def generate_image(
    prompt: str,
    server_url: str = "http://localhost:8091",
    image_url: str | None = None,
    height: int | None = None,
    width: int | None = None,
    steps: int | None = None,
    seed: int | None = None,
    negative_prompt: str | None = None,
    modality: str = "text2img",  # "text2img" (default), "img2img", "img2text", "text2text"
) -> bytes | str | None:
    """Generate an image or text using the chat completions API.

    Args:
        prompt: Text description or prompt
        server_url: Server URL
        image_url: URL or path to input image (for img2img/img2text)
        height: Image height in pixels
        width: Image width in pixels
        steps: Number of inference steps
        seed: Random seed
        negative_prompt: Negative prompt
        modality: Task modality hint

    Returns:
        Image bytes (for image outputs) or Text string (for text outputs) or None if failed
    """
    # Construct Message Content
    content = [{"type": "text", "text": f"<|im_start|>{prompt}<|im_end|>"}]

    if image_url:
        # Local files are inlined as a base64 data URL; anything else is
        # assumed to already be a fetchable URL.
        if Path(image_url).exists():
            b64_data = base64.b64encode(Path(image_url).read_bytes()).decode("utf-8")
            final_image_url = f"data:image/jpeg;base64,{b64_data}"
        else:
            final_image_url = image_url

        content.append({"type": "image_url", "image_url": {"url": final_image_url}})

    messages = [{"role": "user", "content": content}]

    # Build request payload with all parameters at top level.
    # Note: vLLM ignores "extra_body", so we put parameters directly in the payload
    payload = {"messages": messages}

    # Set output modalities at top level
    if modality in ("text2img", "img2img"):
        payload["modalities"] = ["image"]
    elif modality in ("img2text", "text2text"):
        payload["modalities"] = ["text"]

    # Add generation parameters directly to payload
    if height is not None:
        payload["height"] = height
    if width is not None:
        payload["width"] = width
    if steps is not None:
        payload["num_inference_steps"] = steps
    if seed is not None:
        payload["seed"] = seed
    if negative_prompt:
        payload["negative_prompt"] = negative_prompt

    # Send request
    try:
        print(f"Sending request to {server_url} with modality {modality}...")
        response = requests.post(
            f"{server_url}/v1/chat/completions",
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=300,
        )
        response.raise_for_status()
        data = response.json()

        # Extract content - check ALL choices since server may return multiple
        # (e.g., text in choices[0], image in choices[1])
        choices = data.get("choices", [])

        # First pass: look for image output in any item of any choice.
        # (Previously only the first content item was inspected, which missed
        # images preceded by other parts.)
        for choice in choices:
            choice_content = choice.get("message", {}).get("content")
            if not isinstance(choice_content, list):
                continue
            for item in choice_content:
                if isinstance(item, dict) and "image_url" in item:
                    img_url_str = item["image_url"].get("url", "")
                    if img_url_str.startswith("data:image"):
                        _, b64_data = img_url_str.split(",", 1)
                        return base64.b64decode(b64_data)

        # Second pass: look for text output if no image found
        for choice in choices:
            choice_content = choice.get("message", {}).get("content")
            if isinstance(choice_content, str) and choice_content:
                return choice_content

        print(f"Unexpected response format: {choices}")
        return None

    except Exception as e:
        # Best-effort CLI helper: report and signal failure via None.
        print(f"Error: {e}")
        return None


def main():
    """CLI entry point: parse args, call generate_image, save or print the result.

    Exits with status 1 (via SystemExit) if generation fails.
    """
    parser = argparse.ArgumentParser(description="Bagel multimodal chat client")
    parser.add_argument("--prompt", "-p", default="<|im_start|>A cute cat<|im_end|>", help="Text prompt")
    parser.add_argument("--output", "-o", default="bagel_output.png", help="Output file (for image results)")
    parser.add_argument("--server", "-s", default="http://localhost:8091", help="Server URL")

    # Modality Control
    parser.add_argument("--image-url", "-i", type=str, help="Input image URL or local path")
    parser.add_argument(
        "--modality",
        "-m",
        default="text2img",
        choices=["text2img", "img2img", "img2text", "text2text"],
        help="Task modality",
    )

    # Generation Params
    parser.add_argument("--height", type=int, default=512, help="Image height")
    parser.add_argument("--width", type=int, default=512, help="Image width")
    parser.add_argument("--steps", type=int, default=25, help="Inference steps")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--negative", help="Negative prompt")

    args = parser.parse_args()

    print(f"Mode: {args.modality}")
    if args.image_url:
        print(f"Input Image: {args.image_url}")

    result = generate_image(
        prompt=args.prompt,
        server_url=args.server,
        image_url=args.image_url,
        height=args.height,
        width=args.width,
        steps=args.steps,
        seed=args.seed,
        negative_prompt=args.negative,
        modality=args.modality,
    )

    if result:
        if isinstance(result, bytes):
            # It's an image
            output_path = Path(args.output)
            output_path.write_bytes(result)
            print(f"Image saved to: {output_path}")
            print(f"Size: {len(result) / 1024:.1f} KB")
        elif isinstance(result, str):
            # It's text
            print("Response:")
            print(result)
    else:
        print("Failed to generate response")
        # `exit` is a site.py convenience builtin and not guaranteed to exist
        # (e.g. under `python -S`); raise SystemExit for a reliable exit code.
        raise SystemExit(1)


# Entry point: run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
12 changes: 12 additions & 0 deletions examples/online_serving/bagel/run_server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
# Bagel online serving startup script.
# MODEL and PORT may be overridden via the environment; defaults below.

MODEL="${MODEL:-ByteDance-Seed/BAGEL-7B-MoT}"
PORT="${PORT:-8091}"

# Announce the effective configuration (same output as separate echo lines).
printf 'Starting Bagel server...\nModel: %s\nPort: %s\n' "$MODEL" "$PORT"

vllm serve "$MODEL" --omni \
    --port "$PORT"
Loading