Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -49,6 +49,7 @@
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

logger = logging.getLogger(__name__)

Expand All @@ -57,6 +58,67 @@
from sglang.srt.managers.scheduler import Scheduler


class DecodeReqToTokenPool:
    """Request-to-token pool with extra headroom for pre-allocated requests.

    The difference between DecodeReqToTokenPool and ReqToTokenPool is that
    DecodeReqToTokenPool reserves memory for pre-allocated requests.

    In ReqToTokenPool, if `--max-running-requests` is 8,
    #pre-allocated + #transfer + #running <= 8, even though there is in fact
    enough memory to carry more pre-allocated requests.

    In DecodeReqToTokenPool, if `--max-running-requests` is 8,
    #running <= 8 and #pre-allocated + #transfer <= pre_alloc_size, so the
    free memory can be used to pre-allocate requests and unblock prefill.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        pre_alloc_size: int,
    ):
        """Create the pool.

        Args:
            size: Number of slots for running requests.
            max_context_len: Max tokens per request (row width of the table).
            device: Torch device for the mapping table.
            enable_memory_saver: Whether to register the table with the
                torch memory saver so it can be paused/resumed.
            pre_alloc_size: Extra slots reserved for pre-allocated /
                in-transfer requests on top of `size`.
        """
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.pre_alloc_size = pre_alloc_size
        with memory_saver_adapter.region():
            # One row per slot (running + pre-allocated); each row maps a
            # request to its token indices in the KV pool.
            self.req_to_token = torch.zeros(
                (size + pre_alloc_size, max_context_len),
                dtype=torch.int32,
                device=device,
            )

        # Free slot indices covering both the running and pre-allocated ranges.
        # NOTE(review): a collections.deque would make alloc/free
        # O(need_size) instead of O(len(free_slots)); kept as a list here to
        # stay consistent with ReqToTokenPool's attribute type — TODO confirm
        # no caller relies on list slicing before switching.
        self.free_slots = list(range(size + pre_alloc_size))

    def write(self, indices, values):
        """Write token indices `values` into the table cells selected by `indices`."""
        self.req_to_token[indices] = values

    def available_size(self) -> int:
        """Return the number of currently free request slots."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate `need_size` slots.

        Returns:
            The allocated slot indices, or None if fewer than `need_size`
            slots are free.
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return a single slot index (int) or several (list) to the free pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Reset the pool so every slot (running + pre-allocated) is free."""
        self.free_slots = list(range(self.size + self.pre_alloc_size))


@dataclass
class DecodeRequest:
req: Req
Expand Down
26 changes: 20 additions & 6 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,12 +912,26 @@ def init_memory_pool(
)

if self.req_to_token_pool is None:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if self.server_args.disaggregation_mode == "decode":
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

# subscribe memory for pre-allocated requests
# if max_num_reqs <= 32, we pre-allocate 2x requests
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The numbers 2 (multiplier) and 32 (threshold for max_num_reqs) used to calculate pre_alloc_size are magic numbers. While the comment explains the heuristic (2x for small max_num_reqs), it would improve maintainability and configurability if these were defined as named constants (e.g., at the module level or as part of ServerArgs if they need to be configurable).

Could these be refactored into constants? For example:

# At module level or in a config class
_PRE_ALLOC_SIZE_MULTIPLIER = 2
_MAX_NUM_REQS_THRESHOLD_FOR_PRE_ALLOC = 32

# In the function
pre_alloc_size = (max_num_reqs * _PRE_ALLOC_SIZE_MULTIPLIER 
                  if max_num_reqs <= _MAX_NUM_REQS_THRESHOLD_FOR_PRE_ALLOC 
                  else 0)
Suggested change
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0 # TODO: Consider making 2 and 32 named constants or configurable

self.req_to_token_pool = DecodeReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
else:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
Expand Down
Loading