From 10a22755be39148c270e441545be8ef28d30d2b4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Mar 2025 02:36:45 +0000 Subject: [PATCH 1/3] Fix non-contiguous input passed to Marlin kernel Signed-off-by: Qubitium --- .../layers/quantization/kernels/mixed_precision/marlin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index e21801cf6a78..c74cfbe63176 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -115,6 +115,10 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # marlin requires contiguous memory layout + # kv/prefill caching may cause x to be non-contiguous + x = x.contiguous() # no-op if already contiguous + c = self.config w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) From 1f731aecaf31a9bab0787af99a263ef2ac628d28 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Mar 2025 03:29:03 +0000 Subject: [PATCH 2/3] format Signed-off-by: Qubitium --- .../layers/quantization/kernels/mixed_precision/marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index c74cfbe63176..c80cbe2ecaa1 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -117,7 +117,7 @@ def apply_weights(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # marlin requires contiguous memory layout # kv/prefill caching may cause x to be non-contiguous - x = x.contiguous() # no-op if already contiguous + x = x.contiguous() # no-op if already contiguous c = self.config w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) From d3ea7fe63caf112b15e0e0d7bebc1b3ae98afba7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Mar 2025 13:34:41 +0000 Subject: [PATCH 3/3] update comment Signed-off-by: Qubitium --- .../layers/quantization/kernels/mixed_precision/marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index c80cbe2ecaa1..b030e1484a6a 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -116,7 +116,7 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # marlin requires contiguous memory layout - # kv/prefill caching may cause x to be non-contiguous + # prefix caching may cause x to be non-contiguous x = x.contiguous() # no-op if already contiguous c = self.config