From 257808d934228209f0bff24be88918ed9118fc2d Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Tue, 17 Jun 2025 01:31:33 -0700 Subject: [PATCH] fix layer done counter sync and hiradix pre-compute --- benchmark/hicache/bench_multiturn.py | 2 +- .../sglang/srt/managers/cache_controller.py | 39 +++++++++++++------ python/sglang/srt/managers/schedule_batch.py | 13 +++---- python/sglang/srt/managers/scheduler.py | 23 +++++++++-- python/sglang/srt/managers/tp_worker.py | 9 +++++ .../srt/managers/tp_worker_overlap_thread.py | 11 ++++++ python/sglang/srt/mem_cache/hiradix_cache.py | 3 ++ 7 files changed, 76 insertions(+), 24 deletions(-) diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py index a2a88b634abc..5b8d706a399c 100644 --- a/benchmark/hicache/bench_multiturn.py +++ b/benchmark/hicache/bench_multiturn.py @@ -239,7 +239,7 @@ def __init__(self, args): tokenizer=self.tokenizer, dataset_path=args.dataset_path, ) - self.candidate_inputs = [i[0] for i in self.candidate_inputs] + self.candidate_inputs = [i.prompt for i in self.candidate_inputs] init_requests = [ (i, gen_payload(self.candidate_inputs[i], args.output_length)) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 12d9249981ca..d44c5d5e2d96 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -29,22 +29,37 @@ class LayerDoneCounter: def __init__(self, num_layers): - self.counter = num_layers - self.condition = threading.Condition() + self.num_layers = num_layers + # extra producer and consumer counters for overlap mode + self.num_counters = 3 + self.counters = [num_layers] * self.num_counters + self.conditions = [threading.Condition() for _ in range(self.num_counters)] + self.producer_index = 0 + self.consumer_index = 0 + + def next_producer(self): + return (self.producer_index + 1) % self.num_counters + + def update_producer(self): + self.producer_index = self.next_producer() + return self.producer_index + + def set_consumer(self, index): + self.consumer_index = index def increment(self): - with self.condition: - self.counter += 1 - self.condition.notify_all() + with self.conditions[self.producer_index]: + self.counters[self.producer_index] += 1 + self.conditions[self.producer_index].notify_all() def wait_until(self, threshold): - with self.condition: - while self.counter <= threshold: - self.condition.wait() + with self.conditions[self.consumer_index]: + while self.counters[self.consumer_index] <= threshold: + self.conditions[self.consumer_index].wait() def reset(self): - with self.condition: - self.counter = 0 + with self.conditions[self.producer_index]: + self.counters[self.producer_index] = 0 class CacheOperation: @@ -295,7 +310,6 @@ def load_thread_func_direct(self): while not self.stop_event.is_set(): try: operation = self.load_queue.get(block=True, timeout=1) - # time.sleep(18e-6 * len(operation.host_indices)) operation.data = self.mem_pool_host.get_flat_data( operation.host_indices ) @@ -319,6 +333,7 @@ def load_thread_func_layer_by_layer(self): if not self.load_cache_event.is_set(): continue self.load_cache_event.clear() + self.layer_done_counter.update_producer() batch_operation = None while self.load_queue.qsize() > 0: @@ -330,6 +345,7 @@ def load_thread_func_layer_by_layer(self): if batch_operation is None: continue + # start layer-wise KV cache transfer from CPU to GPU self.layer_done_counter.reset() for i in range(self.mem_pool_host.layer_num): if self.page_size == 1: @@ -465,6 +481,7 @@ def _pin_op(op_, put=True): except Exception as e: logger.error(e) + # todo (zhiqiang): double buffering to be deprecated def write_thread_func_buffer(self): aux_thread = threading.Thread(target=self.write_aux_func, daemon=True) aux_thread.start() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 912018ca9e35..ad297db3575a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -655,14 +655,6 @@ def init_next_round_input( self.prefix_indices, self.last_node = tree_cache.match_prefix( rid=self.rid, key=self.adjust_max_prefix_ids() ) - elif enable_hierarchical_cache: - # in case last_node is evicted during scheduling, we need to update the prefix_indices - while self.last_node.evicted: - self.prefix_indices = self.prefix_indices[ - : -len(self.last_node.host_value) - ] - self.last_node = self.last_node.parent - self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): @@ -903,6 +895,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Whether to return hidden states return_hidden_states: bool = False + # hicache pointer for synchronizing data loading from CPU to GPU + hicache_consumer_index: int = 0 + @classmethod def init_new( cls, @@ -1705,6 +1700,7 @@ def get_model_worker_batch( input_embeds=self.input_embeds, spec_algorithm=self.spec_algorithm, spec_info=self.spec_info, + hicache_consumer_index=self.hicache_consumer_index, capture_hidden_mode=( CaptureHiddenMode.FULL if self.return_hidden_states @@ -1801,6 +1797,7 @@ class ModelWorkerBatch: spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None + hicache_consumer_index: int = 0 # Overlap event launch_done: Optional[threading.Event] = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 37f39096cc50..cd8136bd1b21 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -565,6 +565,10 @@ def init_memory_pool_and_cache(self): hicache_size=server_args.hicache_size, hicache_write_policy=server_args.hicache_write_policy, ) + self.tp_worker.register_hicache_layer_transfer_counter( + self.tree_cache.cache_controller.layer_done_counter + ) + else: self.tree_cache = RadixCache( req_to_token_pool=self.req_to_token_pool, @@ -1460,8 +1464,13 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = True break + # bypass prefix_computed if enable_hierarchical_cache req.init_next_round_input( - None if prefix_computed else self.tree_cache, + ( + None + if (prefix_computed and not self.enable_hierarchical_cache) + else self.tree_cache + ), self.enable_hierarchical_cache, ) @@ -1494,9 +1503,6 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: x for x in self.waiting_queue if x not in set(can_run_list) ] - if self.enable_hierarchical_cache: - self.tree_cache.ready_to_load_cache() - if adder.new_chunked_req is not None: assert self.chunked_req is None self.chunked_req = adder.new_chunked_req @@ -1520,6 +1526,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.server_args.enable_custom_logit_processor, chunked_req=self.chunked_req, ) + if self.enable_hierarchical_cache: + # todo (zhiqiang): disable cuda graph execution if hicache loading triggered + new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache() + new_batch.prepare_for_extend() # Mixed-style chunked prefill @@ -1595,6 +1605,11 @@ def run_batch( if self.is_generation: if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() + + # update the consumer index of hicache to the running batch + self.tp_worker.set_hicache_consumer( + model_worker_batch.hicache_consumer_index + ) if self.pp_group.is_last_rank: logits_output, next_token_ids, can_run_cuda_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 786a34a1edbd..88bbde1b6bbd 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -147,6 +147,15 @@ def __init__( # A reference make this class has the same member as TpModelWorkerClient self.worker = self + self.hicache_layer_transfer_counter = None + + def register_hicache_layer_transfer_counter(self, counter): + self.hicache_layer_transfer_counter = counter + + def set_hicache_consumer(self, consumer_index): + if self.hicache_layer_transfer_counter is not None: + self.hicache_layer_transfer_counter.set_consumer(consumer_index) + def get_worker_info(self): return ( self.max_total_num_tokens, diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 783d864eae4c..45f220db62ab 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -88,6 +88,15 @@ def __init__( if self.device == "cpu": self.scheduler_stream.synchronize = lambda: None # No-op for CPU + self.hicache_layer_transfer_counter = None + + def register_hicache_layer_transfer_counter(self, counter): + self.hicache_layer_transfer_counter = counter + + def set_hicache_consumer(self, consumer_index): + if self.hicache_layer_transfer_counter is not None: + self.hicache_layer_transfer_counter.set_consumer(consumer_index) + def get_worker_info(self): return self.worker.get_worker_info() @@ -146,6 +155,8 @@ def forward_thread_func_(self): input_ids = model_worker_batch.input_ids resolve_future_token_ids(input_ids, self.future_token_ids_map) + # update the consumer index of hicache to the running batch + self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) # Run forward logits_output, next_token_ids, can_run_cuda_graph = ( self.worker.forward_batch_generation( diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 4bec901aa600..60f959057ec4 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -305,7 +305,9 @@ def init_load_back( return last_node, prefix_indices def ready_to_load_cache(self): + producer_index = self.cache_controller.layer_done_counter.next_producer() self.load_cache_event.set() + return producer_index def match_prefix(self, key: List[int], include_evicted=False, **kwargs): empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) @@ -370,6 +372,7 @@ def _split_node(self, key, child: TreeNode, split_len: int): new_node.lock_ref = child.lock_ref new_node.key = child.key[:split_len] new_node.loading = child.loading + new_node.hit_count = child.hit_count # split value and host value if exists if child.evicted: