
Commit bb218a6

[Misc][VIP] hotfix for gptq-marlin non-contiguous error (#9)
Signed-off-by: DefTruth <[email protected]>
1 parent 433ffcb commit bb218a6

3 files changed: +8 −5 lines

vllm/_custom_ops.py

Lines changed: 7 additions & 0 deletions
@@ -738,6 +738,13 @@ def gptq_marlin_gemm(a: torch.Tensor,
                      use_atomic_add: bool = False,
                      use_fp32_reduce: bool = False,
                      is_zp_float: bool = False) -> torch.Tensor:
+    # FIXME(DefTruth): Remove this patch once gptq_marlin_gemm
+    # supports non-contiguous input. Currently, marlin requires a
+    # contiguous memory layout, but prefix caching may cause `a`
+    # to be non-contiguous. We lower the non-contiguous fix into
+    # this function, since `gptq_marlin_gemm` is used in multiple
+    # code paths, by both AWQ and GPTQ.
+    a = a.contiguous()  # no-op if already contiguous
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                          g_idx, perm, workspace, b_q_type.id,
                                          size_m, size_n, size_k, is_k_full,
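
For context, a minimal sketch (not part of this commit) of the PyTorch behavior the guard relies on: Tensor.contiguous() returns the tensor itself when it is already contiguous, so the added line is free on the common path, while a sliced view, such as the ones the prefix-cache path can hand to the Marlin kernel, is non-contiguous and gets materialized into a fresh contiguous copy.

    import torch

    a = torch.randn(4, 8)
    assert a.is_contiguous()
    assert a.contiguous() is a        # already contiguous: returns self, no copy

    view = a[:, 2:]                   # slicing keeps the parent's strides
    assert not view.is_contiguous()   # a layout the marlin kernel cannot consume
    fixed = view.contiguous()         # materializes a contiguous copy
    assert fixed.is_contiguous()

Placing the guard inside gptq_marlin_gemm rather than at each call site is what allows the two removals below: every AWQ and GPTQ path ends up funnelling through this one wrapper.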

vllm/attention/backends/mla/common.py

Lines changed: 1 addition & 1 deletion
@@ -1161,7 +1161,7 @@ def _compute_prefill_context(
         k_pe = workspace[:toks]\
             [..., self.kv_lora_rank:].unsqueeze(1)
 
-        kv_nope = self.kv_b_proj(kv_c_normed.contiguous())[0].view( \
+        kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = kv_nope\
             .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py

Lines changed: 0 additions & 4 deletions
@@ -115,10 +115,6 @@ def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        # marlin requires contiguous memory layout
-        # prefix caching may cause x to be non-contiguous
-        x = x.contiguous()  # no-op if already contiguous
-
         c = self.config
         w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

0 commit comments