Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/hicache/bench_multiturn.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __init__(self, args):
tokenizer=self.tokenizer,
dataset_path=args.dataset_path,
)
self.candidate_inputs = [i[0] for i in self.candidate_inputs]
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]

init_requests = [
(i, gen_payload(self.candidate_inputs[i], args.output_length))
Expand Down
39 changes: 28 additions & 11 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,37 @@

class LayerDoneCounter:
    """Synchronizes layer-by-layer KV-cache loading (host -> device) with
    model forward passes.

    A small ring of independent counters is kept so that, in overlap mode,
    a new transfer batch can start producing into the next slot while a
    previously scheduled batch is still being consumed.
    """

    def __init__(self, num_layers, num_counters=3):
        # Number of layers a producer must complete for a full transfer.
        self.num_layers = num_layers
        # extra producer and consumer counters for overlap mode
        self.num_counters = num_counters
        # Counters start at num_layers ("already complete") so that waiting
        # on a slot that has no pending transfer never blocks.
        self.counters = [num_layers] * self.num_counters
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def next_producer(self):
        """Return the ring slot the next producer batch will use."""
        return (self.producer_index + 1) % self.num_counters

    def update_producer(self):
        """Advance the producer to the next ring slot and return its index."""
        self.producer_index = self.next_producer()
        return self.producer_index

    def set_consumer(self, index):
        """Point consumers at the slot ``index`` a producer is filling."""
        self.consumer_index = index

    def increment(self):
        """Mark one more layer of the current producer slot as transferred."""
        # Snapshot the index so lock, counter, and notify all target the same
        # slot even if another thread advances producer_index concurrently.
        idx = self.producer_index
        with self.conditions[idx]:
            self.counters[idx] += 1
            self.conditions[idx].notify_all()

    def wait_until(self, threshold):
        """Block until the consumer slot's counter exceeds ``threshold``."""
        # Snapshot the index: re-reading consumer_index after waking could
        # check a different slot's counter than the condition we hold.
        idx = self.consumer_index
        with self.conditions[idx]:
            while self.counters[idx] <= threshold:
                self.conditions[idx].wait()

    def reset(self):
        """Zero the current producer slot before starting a new transfer."""
        idx = self.producer_index
        with self.conditions[idx]:
            self.counters[idx] = 0


class CacheOperation:
Expand Down Expand Up @@ -295,7 +310,6 @@ def load_thread_func_direct(self):
while not self.stop_event.is_set():
try:
operation = self.load_queue.get(block=True, timeout=1)
# time.sleep(18e-6 * len(operation.host_indices))
operation.data = self.mem_pool_host.get_flat_data(
operation.host_indices
)
Expand All @@ -319,6 +333,7 @@ def load_thread_func_layer_by_layer(self):
if not self.load_cache_event.is_set():
continue
self.load_cache_event.clear()
self.layer_done_counter.update_producer()

batch_operation = None
while self.load_queue.qsize() > 0:
Expand All @@ -330,6 +345,7 @@ def load_thread_func_layer_by_layer(self):
if batch_operation is None:
continue

# start layer-wise KV cache transfer from CPU to GPU
self.layer_done_counter.reset()
for i in range(self.mem_pool_host.layer_num):
if self.page_size == 1:
Expand Down Expand Up @@ -465,6 +481,7 @@ def _pin_op(op_, put=True):
except Exception as e:
logger.error(e)

# todo (zhiqiang): double buffering to be deprecated
def write_thread_func_buffer(self):
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
aux_thread.start()
Expand Down
13 changes: 5 additions & 8 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,14 +655,6 @@ def init_next_round_input(
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
elif enable_hierarchical_cache:
# in case last_node is evicted during scheduling, we need to update the prefix_indices
while self.last_node.evicted:
self.prefix_indices = self.prefix_indices[
: -len(self.last_node.host_value)
]
self.last_node = self.last_node.parent

Comment on lines -662 to -669
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi I'm trying to understand better how the hierarchical cache transfers are happening. What were the invariants regarding last_node / ref counts / eviction before, what are they now, and why can this now be deleted?

self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

def adjust_max_prefix_ids(self):
Expand Down Expand Up @@ -903,6 +895,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Whether to return hidden states
return_hidden_states: bool = False

# hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index: int = 0

@classmethod
def init_new(
cls,
Expand Down Expand Up @@ -1705,6 +1700,7 @@ def get_model_worker_batch(
input_embeds=self.input_embeds,
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
hicache_consumer_index=self.hicache_consumer_index,
capture_hidden_mode=(
CaptureHiddenMode.FULL
if self.return_hidden_states
Expand Down Expand Up @@ -1801,6 +1797,7 @@ class ModelWorkerBatch:
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = 0

# Overlap event
launch_done: Optional[threading.Event] = None
Expand Down
23 changes: 19 additions & 4 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,10 @@ def init_memory_pool_and_cache(self):
hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy,
)
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
)

else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
Expand Down Expand Up @@ -1460,8 +1464,13 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.running_batch.batch_is_full = True
break

# bypass prefix_computed if enable_hierarchical_cache
req.init_next_round_input(
None if prefix_computed else self.tree_cache,
(
None
if (prefix_computed and not self.enable_hierarchical_cache)
else self.tree_cache
),
self.enable_hierarchical_cache,
)

Expand Down Expand Up @@ -1494,9 +1503,6 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
x for x in self.waiting_queue if x not in set(can_run_list)
]

if self.enable_hierarchical_cache:
self.tree_cache.ready_to_load_cache()

if adder.new_chunked_req is not None:
assert self.chunked_req is None
self.chunked_req = adder.new_chunked_req
Expand All @@ -1520,6 +1526,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.server_args.enable_custom_logit_processor,
chunked_req=self.chunked_req,
)
if self.enable_hierarchical_cache:
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()

new_batch.prepare_for_extend()

# Mixed-style chunked prefill
Expand Down Expand Up @@ -1595,6 +1605,11 @@ def run_batch(
if self.is_generation:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()

# update the consumer index of hicache to the running batch
self.tp_worker.set_hicache_consumer(
model_worker_batch.hicache_consumer_index
)
if self.pp_group.is_last_rank:
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ def __init__(
# A reference make this class has the same member as TpModelWorkerClient
self.worker = self

self.hicache_layer_transfer_counter = None

def register_hicache_layer_transfer_counter(self, counter):
    """Attach the LayerDoneCounter shared with the hicache cache controller.

    The scheduler calls this once at init so the worker can later point the
    counter's consumer side at the slot a running batch loads from.
    """
    self.hicache_layer_transfer_counter = counter

def set_hicache_consumer(self, consumer_index):
    """Forward ``consumer_index`` to the registered transfer counter.

    A no-op when no counter has been registered (hicache disabled).
    """
    counter = self.hicache_layer_transfer_counter
    if counter is None:
        return
    counter.set_consumer(consumer_index)

def get_worker_info(self):
return (
self.max_total_num_tokens,
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ def __init__(
if self.device == "cpu":
self.scheduler_stream.synchronize = lambda: None # No-op for CPU

self.hicache_layer_transfer_counter = None

def register_hicache_layer_transfer_counter(self, counter):
    """Attach the LayerDoneCounter shared with the hicache cache controller.

    Registered once by the scheduler; used by the forward thread to select
    which counter slot a batch consumes layer-load signals from.
    """
    self.hicache_layer_transfer_counter = counter

def set_hicache_consumer(self, consumer_index):
    """Point the registered transfer counter at ``consumer_index``.

    Does nothing when hicache is disabled and no counter was registered.
    """
    counter = self.hicache_layer_transfer_counter
    if counter is not None:
        counter.set_consumer(consumer_index)

def get_worker_info(self):
return self.worker.get_worker_info()

Expand Down Expand Up @@ -146,6 +155,8 @@ def forward_thread_func_(self):
input_ids = model_worker_batch.input_ids
resolve_future_token_ids(input_ids, self.future_token_ids_map)

# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
# Run forward
logits_output, next_token_ids, can_run_cuda_graph = (
self.worker.forward_batch_generation(
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ def init_load_back(
return last_node, prefix_indices

def ready_to_load_cache(self):
    """Wake the cache-loading thread and report its upcoming producer slot.

    Returns the index the loader will advance to (via the counter's
    ``update_producer``), so the scheduler can record it as the consumer
    index for the batch being built.
    """
    counter = self.cache_controller.layer_done_counter
    slot = counter.next_producer()
    self.load_cache_event.set()
    return slot

def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
Expand Down Expand Up @@ -370,6 +372,7 @@ def _split_node(self, key, child: TreeNode, split_len: int):
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
new_node.loading = child.loading
new_node.hit_count = child.hit_count

# split value and host value if exists
if child.evicted:
Expand Down