Skip to content

Commit ac23872

Browse files
merrymercy, rkooo567, dhou-xai, hanming-lu
committed
Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
1 parent 0194948 commit ac23872

86 files changed

Lines changed: 4110 additions & 2009 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,20 @@ def get_model_config(model_name: str, tp_size: int):
3030
topk = config.num_experts_per_tok
3131
intermediate_size = config.moe_intermediate_size
3232
shard_intermediate_size = 2 * intermediate_size // tp_size
33+
elif config.architectures[0] in [
34+
"Grok1ForCausalLM",
35+
"Grok1ImgGen",
36+
"Grok1AForCausalLM",
37+
]:
38+
E = config.num_local_experts
39+
topk = config.num_experts_per_tok
40+
intermediate_size = config.moe_intermediate_size
41+
shard_intermediate_size = 2 * intermediate_size // tp_size
3342
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
3443
E = config.n_routed_experts
3544
topk = config.num_experts_per_tok
3645
intermediate_size = config.intermediate_size
37-
shard_intermediate_size = 2 * intermediate_size // args.tp_size
46+
shard_intermediate_size = 2 * intermediate_size // tp_size
3847
else:
3948
# Default: Mixtral
4049
E = config.num_local_experts

benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ def get_model_config(model_name: str, tp_size: int):
3535
topk = config.num_experts_per_tok
3636
intermediate_size = config.moe_intermediate_size
3737
shard_intermediate_size = 2 * intermediate_size // tp_size
38+
elif config.architectures[0] in [
39+
"Grok1ForCausalLM",
40+
"Grok1ImgGen",
41+
"Grok1AForCausalLM",
42+
]:
43+
E = config.num_local_experts
44+
topk = config.num_experts_per_tok
45+
intermediate_size = config.moe_intermediate_size
46+
shard_intermediate_size = 2 * intermediate_size // tp_size
3847
else:
3948
# Default: Mixtral
4049
E = config.num_local_experts

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,15 @@ def main(args: argparse.Namespace):
397397
topk = config.num_experts_per_tok
398398
intermediate_size = config.moe_intermediate_size
399399
shard_intermediate_size = 2 * intermediate_size // args.tp_size
400+
elif config.architectures[0] in [
401+
"Grok1ForCausalLM",
402+
"Grok1ImgGen",
403+
"Grok1AForCausalLM",
404+
]:
405+
E = config.num_local_experts
406+
topk = config.num_experts_per_tok
407+
intermediate_size = config.moe_intermediate_size
408+
shard_intermediate_size = 2 * intermediate_size // args.tp_size
400409
else:
401410
# Default: Mixtral
402411
E = config.num_local_experts

docs/backend/native_api.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,7 @@
210210
"response = requests.post(url, json=data)\n",
211211
"print_highlight(response.text)\n",
212212
"assert response.json()[\"success\"] is True\n",
213-
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\"\n",
214-
"assert response.json().keys() == {\"success\", \"message\"}"
213+
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\""
215214
]
216215
},
217216
{
@@ -411,7 +410,7 @@
411410
" },\n",
412411
")\n",
413412
"output = response.json()\n",
414-
"output_tokens = output[\"token_ids\"]\n",
413+
"output_tokens = output[\"output_ids\"]\n",
415414
"\n",
416415
"output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
417416
"print_highlight(f\"Tokenized Output: {output_tokens}\")\n",

docs/backend/server_arguments.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ Please consult the documentation below to learn more about the parameters you ma
9696
* `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine.
9797
* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance.
9898
* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU.
99-
* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time.
10099

101100
## Other runtime options
102101

python/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
9696
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
9797

9898
[tool.setuptools.package-data]
99-
"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"]
99+
"sglang" = [
100+
"srt/layers/moe/fused_moe_triton/configs/*.json",
101+
"srt/layers/quantization/configs/*.json",
102+
]
100103

101104
[tool.setuptools.packages.find]
102105
exclude = [

python/sglang/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
99
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
1010
- `bench_serving.py`: Benchmark online serving with dynamic requests.
11-
- `check_env.py`: Check the environment variables.
11+
- `check_env.py`: Check the environment variables and dependencies.
1212
- `global_config.py`: The global configs and constants.
1313
- `launch_server.py`: The entry point for launching the local server.
1414
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
15+
- `profiler.py`: Profile a running server.
1516
- `utils.py`: Common utilities.
17+
- `version.py`: Version info.

python/sglang/bench_offline_throughput.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class BenchArgs:
5656
profile: bool = False
5757
skip_warmup: bool = False
5858
do_not_exit: bool = False
59+
prompt_suffix: str = ""
5960

6061
@staticmethod
6162
def add_cli_args(parser: argparse.ArgumentParser):
@@ -177,6 +178,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
177178
action="store_true",
178179
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
179180
)
181+
parser.add_argument(
182+
"--prompt-suffix",
183+
type=str,
184+
default="",
185+
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
186+
)
180187

181188
@classmethod
182189
def from_cli_args(cls, args: argparse.Namespace):
@@ -216,6 +223,10 @@ def throughput_test_once(
216223
]
217224

218225
if profile:
226+
assert (
227+
"SGLANG_TORCH_PROFILER_DIR" in os.environ
228+
), "Please set SGLANG_TORCH_PROFILER_DIR."
229+
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
219230
backend.start_profile()
220231

221232
st = time.perf_counter()
@@ -229,6 +240,8 @@ def throughput_test_once(
229240
if backend_name == "runtime":
230241
gen_out = json.loads(gen_out)
231242

243+
server_info = backend.get_server_info()
244+
232245
measurement_results["total_latency"] = latency
233246
measurement_results["total_output_tokens"] = sum(
234247
o["meta_info"]["completion_tokens"] for o in gen_out
@@ -246,6 +259,7 @@ def throughput_test_once(
246259
measurement_results["total_input_tokens"]
247260
+ measurement_results["total_output_tokens"]
248261
) / latency
262+
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
249263

250264
return measurement_results
251265

@@ -361,6 +375,11 @@ def throughput_test(
361375
print(
362376
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
363377
)
378+
print(
379+
"{:<40} {:<10.2f}".format(
380+
"Last generation throughput (tok/s):", result["last_gen_throughput"]
381+
)
382+
)
364383
print(
365384
"{:<40} {:<10.2f}".format(
366385
"Request throughput (req/s):", result["request_throughput"]

0 commit comments

Comments
 (0)