|
25 | 25 | from collections import deque |
26 | 26 | from dataclasses import dataclass |
27 | 27 | from http import HTTPStatus |
28 | | -from typing import TYPE_CHECKING, List, Optional, Tuple |
| 28 | +from typing import TYPE_CHECKING, List, Optional, Tuple, Union |
29 | 29 |
|
30 | 30 | import numpy as np |
31 | 31 | import torch |
|
49 | 49 | from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache |
50 | 50 | from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator |
51 | 51 | from sglang.srt.model_executor.forward_batch_info import ForwardMode |
| 52 | +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter |
52 | 53 |
|
53 | 54 | logger = logging.getLogger(__name__) |
54 | 55 |
|
|
57 | 58 | from sglang.srt.managers.scheduler import Scheduler |
58 | 59 |
|
59 | 60 |
|
class DecodeReqToTokenPool:
    """
    Request-to-token pool sized for the decode side, with extra headroom
    for pre-allocated requests.

    The difference between DecodeReqToTokenPool and ReqToTokenPool is that
    DecodeReqToTokenPool reserves additional memory for pre-allocated
    requests.

    In ReqToTokenPool, if `--max-running-requests` is 8,
    #pre-allocated + #transfer + #running <= 8, even though there is in fact
    enough memory to carry more pre-allocated requests.

    In DecodeReqToTokenPool, if `--max-running-requests` is 8,
    #running <= 8 and #pre-allocated + #transfer <= pre_alloc_size, so the
    free memory can be used to pre-allocate requests to unblock prefill.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        pre_alloc_size: int,
    ):
        """Create the pool.

        Args:
            size: Number of slots for running requests.
            max_context_len: Max tokens per request (row width of the table).
            device: Torch device for the token table (e.g. "cuda").
            enable_memory_saver: Whether to register the table with the
                torch memory saver so it can be paused/resumed.
            pre_alloc_size: Extra slots reserved for pre-allocated /
                in-transfer requests, on top of `size`.
        """
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.pre_alloc_size = pre_alloc_size
        # Allocate inside the memory-saver region so the table can be
        # released/restored by the adapter when memory saving is enabled.
        with memory_saver_adapter.region():
            self.req_to_token = torch.zeros(
                (size + pre_alloc_size, max_context_len),
                dtype=torch.int32,
                device=device,
            )

        # Free-slot indices cover the full (size + pre_alloc_size) range.
        self.free_slots = list(range(size + pre_alloc_size))

    def write(self, indices, values):
        """Write token ids `values` into the table cells selected by `indices`."""
        self.req_to_token[indices] = values

    def available_size(self) -> int:
        """Return the number of currently free request slots."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate `need_size` slots.

        Returns the allocated slot indices, or None if not enough slots
        are free (the pool is left unchanged in that case).
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return a single slot index (int) or several (list) to the free pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Reset the pool: mark every slot (including pre-alloc ones) as free."""
        self.free_slots = list(range(self.size + self.pre_alloc_size))
| 120 | + |
| 121 | + |
60 | 122 | @dataclass |
61 | 123 | class DecodeRequest: |
62 | 124 | req: Req |
|
0 commit comments