Commit 6370062

feat(cache): StaticCache uses index_copy_ to avoid useless copy (#31857)
* feat(cache): StaticCache uses index_copy_ to avoid useless copy

  Using index_copy_ allows for an explicit in-place change of the tensor. Some backends (XLA) will otherwise copy the tensor, making the code slower and using more memory. The proposed implementation ends up using less memory, and on XLA it also results in less compilation; the change is generic, making no difference whatsoever on the CUDA or CPU backends.

* feat(cache): SlidingWindowCache uses index_copy_ to avoid useless copy

  Applying the same change done in StaticCache.

* fix(cache): fallback of index_copy_ when not implemented

* fix(cache): in index_copy_, ensure tensors are on the same device

* [run slow] llama

* fix(cache): add move of cache_position to the same device in SlidingWindowCache

* Revert "[run slow] llama"

  This reverts commit 02608dd.
1 parent a009fbd commit 6370062
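
As a minimal standalone sketch (not taken from the commit; shapes and names are illustrative), the following shows the equivalence the commit message relies on: `index_copy_` along dim 2 writes the same values as the slice assignment it replaces, but mutates the destination explicitly in place:

    import torch

    # Toy KV-cache-shaped tensors: (batch, num_heads, max_seq_len, head_dim)
    k_out = torch.zeros(1, 2, 8, 4)
    key_states = torch.randn(1, 2, 3, 4)
    cache_position = torch.arange(5, 8)  # sequence positions written this step

    # Slice assignment: advanced indexing, which some backends (e.g. XLA)
    # lower to a copy of the whole tensor
    expected = k_out.clone()
    expected[:, :, cache_position] = key_states

    # index_copy_: an explicit in-place scatter along dim 2
    k_out.index_copy_(2, cache_position, key_states)

    assert torch.equal(k_out, expected)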

File tree

1 file changed: +20 -4 lines changed

src/transformers/cache_utils.py

Lines changed: 20 additions & 4 deletions
@@ -862,8 +862,18 @@ def update(
             k_out.copy_(key_states)
             v_out.copy_(value_states)
         else:
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
+            # Note: here we use `tensor.index_copy_(dim, index, tensor)`, which is equivalent to
+            # `tensor[:, :, index] = tensor` but compile-friendly and explicitly in place, so it
+            # avoids copies and uses less memory.
+            try:
+                # If using several devices (e.g. multiple GPUs), we need to ensure everything is on the same one
+                cache_position = cache_position.to(device=k_out.device)
+                k_out.index_copy_(2, cache_position, key_states)
+                v_out.index_copy_(2, cache_position, value_states)
+            except NotImplementedError:
+                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
+                k_out[:, :, cache_position] = key_states
+                v_out[:, :, cache_position] = value_states
 
         return k_out, v_out
 
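The try/except above reads more easily in isolation. A hypothetical sketch (the helper name and signature are mine, not from the library), assuming the usual (batch, num_heads, seq_len, head_dim) cache layout:

    import torch

    def scatter_seq_positions_(dst: torch.Tensor, index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        """Write `src` into `dst` at `index` along dim 2 (the sequence axis), in place.

        Prefers `index_copy_`; falls back to slice assignment on backends such as
        MPS, where aten::index_copy.out is not implemented.
        """
        # `Tensor.to` is not in place: bind the result, or the device move is lost
        index = index.to(device=dst.device)
        try:
            dst.index_copy_(2, index, src)
        except NotImplementedError:
            dst[:, :, index] = src
        return dst

Note the rebinding of `index`: `Tensor.to` returns a (possibly new) tensor rather than moving its receiver, so calling it without using the result is a no-op.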

@@ -958,8 +968,14 @@ def update(
         k_out = k_out[:, :, indices]
         v_out = v_out[:, :, indices]
 
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
+        try:
+            cache_position = cache_position.to(device=k_out.device)
+            k_out.index_copy_(2, cache_position, key_states)
+            v_out.index_copy_(2, cache_position, value_states)
+        except NotImplementedError:
+            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
 
         # `_.zero()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
         self.key_cache[layer_idx].zero_()
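
The trailing context line points at a related idiom used just below the hunk: rebinding `self.key_cache[layer_idx]` to a fresh tensor would trigger graph breaks under `torch.compile`, so the buffer is cleared and refilled in place instead. A minimal sketch of the equivalence:

    import torch

    cache = torch.ones(2, 3)
    new_values = torch.randn(2, 3)

    # In-place equivalent of `cache = new_values`: the storage (and any
    # reference captured by a compiled graph) stays valid, only the contents change
    cache.zero_()
    cache += new_values

    assert torch.equal(cache, new_values)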
