Summary
Recently Triton added `scaled_dot`, which consumes A and B in f8, f6, or f4 (packed into int32 format) together with e8m0 scales passed via the int8 datatype: https://github.com/triton-lang/triton/pull/4795/files#diff-1d96a0ed473569188c00d6e16c54dd7050e0a66040438ac630c889aef7cbbbe8R1544
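For orientation, here is a minimal sketch of a kernel built on the new primitive, assuming it is exposed in the frontend as `tl.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format)` as in recent Triton releases (the exact name and signature may differ at this PR). It uses e4m3 for both operands so element bytes map one-to-one, one e8m0 scale byte per 32 K elements, and a single K block to avoid an accumulation loop; `mx_matmul_kernel` / `mx_matmul` are illustrative names, and shapes are assumed divisible by the block sizes:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def mx_matmul_kernel(
    a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, c_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    SCALE_K: tl.constexpr = BLOCK_K // 32  # one e8m0 scale per 32 K elements
    offs_sk = tl.arange(0, SCALE_K)

    # fp8 e4m3 element bytes: A is row-major (M, K), B is row-major (K, N)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # e8m0 scale bytes: (M, K // 32) for A, (N, K // 32) for B
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * SCALE_K + offs_sk[None, :])
    b_scale = tl.load(b_scale_ptr + offs_n[:, None] * SCALE_K + offs_sk[None, :])

    c = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3")
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


def mx_matmul(a_u8, a_scale, b_u8, b_scale):
    M, K = a_u8.shape
    _, N = b_u8.shape
    c = torch.empty(M, N, device=a_u8.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
    # single K block for simplicity; a real kernel would loop over K
    mx_matmul_kernel[grid](a_u8, a_scale, b_u8, b_scale, c, M, N, K,
                           BLOCK_M=128, BLOCK_N=128, BLOCK_K=K)
    return c


# example inputs: fp8 e4m3 bytes with unit scales (e8m0 byte 127 == 2**0)
M = N = K = 128
a_u8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn).view(torch.uint8)
b_u8 = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn).view(torch.uint8)
a_scale = torch.full((M, K // 32), 127, device="cuda", dtype=torch.uint8)
b_scale = torch.full((N, K // 32), 127, device="cuda", dtype=torch.uint8)
c = mx_matmul(a_u8, a_scale, b_u8, b_scale)
```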
Steps
- Implement the new mx matmul in triton; add utilities to ensure that this op is only available when a new enough triton is installed (see the version-gating sketch after this list)
- Write unit tests verifying the correctness of the implementation against the existing upcast-and-matmul approach (see the test sketch after this list)
- Update MXTensor's dispatch to (based on config) use the new op instead of upcasting and running in the original precision (see the dispatch sketch after this list). The current path, from torchao/prototype/mx_formats/mx_ops.py (lines 64 to 68 at 48bc81c):

  ```python
  b = args[1]
  assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
  a_hp = a.to_dtype(a._orig_dtype)
  b_hp = b.to_dtype(b._orig_dtype)
  res = aten_op(a_hp, b_hp)
  ```
- Create profile + memory traces (see the profiling sketch after this list)
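For step 1's gating utility, one option is to feature-detect the primitive rather than compare version strings; `has_triton_scaled_dot` is a hypothetical helper name, and the frontend attribute is assumed to be `dot_scaled`:

```python
def has_triton_scaled_dot() -> bool:
    """Return True if the installed triton exposes the scaled-dot primitive.

    Feature-detection also handles nightlies and source builds that plain
    version comparisons would misclassify.
    """
    try:
        import triton.language as tl
    except ImportError:
        return False
    return hasattr(tl, "dot_scaled")
```

Registration of the new op can then be guarded with this check at import time, so older triton installs keep the existing upcast path.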
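For step 2, a sketch of the correctness test, assuming the prototype's `MXTensor.to_mx(data_hp, elem_dtype, block_size)` constructor and the hypothetical fused op `mx_mm` from step 1; tolerances would need tuning since the two paths accumulate in different orders:

```python
import pytest
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor


@pytest.mark.parametrize("shape", [(128, 256, 512), (256, 256, 256)])
def test_mx_mm_matches_upcast_matmul(shape):
    M, K, N = shape
    a_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b_hp = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    a = MXTensor.to_mx(a_hp, torch.float8_e4m3fn, block_size=32)
    b = MXTensor.to_mx(b_hp, torch.float8_e4m3fn, block_size=32)

    # reference: the existing dispatch path (dequantize, then matmul)
    ref = a.to_dtype(a._orig_dtype) @ b.to_dtype(b._orig_dtype)

    # new path: fused triton kernel consuming packed elements + e8m0 scales
    out = mx_mm(a, b)  # hypothetical op from step 1

    # compare in fp32; both paths see identical quantized inputs, so only
    # accumulation-order differences remain
    torch.testing.assert_close(
        out.to(torch.float32), ref.to(torch.float32), atol=1e-1, rtol=1e-1
    )
```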
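For step 3, a sketch of the config-gated dispatch change, reusing torchao's `implements` decorator from mx_ops.py; `use_fused_mx_mm` is a hypothetical config flag and `mx_mm` the hypothetical fused op:

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

aten = torch.ops.aten
use_fused_mx_mm = True  # hypothetical config flag, real name TBD


@implements([aten.mm.default])  # torchao's existing decorator in mx_ops.py
def mx_mm_dispatch(aten_op, args, kwargs=None):
    a = args[0]
    b = args[1]
    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
    if use_fused_mx_mm and has_triton_scaled_dot():
        # new path: hand packed elements + e8m0 scales straight to triton,
        # skipping the dequantize round-trip entirely
        return mx_mm(a, b)
    # existing path: upcast and run the matmul in the original precision
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    return aten_op(a_hp, b_hp)
```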
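For step 4, the traces can be collected with the standard profiler plus the (private, subject to change) CUDA memory-snapshot hooks; `a` and `b` are the MXTensor inputs from the test sketch above, and `mx_mm` is again the hypothetical fused op:

```python
import torch
from torch.profiler import ProfilerActivity, profile

# record allocator events so the snapshot shows where memory is spent
torch.cuda.memory._record_memory_history(max_entries=100_000)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = mx_mm(a, b)  # swap in the upcast path here to compare the two
    torch.cuda.synchronize()

prof.export_chrome_trace("mx_mm_trace.json")  # open in chrome://tracing or perfetto
torch.cuda.memory._dump_snapshot("mx_mm_memory.pickle")  # open at pytorch.org/memory_viz
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```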