
Commit 00d1328

Zero buffer inside shrink op

Signed-off-by: Andy Lo <[email protected]>

1 parent 5ce2b79

2 files changed: +4 -1 lines changed

vllm/lora/ops/triton_ops/lora_shrink_op.py
Lines changed: 2 additions & 0 deletions

@@ -152,6 +152,8 @@ def _lora_shrink(
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
     assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

+    output_tensor.zero_()
+
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
      lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
     N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
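For context, here is a minimal sketch of the pattern this hunk establishes; it is not the vLLM Triton kernel itself, and all names and shapes are illustrative. A shrink-style op that accumulates partial results into output_tensor (e.g. across split-K blocks) only produces correct results if the buffer starts at zero, so the op performs that zeroing itself rather than relying on the caller.

import torch

def shrink_like_op(inputs: torch.Tensor, weight: torch.Tensor,
                   output_tensor: torch.Tensor, num_splits: int = 4) -> None:
    # The op owns its own initialization, mirroring the hunk above.
    output_tensor.zero_()
    # Accumulate partial products over chunks of K, mimicking a split-K
    # kernel that adds into the output instead of overwriting it.
    for k_chunk in torch.chunk(torch.arange(inputs.size(-1)), num_splits):
        output_tensor += inputs[:, k_chunk] @ weight[k_chunk, :]

x = torch.randn(8, 64)
w = torch.randn(64, 16)
out = torch.empty(8, 16)  # callers may now pass uninitialized memory
shrink_like_op(x, w, out)
assert torch.allclose(out, x @ w, atol=1e-4)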

vllm/lora/punica_wrapper/punica_gpu.py
Lines changed: 2 additions & 1 deletion

@@ -207,7 +207,8 @@ def add_lora_linear(self,
         r = lora_b_stacked[0].size(-1)
         # We set the buffer to be float32 by default, refer to:
         # https://github.com/triton-lang/triton/issues/1387
-        buffer = torch.zeros(  # type: ignore
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty(  # type: ignore
             (len(output_slices), x.size(0), r),
             dtype=torch.float32,
             device=x.device,
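The caller-side change is the payoff: because the shrink op now guarantees a zeroed output, the float32 staging buffer (kept in float32 per the Triton issue linked above) can be allocated with torch.empty, which returns uninitialized memory, instead of torch.zeros, which also fills it. A hedged sketch of the allocation, with illustrative shape values:

import torch

num_slices, num_tokens, rank = 2, 8, 16  # illustrative values only
# torch.empty skips the zero-fill that torch.zeros would perform; this is
# safe here only because the shrink op zeroes the buffer before use.
buffer = torch.empty((num_slices, num_tokens, rank), dtype=torch.float32)

Skipping the fill avoids redundant work on every call, and the invariant that the buffer is zeroed moves into the one op that actually depends on it.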
