
Commit 22352d4

Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)
Co-authored-by: Kan Wu <[email protected]>
Parent: c5131f7

24 files changed: +626 / -160 lines

docs/backend/server_arguments.md

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--log-level` | The logging level of all loggers. | info |
 | `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
 | `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
-| `--log-requests-level` | 0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output. | 0 |
+| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
 | `--show-time-cost` | Show time cost of custom marks. | False |
 | `--enable-metrics` | Enable log prometheus metrics. | False |
 | `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
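
A quick way to exercise the new verbosity levels is to pass the flag at launch. A minimal sketch, assuming a locally available model (the model path below is a placeholder, not part of this commit):

# Launch sketch: --log-requests enables per-request logging, and
# --log-requests-level picks one of the levels documented in the table above.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--log-requests",
        "--log-requests-level", "2",  # metadata + sampling params + partial input/output
    ],
    check=True,
)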

python/sglang/bench_one_batch_server.py

Lines changed: 14 additions & 2 deletions
@@ -38,6 +38,7 @@ class BenchArgs:
     output_len: Tuple[int] = (16,)
     temperature: float = 0.0
     return_logprob: bool = False
+    client_stream_interval: int = 1
     input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
@@ -60,6 +61,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
         parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--client-stream-interval",
+            type=int,
+            default=BenchArgs.client_stream_interval,
+        )
         parser.add_argument(
             "--input-len-step-percentage",
             type=float,
@@ -120,6 +126,7 @@ def run_one_case(
     output_len: int,
     temperature: float,
     return_logprob: bool,
+    stream_interval: int,
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
                 "json_schema": json_schema,
+                "stream_interval": stream_interval,
             },
             "return_logprob": return_logprob,
             "stream": True,
@@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)
 
-    tokenizer_id = server_args.tokenizer_path or server_args.model_path
-    tokenizer = get_tokenizer(tokenizer_id)
+    server_info = requests.get(base_url + "/get_server_info")
+    tokenizer_path = server_info.json()["tokenizer_path"]
+    tokenizer = get_tokenizer(tokenizer_path)
 
     # warmup
     if not bench_args.skip_warmup:
@@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             output_len=16,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
@@ -280,6 +290,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             ol,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name=bench_args.run_name,
             result_filename=bench_args.result_filename,
@@ -301,6 +312,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             ol,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name=bench_args.run_name,
             result_filename=bench_args.result_filename,
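
For reference, the payload that run_one_case now builds carries the client-side stream interval inside sampling_params. A stand-alone sketch of the same request (URL, prompt, and interval value are illustrative placeholders):

# Hedged sketch of the streaming /generate request assembled above.
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 16,
            "ignore_eos": True,
            "stream_interval": 4,  # assumed: stream a chunk every 4 decoded tokens
        },
        "return_logprob": False,
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line and line != b"data: [DONE]":
        pass  # each "data: ..." line carries a JSON payload with the partial output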

python/sglang/bench_serving.py

Lines changed: 0 additions & 1 deletion
@@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace):
             if args.base_url
             else f"http://{args.host}:{args.port}/generate"
         )
-        args.apply_chat_template = True
     elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
         api_url = (
             f"{args.base_url}/v1/completions"

python/sglang/srt/configs/internvl.py

Lines changed: 2 additions & 3 deletions
@@ -147,12 +147,11 @@ def _rope_scaling_validation(self):
            )
        if (
            rope_scaling_factor is None
-            or not isinstance(rope_scaling_factor, float)
-            or not isinstance(rope_scaling_factor, int)
+            or not isinstance(rope_scaling_factor, (float, int))
            or rope_scaling_factor < 1.0
        ):
            raise ValueError(
-                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}"
+                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
            )
        if isinstance(rope_scaling_factor, int):
            rope_scaling_factor = float(rope_scaling_factor)
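
The replaced check was a tautology: no value is an instance of both float and int, so one of the two `not isinstance(...)` branches was always True and the ValueError fired for every factor, valid or not. A minimal demonstration of the old versus new predicate:

# Why the old validation always failed, and what the tuple form fixes.
def old_check(factor):
    return not isinstance(factor, float) or not isinstance(factor, int)

def new_check(factor):
    return not isinstance(factor, (float, int))

assert old_check(2.0) is True    # a valid float was still flagged as invalid
assert new_check(2.0) is False   # accepted after the fix
assert new_check("2.0") is True  # non-numeric values are still rejected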

python/sglang/srt/entrypoints/http_server.py

Lines changed: 23 additions & 6 deletions
@@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState):
 
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args: ServerArgs = fast_api_app.server_args
-
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
         _global_state.tokenizer_manager
     )
 
+    server_args: ServerArgs = fast_api_app.server_args
     if server_args.warmups is not None:
         await execute_warmups(
-            server_args.warmups.split(","), _global_state.tokenizer_manager
+            server_args.disaggregation_mode,
+            server_args.warmups.split(","),
+            _global_state.tokenizer_manager,
         )
         logger.info("Warmup ended")
 
@@ -280,13 +281,17 @@ async def get_model_info():
         "model_path": _global_state.tokenizer_manager.model_path,
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
+        "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
     }
     return result
 
 
 @app.get("/get_server_info")
 async def get_server_info():
-    internal_states = await _global_state.tokenizer_manager.get_internal_state()
+    # Returns internal states per DP.
+    internal_states: List[Dict[Any, Any]] = (
+        await _global_state.tokenizer_manager.get_internal_state()
+    )
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
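
The benchmark script above relies on /get_server_info to discover the tokenizer, and /get_model_info now also surfaces preferred_sampling_params. A hedged client-side sketch (the base URL is a placeholder for a locally launched server):

# Query the two info endpoints touched by this commit.
import requests

base_url = "http://localhost:30000"

server_info = requests.get(base_url + "/get_server_info").json()
print(server_info["tokenizer_path"])            # consumed by bench_one_batch_server.py

model_info = requests.get(base_url + "/get_model_info").json()
print(model_info["preferred_sampling_params"])  # newly exposed field (may be None)
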
@@ -300,6 +305,8 @@ async def get_load():
     return await _global_state.tokenizer_manager.get_load()
 
 
+# example usage:
+# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
 @app.api_route("/set_internal_state", methods=["POST", "PUT"])
 async def set_internal_state(obj: SetInternalStateReq, request: Request):
     res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -886,14 +893,23 @@ def launch_server(
         add_prometheus_middleware(app)
         enable_func_timer()
 
+    image_token_text = None
+    if (
+        tokenizer_manager.image_token_id is not None
+        and not server_args.skip_tokenizer_init
+    ):
+        image_token_text = tokenizer_manager.tokenizer.decode(
+            [tokenizer_manager.image_token_id]
+        )
+
     # Send a warmup request - we will create the thread launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
         target=_wait_and_warmup,
         args=(
             server_args,
             pipe_finish_writer,
-            _global_state.tokenizer_manager.image_token_id,
+            image_token_text,
             launch_callback,
         ),
     )
@@ -1022,9 +1038,10 @@ def _wait_and_warmup(
         return
 
     # Debug print
-    # logger.info(f"{res.json()=}")
+    # logger.info(f"warmup request returns: {res.json()=}")
 
     logger.info("The server is fired up and ready to roll!")
+
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")

python/sglang/srt/layers/elementwise.py

Lines changed: 76 additions & 12 deletions
@@ -8,6 +8,7 @@
 
 _is_hip = is_hip()
 
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
 
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
@@ -331,6 +325,75 @@ def forward_native(
         return self.rmsnorm2.forward_native(residual), residual
 
 
+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True
 
+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
        "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
        ),
    }
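
The new experts_combine_triton helper sums the combine_k routed-expert outputs per token, adds the shared-MLP output, and scales the result by 1/sqrt(2). A hedged usage sketch with illustrative shapes, assuming CUDA tensors and that the helper is imported from sglang.srt.layers.elementwise:

# Usage sketch for experts_combine_triton; shapes and dtype are illustrative.
import torch
from sglang.srt.layers.elementwise import experts_combine_triton

bs, combine_k, hidden_dim = 4, 2, 4096
moe_out = torch.randn(bs, combine_k, hidden_dim, dtype=torch.bfloat16, device="cuda")
mlp_out = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device="cuda")

combined = experts_combine_triton(moe_out, mlp_out)  # shape: (bs, hidden_dim)

# Reference computation the kernel performs: sum over the k expert outputs,
# add the MLP branch, then divide by sqrt(2).
ref = (moe_out.float().sum(dim=1) + mlp_out.float()) / 2**0.5
torch.testing.assert_close(combined.float(), ref, rtol=2e-2, atol=2e-2)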
