Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,13 @@ def maybe_pull_model_tokenizer_from_remote(self) -> None:
self.model_weights = self.model_path
self.model_path = client.get_local_dir()

@property
def is_mrope_enabled(self) -> bool:
    """Return True if the model uses M-RoPE (multimodal rotary embedding).

    Detected from the HF text config: ``rope_scaling`` must be present,
    non-None, and carry an ``"mrope_section"`` entry (e.g. the Qwen-VL
    model series).
    """
    # Use getattr instead of a membership test so this works on any
    # config object that supports attribute access (membership on an HF
    # PretrainedConfig is not a guaranteed protocol), and guard against
    # rope_scaling being explicitly set to None, which HF configs
    # commonly carry as a default — `"mrope_section" in None` would
    # raise TypeError.
    rope_scaling = getattr(self.hf_text_config, "rope_scaling", None)
    return rope_scaling is not None and "mrope_section" in rope_scaling


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ def pop_transferred(self) -> List[Req]:
output_top_logprobs_val,
output_top_logprobs_idx,
output_hidden_states,
mrope_position_delta,
) = self.metadata_buffers.get_buf(idx)

decode_req.req.output_ids.append(output_id[0].item())
Expand All @@ -619,6 +620,10 @@ def pop_transferred(self) -> List[Req]:
].tolist()
)

if mrope_position_delta is not None:
decode_req.req.decode_mrope_position_delta = mrope_position_delta[
0
].item()
if hasattr(decode_req.kv_receiver, "clear"):
decode_req.kv_receiver.clear()

Expand Down
21 changes: 21 additions & 0 deletions python/sglang/srt/disaggregation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
dtype: torch.dtype,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
is_mrope_enabled: bool = False,
):
self.custom_mem_pool = custom_mem_pool
device = "cpu"
Expand All @@ -100,6 +101,7 @@ def __init__(
device = "npu"
elif self.custom_mem_pool:
device = "cuda"
self.is_mrope_enabled = is_mrope_enabled
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
Expand All @@ -126,6 +128,10 @@ def __init__(
self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=dtype, device=device
)
if is_mrope_enabled:
self.mrope_position_delta = torch.zeros(
(size, 16), dtype=torch.int32, device=device
)

def get_buf_infos(self):
ptrs = [
Expand All @@ -152,6 +158,10 @@ def get_buf_infos(self):
self.output_top_logprobs_idx[0].nbytes,
self.output_hidden_states[0].nbytes,
]
if self.is_mrope_enabled:
ptrs.append(self.mrope_position_delta.data_ptr())
data_lens.append(self.mrope_position_delta.nbytes)
item_lens.append(self.mrope_position_delta[0].nbytes)
return ptrs, data_lens, item_lens

def get_buf(self, idx: int):
Expand All @@ -162,6 +172,7 @@ def get_buf(self, idx: int):
self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx],
self.output_hidden_states[idx],
self.mrope_position_delta[idx] if self.is_mrope_enabled else None,
)

def set_buf(self, req: Req):
Expand Down Expand Up @@ -194,6 +205,16 @@ def set_buf(self, req: Req):
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)
if self.is_mrope_enabled:
if (
req.multimodal_inputs is not None
and req.multimodal_inputs.mrope_position_delta is not None
):
self.mrope_position_delta[req.metadata_buffer_index][
0
] = req.multimodal_inputs.mrope_position_delta
else:
self.mrope_position_delta[req.metadata_buffer_index][0] = 0


#########################
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ def __init__(
# For multimodal inputs
self.multimodal_inputs: Optional[MultimodalInputs] = None

# For qwen-vl series, P-D disaggregation, decode mode.
self.decode_mrope_position_delta: Optional[int] = None

# Prefix info
# The indices to kv cache for the shared prefix.
self.prefix_indices: torch.Tensor = []
Expand Down Expand Up @@ -1728,6 +1731,9 @@ def get_model_worker_batch(
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
multimodal_inputs=self.multimodal_inputs,
decode_mrope_position_delta=[
req.decode_mrope_position_delta for req in self.reqs
],
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
Expand Down Expand Up @@ -1866,6 +1872,9 @@ class ModelWorkerBatch:
# For multimodal
multimodal_inputs: Optional[List[MultimodalInputs]]

# For qwen-vl series, P-D disaggregation, decode mode.
decode_mrope_position_delta: Optional[List[int]]

# For encoder-decoder
encoder_cached: Optional[List[bool]]
encoder_lens: Optional[torch.Tensor]
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ def init_disaggregation(self):
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
is_mrope_enabled=self.model_config.is_mrope_enabled,
)

# The decode requests polling kv cache
Expand Down Expand Up @@ -763,6 +764,7 @@ def init_disaggregation(self):
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
is_mrope_enabled=self.model_config.is_mrope_enabled,
)

self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
Expand Down
29 changes: 22 additions & 7 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ class ForwardBatch:
# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None

# For qwen-vl series, P-D disaggregation, decode mode.
decode_mrope_position_delta: Optional[List[int]] = None

# Encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -304,6 +307,7 @@ def init_new(
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
mm_inputs=batch.multimodal_inputs,
decode_mrope_position_delta=batch.decode_mrope_position_delta,
encoder_cached=batch.encoder_cached,
encoder_lens=batch.encoder_lens,
encoder_lens_cpu=batch.encoder_lens_cpu,
Expand Down Expand Up @@ -488,20 +492,31 @@ def _compute_mrope_positions(
batch_size = self.seq_lens.shape[0]
mrope_positions_list = [[]] * batch_size
for batch_idx in range(batch_size):
mm_input = batch.multimodal_inputs[batch_idx]
mm_input = (
batch.multimodal_inputs[batch_idx]
if batch.multimodal_inputs is not None
else None
)
if self.forward_mode.is_decode():
mrope_position_deltas = (
[0]
if mm_input is None
else flatten_nested_list(mm_input.mrope_position_delta.tolist())
)
# The mrope position delta precomputed by the batch scheduler is only
# set in PD-disaggregation decode mode, so use it directly when available.
# NOTE: mrope_position_delta can be None during decode warmup; the
# 'or 0' below keeps None from being passed to the positions helper.
if batch.decode_mrope_position_delta:
mrope_position_deltas = [batch.decode_mrope_position_delta[batch_idx]]
else:
mrope_position_deltas = (
[0]
if mm_input is None
else flatten_nested_list(mm_input.mrope_position_delta.tolist())
)
next_input_positions = []
for mrope_position_delta in mrope_position_deltas:
# batched deltas need to be processed separately
# Convert list of lists to tensor with shape [3, seq_len]
next_input_positions += [
MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
mrope_position_delta or 0,
int(self.seq_lens[batch_idx]) - 1,
int(self.seq_lens[batch_idx]),
)
Expand Down