Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlx/backend/metal/kernels/fp4.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ struct fp4_e2m1 {
}

// Decodes the 4-bit E2M1 value stored in `bits` (bit 3 = sign, bits 2..0 =
// exponent/mantissa) and returns it as a float.
//
// The decode is purely arithmetic: only the low four bits of `bits` are ever
// read, so the result does not depend on the upper nibble being clear and no
// lookup table is indexed.
operator float() {
  // Place the 3 magnitude bits at bit 9 of a half, i.e. the 2 exponent bits
  // become the low bits of the half exponent field and the mantissa bit
  // becomes the top bit of the half mantissa field.
  half magnitude = as_type<half>(ushort((bits & 7) << 9));
  // Rescale by 2^14 to compensate for the difference between the half
  // exponent bias and the E2M1 exponent bias.
  magnitude *= 16384.0;
  // Bit 3 is the sign bit.
  return bits & 8 ? -magnitude : magnitude;
}

uint8_t bits;
Expand Down
24 changes: 9 additions & 15 deletions mlx/backend/metal/kernels/fp8.h
Original file line number Diff line number Diff line change
@@ -1,29 +1,22 @@
#pragma once

// Reinterprets the 32-bit pattern in `bits` as a float (no numeric
// conversion). Uses the as_type<> built-in instead of casting the address of
// the argument to a pointer of a different type.
inline float fp32_from_bits(uint32_t bits) {
  return as_type<float>(bits);
}
// Returns the raw 32-bit pattern of the float `x` (no numeric conversion).
//
// The result is returned as uint32_t: returning it as float would
// value-convert the bit pattern to floating point and round away its low
// bits before the caller can store it in an integer.
inline uint32_t fp32_to_bits(float x) {
  return as_type<uint32_t>(x);
}

struct fp8_e4m3 {
template <typename T>
fp8_e4m3(T f) {
// From PyTorch
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
uint32_t fp8_max = 543 << 21;
uint32_t denorm_mask = 141 << 23;
uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
uint32_t sign = f_bits & 0x80000000;
f_bits ^= sign;
if (f_bits >= fp8_max) {
// Default behavior saturates to min/max
bits = 0x7E;
} else {
if (f_bits < (121 << 23)) {
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
f_bits = as_type<uint32_t>(
as_type<float>(f_bits) + as_type<float>(denorm_mask));
bits = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
Expand Down Expand Up @@ -53,7 +46,7 @@ struct fp8_e4m3 {
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
return as_type<float>(result);
}

uint8_t bits;
Expand All @@ -77,11 +70,12 @@ struct fp8_e8m0 {
bits = static_cast<uint8_t>(n + 127);
}

// Converts the stored 8-bit exponent to bfloat16, i.e. 2^(bits - 127).
operator bfloat16_t() {
  // 0xFF is treated as NaN. Shifting it into the exponent field like the
  // other codes would produce 0x7F80, which is +infinity, so it is mapped to
  // a quiet-NaN pattern explicitly.
  if (bits == 0xFF) {
    return as_type<bfloat16_t>(static_cast<uint16_t>(0x7FC0));
  }
  // For bits in [1, 254] the code is exactly the bfloat16 exponent field
  // (7 mantissa bits below it, all zero). bits == 0 stands for 2^-127, which
  // needs a zero exponent field with only the top mantissa bit set (0x40).
  uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
  return as_type<bfloat16_t>(out);
}

// Converts to float by widening the bfloat16 conversion above, so both
// conversions agree on every code, including the NaN code 0xFF.
operator float() {
  return static_cast<float>(this->operator bfloat16_t());
}

uint8_t bits;
Expand Down
Loading