Skip to content

Commit 8cbbb3c

Browse files
ByronHsu and jianan-gu
authored and committed
Add decode req pool (sgl-project#6980)
1 parent d2f4f63 commit 8cbbb3c

2 files changed

Lines changed: 83 additions & 7 deletions

File tree

python/sglang/srt/disaggregation/decode.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from collections import deque
2626
from dataclasses import dataclass
2727
from http import HTTPStatus
28-
from typing import TYPE_CHECKING, List, Optional, Tuple
28+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
2929

3030
import numpy as np
3131
import torch
@@ -49,6 +49,7 @@
4949
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
5050
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
5151
from sglang.srt.model_executor.forward_batch_info import ForwardMode
52+
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
5253

5354
logger = logging.getLogger(__name__)
5455

@@ -57,6 +58,67 @@
5758
from sglang.srt.managers.scheduler import Scheduler
5859

5960

61+
class DecodeReqToTokenPool:
    """
    A request-to-token pool for disaggregated decode that reserves extra
    slots for pre-allocated requests.

    The difference between DecodeReqToTokenPool and ReqToTokenPool is that
    DecodeReqToTokenPool reserves memory for pre-allocated requests.

    In ReqToTokenPool, if `--max-running-requests` is 8,
    #pre-allocated + #transfer + #running <= 8, but there is in fact more
    memory that can carry pre-allocated requests.

    In DecodeReqToTokenPool, if `--max-running-requests` is 8,
    #running <= 8 and #pre-allocated + #transfer <= pre_alloc_size, so the
    free memory can be used to pre-allocate requests to unblock prefill.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        pre_alloc_size: int,
    ):
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.pre_alloc_size = pre_alloc_size
        with memory_saver_adapter.region():
            # One row per slot; `pre_alloc_size` extra rows are reserved for
            # pre-allocated (not-yet-running) requests.
            self.req_to_token = torch.zeros(
                (size + pre_alloc_size, max_context_len),
                dtype=torch.int32,
                device=device,
            )

        # Free-slot indices cover both regular and pre-alloc rows.
        self.free_slots = list(range(size + pre_alloc_size))

    def write(self, indices, values) -> None:
        """Write token values into the pool rows selected by `indices`."""
        self.req_to_token[indices] = values

    def available_size(self) -> int:
        """Return the number of currently free slots."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate `need_size` slots and return their indices.

        Returns None (without side effects) if not enough slots are free.
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return a slot index (or a list of slot indices) to the free pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Reset the pool so every slot (including pre-alloc slots) is free."""
        self.free_slots = list(range(self.size + self.pre_alloc_size))
120+
121+
60122
@dataclass
61123
class DecodeRequest:
62124
req: Req

python/sglang/srt/model_executor/model_runner.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -916,12 +916,26 @@ def init_memory_pool(
916916
)
917917

918918
if self.req_to_token_pool is None:
919-
self.req_to_token_pool = ReqToTokenPool(
920-
size=max_num_reqs,
921-
max_context_len=self.model_config.context_len + 4,
922-
device=self.device,
923-
enable_memory_saver=self.server_args.enable_memory_saver,
924-
)
919+
if self.server_args.disaggregation_mode == "decode":
920+
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
921+
922+
# subscribe memory for pre-allocated requests
923+
# if max_num_reqs <= 32, we pre-allocate 2x requests
924+
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
925+
self.req_to_token_pool = DecodeReqToTokenPool(
926+
size=max_num_reqs,
927+
max_context_len=self.model_config.context_len + 4,
928+
device=self.device,
929+
enable_memory_saver=self.server_args.enable_memory_saver,
930+
pre_alloc_size=pre_alloc_size,
931+
)
932+
else:
933+
self.req_to_token_pool = ReqToTokenPool(
934+
size=max_num_reqs,
935+
max_context_len=self.model_config.context_len + 4,
936+
device=self.device,
937+
enable_memory_saver=self.server_args.enable_memory_saver,
938+
)
925939
else:
926940
# Draft worker shares req_to_token_pool with the target worker.
927941
assert self.is_draft_worker

0 commit comments

Comments
 (0)