diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 84c96d91df0b..2bd0e963e716 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -516,6 +516,13 @@ def maybe_pull_model_tokenizer_from_remote(self) -> None: self.model_weights = self.model_path self.model_path = client.get_local_dir() + @property + def is_mrope_enabled(self) -> bool: + return ( + "rope_scaling" in self.hf_text_config + and "mrope_section" in self.hf_text_config.rope_scaling + ) + # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py _STR_DTYPE_TO_TORCH_DTYPE = { diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index ddc405c4819f..9ed2f984ab24 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -596,6 +596,7 @@ def pop_transferred(self) -> List[Req]: output_top_logprobs_val, output_top_logprobs_idx, output_hidden_states, + mrope_position_delta, ) = self.metadata_buffers.get_buf(idx) decode_req.req.output_ids.append(output_id[0].item()) @@ -619,6 +620,10 @@ def pop_transferred(self) -> List[Req]: ].tolist() ) + if mrope_position_delta is not None: + decode_req.req.decode_mrope_position_delta = mrope_position_delta[ + 0 + ].item() if hasattr(decode_req.kv_receiver, "clear"): decode_req.kv_receiver.clear() diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 720c9d5a59e9..b4fddcc07c1d 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -92,6 +92,7 @@ def __init__( dtype: torch.dtype, max_top_logprobs_num: int = 128, custom_mem_pool: torch.cuda.MemPool = None, + is_mrope_enabled: bool = False, ): self.custom_mem_pool = custom_mem_pool device = "cpu" @@ -100,6 +101,7 @@ def __init__( device = "npu" elif self.custom_mem_pool: device = "cuda" + 
self.is_mrope_enabled = is_mrope_enabled with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.custom_mem_pool @@ -126,6 +128,10 @@ def __init__( self.output_hidden_states = torch.zeros( (size, hidden_size), dtype=dtype, device=device ) + if is_mrope_enabled: + self.mrope_position_delta = torch.zeros( + (size, 16), dtype=torch.int32, device=device + ) def get_buf_infos(self): ptrs = [ @@ -152,6 +158,10 @@ def get_buf_infos(self): self.output_top_logprobs_idx[0].nbytes, self.output_hidden_states[0].nbytes, ] + if self.is_mrope_enabled: + ptrs.append(self.mrope_position_delta.data_ptr()) + data_lens.append(self.mrope_position_delta.nbytes) + item_lens.append(self.mrope_position_delta[0].nbytes) return ptrs, data_lens, item_lens def get_buf(self, idx: int): @@ -162,6 +172,7 @@ def get_buf(self, idx: int): self.output_top_logprobs_val[idx], self.output_top_logprobs_idx[idx], self.output_hidden_states[idx], + self.mrope_position_delta[idx] if self.is_mrope_enabled else None, ) def set_buf(self, req: Req): @@ -194,6 +205,16 @@ def set_buf(self, req: Req): self.output_hidden_states[req.metadata_buffer_index].copy_( req.hidden_states_tensor ) + if self.is_mrope_enabled: + if ( + req.multimodal_inputs is not None + and req.multimodal_inputs.mrope_position_delta is not None + ): + self.mrope_position_delta[req.metadata_buffer_index][ + 0 + ] = req.multimodal_inputs.mrope_position_delta + else: + self.mrope_position_delta[req.metadata_buffer_index][0] = 0 ######################### diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 536198cd27b4..6a45b7b75d33 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -496,6 +496,9 @@ def __init__( # For multimodal inputs self.multimodal_inputs: Optional[MultimodalInputs] = None + # For qwen-vl series, P-D disaggregation, decode mode. 
+ self.decode_mrope_position_delta: Optional[int] = None + # Prefix info # The indices to kv cache for the shared prefix. self.prefix_indices: torch.Tensor = [] @@ -1728,6 +1731,9 @@ def get_model_worker_batch( extend_prefix_lens=extend_prefix_lens, extend_logprob_start_lens=extend_logprob_start_lens, multimodal_inputs=self.multimodal_inputs, + decode_mrope_position_delta=[ + req.decode_mrope_position_delta for req in self.reqs + ], encoder_cached=self.encoder_cached, encoder_lens=self.encoder_lens, encoder_lens_cpu=self.encoder_lens_cpu, @@ -1866,6 +1872,9 @@ class ModelWorkerBatch: # For multimodal multimodal_inputs: Optional[List[MultimodalInputs]] + # For qwen-vl series, P-D disaggregation, decode mode. + decode_mrope_position_delta: Optional[List[int]] + # For encoder-decoder encoder_cached: Optional[List[bool]] encoder_lens: Optional[torch.Tensor] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e6dd80d717ad..1699c2786c12 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -714,6 +714,7 @@ def init_disaggregation(self): hidden_size=self.model_config.hf_text_config.hidden_size, dtype=self.model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), + is_mrope_enabled=self.model_config.is_mrope_enabled, ) # The decode requests polling kv cache @@ -763,6 +764,7 @@ def init_disaggregation(self): hidden_size=self.model_config.hf_text_config.hidden_size, dtype=self.model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), + is_mrope_enabled=self.model_config.is_mrope_enabled, ) self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 6f3ea547477f..29e024024f3d 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ 
b/python/sglang/srt/model_executor/forward_batch_info.py @@ -232,6 +232,9 @@ class ForwardBatch: # For multimodal mm_inputs: Optional[List[MultimodalInputs]] = None + # For qwen-vl series, P-D disaggregation, decode mode. + decode_mrope_position_delta: Optional[List[int]] = None + # Encoder-decoder encoder_cached: Optional[List[bool]] = None encoder_lens: Optional[torch.Tensor] = None @@ -304,6 +307,7 @@ def init_new( seq_lens=batch.seq_lens, out_cache_loc=batch.out_cache_loc, mm_inputs=batch.multimodal_inputs, + decode_mrope_position_delta=batch.decode_mrope_position_delta, encoder_cached=batch.encoder_cached, encoder_lens=batch.encoder_lens, encoder_lens_cpu=batch.encoder_lens_cpu, @@ -488,20 +492,31 @@ def _compute_mrope_positions( batch_size = self.seq_lens.shape[0] mrope_positions_list = [[]] * batch_size for batch_idx in range(batch_size): - mm_input = batch.multimodal_inputs[batch_idx] + mm_input = ( + batch.multimodal_inputs[batch_idx] + if batch.multimodal_inputs is not None + else None + ) if self.forward_mode.is_decode(): - mrope_position_deltas = ( - [0] - if mm_input is None - else flatten_nested_list(mm_input.mrope_position_delta.tolist()) - ) + # The previously computed mrope position delta from the batch scheduler is set + # only in PD disaggregation, decode mode, so we use it directly when available. + # NOTE: mrope_position_delta can be None during decode warmup; the + # 'or 0' below prevents None from being passed to the function. 
+ if batch.decode_mrope_position_delta: + mrope_position_deltas = [batch.decode_mrope_position_delta[batch_idx]] + else: + mrope_position_deltas = ( + [0] + if mm_input is None + else flatten_nested_list(mm_input.mrope_position_delta.tolist()) + ) next_input_positions = [] for mrope_position_delta in mrope_position_deltas: # batched deltas needs to be processed separately # Convert list of lists to tensor with shape [3, seq_len] next_input_positions += [ MRotaryEmbedding.get_next_input_positions( - mrope_position_delta, + mrope_position_delta or 0, int(self.seq_lens[batch_idx]) - 1, int(self.seq_lens[batch_idx]), )