7 changes: 7 additions & 0 deletions benchmark/gsm8k/bench_sglang.py
@@ -10,6 +10,7 @@
from sglang.api import set_default_backend
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    dump_bench_raw_result,
    select_sglang_backend,
)
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
@@ -115,6 +116,12 @@ def few_shot_gsm8k(s, question):

    # Dump results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)
    dump_bench_raw_result(
        path=args.raw_result_file,
        states=states,
        preds=preds,
        labels=labels,
    )

    with open(args.result_file, "a") as fout:
        value = {
8 changes: 8 additions & 0 deletions benchmark/mmlu/bench_sglang.py
@@ -9,6 +9,7 @@

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    dump_bench_raw_result,
    select_sglang_backend,
)

@@ -142,6 +143,13 @@ def few_shot_mmlu(s, examples, question):
    assert pt == len(cors)
    weighted_acc = np.mean(cors)

    dump_bench_raw_result(
        path=args.raw_result_file,
        states=states,
        preds=preds,
        labels=labels,
    )

    # Print results
    print("Total latency: {:.3f}".format(latency))
    print("Average accuracy: {:.3f}".format(weighted_acc))
172 changes: 172 additions & 0 deletions python/sglang/srt/debug_utils/text_comparator.py
@@ -0,0 +1,172 @@
import argparse
import json
from pathlib import Path

import polars as pl

_DESCRIPTION = """Compare and find differences to benchmark outputs.

Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""


def main(args):
    df_input = _transform_df_input(_compute_df_raw(args))
    assert all(
        c in df_input.columns
        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
    )

    df_meta = _compute_df_meta(df_input)

    df_correctness_per_trial = df_input.group_by(
        "category", "trial_index", maintain_order=True
    ).agg(pl.col("correct").mean())
    df_correctness_delta = (
        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
    )
    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)

    print(f"Dump output to {args.output_path}")
    Path(args.output_path).write_text(
        json.dumps(
            dict(
                df_meta=df_meta.to_dicts(),
                df_good_to_bad=df_good_to_bad.to_dicts(),
                df_bad_to_good=df_bad_to_good.to_dicts(),
            )
        )
    )

    if not args.disable_print_details:
        with pl.Config(
            fmt_str_lengths=10000,
            tbl_cols=-1,
            tbl_rows=-1,
            tbl_width_chars=-1,
            tbl_formatting="UTF8_FULL",
        ):
            print("====== Correctness per trial ======")
            print(df_correctness_per_trial)

            print(
                "====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
            )
            print(df_correctness_delta)

            for name, df in [
                ("Good->Bad", df_good_to_bad),
                ("Bad->Good", df_bad_to_good),
            ]:
                print(f"====== Concrete Examples: {name} ======")
                print(df)


def _compute_df_raw(args):
    return pl.concat(
        [
            _read_df_raw(p, category=category, trial_index=i)
            for category, paths in [
                ("baseline", args.baseline_path),
                ("target", args.target_path),
            ]
            for i, p in enumerate(paths)
        ]
    )


def _read_df_raw(path: str, category: str, trial_index: int):
    return pl.read_ndjson(path).with_columns(
        category=pl.lit(category), trial_index=trial_index
    )


def _transform_df_input(df: pl.DataFrame):
if "doc_id" in df.columns:
print("Transform mode: lm_eval")

filter_names = df["filter"].unique(maintain_order=True).to_list()
if len(filter_names) > 1:
filter_name = filter_names[0]
print(f"Choose {filter_name=} among {filter_names}")
df = df.filter(pl.col("filter") == filter_name)

df = df.select(
pl.col("category"),
pl.col("trial_index"),
prompt_id=pl.col("doc_id"),
prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
output=pl.col("resps").list.get(0).list.get(0),
correct=pl.col("exact_match").cast(bool),
)

return df
elif "prompt_id" in df.columns:
print("Transform mode: SGLang bench")
return df
else:
raise Exception(f"Unknown data: {df.columns}")


def _compute_df_meta(df_input: pl.DataFrame):
    df_input = df_input.sort("prompt_id", "category", "trial_index")
    # One summary row per prompt, aggregating its baseline and target trials
    df_meta = pl.DataFrame(
        [
            _handle_one_prompt(df_one_prompt)
            for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
        ]
    )
    df_meta = df_meta.with_columns(
        correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
    )
    df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
    return df_meta


def _handle_one_prompt(df_one_prompt: pl.DataFrame):
    assert len(set(df_one_prompt["prompt"])) == 1

    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
    df_target = df_one_prompt.filter(pl.col("category") == "target")

    outputs_baseline = df_baseline["output"].to_list()
    outputs_target = df_target["output"].to_list()

    output_same_prefix_len = max(
        _compute_str_prefix_len(output_baseline, output_target)
        for output_baseline in outputs_baseline
        for output_target in outputs_target
    )

    return dict(
        prompt_id=df_one_prompt[0, "prompt_id"],
        correctness_baseline=df_baseline["correct"].mean(),
        correctness_target=df_target["correct"].mean(),
        output_same_prefix_len=output_same_prefix_len,
        prompt=df_one_prompt[0, "prompt"],
        outputs_baseline=outputs_baseline,
        outputs_target=outputs_target,
    )


def _compute_str_prefix_len(a: str, b: str) -> int:
    min_len = min(len(a), len(b))
    for i in range(min_len):
        if a[i] != b[i]:
            return i
    return min_len


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
    parser.add_argument(
        "--output-path", type=str, default="/tmp/text_comparator_output.json"
    )
    parser.add_argument("--disable-print-details", action="store_true")
    args = parser.parse_args()
    main(args)
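
Below is a minimal usage sketch (not part of the diff) showing the row schema the comparator accepts in its "SGLang bench" mode and one way to drive it programmatically; the temporary file paths, the sample rows, and the `argparse.Namespace` call are illustrative assumptions, and the module path is inferred from the file location in this PR.

import argparse
import json
from pathlib import Path

from sglang.srt.debug_utils.text_comparator import main  # module path inferred from python/sglang/srt/debug_utils/

# Two tiny ndjson files in the "SGLang bench" schema: prompt_id, prompt, output, correct.
rows_baseline = [
    {"prompt_id": 0, "prompt": "Q: 1+1?", "output": "2", "correct": True},
    {"prompt_id": 1, "prompt": "Q: 2+2?", "output": "5", "correct": False},
]
rows_target = [
    {"prompt_id": 0, "prompt": "Q: 1+1?", "output": "2", "correct": True},
    {"prompt_id": 1, "prompt": "Q: 2+2?", "output": "4", "correct": True},
]
Path("/tmp/baseline.jsonl").write_text("\n".join(json.dumps(r) for r in rows_baseline) + "\n")
Path("/tmp/target.jsonl").write_text("\n".join(json.dumps(r) for r in rows_target) + "\n")

# Equivalent CLI (paths are illustrative):
#   python -m sglang.srt.debug_utils.text_comparator \
#       --baseline-path /tmp/baseline.jsonl --target-path /tmp/target.jsonl
main(
    argparse.Namespace(
        baseline_path=["/tmp/baseline.jsonl"],
        target_path=["/tmp/target.jsonl"],
        output_path="/tmp/text_comparator_output.json",
        disable_print_details=False,
    )
)
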
35 changes: 35 additions & 0 deletions python/sglang/test/test_utils.py
@@ -15,6 +15,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple

@@ -27,6 +28,7 @@
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.interpreter import ProgramState
from sglang.srt.utils import (
    get_bool_env_var,
    get_device,
@@ -348,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
)
parser.add_argument("--result-file", type=str, default="result.jsonl")
parser.add_argument("--raw-result-file", type=str)
args = parser.parse_args()

return args
@@ -1309,3 +1312,35 @@ def _callTestMethod(self, method):
            lambda: super(CustomTestCase, self)._callTestMethod(method),
            max_retry=max_retry,
        )


def dump_bench_raw_result(
    path: str,
    states,
    preds,
    labels,
):
    if not path:
        return

    rows = []
    for i in range(len(states)):
        state = states[i]
        output = state["answer"]
        prompt = _ensure_remove_suffix(state.text(), output)
        rows.append(
            dict(
                prompt_id=i,
                prompt=prompt,
                output=output,
                correct=bool(preds[i] == labels[i]),
            )
        )

    print(f"BenchRawResultDumper save results to {path}")
    Path(path).write_text("\n".join(json.dumps(row) for row in rows))
Reviewer comment (Contributor, severity: medium):

Add a trailing newline character to the end of the file to comply with the standard format for line-delimited JSON files.

    Path(path).write_text("\n".join(json.dumps(row) for row in rows) + "\n")



def _ensure_remove_suffix(text: str, suffix: str):
    assert text.endswith(suffix)
    return text.removesuffix(suffix)
Comment on lines +1345 to +1346
Reviewer comment (Contributor, severity: medium):

Consider raising a ValueError instead of using assert for data validation. Assertions can be disabled in production, leading to unexpected behavior.

def _ensure_remove_suffix(text: str, suffix: str):
    if not text.endswith(suffix):
        raise ValueError(f"Text does not end with the expected suffix: '{suffix}'")
    return text.removesuffix(suffix)
