
Commit 0168f9e

rebase + add meta functions for machete kernels

Parent: af6302f

File tree

3 files changed: +30 -5 lines

csrc/ops.h

Lines changed: 0 additions & 5 deletions

@@ -83,11 +83,6 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                           torch::Tensor& b_scales, torch::Tensor& workspace,
                           int64_t size_m, int64_t size_n, int64_t size_k);
 
-torch::Tensor marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales,
-                               torch::Tensor& workspace, int64_t size_m,
-                               int64_t size_n, int64_t size_k);
-
 namespace machete {
 
 std::vector<std::string> supported_schedules(

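The marlin_gemm_meta declaration deleted above was a C++-side "meta" function: an implementation that computes only the output tensor's shape and dtype so PyTorch can trace the op without executing it. For the machete kernels this commit takes the Python route instead, registering fake implementations with torch.library.register_fake (see vllm/_custom_ops.py below). A minimal, self-contained sketch of that pattern, using a hypothetical my_lib::row_gemm op rather than vLLM's real _C library:

import torch

# Hypothetical op for illustration only; not part of vLLM.
# torch.library.register_fake requires a recent PyTorch (2.4+).
lib = torch.library.Library("my_lib", "DEF")
lib.define("row_gemm(Tensor a, Tensor b) -> Tensor")

@torch.library.register_fake("my_lib::row_gemm")
def _row_gemm_fake(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Metadata propagation only: (m, k) x (k, n) -> (m, n).
    # No data is read and no kernel is launched.
    return torch.empty((a.size(0), b.size(1)), dtype=a.dtype, device=a.device)

The fake never sees real data, so it must depend only on tensor metadata (shapes, dtypes, devices) and plain Python arguments, exactly as the machete fakes below do.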
tests/kernels/test_machete_gemm.py

Lines changed: 7 additions & 0 deletions

@@ -9,6 +9,7 @@
 import pytest
 import torch
 
+from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_rows, quantize_weights)
@@ -76,6 +77,8 @@ def machete_quantize_and_pack(w: torch.Tensor,
     w_q = w_q.t().contiguous().t()  # convert to col major
     w_q_machete = ops.machete_prepack_B(w_q, wtype)
 
+    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
+
     return w_ref, w_q_machete, w_s, w_zp
 
 
@@ -146,6 +149,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
         schedule=schedule,
     )
 
+    opcheck(torch.ops._C.machete_gemm,
+            (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
+                w_zp, w_s), group_size, None, None, None, schedule))
+
     # Relax atol as our reduction dim becomes larger (more rounding error)
     # Relax atol when we have zeropoints since the way machete applies
     # zeropoints (after scales) causes noise around 0

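The opcheck calls added here exercise the new fake registrations. Assuming vLLM's tests.kernels.utils.opcheck is a thin wrapper around torch.library.opcheck (the names match), each call runs the op through PyTorch's consistency tests: schema correctness, agreement between the real kernel and its registered fake, and autograd registration. A hedged, self-contained sketch of the underlying API, using a toy demo::scale op that is not part of vLLM:

import torch

# Toy op for illustration; any custom op with a real impl and a fake works.
lib = torch.library.Library("demo", "DEF")
lib.define("scale(Tensor x, float s) -> Tensor")

@torch.library.impl("demo::scale", "CPU")
def _scale_cpu(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

@torch.library.register_fake("demo::scale")
def _scale_fake(x: torch.Tensor, s: float) -> torch.Tensor:
    return torch.empty_like(x)

# Raises if the schema, the fake, or the dispatcher registrations
# disagree with the op's actual behavior.
torch.library.opcheck(torch.ops.demo.scale, (torch.randn(4), 2.0))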
vllm/_custom_ops.py

Lines changed: 23 additions & 0 deletions

@@ -365,6 +365,29 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               size_k: int) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
 
+    @torch.library.register_fake("_C::machete_gemm")
+    def machete_gemm_fake(
+            a: torch.Tensor,
+            # b_q should be the tensor returned by machete_prepack_B
+            b_q: torch.Tensor,
+            b_type: ScalarType,
+            b_scales: Optional[torch.Tensor] = None,
+            b_zeros: Optional[torch.Tensor] = None,
+            b_group_size: Optional[int] = None,
+            c: Optional[torch.Tensor] = None,
+            alpha: Optional[float] = None,
+            beta: Optional[float] = None,
+            schedule: Optional[str] = None,
+    ) -> torch.Tensor:
+        m = a.size(0)
+        n = b_q.size(1)
+        return torch.empty((m, n), device=a.device, dtype=a.dtype)
+
+    @torch.library.register_fake("_C::machete_prepack_B")
+    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
+                               b_type: ScalarType) -> torch.Tensor:
+        return torch.empty_like(b_q_weight)
+
 except Exception:
     pass

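With these two fakes registered, FakeTensor-based tracing (the mode torch.compile and opcheck run under) can propagate shapes and dtypes through machete_gemm and machete_prepack_B without touching real weights or launching a CUDA kernel. Reusing the toy my_lib::row_gemm op from the sketch after the csrc/ops.h diff, the effect looks like this:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumes the my_lib::row_gemm op and its fake from the earlier sketch
# are already defined in this process.
with FakeTensorMode():
    a = torch.empty(16, 64)                # FakeTensor: metadata only
    b = torch.empty(64, 32)
    out = torch.ops.my_lib.row_gemm(a, b)  # dispatches to the fake
    print(out.shape)                       # torch.Size([16, 32]); no kernel ran

The try/except wrapping the registrations (visible in the context lines above) presumably guards older PyTorch builds that lack torch.library.register_fake, which shipped in PyTorch 2.4.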