
Commit afe78bd

lazy cache init
1 parent f962c86 commit afe78bd

File tree

1 file changed: +21 -11 lines changed

src/transformers/generation/continuous_batching.py

Lines changed: 21 additions & 11 deletions
@@ -188,21 +188,11 @@ def __init__(
         self.num_blocks = num_blocks
         self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim)

-        self.dtype = dtype
+        self._dtype = dtype
         self.device = device

         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        for idx in range(config.num_hidden_layers):
-            layer_device = layer_device_map[idx] if layer_device_map is not None else device
-            new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
-            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
-            # preventing compiled graph breaks when updating the cache.
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)

         # Block management data structures
         self._free_blocks = deque(range(num_blocks))
@@ -280,6 +270,25 @@ def _get_physical_indices(self, state: RequestState, logical_indices: List[int])

         return physical_indices

+    @torch.compiler.disable
+    def initialise_cache_layer(self, layer_idx, key_states):
+        if len(self.key_cache) > layer_idx:
+            return
+        self.num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        cache_shape = (
+            self.num_key_value_heads,
+            self.num_blocks,
+            self.block_size,
+            self.head_dim,
+        )
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)
+
     @traced
     def update(
         self,
@@ -291,6 +300,7 @@ def update(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Reshape cache for easier indexing
+        self.initialise_cache_layer(layer_idx, key_states)
         total_slots = self.num_blocks * self.block_size
         k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
         v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
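For reference, here is a minimal, self-contained sketch of the lazy per-layer allocation pattern this commit introduces. The class name LazyPagedCache, the slot_indices argument, and the assumed tensor shapes are illustrative only, not the actual continuous_batching.py API: each layer's key/value buffers are created on the first update() call for that layer, with the number of KV heads and the device inferred from the incoming key_states rather than read from the model config up front.

from collections import deque
from typing import List, Tuple

import torch


class LazyPagedCache:
    """Illustrative paged KV cache that allocates per-layer buffers lazily."""

    def __init__(self, num_blocks: int, block_size: int, head_dim: int, dtype: torch.dtype = torch.float16):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.head_dim = head_dim
        self._dtype = dtype
        # No per-layer allocation here: buffers are created on first use.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._free_blocks = deque(range(num_blocks))

    @torch.compiler.disable
    def initialise_cache_layer(self, layer_idx: int, key_states: torch.Tensor) -> None:
        # Skip if this layer's buffers already exist.
        if len(self.key_cache) > layer_idx:
            return
        # Infer the number of KV heads and the device from the first batch of
        # key states instead of from the config or a layer_device_map.
        num_key_value_heads = key_states.shape[1]  # assumed (batch, kv_heads, num_tokens, head_dim)
        cache_shape = (num_key_value_heads, self.num_blocks, self.block_size, self.head_dim)
        key = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
        value = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
        # Fixed data pointers avoid graph breaks when torch.compile sees cache updates.
        torch._dynamo.mark_static_address(key)
        torch._dynamo.mark_static_address(value)
        self.key_cache.append(key)
        self.value_cache.append(value)

    def update(
        self,
        key_states: torch.Tensor,    # assumed (1, kv_heads, num_tokens, head_dim)
        value_states: torch.Tensor,  # same shape as key_states
        layer_idx: int,
        slot_indices: torch.Tensor,  # flat cache slots to write, length num_tokens
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Allocate this layer's buffers on first use, then write into flat views.
        self.initialise_cache_layer(layer_idx, key_states)
        kv_heads = self.key_cache[layer_idx].shape[0]
        total_slots = self.num_blocks * self.block_size
        k_flat = self.key_cache[layer_idx].view(kv_heads, total_slots, self.head_dim)
        v_flat = self.value_cache[layer_idx].view(kv_heads, total_slots, self.head_dim)
        k_flat[:, slot_indices, :] = key_states[0]
        v_flat[:, slot_indices, :] = value_states[0]
        return k_flat, v_flat

One practical effect, assuming the rest of the class is otherwise unchanged, is that each layer's buffers land on whatever device its key_states arrive on, which is why the diff can drop the explicit layer_device_map handling from __init__.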
