@@ -127,7 +127,7 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable):
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
             spec_algorithm=self.model_runner.spec_algorithm,
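
Both draft CUDA-graph runners receive the same one-line fix (here and in the second `capture_one_batch_size` hunk below): `seq_lens.sum()` produces a 0-dim tensor, while `seq_lens_sum` is presumably expected downstream as a plain Python int. Calling `.item()` resolves the scalar on the host once, at capture time, instead of carrying a device tensor in the batch metadata. A minimal standalone sketch of the distinction, with made-up lengths:

    import torch

    seq_lens = torch.tensor([5, 9, 3])  # per-request sequence lengths
    as_tensor = seq_lens.sum()          # tensor(17): still a torch.Tensor
    as_int = seq_lens.sum().item()      # 17: a host-side Python int

    print(type(as_tensor))  # <class 'torch.Tensor'>
    print(type(as_int))     # <class 'int'>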
@@ -209,7 +209,7 @@ def replay(self, forward_batch: ForwardBatch):
         forward_batch.positions = self.positions[:num_tokens]

         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
+        if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
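
The replay path above pads the CPU-side sequence-length buffer whenever the graph was captured at a larger batch size (`bs`) than the live batch (`raw_bs`). A self-contained sketch of that pattern with made-up sizes (in the real runner, `seq_lens_cpu` is preallocated at capture time):

    import torch

    bs, raw_bs = 8, 5                                  # captured vs. live batch size
    seq_lens_cpu = torch.empty(bs, dtype=torch.int64)  # preallocated buffer
    live_lens = torch.tensor([12, 7, 30, 4, 9])        # lengths of live requests

    seq_lens_cpu.fill_(1)                   # dummy length 1 for padded slots
    seq_lens_cpu[:raw_bs].copy_(live_lens)  # real lengths for live requests
    print(seq_lens_cpu)  # tensor([12,  7, 30,  4,  9,  1,  1,  1])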
@@ -138,7 +138,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
             spec_algorithm=self.model_runner.spec_algorithm,
98 changes: 50 additions & 48 deletions python/sglang/srt/speculative/eagle_utils.py
@@ -1,8 +1,10 @@
 from __future__ import annotations

+import logging
 import os
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import List, Optional

 import torch
+import torch.nn.functional as F
@@ -12,6 +14,7 @@
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import apply_custom_logit_processor
 from sglang.srt.managers.schedule_batch import (
     Req,
     ScheduleBatch,
@@ -20,7 +23,6 @@
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
@@ -34,15 +36,15 @@
 elif is_hip():
     from sgl_kernel import verify_tree_greedy

-if TYPE_CHECKING:
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
-
-import logging

 logger = logging.getLogger(__name__)


+# Simulate acceptance length for benchmarking purposes
+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
+
 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly


 @dataclass
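
The two new environment knobs make benchmarking reproducible: `SIMULATE_ACC_LEN` overrides the measured acceptance length, and `SIMULATE_ACC_METHOD` selects how simulated lengths are drawn (defaulting to "multinomial"). The code that consumes them is not shown in this hunk; the following is a hedged sketch of the general idea, and the helper name `maybe_override_accept_length` is illustrative rather than part of the PR:

    import os

    import torch

    SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")  # e.g. "3.0"

    def maybe_override_accept_length(accept_length: torch.Tensor) -> torch.Tensor:
        # Normal path: keep the measured per-request acceptance lengths.
        if SIMULATE_ACC_LEN is None:
            return accept_length
        # Benchmark path: pretend every request accepted the same token count.
        simulated = int(float(SIMULATE_ACC_LEN))
        return torch.full_like(accept_length, simulated)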
@@ -84,9 +86,9 @@ def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
         speculative_num_steps: int,
+        context_length: int,
         pad_input: bool = False,
     ):
-        assert len(self.verified_id) == len(batch.out_cache_loc)
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
@@ -112,49 +114,49 @@ def prepare_extend_after_decode(
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id

-        if pad_input:
-            batch_size = sum(not req.finished() for req in batch.reqs)
-            # Total constant input length after padding
-            static_len = speculative_num_steps + 1
-            # Total size after padding
-            padded_input_size = batch_size * static_len
+        if not pad_input:
+            return
+
+        batch_size = sum(not req.finished() for req in batch.reqs)
+        # Total constant input length after padding
+        static_len = speculative_num_steps + 1
+        # Total size after padding
+        padded_input_size = batch_size * static_len

-            padded_len = padded_input_size - batch.input_ids.shape[0]
-            if padded_len > 0:
-                new_input_ids = torch.nn.functional.pad(
-                    batch.input_ids, (0, padded_len), value=0
-                )
-                position_padding = torch.arange(
-                    padded_len, device=self.positions.device
-                )
-                new_positions = torch.cat([self.positions, position_padding])
+        padded_len = padded_input_size - batch.input_ids.shape[0]
+        if padded_len > 0:
+            new_input_ids = torch.nn.functional.pad(
+                batch.input_ids, (0, padded_len), value=0
+            )
+            position_padding = torch.arange(padded_len, device=self.positions.device)
+            new_positions = torch.cat([self.positions, position_padding])

-                # need dummy hidden states for the padded positions
-                hidden_states_dim = self.hidden_states.shape[-1]
-                new_hidden_states = torch.cat(
-                    [
-                        self.hidden_states,
-                        torch.zeros(
-                            (padded_len, hidden_states_dim),
-                            dtype=self.hidden_states.dtype,
-                            device=self.hidden_states.device,
-                        ),
-                    ],
-                    dim=0,
-                )
+            # need dummy hidden states for the padded positions
+            hidden_states_dim = self.hidden_states.shape[-1]
+            new_hidden_states = torch.cat(
+                [
+                    self.hidden_states,
+                    torch.zeros(
+                        (padded_len, hidden_states_dim),
+                        dtype=self.hidden_states.dtype,
+                        device=self.hidden_states.device,
+                    ),
+                ],
+                dim=0,
+            )

-                # allocate KV cache location for the padded tokens
-                padded_cache_loc = torch.zeros(
-                    padded_len,
-                    dtype=batch.out_cache_loc.dtype,
-                    device=batch.out_cache_loc.device,
-                )
-                new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
+            # allocate KV cache location for the padded tokens
+            padded_cache_loc = torch.zeros(
+                padded_len,
+                dtype=batch.out_cache_loc.dtype,
+                device=batch.out_cache_loc.device,
+            )
+            new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])

-                batch.input_ids = new_input_ids
-                self.hidden_states = new_hidden_states
-                self.positions = new_positions
-                batch.out_cache_loc = new_out_cache_loc
+            batch.input_ids = new_input_ids
+            self.hidden_states = new_hidden_states
+            self.positions = new_positions
+            batch.out_cache_loc = new_out_cache_loc

     def generate_attn_arg_prefill(
         self,
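
The rewrite above replaces the nested `if pad_input:` block with an early return and dedents the padding logic one level; the behavior is unchanged. The padding itself is what makes the draft-extend CUDA graph viable: every running request contributes exactly `static_len = speculative_num_steps + 1` tokens, so the graph always replays with a fixed input shape. A standalone sketch of the size arithmetic with made-up numbers:

    import torch
    import torch.nn.functional as F

    speculative_num_steps = 4
    static_len = speculative_num_steps + 1            # 5 tokens per request, always
    accept_length_cpu = [2, 4, 0]                     # accepted draft tokens per request
    extend_lens = [x + 1 for x in accept_length_cpu]  # [3, 5, 1] real tokens

    input_ids = torch.arange(sum(extend_lens))           # 9 real tokens in the batch
    padded_input_size = len(extend_lens) * static_len    # 15 = 3 requests * 5
    padded_len = padded_input_size - input_ids.shape[0]  # 6 dummy tokens to append

    input_ids = F.pad(input_ids, (0, padded_len), value=0)
    assert input_ids.shape[0] == padded_input_size       # fixed shape for graph replay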
1 change: 1 addition & 0 deletions python/sglang/srt/speculative/eagle_worker.py
@@ -687,6 +687,7 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         batch.spec_info.prepare_extend_after_decode(
             batch,
             self.speculative_num_steps,
+            self.server_args.context_length,
             pad_input=self.cuda_graph_runner_for_draft_extend is not None,
         )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
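
This call-site change threads the model's context length (`self.server_args.context_length`) through to the new `context_length: int` parameter of `prepare_extend_after_decode` in eagle_utils.py, and `pad_input` stays tied to whether a draft-extend CUDA-graph runner exists.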
4 changes: 4 additions & 0 deletions test/srt/test_eagle_infer.py
@@ -23,6 +23,7 @@
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
+    is_in_ci,
     popen_launch_server,
     run_logprob_check,
 )
@@ -578,6 +579,7 @@ def setUpClass(cls):
         )


+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
 class TestEAGLEDraftExtend(CustomTestCase):
     @classmethod
     def setUpClass(cls):
@@ -669,6 +671,7 @@ def setUpClass(cls):
         cls.accept_len_threshold = 1.50


+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
 class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
     @classmethod
     def setUpClass(cls):
@@ -697,6 +700,7 @@ def setUpClass(cls):
         cls.accept_len_threshold = 1.50


+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
 class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
     @classmethod
     def setUpClass(cls):