Commit 1b94178

LucasWilkinson authored and shreyankg committed
[Attention] Remove slow setattr in MLA (vllm-project#14769)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent f204bde commit 1b94178

File tree

1 file changed: +7 −2 lines changed


vllm/model_executor/layers/rotary_embedding.py

Lines changed: 7 additions & 2 deletions
@@ -161,8 +161,13 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops

-        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                   dtype=query.dtype)
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if self.cos_sin_cache.device != query.device or \
+            self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
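The pattern in the patch is general: `nn.Module.__setattr__` checks the module's parameter, buffer, and submodule registries on every attribute assignment, so reassigning a tensor attribute in a hot path (here, on every forward call) adds avoidable overhead. A minimal, self-contained sketch of the guarded-assignment pattern (the class name and constructor here are illustrative, not the actual vLLM `RotaryEmbedding` implementation):

```python
import torch
import torch.nn as nn


class RotaryEmbeddingSketch(nn.Module):
    """Illustrative sketch of the guarded-reassignment pattern.

    nn.Module.__setattr__ scans the module's parameter/buffer/submodule
    dicts on every assignment, so we only reassign the cache when a
    device or dtype move is actually required.
    """

    def __init__(self, cos_sin_cache: torch.Tensor):
        super().__init__()
        self.cos_sin_cache = cos_sin_cache

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Guard: skip the expensive __setattr__ when the cache already
        # matches the query's device and dtype (the common case).
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)
        return self.cos_sin_cache
```

When the cache already matches, `forward` returns the same tensor object with no reassignment; only a mismatched device or dtype triggers the `.to(...)` move and the `__setattr__` call.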

0 commit comments
