Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions python/sglang/srt/speculative/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,16 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None

# Inputs for draft extend
# shape: (b,)
seq_lens_for_draft_extend: torch.Tensor = None
req_pool_indices_for_draft_extend: torch.Tensor = None

def prepare_for_extend(self, batch: ScheduleBatch):

if batch.forward_mode.is_idle():
return

# Prefill only generates 1 token.
assert len(self.verified_id) == len(batch.seq_lens)

Expand All @@ -94,7 +101,7 @@ def create_idle_input(
capture_hidden_mode: CaptureHiddenMode,
):
return cls(
verified_id=None,
verified_id=torch.empty((0,), device=device, dtype=torch.int32),
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
Expand All @@ -108,7 +115,10 @@ def prepare_extend_after_decode(
batch: ScheduleBatch,
speculative_num_steps: int,
):
batch.forward_mode = ForwardMode.DRAFT_EXTEND

if batch.forward_mode.is_idle():
return

batch.input_ids = self.verified_id
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
Expand Down Expand Up @@ -315,7 +325,7 @@ def generate_attn_arg_prefill(
def verify(
self,
batch: ScheduleBatch,
logits_output: torch.Tensor,
logits_output: LogitsProcessorOutput,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
vocab_mask: Optional[torch.Tensor] = None, # For grammar
Expand Down Expand Up @@ -593,13 +603,14 @@ def verify(
batch.out_cache_loc = tgt_cache_loc
batch.seq_lens.add_(accept_length + 1)

draft_input = EagleDraftInput()
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
draft_input.verified_id = verified_id
draft_input.accept_length = accept_length
draft_input.accept_length_cpu = accept_length.tolist()
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
draft_input = EagleDraftInput(
hidden_states=batch.spec_info.hidden_states[accept_index],
verified_id=verified_id,
accept_length=accept_length,
accept_length_cpu=accept_length.tolist(),
seq_lens_for_draft_extend=batch.seq_lens,
req_pool_indices_for_draft_extend=batch.req_pool_indices,
)

return EagleVerifyOutput(
draft_input=draft_input,
Expand All @@ -622,7 +633,6 @@ def verify(
batch.seq_lens.add_(accept_length + 1)

accept_length_cpu = accept_length.tolist()
draft_input = EagleDraftInput()
if len(unfinished_accept_index) > 0:
unfinished_accept_index = torch.cat(unfinished_accept_index)
unfinished_index_device = torch.tensor(
Expand Down Expand Up @@ -653,18 +663,26 @@ def verify(
next_power_of_2(self.draft_token_num),
)

draft_input.hidden_states = batch.spec_info.hidden_states[
unfinished_accept_index
]
draft_input.verified_id = predict[unfinished_accept_index]
draft_input.accept_length_cpu = draft_input_accept_length_cpu
draft_input.accept_length = accept_length[unfinished_index_device]
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
unfinished_index_device
]
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
unfinished_index_device
]
draft_input = EagleDraftInput(
hidden_states=batch.spec_info.hidden_states[
unfinished_accept_index
],
verified_id=predict[unfinished_accept_index],
accept_length_cpu=draft_input_accept_length_cpu,
accept_length=accept_length[unfinished_index_device],
seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
req_pool_indices_for_draft_extend=batch.req_pool_indices[
unfinished_index_device
],
)
else:
draft_input = EagleDraftInput.create_idle_input(
device=batch.device,
hidden_size=batch.model_config.hidden_size,
dtype=batch.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)

return EagleVerifyOutput(
draft_input=draft_input,
Expand Down
77 changes: 42 additions & 35 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def draft_model_runner(self):

def forward_batch_speculative_generation(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
"""Run speculative decoding forward.

NOTE: Many states of the batch are modified as you go through. It is not guaranteed that
Expand Down Expand Up @@ -325,11 +325,16 @@ def forward_batch_speculative_generation(
self.verify(batch, spec_info)
)

if self.check_forward_draft_extend_after_decode(batch):
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(
batch,
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
if (
self.server_args.enable_dp_attention
or batch.spec_info.verified_id.shape[0] > 0
):
# decode is not finished
self.forward_draft_extend_after_decode(batch)

return (
logits_output,
verify_output.verified_id,
Expand All @@ -339,10 +344,7 @@ def forward_batch_speculative_generation(
)

def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
local_need_forward = (
batch.spec_info.verified_id is not None
and batch.spec_info.verified_id.shape[0] > 0
)
local_need_forward = batch.spec_info.verified_id.shape[0] > 0
if not self.server_args.enable_dp_attention:
return local_need_forward

Expand All @@ -361,7 +363,7 @@ def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):

def forward_target_extend(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int]:
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
"""Run the target extend.

Args:
Expand Down Expand Up @@ -782,8 +784,8 @@ def forward_draft_extend(
self,
batch: ScheduleBatch,
hidden_states: torch.Tensor,
next_token_ids: List[int],
seq_lens_cpu: torch.Tensor,
next_token_ids: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
):
"""Run draft model extend. This API modifies the states of the batch.

Expand Down Expand Up @@ -819,29 +821,34 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob

input_is_idle = batch.forward_mode.is_idle()
if not input_is_idle:
# Prepare metadata
if batch.spec_info.verified_id is not None:
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
else:
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
else self.model_config.hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)

if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
else self.model_config.hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
batch.forward_mode = (
ForwardMode.DRAFT_EXTEND
if not batch.forward_mode.is_idle()
else ForwardMode.IDLE
)

batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
Expand Down