diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/marlin.py
index 5b4bba76ee0c..6969583d6d47 100644
--- a/vllm/model_executor/layers/quantization/kernels/marlin.py
+++ b/vllm/model_executor/layers/quantization/kernels/marlin.py
@@ -38,10 +38,11 @@ def can_implement(cls,
                             "Marlin, supported group sizes are: "\
                             f"{MARLIN_SUPPORTED_GROUP_SIZES}"
 
-        return check_marlin_supports_shape(c.partition_weight_shape[0],
-                                           c.partition_weight_shape[1],
-                                           c.full_weight_shape[1],
-                                           c.group_size)
+        return check_marlin_supports_shape(
+            c.partition_weight_shape[1],  # out_features
+            c.partition_weight_shape[0],  # in_features
+            c.full_weight_shape[0],  # in_features
+            c.group_size)
 
     # note assumes that
     # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
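
For context, a minimal sketch of the index convention behind the corrected call. The concrete shapes and variable names below are hypothetical; only the {input_dim = 0, output_dim = 1} layout of `weight_packed`, stated in the note above, comes from the diff itself.

    # Hypothetical shapes for one tensor-parallel shard of a layer whose packed
    # weight follows {input_dim = 0, output_dim = 1}, as the note above assumes.
    full_weight_shape = (4096, 11008)      # (in_features, out_features) before sharding
    partition_weight_shape = (4096, 5504)  # (in_features, out_features) of this shard
    group_size = 128

    # With that layout, the corrected arguments read off the shapes as:
    out_features_per_partition = partition_weight_shape[1]  # index 1 -> out_features
    in_features_per_partition = partition_weight_shape[0]   # index 0 -> in_features
    in_features_full = full_weight_shape[0]                 # index 0 -> full in_features
    # The old call passed indices [0], [1] and full_weight_shape[1], i.e. it
    # handed in_features to the out_features slot (and vice versa) in the check.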