@@ -188,21 +188,11 @@ def __init__(
         self.num_blocks = num_blocks
         self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim)
 
-        self.dtype = dtype
+        self._dtype = dtype
         self.device = device
 
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        for idx in range(config.num_hidden_layers):
-            layer_device = layer_device_map[idx] if layer_device_map is not None else device
-            new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
-            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
-            # preventing compiled graph breaks when updating the cache.
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
 
         # Block management data structures
         self._free_blocks = deque(range(num_blocks))
@@ -280,6 +270,25 @@ def _get_physical_indices(self, state: RequestState, logical_indices: List[int])
 
         return physical_indices
 
+    @torch.compiler.disable
+    def initialise_cache_layer(self, layer_idx, key_states):
+        if len(self.key_cache) > layer_idx:
+            return
+        self.num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        cache_shape = (
+            self.num_key_value_heads,
+            self.num_blocks,
+            self.block_size,
+            self.head_dim,
+        )
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)
+
     @traced
     def update(
         self,
@@ -291,6 +300,7 @@ def update(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Reshape cache for easier indexing
+        self.initialise_cache_layer(layer_idx, key_states)
         total_slots = self.num_blocks * self.block_size
         k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
         v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
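The net effect of the diff is that per-layer KV cache tensors are no longer allocated eagerly in __init__; instead, initialise_cache_layer allocates them lazily on the first update() call for each layer, inferring num_key_value_heads and the device from the incoming key_states. Below is a minimal standalone sketch of that lazy-allocation pattern; the toy class, shapes, and values are illustrative assumptions, not the actual cache implementation, and it omits the torch.compile / mark_static_address handling shown in the diff.

import torch

class LazyLayerCache:
    """Toy sketch: allocate a layer's KV cache on first use instead of in __init__."""

    def __init__(self, num_blocks, block_size, head_dim, dtype=torch.float16):
        self.num_blocks, self.block_size, self.head_dim = num_blocks, block_size, head_dim
        self._dtype = dtype
        self.key_cache, self.value_cache = [], []

    def initialise_cache_layer(self, layer_idx, key_states):
        # Already allocated for this layer: nothing to do.
        if len(self.key_cache) > layer_idx:
            return
        # Infer the KV head count and device from the first batch of key states.
        num_key_value_heads = key_states.shape[1]
        shape = (num_key_value_heads, self.num_blocks, self.block_size, self.head_dim)
        self.key_cache.append(torch.zeros(shape, dtype=self._dtype, device=key_states.device))
        self.value_cache.append(torch.zeros(shape, dtype=self._dtype, device=key_states.device))

    def update(self, key_states, value_states, layer_idx):
        self.initialise_cache_layer(layer_idx, key_states)
        # ... writing key_states / value_states into the cache is omitted in this sketch ...
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

# First update() for layer 0 triggers allocation; later calls reuse the same tensors.
cache = LazyLayerCache(num_blocks=4, block_size=8, head_dim=16)
k = torch.zeros(1, 2, 8, 16, dtype=torch.float16)  # (batch, kv_heads, seq_len, head_dim)
cache.update(k, k.clone(), layer_idx=0)
assert cache.key_cache[0].shape == (2, 4, 8, 16)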