I'm doing a matmul of a 4-bit A against an fp16 B, with a large A and a small B. I expect it to beat an fp8 matmul, since the kernel should be memory-bound and the 4-bit operand halves the bytes read for A.
In reality, it is consistently worse.
Example:
Kernel code is here: https://gist.github.com/divchenko/9b02f40ae109e8dc8549afbde059d32e
It's called from Python like this:
```python
import torch
import cuscratch

g = 64
m = 3584
n = 16
k = 8192
scale_k = (k + g - 1) // g  # number of quantization groups along k

s = torch.ones((m, scale_k), dtype=torch.half, device="cuda")       # per-group fp16 scales for A
a = torch.ones((m, (k + 1) // 2), dtype=torch.int8, device="cuda")  # A, two 4-bit values packed per int8
b = torch.ones((n, k), dtype=torch.half, device="cuda")             # fp16 B
d = torch.zeros((n, m), dtype=torch.half, device="cuda")            # fp16 output
cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)
```
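
Here is roughly how I estimate the bandwidth-utilization numbers below (a minimal sketch: it assumes each operand is read from HBM exactly once, and the peak-bandwidth figure is just a placeholder to replace with your GPU's spec):

```python
import torch
import cuscratch

# Same shapes as the repro above.
g, m, n, k = 64, 3584, 16, 8192
scale_k = (k + g - 1) // g

s = torch.ones((m, scale_k), dtype=torch.half, device="cuda")
a = torch.ones((m, (k + 1) // 2), dtype=torch.int8, device="cuda")
b = torch.ones((n, k), dtype=torch.half, device="cuda")
d = torch.zeros((n, m), dtype=torch.half, device="cuda")

# Bytes moved per call: packed 4-bit A, fp16 scales, and fp16 B read; fp16 D written.
bytes_moved = (a.numel() * a.element_size()
               + s.numel() * s.element_size()
               + b.numel() * b.element_size()
               + d.numel() * d.element_size())

# Warm up, then time with CUDA events.
for _ in range(10):
    cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)
torch.cuda.synchronize()

iters = 100
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
    cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)
end.record()
torch.cuda.synchronize()

ms = start.elapsed_time(end) / iters
achieved_gbs = bytes_moved / (ms * 1e-3) / 1e9
peak_gbs = 3350.0  # placeholder: peak HBM bandwidth of the GPU being tested
print(f"{ms:.3f} ms/iter, {achieved_gbs:.0f} GB/s, {achieved_gbs / peak_gbs:.1%} of peak")
```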
The best performance I can get is with the stream-K scheduler (k is large indeed), but it still reaches only ~20% of memory bandwidth.
The persistent tile scheduler is far worse with both the TMA and TMACooperative kernel schedulers.
The fp8 implementation reaches ~60% of memory bandwidth and is therefore faster, even though it reads roughly 2x more bytes; see the traffic estimate below.
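
For reference, the back-of-the-envelope traffic estimate behind the "~2x more bytes" claim (this assumes the fp8 path keeps both A and B in fp8 with no separate scale tensor, which may not match the fp8 kernel exactly):

```python
# Bytes moved per GEMM at m=3584, n=16, k=8192, g=64 (reads plus the fp16 output write).
m, n, k, g = 3584, 16, 8192, 64

mixed_bytes = m * k // 2 + m * (k // g) * 2 + n * k * 2 + n * m * 2  # packed int4 A + fp16 scales + fp16 B + fp16 D
fp8_bytes = m * k + n * k + n * m * 2                                # fp8 A + fp8 B + fp16 D (assumed, no scales)

print(mixed_bytes / 1e6, fp8_bytes / 1e6, fp8_bytes / mixed_bytes)
# ~16.0 MB vs ~29.6 MB, i.e. the fp8 kernel moves roughly 1.9x the data
```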
Am I missing anything? Thank you!