Skip to content
59 changes: 23 additions & 36 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,7 @@ def __init__(self, model_runner: ModelRunner):
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
if global_server_args_dict["attention_backend"] == "flashmla":
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
else:
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
Expand Down Expand Up @@ -259,23 +256,8 @@ def __init__(self, model_runner: ModelRunner):
}

# Speculative_inference
if (
model_runner.spec_algorithm.is_eagle3()
and not model_runner.is_draft_worker
):
self.hidden_states = torch.zeros(
(
self.max_num_token,
3 * self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
if model_runner.spec_algorithm.is_eagle3():
self.model_runner.model.set_eagle3_layers_to_capture()
elif model_runner.spec_algorithm.is_eagle():
self.hidden_states = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)

if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
Expand All @@ -284,8 +266,8 @@ def __init__(self, model_runner: ModelRunner):
)
else:
self.encoder_lens = None
if self.enable_dp_attention or self.enable_sp_layernorm:
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer

if self.enable_dp_attention:
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.dp_size * self.num_tokens_per_bs,
Expand All @@ -303,13 +285,7 @@ def __init__(self, model_runner: ModelRunner):
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture CUDA graph failed: {e}\n"
"Possible solutions:\n"
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)

@contextmanager
Expand Down Expand Up @@ -439,6 +415,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
self.capture_hidden_mode = (
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
)

if self.model_runner.server_args.lora_paths is not None:
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
Expand Down Expand Up @@ -467,9 +444,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
num_token_non_padded=self.num_token_non_padded,
global_forward_mode=self.capture_forward_mode,
lora_paths=lora_paths,
)
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

Expand Down Expand Up @@ -497,7 +474,9 @@ def run_once():
self.pp_size > 1
and "pp_proxy_tensors" in inspect.signature(forward).parameters
):
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
kwargs["pp_proxy_tensors"] = PPProxyTensors(
{k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
)

logits_output_or_pp_proxy_tensors = forward(
input_ids,
Expand Down Expand Up @@ -590,9 +569,6 @@ def replay_prepare(
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

if hasattr(forward_batch.spec_info, "hidden_states"):
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs,
Expand Down Expand Up @@ -650,7 +626,7 @@ def get_spec_info(self, num_tokens: int):
else:
spec_info = EagleVerifyInput(
draft_token=None,
custom_mask=torch.zeros(
custom_mask=torch.ones(
(num_tokens * self.model_runner.model_config.context_len),
dtype=torch.bool,
device="cuda",
Expand All @@ -660,9 +636,20 @@ def get_spec_info(self, num_tokens: int):
retrive_next_token=None,
retrive_next_sibling=None,
retrive_cum_len=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
spec_steps=self.model_runner.server_args.speculative_num_steps,
topk=self.model_runner.server_args.speculative_eagle_topk,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
capture_hidden_mode=CaptureHiddenMode.FULL,
)

return spec_info


# Guidance appended to the exception raised when CUDA graph capture fails
# (see the `capture()` call site). Lists the mitigations a user can try,
# ordered from least to most disruptive.
CUDA_GRAPH_CAPTURE_FAILED_MSG = "\n".join(
    [
        "Possible solutions:",
        "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)",
        "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)",
        "3. disable torch compile by not using --enable-torch-compile",
        "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)",
        "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose ",
    ]
) + "\n"
Loading