Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions benchmark/gsm8k/bench_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
dump_bench_raw_result,
select_sglang_backend,
)
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
Expand Down Expand Up @@ -115,6 +116,12 @@ def few_shot_gsm8k(s, question):

# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
dump_bench_raw_result(
path=args.raw_result_file,
states=states,
preds=preds,
labels=labels,
)

with open(args.result_file, "a") as fout:
value = {
Expand Down
8 changes: 8 additions & 0 deletions benchmark/mmlu/bench_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
dump_bench_raw_result,
select_sglang_backend,
)

Expand Down Expand Up @@ -142,6 +143,13 @@ def few_shot_mmlu(s, examples, question):
assert pt == len(cors)
weighted_acc = np.mean(cors)

dump_bench_raw_result(
path=args.raw_result_file,
states=states,
preds=preds,
labels=labels,
)

# Print results
print("Total latency: {:.3f}".format(latency))
print("Average accuracy: {:.3f}".format(weighted_acc))
Expand Down
36 changes: 36 additions & 0 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import copy
import json
import logging
import os
import random
Expand All @@ -13,6 +14,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, List, Optional, Tuple

Expand All @@ -25,6 +27,7 @@
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.interpreter import ProgramState
from sglang.srt.utils import (
get_bool_env_var,
get_device,
Expand Down Expand Up @@ -337,6 +340,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
)
parser.add_argument("--result-file", type=str, default="result.jsonl")
parser.add_argument("--raw-result-file", type=str)
args = parser.parse_args()

return args
Expand Down Expand Up @@ -1249,3 +1253,35 @@ def _callTestMethod(self, method):
lambda: super(CustomTestCase, self)._callTestMethod(method),
max_retry=max_retry,
)


def dump_bench_raw_result(
    path: Optional[str],
    states,
    preds,
    labels,
) -> None:
    """Dump per-sample benchmark results to ``path``.

    One record is produced per entry of ``states``; each record holds the
    sample index, the prompt text, the generated output, and whether the
    prediction matched the label.

    Args:
        path: Destination file. ``None`` or an empty string disables dumping.
            ``--raw-result-file`` is registered without a default, so argparse
            passes ``None`` when the flag is omitted.
        states: Program states, indexed in the same order as ``preds``.
        preds: Predictions, one per state.
        labels: Ground-truth labels, one per state.
    """
    # Treat both None (flag omitted) and "" as "no dump requested"; comparing
    # only against "" would let None fall through to Path(None) -> TypeError.
    if not path:
        return
Comment on lines +1259 to +1324
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The `path` argument should be `Optional[str]` to reflect that the `--raw-result-file` argument is optional. Also, use `if path is None:` instead of `if path == "":` to correctly handle the case where no path is provided. Finally, consider adding type hints for the other arguments and the return type to improve readability.

def dump_bench_raw_result(
    path: Optional[str],
    states: List[ProgramState],
    preds: List,
    labels: List,
) -> None:
    if path is None:
        return


rows = []
for i, state in enumerate(states):
    output = state["answer"]
    # The prompt is the full program text with the generated answer
    # stripped from its end.
    prompt = _ensure_remove_suffix(state.text(), output)
    rows.append(
        dict(
            prompt_id=i,
            prompt=prompt,
            output=output,
            # bool() keeps the value JSON-serializable even when the
            # comparison yields a non-builtin boolean type.
            correct=bool(preds[i] == labels[i]),
        )
    )

print(f"BenchRawResultDumper save results to {path}")
# Terminate every record with a newline so the file is well-formed
# line-delimited JSON; an empty result set yields an empty file.
Path(path).write_text("".join(json.dumps(row) + "\n" for row in rows))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a trailing newline character to the end of the file to comply with the standard format for line-delimited JSON files.

    Path(path).write_text("\n".join(json.dumps(row) for row in rows) + "\n")



def _ensure_remove_suffix(text: str, suffix: str):
assert text.endswith(suffix)
return text.removesuffix(suffix)
Comment on lines +1345 to +1346
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider raising a ValueError instead of using assert for data validation. Assertions can be disabled in production, leading to unexpected behavior.

def _ensure_remove_suffix(text: str, suffix: str):
    if not text.endswith(suffix):
        raise ValueError(f"Text does not end with the expected suffix: '{suffix}'")
    return text.removesuffix(suffix)

Loading