Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 154 additions & 22 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,124 @@ async def async_request_openai_completions(
return output


async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one request to an OpenAI-compatible Chat Completions endpoint.

    Builds a single-turn user message (multimodal when ``image_data`` is
    set), POSTs it to ``request_func_input.api_url``, and records latency,
    TTFT, and inter-token latencies on the returned ``RequestFuncOutput``.
    Streaming vs. non-streaming mode follows the global ``args.disable_stream``
    flag, matching the other request functions in this file.

    Args:
        request_func_input: Prompt, model, output length, optional image
            data, and any extra request-body fields.
        pbar: Optional progress bar; advanced by one on completion.

    Returns:
        A populated RequestFuncOutput. On failure ``success`` is False and
        ``error`` holds the formatted traceback or HTTP reason.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    if request_func_input.image_data:
        # Multimodal message: image part first, then the text prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": request_func_input.image_data},
                    },
                    {"type": "text", "text": request_func_input.prompt},
                ],
            },
        ]
    else:
        messages = [{"role": "user", "content": request_func_input.prompt}]

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    if args.disable_stream:
                        # Non-streaming response
                        response_json = await response.json()
                        output.generated_text = response_json["choices"][0][
                            "message"
                        ]["content"]
                        output.success = True
                        output.latency = time.perf_counter() - st
                        output.ttft = (
                            output.latency
                        )  # For non-streaming, TTFT = total latency
                        # "usage" may be present but null; `or {}` handles
                        # both the missing and the null case (the streaming
                        # branch below already does the same).
                        output.output_len = (
                            response_json.get("usage") or {}
                        ).get("completion_tokens", output_len)
                    else:
                        # Streaming (SSE) response. Initialize `latency` so a
                        # 200 response with an empty body cannot raise
                        # NameError at the assignment after the loop.
                        latency = 0.0
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                            latency = time.perf_counter() - st
                            if chunk == "[DONE]":
                                pass
                            else:
                                data = json.loads(chunk)

                                # Check if this chunk contains content
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")

                                if content:
                                    timestamp = time.perf_counter()
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    most_recent_timestamp = timestamp
                                    generated_text += content

                                # Check for usage info in final chunk;
                                # "usage" may be null until the last chunk.
                                output_len = (data.get("usage") or {}).get(
                                    "completion_tokens", output_len
                                )

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                        output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_truss(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
Expand Down Expand Up @@ -544,6 +662,7 @@ def get_dataset(args, tokenizer):
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.random_output_len,
apply_chat_template=args.apply_chat_template,
random_sample=True,
)
else:
Expand All @@ -555,8 +674,11 @@ def get_dataset(args, tokenizer):
"sglang": async_request_sglang_generate,
"sglang-native": async_request_sglang_generate,
"sglang-oai": async_request_openai_completions,
"sglang-oai-chat": async_request_openai_chat_completions,
"vllm": async_request_openai_completions,
"vllm-chat": async_request_openai_chat_completions,
"lmdeploy": async_request_openai_completions,
"lmdeploy-chat": async_request_openai_chat_completions,
"trt": async_request_trt_llm,
"gserver": async_request_gserver,
"truss": async_request_truss,
Expand Down Expand Up @@ -661,6 +783,7 @@ def sample_mmmu_requests(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
apply_chat_template: bool = True,
random_sample: bool = True,
) -> List[DatasetRow]:
"""
Expand Down Expand Up @@ -739,28 +862,30 @@ def sample_mmmu_requests(

# Construct the prompt
prompt = f"Question: {question}\n\nAnswer: "

try:
prompt = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": image_data},
},
{"type": "text", "text": prompt},
],
}
],
add_generation_prompt=True,
tokenize=False,
)
except Exception as e:
# Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
print(f"Error applying chat template: {e}, fallback to <image> tag")
prompt = f"<image>{prompt}"
if apply_chat_template:
try:
prompt = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": image_data},
},
{"type": "text", "text": prompt},
],
}
],
add_generation_prompt=True,
tokenize=False,
)
except Exception as e:
# Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
print(
f"Error applying chat template: {e}, fallback to <image> tag"
)
prompt = f"<image>{prompt}"

# Calculate token lengths for text only (without image data)
prompt_token_ids = tokenizer.encode(prompt)
Expand Down Expand Up @@ -1538,12 +1663,19 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url
else f"http://{args.host}:{args.port}/generate"
)
args.apply_chat_template = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this debug code

elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
api_url = (
f"{args.base_url}/v1/completions"
if args.base_url
else f"http://{args.host}:{args.port}/v1/completions"
)
elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
api_url = (
f"{args.base_url}/v1/chat/completions"
if args.base_url
else f"http://{args.host}:{args.port}/v1/chat/completions"
)
elif args.backend == "trt":
api_url = (
f"{args.base_url}/v2/models/ensemble/generate_stream"
Expand Down
Loading