diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp
index 3acba5d054..b2f5c91ffc 100644
--- a/common/chat-parser.cpp
+++ b/common/chat-parser.cpp
@@ -82,28 +82,38 @@ bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
}
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
- auto start_pos = input_.find(start_think, pos_);
- if (start_pos == std::string::npos) {
- return false;
- }
+ auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
+ auto stripped_reasoning = string_strip(reasoning);
+ if (stripped_reasoning.empty()) {
+ return;
+ }
+ if (syntax_.reasoning_in_content) {
+ add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : start_think);
+ add_content(stripped_reasoning);
+ if (closed) {
+ add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : end_think);
+ }
+ } else {
+ add_reasoning_content(stripped_reasoning);
+ }
+ };
- auto end_pos = input_.find(end_think, start_pos + start_think.size());
- if (end_pos == std::string::npos) {
- if (is_partial_) {
- // Partial reasoning content
- auto reasoning = input_.substr(start_pos + start_think.size());
- add_reasoning_content(string_strip(reasoning));
- pos_ = input_.size();
+ if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
+ if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
+ if (auto res = try_find_literal(end_think)) {
+ handle_reasoning(res->prelude, /* closed */ true);
+ consume_spaces();
+ return true;
+ }
+ auto rest = consume_rest();
+ if (!rest.empty()) {
+ handle_reasoning(rest, /* closed */ !is_partial());
+ }
+ // Allow unclosed thinking tags for now (following original llama.cpp)
return true;
}
- return false;
}
-
- // Extract reasoning content
- auto reasoning = input_.substr(start_pos + start_think.size(), end_pos - start_pos - start_think.size());
- add_reasoning_content(string_strip(reasoning));
- pos_ = end_pos + end_think.size();
- return true;
+ return false;
}
std::optional common_chat_msg_parser::try_find_literal_legacy(const std::string & literal) {
diff --git a/common/chat.cpp b/common/chat.cpp
index 15cfbbf03d..f62c280119 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -278,6 +278,9 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
throw; // Re-throw for partial mode
}
}
+
+ // Add any remaining content (critical for responses without tool calls)
+ builder.add_content(builder.consume_rest());
}
// Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern
diff --git a/common/chat.h b/common/chat.h
index e23f84f383..5899ef1a1e 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -135,8 +135,18 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility)
};
+enum common_reasoning_format {
+ COMMON_REASONING_FORMAT_NONE,
+ COMMON_REASONING_FORMAT_DEEPSEEK,
+ COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
+};
+
struct common_chat_syntax {
common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2;
+ common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
+ // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
+ bool reasoning_in_content = false;
+ bool thinking_forced_open = false;
bool enable_thinking = false;
bool enable_tool_calls = true;
};
diff --git a/common/common.cpp b/common/common.cpp
index 1801da0395..da702368de 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1080,6 +1080,24 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
+ if (arg == "--cpu-moe" || arg == "-cmoe") {
+ params.tensor_buft_overrides.push_back({strdup("\\.ffn_(up|down|gate)_exps"), ggml_backend_cpu_buffer_type()});
+ return true;
+ }
+ if (arg == "--n-cpu-moe" || arg == "-ncmoe") {
+ CHECK_ARG
+ int32_t n_layers = std::stoi(argv[i]);
+ if (n_layers < 0) {
+ fprintf(stderr, "error: Invalid value for --n-cpu-moe: %d (must be >= 0)\n", n_layers);
+ invalid_param = true;
+ return true;
+ }
+ for (int32_t l = 0; l < n_layers; ++l) {
+ std::string pattern = "blk\\." + std::to_string(l) + "\\.(ffn_(up|down|gate)_exps)";
+ params.tensor_buft_overrides.push_back({strdup(pattern.c_str()), ggml_backend_cpu_buffer_type()});
+ }
+ return true;
+ }
if (arg == "--no-mmap") {
params.use_mmap = false;
return true;
@@ -1794,6 +1812,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" });
}
options.push_back({ "*", " --run-time-repack", "repack tensors if interleaved variant is available"});
+ options.push_back({ "*", " --cpu-moe", "keep all MoE weights in CPU memory"});
+ options.push_back({ "*", " --n-cpu-moe N", "keep MoE weights of the first N layers in CPU memory"});
options.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n"
" - distribute: spread execution evenly over all nodes\n"
" - isolate: only spawn threads on CPUs on the node that execution started on\n"
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 3de7bc20ad..2e2c62bf4e 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -28,6 +28,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
{ "Q6_0", LLAMA_FTYPE_MOSTLY_Q6_0, " 6.5 bpw quantization", },
+ { "MXFP4", LLAMA_FTYPE_MOSTLY_MXFP4, " 4.25 bpw 4-bit float quantization",},
{ "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", },
{ "IQ2_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4,"IQ2_XXS repacked", },
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", },
diff --git a/examples/server/function_calls.hpp b/examples/server/function_calls.hpp
index 068c5f24ce..92d25a0d40 100644
--- a/examples/server/function_calls.hpp
+++ b/examples/server/function_calls.hpp
@@ -89,6 +89,8 @@ static ik_chat_msg parse_chat_message_incremental(const std::string& content, bo
try {
common_chat_syntax syntax;
syntax.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
+ syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+ syntax.reasoning_in_content = true; // Fix for thinking tag termination issue
syntax.enable_tool_calls = true;
common_chat_msg_parser parser(content, is_partial, syntax);
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5b90c9a50e..2c2613928b 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -403,6 +403,7 @@ extern "C" {
GGML_TYPE_Q4_0_4_4 = 31,
GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33,
+ GGML_TYPE_MXFP4 = 39, // so we are compatible with mainline
//
// So we are able to consume MS BitNet I2_S quants
//
@@ -507,9 +508,10 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
- GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
- GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
- GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
+    GGML_FTYPE_MOSTLY_MXFP4    = 25, // except 1d tensors; value chosen to be compatible with mainline
+ GGML_FTYPE_MOSTLY_Q4_0_4_4 = 26, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_0_4_8 = 27, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_0_8_8 = 28, // except 1d tensors
//
GGML_FTYPE_MOSTLY_Q6_0 = 127, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_BN = 128, // except 1d tensors
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 1dc1ff6ecb..59f7ae711a 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -158,6 +158,9 @@ typedef sycl::half2 ggml_half2;
#define QI1_BN (QK_IQ1BN / (4*QR1_BN))
#define QR1_BN 8
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
#define QK4_0 32
@@ -174,6 +177,15 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
+// This is unfortunate (block is 17 bytes, so not even a 2-byte alignment)
+// But to be able to use MXFP4-quantized models from mainline, we do the same.
+#define QK_MXFP4 32
+typedef struct {
+ uint8_t e; // E8M0
+ uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
#define QK5_0 32
typedef struct {
ggml_half d; // delta
@@ -2211,5 +2223,11 @@ GGML_TABLE_BEGIN(int8_t, iq6nl_values, 128)
48, 52, 56, 60, 64, 69, 73, 78, 83, 88, 93, 99, 104, 110, 116, 122,
GGML_TABLE_END()
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+ 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
#endif // GGML_COMMON_IMPL
#endif // GGML_COMMON_IMPL
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 7fee71d848..9372c05c16 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -1647,7 +1647,7 @@ static void ggml_cuda_op_mul_mat(
}
const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
- if (!(split && used_devices > 1) && quantization_done && ne11 == 1 && ne12 > 1 && ne13 == 1) {
+ if (!(split && used_devices > 1) && quantization_done && ne11 == 1 && ne12 > 1 && ne13 == 1 && ne02 == ne12 && ne02 == dst->ne[2]) {
//printf("invoking fast path for %s x %s\n", src0->name, src1->name);
int id = ctx.device;
char * src0_dd_i = dev[id].src0_dd;
@@ -3498,6 +3498,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 15485f6065..c856a44b6e 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -550,6 +550,13 @@ struct ggml_cuda_type_traits {
static constexpr int qi = QI4_NL;
};
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+ static constexpr int qk = QK4_NL;
+ static constexpr int qr = QR4_NL;
+ static constexpr int qi = QI4_NL;
+};
+
template<>
struct ggml_cuda_type_traits {
static constexpr int qk = QK_K;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 8c03ae1bc6..689613f59a 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -736,6 +736,27 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
}
}
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
+ const int64_t i = blockIdx.x;
+ const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK4_NL);
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ union { float f; uint32_t u; } helper;
+ helper.u = x[ib].e >= 2 ? uint32_t(x[ib].e - 1) << 23u : uval[x[ib].e];
+ const float d = helper.f;
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf];
+ y[j+16] = d * kvalues_mxfp4[q4[j] >> 4];
+ }
+}
+
 template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
@@ -1611,6 +1632,13 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1943,6 +1971,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS:
@@ -2044,6 +2074,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS:
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index a0e7da129c..bebc7c8739 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -14,7 +14,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
- const int64_t nb01 = src0->nb[1];
+ const int64_t nb01 = ggml_row_size(src0->type, ne00);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
@@ -94,6 +94,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_NL:
             mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
break;
+ case GGML_TYPE_MXFP4:
+            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+ break;
case GGML_TYPE_IQ2_KL:
             mul_mat_q_case<GGML_TYPE_IQ2_KL>(ctx, args, stream);
break;
@@ -210,6 +213,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ4_KSS:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 20277041f8..9adc94a684 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -84,6 +84,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -204,6 +205,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ1_S_R4: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_MXFP4 : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KSS : return MMQ_DP4A_TXS_Q8_0;
@@ -263,6 +265,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ1_S_R4: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_MXFP4 : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KSS : return MMQ_MMA_TILE_X_K_Q8_0;
@@ -2078,6 +2081,67 @@ template static __device__ __forceinlin
}
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI4_NL;
+ const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbx;
+
+ const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+ const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+ union { float f; uint32_t u; } helper;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+ int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd;
+ helper.u = bxi->e ? uint32_t(bxi->e) << 23u : 0x00400000;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.5f * helper.f;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = 0.5f * helper.f;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -3624,6 +3688,13 @@ struct mmq_type_traits {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_MXFP4> {
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+};
+
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs;
@@ -4164,6 +4235,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 012b3e5e73..10d16aeb82 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -31,6 +31,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
case GGML_TYPE_IQ1_S : return vec_dot_iq1_s_q8_1;
case GGML_TYPE_IQ1_M : return vec_dot_iq1_m_q8_1;
case GGML_TYPE_IQ4_NL : return vec_dot_iq4_nl_q8_1;
+ case GGML_TYPE_MXFP4 : return vec_dot_mxfp4_q8_1;
case GGML_TYPE_IQ4_XS : return vec_dot_iq4_xs_q8_1;
case GGML_TYPE_IQ3_S : return vec_dot_iq3_s_q8_1;
default : return nullptr;
@@ -56,6 +57,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ;
case GGML_TYPE_IQ3_S : return VDR_IQ3_S_Q8_1_MMVQ;
case GGML_TYPE_IQ4_NL : return VDR_IQ4_NL_Q8_1_MMVQ;
+ case GGML_TYPE_MXFP4 : return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_IQ4_XS : return VDR_IQ4_XS_Q8_1_MMVQ;
default : return 1;
}
@@ -417,6 +419,14 @@ static void mul_mat_vec_iq4_nl_q8_1_cuda(
     mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL, VDR_IQ4_NL_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
+static void mul_mat_vec_mxfp4_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_MXFP4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
static void mul_mat_vec_iq4_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
@@ -509,6 +519,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_vec_mxfp4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
@@ -686,6 +699,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KL:
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
index eb02fab002..c88946c219 100644
--- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
@@ -1,5 +1,4 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index cae5e04fed..97a792bd68 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -17,6 +17,15 @@ static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32
return x32;
}
+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+ const uint8_t * x8 = (const uint8_t *)x;
+
+ int x32 = x8[4*i32 + 0] | (x8[4*i32 + 1] << 8);
+ x32 |= (x8[4*i32 + 2] | (x8[4*i32 + 3] << 8)) << 16;
+
+ return x32;
+}
+
static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
@@ -1167,6 +1176,32 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
return d * sumi;
}
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+ const int * q8 = (const int *) bq8_1->qs + iqs;
+
+ int2 sumi = {0, 0};
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+ const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+ sumi.x = ggml_cuda_dp4a(v.x, q8[l + 0], sumi.x);
+ sumi.y = ggml_cuda_dp4a(v.y, q8[l + 4], sumi.y);
+ }
+
+ union { float f; uint32_t u; } helper;
+ helper.u = bq4->e ? uint32_t(bq4->e) << 23u : 0x00400000;
+
+ return 0.5f * helper.f * __low2float(bq8_1->ds) * (sumi.x + sumi.y);
+}
+
#define VDR_IQ4_XS_Q8_1_MMVQ 4
#define VDR_IQ4_XS_Q8_1_MMQ 4
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index e4e3686088..62f07e1e5b 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -29,6 +29,24 @@
#endif
+// Does not handle NaN
+static inline float ggml_e8m0_to_fp32(uint8_t x) {
+ union { float f; uint32_t u; } helper;
+ helper.u = x ? (uint32_t)x << 23u : 0x00400000;
+ return helper.f;
+}
+
+// As above, but returns ggml_e8m0_to_fp32(x)/2
+static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
+ static uint32_t val[2] = { 0x00200000, 0x00400000 };
+ union { float f; uint32_t u; } helper;
+ helper.u = x >= 2 ? (uint32_t)(x - 1) << 23u : val[x];
+ return helper.f;
+}
+
+#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
+#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
/**
* Converts brain16 to float32.
*
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index a86c66b6dc..104ad664f5 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -105,6 +105,7 @@
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS,
@@ -153,6 +154,7 @@
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32,
@@ -195,6 +197,7 @@
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32,
@@ -234,6 +237,7 @@
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32,
@@ -273,6 +277,7 @@
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16,
@@ -312,6 +317,7 @@
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32,
@@ -767,6 +773,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS, get_rows_iq3_ks, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, get_rows_iq4_ks, true);
@@ -815,6 +822,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32, mul_mv_iq3_ks_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, mul_mv_iq4_ks_f32, ctx->support_simdgroup_reduction);
@@ -857,6 +865,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32, mul_mv_id_iq3_ks_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, mul_mv_id_iq4_ks_f32, ctx->support_simdgroup_reduction);
@@ -896,6 +905,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32, mul_mm_iq3_ks_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, mul_mm_iq4_ks_f32, ctx->support_simdgroup_mm);
@@ -935,6 +945,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, mul_mm_iq1_bn_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16, mul_mm_mxfp4_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16, mul_mm_iq3_ks_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm);
@@ -974,6 +985,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32, mul_mm_id_mxfp4_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32, mul_mm_id_iq3_ks_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, mul_mm_id_iq4_ks_f32, ctx->support_simdgroup_mm);
@@ -2192,6 +2204,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break;
@@ -2236,6 +2249,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break;
@@ -2450,6 +2464,12 @@ static void ggml_metal_encode_node(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
+ } break;
case GGML_TYPE_IQ4_XS:
{
nth0 = 4;
@@ -2595,7 +2615,7 @@ static void ggml_metal_encode_node(
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
- src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) {
+ src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS || src0t == GGML_TYPE_MXFP4) {
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float)
: src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 64*sizeof(float) : 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
@@ -2690,6 +2710,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break;
@@ -2888,6 +2909,12 @@ static void ggml_metal_encode_node(
nth1 = 2;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nth0 = 32;
+ nth1 = 2;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+ } break;
case GGML_TYPE_IQ4_XS:
{
nth0 = 32;
@@ -3044,7 +3071,7 @@ static void ggml_metal_encode_node(
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
- src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) {
+ src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS || src0t == GGML_TYPE_MXFP4) {
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float)
: src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 64*sizeof(float) : 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
@@ -3095,6 +3122,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 53de59dd8a..f700e6f7ad 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -3975,6 +3975,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
+constexpr constant static float kvalues_mxfp4_f[16] = {
+ 0.f, 1.f, 2.f, 3.f, 4.f, 6.f, 8.f, 12.f, 0.f, -1.f, -2.f, -3.f, -4.f, -6.f, -8.f, -12.f
+};
+
constexpr constant static float kvalues_iq4k_f[32] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f,
-123.f, -100.f, -79.f, -61.f, -45.f, -31.f, -18.f, -6.f, 5.f, 17.f, 29.f, 42.f, 57.f, 73.f, 93.f, 117.f,
@@ -6082,6 +6086,104 @@ void kernel_mul_mv_iq4_nl_f32_impl(
}
}
+void kernel_mul_mv_mxfp4_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
+ const int nb = ne00/QK4_NL;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * 2 + sgitg) * 2;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_mxfp4 * x = (device const block_mxfp4 *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ const int ix = tiisg/2; // 0...15
+ const int it = tiisg%2; // 0 or 1
+
+ shared_values[tiisg] = kvalues_mxfp4_f[tiisg%16];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[2]={0.f}, all_sum;
+
+ device const float * yb = y + ix * QK4_NL + it * 8;
+
+ uint32_t aux32[2];
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+ float4 qf1, qf2;
+
+ constexpr uint32_t val[2] = { 0x00200000, 0x00400000 };
+ union { float f; uint32_t u; } helper;
+
+ for (int ib = ix; ib < nb; ib += 16) {
+
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+ for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+
+ device const block_mxfp4 & xb = x[row*nb + ib];
+ device const uint8_t * q4 = (device const uint8_t *)(xb.qs + 8*it);
+
+ float4 acc1 = {0.f}, acc2 = {0.f};
+
+ aux32[0] = q4[0] | (q4[1] << 8) | (q4[2] << 16) | (q4[3] << 24);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ acc1 += yl[0] * qf1;
+ acc2 += yl[1] * qf2;
+
+ aux32[0] = q4[4] | (q4[5] << 8) | (q4[6] << 16) | (q4[7] << 24);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ acc1 += yl[2] * qf1;
+ acc2 += yl[3] * qf2;
+
+ acc1 += acc2;
+
+ helper.u = xb.e >= 2 ? (uint32_t)(xb.e - 1) << 23u : val[xb.e];
+ sumf[row] += helper.f * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+ }
+
+ yb += 16 * QK4_NL;
+ }
+
+ for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
void kernel_mul_mv_iq4_xs_f32_impl(
device const void * src0,
device const float * src1,
@@ -8129,6 +8231,35 @@ kernel void kernel_mul_mv_iq4_nl_f32(
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_mxfp4_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
device const void * src0,
@@ -8791,6 +8922,24 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
}
}
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+ constexpr uint32_t val[2] = { 0x00200000, 0x00400000 };
+ device const uint8_t * q4 = (device const uint8_t *)xb->qs;
+ union { float f; uint32_t u; } helper;
+ helper.u = xb->e >= 2 ? (uint32_t)(xb->e - 1) << 23u : val[xb->e];
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = q4[4*i] | (q4[4*i+1] << 8) | (q4[4*i+2] << 16) | (q4[4*i+3] << 24);
+ aux32 = (aux32 >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = helper.f * kvalues_mxfp4_f[q8[0]];
+ reg[i][1] = helper.f * kvalues_mxfp4_f[q8[1]];
+ reg[i][2] = helper.f * kvalues_mxfp4_f[q8[2]];
+ reg[i][3] = helper.f * kvalues_mxfp4_f[q8[3]];
+ }
+}
+
template
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -9761,6 +9910,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_mxfp4")]]   kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q;
template [[host_name("kernel_get_rows_iq2_k")]] kernel get_rows_q_t kernel_get_rows_q;
template [[host_name("kernel_get_rows_iq3_k")]] kernel get_rows_q_t kernel_get_rows_q;
@@ -9810,6 +9960,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm, float>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm, float>;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mat_mm_t kernel_mul_mm, float>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm, float>;
template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>;
template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>;
@@ -9850,6 +10001,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mat_mm_t kernel_mul_mm, half>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mat_mm_t kernel_mul_mm, half>;
+template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mat_mm_t kernel_mul_mm, half>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mat_mm_t kernel_mul_mm, half>;
template [[host_name("kernel_mul_mm_iq2_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>;
template [[host_name("kernel_mul_mm_iq3_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>;
@@ -9897,6 +10049,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>;
+template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>;
template [[host_name("kernel_mul_mm_id_iq2_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>;
template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>;
@@ -10126,6 +10279,7 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
template [[host_name("kernel_mul_mv_id_iq3_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
template [[host_name("kernel_mul_mv_id_iq4_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index e49417af5e..7a14fcf28a 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15418,6 +15418,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
+ case GGML_TYPE_MXFP4: break;
case GGML_TYPE_Q6_0: break;
case GGML_TYPE_IQ2_K: break;
case GGML_TYPE_IQ2_KS: break;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5aec6b0db7..f3a23727b7 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1301,14 +1301,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_nl,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
-#if GGML_USE_IQK_MULMAT
#if defined HAVE_FANCY_SIMD
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
-#endif
-#else
- .vec_dot_type = GGML_TYPE_Q8_0,
#endif
.nrows = 1,
.row_meta_size = 0,
@@ -1326,6 +1322,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_MXFP4] = {
+ .type_name = "mxfp4",
+ .blck_size = QK_MXFP4,
+ .type_size = sizeof(block_mxfp4),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
+ .from_float = quantize_row_mxfp4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref,
+ .vec_dot = vec_dot_mxfp4_q8_0_x4,
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_IQ4_KS] = {
.type_name = "iq4_ks",
.blck_size = QK_K,
@@ -4609,6 +4622,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q6_0_R4: wtype = GGML_TYPE_Q6_0_R4; break;
case GGML_FTYPE_MOSTLY_Q8_0_R8: wtype = GGML_TYPE_Q8_0_R8; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
+ case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break;
case GGML_FTYPE_MOSTLY_IQ4_KS_R4: wtype = GGML_TYPE_IQ4_KS_R4;break;
case GGML_FTYPE_MOSTLY_IQ5_KS_R4: wtype = GGML_TYPE_IQ5_KS_R4;break;
@@ -11388,6 +11402,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
@@ -11868,6 +11883,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
@@ -12045,6 +12061,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
@@ -15549,6 +15566,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
@@ -15966,6 +15984,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
@@ -16289,6 +16308,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
@@ -16929,6 +16949,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
@@ -24005,6 +24026,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q5_0_R4: result = quantize_q5_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_0_R4: result = quantize_q6_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0_R8: result = quantize_q8_0_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KS_R4:result = quantize_iq4_ks_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index ab6eb130ac..031283192a 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -1,4 +1,5 @@
#include "iqk_gemm_legacy_quants.h"
+#include <type_traits>
#ifdef IQK_IMPLEMENT
@@ -105,6 +106,21 @@ struct ScaleHelperQ_0 {
template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
};
+struct ScaleHelperQ_0_MXFP4 {
+    float scales[4];
+    template <typename Q>
+    inline __m128 prepare4(const Q * y) {
+        for (int j = 0; j < 4; ++j) scales[j] = GGML_E8M0_TO_FP32_HALF(y[j].e);
+        return _mm_loadu_ps(scales);
+    }
+    template <typename Q>
+    inline __m128 prepare4(__m128 other_scales, const Q * y) {
+        return _mm_mul_ps(other_scales, prepare4(y));
+    }
+    template <typename Q> inline float prepare1(const Q * y) const { return GGML_E8M0_TO_FP32_HALF(y->e); }
+    template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
+};
+
template
struct ScaleHelperQ_0_1 {
ggml_half scales8[4];
@@ -128,28 +144,28 @@ struct ScaleHelperQ_0_1 {
const __m128 min = _mm_set1_ps(float(-min_value));
};
-//template
-//struct ScaleHelperQ_0_2 {
-// ggml_bf16_t scales8[4];
-// template
-// inline __m256 prepare4(const Q * y) {
-// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
-// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16));
-// return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
-// }
-// template
-// inline __m256 prepare4(__m256 other_scales, const Q * y) {
-// return _mm_mul256_ps(other_scales, prepare4(y));
-// }
-// template inline std::pair prepare1(const Q * y) const {
-// float d = GGML_BF16_TO_FP32(y->d);
-// return std::make_pair(d, -d*float(min_value));
-// }
-// std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const {
-// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
-// }
-// const __m128 min = _mm_set1_ps(float(-min_value));
-//};
+template <int min_value>
+struct ScaleHelperQ_0_1_MXFP4 {
+    float scales[4];
+    template <typename Q>
+    inline __m256 prepare4(const Q * y) {
+        for (int j = 0; j < 4; ++j) scales[j] = GGML_E8M0_TO_FP32_HALF(y[j].e);
+        auto s4 = _mm_loadu_ps(scales);
+        return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
+    }
+    template <typename Q>
+    inline __m256 prepare4(__m256 other_scales, const Q * y) {
+        return _mm_mul256_ps(other_scales, prepare4(y));
+    }
+    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+        float d = GGML_E8M0_TO_FP32_HALF(y->e);
+        return std::make_pair(d, -d*float(min_value));
+    }
+    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+    }
+    const __m128 min = _mm_set1_ps(float(-min_value));
+};
struct ScaleHelperQ8_1 {
template
@@ -553,6 +569,49 @@ struct IQ4_NL0_Dequantizer {
}
};
+//=============================
+static inline __m128i load_unsigned_mxfp4_values_128() {
+ static const uint8_t kvalues_mxfp4_unsigned[16] = {12, 13, 14, 15, 16, 18, 20, 24, 12, 11, 10, 9, 8, 6, 4, 0};
+ return _mm_loadu_si128((const __m128i *)kvalues_mxfp4_unsigned);
+}
+
+static inline __m256i load_unsigned_mxfp4_values_256() {
+ auto val128 = load_unsigned_mxfp4_values_128();
+ return MM256_SET_M128I(val128, val128);
+}
+
+#ifdef HAVE_FANCY_SIMD
+static inline __m512i load_unsigned_mxfp4_values_512() {
+ auto val256 = load_unsigned_mxfp4_values_256();
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+}
+#endif
+
+static inline __m128i load_mxfp4_values_128() {
+ return _mm_loadu_si128((const __m128i *)kvalues_mxfp4);
+}
+
+static inline __m256i load_mxfp4_values_256() {
+ auto val128 = load_mxfp4_values_128();
+ return MM256_SET_M128I(val128, val128);
+}
+
+struct MXFP4_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i values = load_unsigned_mxfp4_values_256();
+ inline __m256i dequant(const block_mxfp4 * x) const {
+ return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
+ }
+};
+
+struct MXFP40_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i values = load_mxfp4_values_256();
+ inline __m256i dequant(const block_mxfp4 * x) const {
+ return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
+ }
+};
+
struct Q4_1_Dequantizer {
Dequantizer4bit b4;
inline __m256i dequant(const block_q4_1 * x) const {
@@ -665,6 +724,11 @@ struct Q4_0_1_Unpacker final : public Q_Unpacker
using Sum4T = Sum4q4;
inline static int block_size() { return QK4_0; }
};
+struct MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0_1_MXFP4<12>, MXFP4_Dequantizer> {
+    MXFP4_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+    using Sum4T = Sum4TypeQ82;
+    inline static int block_size() { return QK4_NL; }
+};
#ifdef HAVE_FANCY_SIMD
struct IQ4_NL_Unpacker final : public Q_Unpacker, IQ4_NL_Dequantizer> {
IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
@@ -672,7 +736,7 @@ struct IQ4_NL_Unpacker final : public Q_Unpacker {
+struct IQ4_NL_Unpacker final : public Q_Unpacker {
IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
inline static int block_size() { return QK4_NL; }
@@ -1757,7 +1821,11 @@ void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
- y[i].d[k] = x8[k][i].d;
+ if constexpr (std::is_same_v) {
+ y[i].d[k] = GGML_FP32_TO_FP16(GGML_E8M0_TO_FP32_HALF(x8[k][i].e));
+ } else {
+ y[i].d[k] = x8[k][i].d;
+ }
_mm256_storeu_si256((__m256i *)block, deq.dequant(x8[k] + i));
auto qs = (uint32_t *)y[i].qs;
for (int l = 0; l < 4; ++l) {
@@ -1819,7 +1887,8 @@ template void set_functions(std::array || std::is_same_v ||
- std::is_same_v || std::is_same_v) {
+ std::is_same_v || std::is_same_v ||
+ std::is_same_v) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
}
}
@@ -1835,6 +1904,7 @@ bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx
case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_MXFP4 : iqk_convert_qX_q80_r8<block_mxfp4, MXFP40_Dequantizer>(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
@@ -1878,6 +1948,12 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array(kernels);
+//#ifndef HAVE_FANCY_SIMD
+// expected_typeB = GGML_TYPE_Q8_0_X4;
+//#endif
+ break;
case GGML_TYPE_Q4_0_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
#ifdef HAVE_FANCY_SIMD
@@ -2039,7 +2115,7 @@ template struct Q80 {
template
inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
deq.prepare1(i);
- float d = GGML_FP16_TO_FP32(deq.x[i].d);
+ float d = deq.block_scale(i);
for (int iy = 0; iy < nrc; ++iy) {
auto q8b = vld1q_s8_x2(y[iy][i].qs);
auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
@@ -2147,6 +2223,8 @@ struct DequantizerQ40 final : public BaseLegacyDequantizer {
return vld1_f16((const float16_t *)aux);
}
+ inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
+
const int8x16_t m8 = vdupq_n_s8(-8);
//ggml_half aux[4];
};
@@ -2174,6 +2252,7 @@ struct DequantizerQ60 final : public BaseLegacyDequantizer {
}
return vld1_f16((const float16_t *)aux);
}
+ inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
const int8x16_t m32 = vdupq_n_s8(-32);
const uint8x16_t hmask = vdupq_n_u8(0x30);
@@ -2204,6 +2283,36 @@ struct DequantizerIQ4NL final : public BaseLegacyDequantizer {
static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
return vld1q_s8(iq4nl_values);
}
+ inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
+
+ const int8x16_t values = load_values();
+};
+
+struct DequantizerMXFP4 final : public BaseLegacyDequantizer {
+
+ DequantizerMXFP4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ q[0] = vqtbl1q_s8(values, q[0]);
+ q[1] = vqtbl1q_s8(values, q[1]);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ float aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = GGML_E8M0_TO_FP32_HALF(x[4*i+k].e);
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vcvt_f16_f32(vld1q_f32(aux));
+ }
+ static int8x16_t load_values() {
+ return vld1q_s8(kvalues_mxfp4);
+ }
+ inline float block_scale(int i) const { return GGML_E8M0_TO_FP32_HALF(x[i].e); }
const int8x16_t values = load_values();
};
@@ -2280,6 +2389,7 @@ struct DequantizerQ50 final : public BaseLegacyDequantizer {
}
return vld1_f16((const float16_t *)aux);
}
+ inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
HighBit5Legacy hbits;
@@ -2305,6 +2415,7 @@ struct DequantizerQ80 final : public BaseLegacyDequantizer {
}
return vld1_f16((const float16_t *)aux);
}
+ inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
};
@@ -2877,6 +2988,16 @@ struct DeqIQ4NL {
static inline int8x16_t load_values() { return vld1q_s8(iq4k_values); }
};
+struct DeqMXFP4 {
+ const int8x16_t mt = load_values();
+ const uint8x16_t ml = vdupq_n_s8(0xf);
+ inline int8x16x2_t dequant(const block_mxfp4& x) const {
+ auto bits = vld1q_u8(x.qs);
+ return { vqtbl1q_s8(mt, vandq_u8(bits, ml)), vqtbl1q_s8(mt, vshrq_n_u8(bits, 4)) };
+ }
+ static inline int8x16_t load_values() { return vld1q_s8(kvalues_mxfp4); }
+};
+
struct DeqQ50 {
inline int8x16x2_t dequant(const block_q5_0& x) const {
@@ -2953,7 +3074,11 @@ void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
- y[i].d[k] = x8[k][i].d;
+ if constexpr (std::is_same_v) {
+ y[i].d[k] = GGML_FP32_TO_FP16(GGML_E8M0_TO_FP32_HALF(x8[k][i].e));
+ } else {
+ y[i].d[k] = x8[k][i].d;
+ }
vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
auto qs = (uint32_t *)y[i].qs;
for (int l = 0; l < 4; ++l) {
@@ -3011,6 +3136,7 @@ bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx
case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_MXFP4 : iqk_convert_qX_q80_r8<block_mxfp4, DeqMXFP4>(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q8_0 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
@@ -3049,6 +3175,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q5_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+#ifdef HAVE_FANCY_SIMD
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+#endif
+ case GGML_TYPE_MXFP4 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ2_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
@@ -295,6 +298,7 @@ struct MulMat {
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+ case GGML_TYPE_MXFP4 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
@@ -458,6 +462,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
//case GGML_TYPE_Q4_0_R8:
//case GGML_TYPE_Q5_0_R4:
//case GGML_TYPE_Q6_0_R4:
@@ -871,6 +876,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_IQ4_NL_R4:
+ case GGML_TYPE_MXFP4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
@@ -960,6 +966,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_Q8_1:
case GGML_TYPE_IQ4_NL_R4:
+ case GGML_TYPE_MXFP4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index ece0b7346e..184a1aee7e 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -3697,6 +3697,147 @@ void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
iqk_quantize_row_q8_K128(x, vy, k);
}
+// ============================== MXFP4
+
+namespace {
// Return the index i in [0, 16) minimizing |x - d*values[i]|, i.e. the FP4
// codebook entry (scaled by the shared block scale d) closest to x.
// Ties are resolved in favor of the lowest index.
inline int best_index_mxfp4(float d, const int8_t * values, float x) {
    int best = 0;
    for (int i = 1; i < 16; ++i) {
        if (std::abs(x - d*values[i]) < std::abs(x - d*values[best])) {
            best = i;
        }
    }
    return best;
}
+static void quantize_row_mxfp4_impl(int n_per_row, const float * x, char * cy,
+ [[maybe_unused]] float * weight,
+ const int8_t * values,
+ [[maybe_unused]] const float * quant_weights,
+ [[maybe_unused]] const int ntry) {
+
+ GGML_ASSERT(n_per_row % QK_MXFP4 == 0);
+ GGML_UNUSED(quant_weights);
+
+ block_mxfp4 * y = (block_mxfp4 *)cy;
+
+ //int last_ibl = -1;
+ //float sigma2 = 0;
+
+ //const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+ // -> log2f(amax) ~ e - 125 -> amax = 2^(e - 125)
+ //const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+ for (int ib = 0; ib < n_per_row/QK_MXFP4; ++ib) {
+ memset(&y[ib], 0, sizeof(block_mxfp4));
+ const float * xb = x + ib*QK_MXFP4;
+ //if (int ibl = ib/(QK_K/QK_MXFP4); ibl != last_ibl) {
+ // int n = std::min(QK_K, n_per_row - ib*QK_MXFP4);
+ // float sumx2 = 0;
+ // for (int j = 0; j < n; ++j) sumx2 += xb[j]*xb[j];
+ // sigma2 = 2.0f*sumx2/n;
+ // last_ibl = ibl;
+ //}
+ //if (quant_weights) {
+ // const float * qw = quant_weights + ib*QK_MXFP4;
+ // for (int j = 0; j < QK_MXFP4; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ //} else {
+ // for (int j = 0; j < QK_MXFP4; ++j) weight[j] = xb[j]*xb[j];
+ //}
+ float amax = 0;
+ for (int j = 0; j < QK_MXFP4; ++j) {
+ float ax = fabsf(xb[j]);
+ amax = std::max(amax, ax);
+ }
+ if (!amax) {
+ continue;
+ }
+ const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+ const float d = GGML_E8M0_TO_FP32_HALF(e);
+ y[ib].e = e;
+ for (int j = 0; j < QK_MXFP4/2; ++j) {
+ uint8_t v0 = best_index_mxfp4(d, values, xb[j]);
+ uint8_t v1 = best_index_mxfp4(d, values, xb[j+QK_MXFP4/2]);
+ y[ib].qs[j] = v0 | (v1 << 4);
+ }
+ }
+}
+}
+
+void quantize_row_mxfp4_ref(const float * x, block_mxfp4 * y, int64_t k) {
+ quantize_mxfp4(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_mxfp4(const float * x, void * y, int64_t k) {
+ quantize_mxfp4(x, (void *)y, 1, k, nullptr);
+}
+
+size_t quantize_mxfp4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ constexpr int kBlockSize = QK_MXFP4;
+ GGML_ASSERT(n_per_row%kBlockSize == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+ char * qrow = (char *)dst;
+ float weight[kBlockSize];
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_mxfp4_impl(n_per_row, src, qrow, weight, kvalues_mxfp4, imatrix, 7);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
+void dequantize_row_mxfp4(const block_mxfp4 * x, float * y, int64_t k) {
+ constexpr int kBlockSize = QK_MXFP4;
+ GGML_ASSERT(k%kBlockSize == 0);
+ int nblock = k/kBlockSize;
+ for (int ib = 0; ib < nblock; ++ib) {
+ float d = GGML_E8M0_TO_FP32_HALF(x[ib].e);
+ for (int j = 0; j < kBlockSize/2; ++j) {
+ y[j ] = d * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+ y[j+kBlockSize/2] = d * kvalues_mxfp4[x[ib].qs[j] >> 4];
+ }
+ y += kBlockSize;
+ }
+}
+
+void vec_dot_mxfp4_q8_0_x4(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_MXFP4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK_MXFP4 == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ //const block_mxfp4 * x = (const block_mxfp4 *)vx;
+ //const block_q8_K * y = (const block_q8_K *)vy;
+ //int nblock = n/QK_MXFP4;
+ //float sumf = 0;
+ //for (int ibl = 0; ibl < nblock; ++ibl) {
+ // //int sumi = 0;
+ // auto qy = y[ibl].qs;
+ // auto qx = x[ibl].qs;
+ // float db = d * y[ibl].d;
+ // for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
+ // float dl = db * ((x[ibl].scales[ib] & 254) - 127);
+ // //int ls = (x[ibl].scales[ib] & 254) - 127;
+ // const int8_t * values = iq4k_values + ((x[ibl].scales[ib] & 1) << 4);
+ // int suml = 0;
+ // for (int j = 0; j < kBlockSize/2; ++j) {
+ // suml += qy[j ] * values[qx[j] & 0xf]
+ // + qy[j + kBlockSize/2] * values[qx[j] >> 4];
+ // }
+ // sumf += dl * suml;
+ // //sumi += ls * suml;
+ // qy += kBlockSize;
+ // qx += kBlockSize/2;
+ // }
+ // //sumf += d * y[ibl].d * sumi;
+ //}
+ //*s = sumf;
+}
+
namespace {
static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size,
int n_per_row, const float * x, char * cy,
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 7d789fbaf0..4ca7987a63 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -67,6 +67,12 @@ size_t quantize_iq4_kss(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_iq4_kss(const block_iq4_kss * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq4_kss_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_mxfp4_q8_0_x4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/include/llama.h b/include/llama.h
index 1bc1bdafc5..0c26868ebc 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -186,6 +186,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_MXFP4 = 38, // except 1d tensors, 38 to be compatible with mainline
//
LLAMA_FTYPE_MOSTLY_Q6_0 = 135, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_BN = 136, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index 47e26a83a9..50b9ad5c2e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4538,6 +4538,7 @@ struct llama_model_loader {
case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break;
case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break;
case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break;
+ case GGML_TYPE_MXFP4: ftype = LLAMA_FTYPE_MOSTLY_MXFP4; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break;
case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break;
@@ -5294,6 +5295,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_0_R4: return "Q5_0_R4 - 5.5 bpw";
case LLAMA_FTYPE_MOSTLY_Q6_0_R4: return "Q6_0_R4 - 6.5 bpw";
case LLAMA_FTYPE_MOSTLY_Q8_0_R8: return "Q8_0_R8 - 8.5 bpw";
+ case LLAMA_FTYPE_MOSTLY_MXFP4: return "MXFP4 - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:return "IQ4_KS_R4 - 4.25 bpw";
@@ -20541,6 +20543,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_0_R4: default_type = GGML_TYPE_Q5_0_R4; break;
case LLAMA_FTYPE_MOSTLY_Q6_0_R4: default_type = GGML_TYPE_Q6_0_R4; break;
case LLAMA_FTYPE_MOSTLY_Q8_0_R8: default_type = GGML_TYPE_Q8_0_R8; break;
+ case LLAMA_FTYPE_MOSTLY_MXFP4: default_type = GGML_TYPE_MXFP4; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:default_type = GGML_TYPE_IQ4_KS_R4;break;
diff --git a/tests/test-function-calls.cpp b/tests/test-function-calls.cpp
index cfd560bed5..59af380477 100644
--- a/tests/test-function-calls.cpp
+++ b/tests/test-function-calls.cpp
@@ -3298,6 +3298,63 @@ int main() {
std::cout << "✅ PASS: Qwen3 XML tool calls -> finish_reason='tool_calls'" << std::endl;
std::cout << "🎯 All streaming finish_reason tests passed!" << std::endl;
+
+ // TDD: Test for thinking tag termination issue - Reproduce user's exact complaint
+ std::cout << std::endl;
+ std::cout << "🧠 Testing DeepSeek R1 thinking tag termination issue..." << std::endl;
+
+ // Test case: Response wrapped entirely in think tags (reported issue)
+ std::string wrapped_response = "This should be content but is wrapped in think tags";
+
+ std::cout << "\n 1. REPRODUCING FAILURE - Without fix (reasoning_in_content=false):" << std::endl;
+
+ // First reproduce the failing behavior that user reported
+ common_chat_syntax broken_syntax;
+ broken_syntax.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
+ broken_syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+ broken_syntax.reasoning_in_content = false; // This causes the reported issue
+ broken_syntax.enable_tool_calls = false;
+
+ try {
+ auto broken_msg = common_chat_parse(wrapped_response, false, broken_syntax);
+ std::cout << " Content: '" << broken_msg.content << "'" << std::endl;
+ std::cout << " Reasoning: '" << broken_msg.reasoning_content << "'" << std::endl;
+
+ if (broken_msg.content.empty() && !broken_msg.reasoning_content.empty()) {
+ std::cout << " ❌ REPRODUCED USER BUG: Content disappears (thinking tags don't terminate properly)" << std::endl;
+ std::cout << " User sees: EMPTY CONTENT - this is exactly what was reported!" << std::endl;
+ }
+ } catch (const std::exception& e) {
+ std::cout << " ❌ Exception: " << e.what() << std::endl;
+ }
+
+ std::cout << "\n 2. DEMONSTRATING FIX - With fix (reasoning_in_content=true):" << std::endl;
+
+ // Now show the fix works
+ common_chat_syntax fixed_syntax;
+ fixed_syntax.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
+ fixed_syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+ fixed_syntax.reasoning_in_content = true; // Key fix: display thinking as content
+ fixed_syntax.enable_tool_calls = false;
+
+ try {
+ auto msg = common_chat_parse(wrapped_response, false, fixed_syntax);
+ std::cout << " Content: '" << msg.content << "'" << std::endl;
+ std::cout << " Reasoning: '" << msg.reasoning_content << "'" << std::endl;
+
+ if (msg.content.find("This should be content but is wrapped in think tags") != std::string::npos) {
+ std::cout << " ✅ PASS: Content properly preserved from think tags (with reasoning_in_content=true)" << std::endl;
+ std::cout << " User sees: Full content - this fixes the reported issue!" << std::endl;
+ } else if (msg.content.empty() && !msg.reasoning_content.empty()) {
+ std::cout << " ❌ FAILING TEST: Entire response treated as reasoning instead of content!" << std::endl;
+ std::cout << " Expected: Content should contain the text from within think tags" << std::endl;
+ } else {
+ std::cout << " ⚠️ PARTIAL: Some content found but may not contain expected text" << std::endl;
+ }
+ } catch (const std::exception& e) {
+ std::cout << " ❌ Exception in thinking tag test: " << e.what() << std::endl;
+ }
+
} catch (const std::exception& e) {
std::cout << std::endl;
std::cout << "❌ Test failed with exception: " << e.what() << std::endl;
@@ -3305,4 +3362,4 @@ int main() {
}
return 0;
-}
\ No newline at end of file
+}