Skip to content
Merged

fix #2701

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mlx/backend/cuda/quantized/cuda_fp4.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
#pragma once

// One-byte, exponent-only scale value. The stored byte is treated as a biased
// power-of-two exponent (bias 127), with 0xFF reserved to mean NaN.
// NOTE(review): presumably a fallback for CUDA toolkits whose headers do not
// provide `__nv_fp8_e8m0` -- confirm against the build configuration.
struct __nv_fp8_e8m0 {
  // Stores `x` verbatim as the raw encoded byte; no numeric conversion of the
  // argument takes place.
  __device__ __nv_fp8_e8m0(uint8_t x) : __x(x) {}

  // Decodes the stored byte to float:
  //   0xFF      -> quiet NaN
  //   otherwise -> 2^(__x - 127)
  // Const-qualified so that const objects and const references can be
  // converted as well; the conversion never modifies the stored byte.
  __device__ operator float() const {
    if (__x == 0xFF) {
      return std::numeric_limits<float>::quiet_NaN();
    }
    // Widen to int before subtracting the bias so the exponent can go negative.
    return std::ldexp(1.0f, static_cast<int>(__x) - 127);
  }

  // Raw encoded byte; defaults to 0.
  uint8_t __x{0};
};

struct __nv_fp4_e2m1 {
__device__ __nv_fp4_e2m1(float x) {
if (std::isnan(x)) {
Expand Down