
[QST] Hopper mixed precision gemm always worse than FP8 #1549

@divchenko

I'm doing an A (4-bit) x B (fp16) matmul with a large A and a small B. I expect it to beat an fp8 matmul, since the kernel should be memory-bound.
In reality it is consistently worse.

Example:
Kernel code is here: https://gist.github.com/divchenko/9b02f40ae109e8dc8549afbde059d32e
It is called from Python like this:

import torch
import cuscratch  # extension built from the kernel in the gist above

g = 64   # group size for the 4-bit quantization scales
m = 3584
n = 16
k = 8192

scale_k = (k + g - 1) // g  # number of scale groups along k

s = torch.ones((m, scale_k), dtype=torch.half, device="cuda")       # per-group fp16 scales for A
a = torch.ones((m, (k + 1) // 2), dtype=torch.int8, device="cuda")  # 4-bit A, packed two values per byte
b = torch.ones((n, k), dtype=torch.half, device="cuda")             # small fp16 B
d = torch.zeros((n, m), dtype=torch.half, device="cuda")            # fp16 output

cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)
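
For reference, here is a rough byte count for these shapes (my own back-of-the-envelope sketch, not from the gist; the fp8 line assumes fp8 A and B with an fp16 output). The quantized A operand dominates the traffic, so the mixed-input kernel should move roughly half the bytes of the fp8 GEMM:

# Back-of-the-envelope memory traffic for the shapes above (illustrative only).
m, n, k, g = 3584, 16, 8192, 64
bytes_a_int4 = m * k // 2                       # packed 4-bit weights
bytes_scales = m * ((k + g - 1) // g) * 2       # fp16 group scales
bytes_b_fp16 = n * k * 2                        # small fp16 operand
bytes_d_fp16 = m * n * 2                        # fp16 output
w4a16_bytes = bytes_a_int4 + bytes_scales + bytes_b_fp16 + bytes_d_fp16
fp8_bytes = m * k + n * k + m * n * 2           # fp8 A and B, fp16 D (assumed)
print(w4a16_bytes / 2**20, fp8_bytes / 2**20)   # ~15.2 MiB vs ~28.2 MiB, i.e. fp8 moves ~1.9x more data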

The best performance I can get is with the stream-K tile scheduler (k is large, after all), but it still reaches only ~20% of peak memory bandwidth.
The persistent tile scheduler is much worse with both the TMA and TMA cooperative kernel schedules.
An fp8 implementation reaches ~60% of peak memory bandwidth and is therefore faster, even though it reads roughly 2x more bytes.
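
This is how the bandwidth fraction can be estimated (a minimal sketch that continues the script above and reuses its tensors and w4a16_bytes; the ~3.35 TB/s figure is the H100 SXM HBM3 spec and is my assumption about the hardware):

# Rough bandwidth-utilization estimate (continues the snippet above).
cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)  # warm-up
start = torch.cuda.Event(enable_timing=True)
stop = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)
stop.record()
torch.cuda.synchronize()
ms_per_call = start.elapsed_time(stop) / 100
achieved_gbps = w4a16_bytes / (ms_per_call * 1e-3) / 1e9
print(f"{achieved_gbps / 3350 * 100:.1f}% of peak HBM bandwidth")  # assumes ~3.35 TB/s peak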

Am I missing anything? Thank you!
