20 changes: 20 additions & 0 deletions python/sglang/bench_one_batch.py
@@ -60,6 +60,7 @@
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -184,6 +185,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
req.prefix_indices = []
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
reqs.append(req)

return input_ids, reqs
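
For context on the repeated `req.logprob_start_len` additions: the scheduler normally fills this field per request, but the benchmark builds `Req` objects by hand, so the PR sets it explicitly; `len(origin_input_ids) - 1` appears to restrict input-logprob computation to the final prompt position rather than the whole prompt. A minimal sketch (not part of the PR) of the request setup that all three touched helpers now share — the `Req` constructor arguments are simplified, see `schedule_batch.Req` for the full set:

```python
# Illustrative sketch only: consolidated view of the request setup used in
# bench_one_batch.py. Constructor arguments are simplified/assumed.
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.sampling.sampling_params import SamplingParams


def build_benchmark_req(rid: int, input_ids: list) -> Req:
    req = Req(
        rid=str(rid),
        origin_input_text="",
        origin_input_ids=input_ids,
        sampling_params=SamplingParams(),
    )
    req.prefix_indices = []  # no prefix-cache hit in the benchmark
    req.fill_ids = req.origin_input_ids  # prefill the entire prompt
    req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
    # Start logprob computation at the last prompt token only, instead of
    # leaving the field unset or paying for full-prompt logprobs.
    req.logprob_start_len = len(req.origin_input_ids) - 1
    return req
```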
@@ -199,6 +201,7 @@ def prepare_extend_inputs_for_correctness_test(
i, : bench_args.cut_len
]
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
return reqs


@@ -220,6 +223,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
req.prefix_indices = []
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
reqs.append(req)

return reqs
@@ -238,6 +242,7 @@ def extend(reqs, model_runner):
enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
_maybe_prepare_dp_attn_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
@@ -249,13 +254,28 @@ def extend(reqs, model_runner):
def decode(input_token_ids, batch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
_maybe_prepare_dp_attn_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits


def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
if model_runner.server_args.enable_dp_attention:
Scheduler.prepare_dp_attn_batch_raw(
batch,
dp_size=model_runner.server_args.dp_size,
attn_tp_size=1,
tp_cpu_group=model_runner.tp_group.cpu_group,
get_idle_batch=None,
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
)


def correctness_test(
server_args,
port_args,
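The new `_maybe_prepare_dp_attn_batch` hook mirrors what the scheduler does before each forward step when data-parallel attention is enabled: every DP rank must agree on global token counts (and whether CUDA graphs can run) before the model is called. The benchmark has a single local batch and no idle-batch path, which is presumably why it passes `attn_tp_size=1` and `get_idle_batch=None`. A rough usage sketch of how the changed `extend`/`decode` helpers would be driven — the three-value return of `extend` follows the unchanged part of this file (partly collapsed in the diff), and the model runner is assumed to be built with `enable_dp_attention=True`:

```python
# Rough sketch (assumptions: extend() returns (next_token_ids,
# next_token_logits, batch) as in the rest of bench_one_batch.py, and
# model_runner was created from ServerArgs with enable_dp_attention=True).
def run_latency_probe(reqs, model_runner, decode_steps: int = 8):
    # Prefill: prepare_for_extend() plus the new DP-attention sync inside extend().
    next_token_ids, _, batch = extend(reqs, model_runner)

    # Decode: each step re-synchronizes global token counts across DP ranks
    # via _maybe_prepare_dp_attn_batch() before the forward pass.
    for _ in range(decode_steps):
        next_token_ids, next_token_logits = decode(next_token_ids, batch, model_runner)
    return next_token_logits
```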
36 changes: 29 additions & 7 deletions python/sglang/srt/managers/scheduler.py
@@ -1466,14 +1466,36 @@ def process_batch_result(
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
return self.prepare_dp_attn_batch_raw(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)

@staticmethod
def prepare_dp_attn_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
spec_algorithm,
speculative_num_draft_tokens,
):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
global_num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
- if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
-     num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+ if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+     num_tokens = num_tokens * speculative_num_draft_tokens
global_num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
@@ -1492,7 +1514,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
else:
can_cuda_graph = 0

- if not self.spec_algorithm.is_none():
+ if not spec_algorithm.is_none():
# TODO(sang): Support cuda graph when idle batch is there.
if local_batch is None or local_batch.forward_mode.is_idle():
can_cuda_graph = 0
@@ -1510,28 +1532,28 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
dtype=torch.int64,
)
global_info = torch.empty(
- (self.server_args.dp_size, self.attn_tp_size, 4),
+ (dp_size, attn_tp_size, 4),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
global_info.flatten(),
local_info,
- group=self.tp_cpu_group,
+ group=tp_cpu_group,
)
global_num_tokens = global_info[:, 0, 0].tolist()
can_cuda_graph = min(global_info[:, 0, 1].tolist())
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
is_extend_in_batch = global_info[:, 0, 3].tolist()

if local_batch is None and max(global_num_tokens) > 0:
- local_batch = self.get_idle_batch()
+ local_batch = get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

# Check forward mode for cuda graph
- if not self.server_args.disable_cuda_graph:
+ if not disable_cuda_graph:
local_batch.can_run_dp_cuda_graph = can_cuda_graph

return local_batch, any(is_extend_in_batch)
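
The refactor keeps the behavior of `prepare_dp_attn_batch` but moves the body into a static `prepare_dp_attn_batch_raw`, so callers without a `Scheduler` instance (such as `bench_one_batch.py`) can reuse it. The core of the method is a metadata exchange: each attention-DP rank packs four integers into `local_info`, all ranks all-gather them over the TP CPU group, and the results decide per-rank idle batches and whether CUDA graphs can run. A single-process sketch of that reduction, with the all-gather replaced by a plain `torch.stack` over two made-up ranks:

```python
import torch

# Single-process illustration of the gathered bookkeeping (the real code
# builds `local_info` per rank and calls
# torch.distributed.all_gather_into_tensor over tp_cpu_group; the two ranks
# below are made up).
dp_size, attn_tp_size = 2, 1
rank_infos = [
    # [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch]
    torch.tensor([[64, 1, 64, 0]], dtype=torch.int64),  # rank 0: decode batch
    torch.tensor([[0, 0, 0, 0]], dtype=torch.int64),    # rank 1: no local work
]
global_info = torch.stack(rank_infos)  # shape (dp_size, attn_tp_size, 4)

global_num_tokens = global_info[:, 0, 0].tolist()             # [64, 0]
can_cuda_graph = min(global_info[:, 0, 1].tolist())            # 0 -> disabled for all ranks
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
is_extend_in_batch = global_info[:, 0, 3].tolist()

# Rank 1 has no batch but another rank does, so it would fetch an idle batch
# (get_idle_batch()) and attach the gathered token counts to it.
needs_idle_batch = max(global_num_tokens) > 0
print(global_num_tokens, can_cuda_graph, needs_idle_batch, any(is_extend_in_batch))
```

Note that only index `[:, 0, :]`, i.e. the first attention-TP rank of each DP group, is read after the gather, which is presumably why the benchmark can pass `attn_tp_size=1` without changing the result.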