Mistral Small 4: Absorbed MLA + INT4 quantized latent cache #1037
ProducerGuy wants to merge 1 commit into ml-explore:main
Conversation
Adds MoE + MLA model support for Mistral Small 4 (mistralai/Mistral-Small-4-119B-2603), enabling mlx-community/Mistral-Small-4-119B-2603-4bit to load and run.

New file: `mlx_lm/models/mistral4.py`
- MoE feedforward with SwitchGLU routing (128 experts, top-4)
- Shared expert support
- MLA attention with compressed KV via explicit `kv_b_proj` (distinct from DeepSeek V3's MultiLinear approach)
- Standard attention fallback for dense layers
- All dimensions read from config, nothing hardcoded

Modified: `mlx_lm/models/mistral3.py`
- Structural routing: `n_routed_experts` presence routes to mistral4
- Forward-compatible with future MoE Mistral variants
- Dense Ministral 3B/8B/14B models unaffected

Tested on MacBook Pro M5 Max (128GB):
- 104 tok/s generation
- 67 GB peak memory
- Correct factual output confirmed

Before: `ValueError: Received 1260 parameters not in model`
After: model loads and generates correctly.

Chat template (`apply_chat_template`) works out of the box once the model loads — no additional changes needed.
Thanks for the work on this @ProducerGuy — getting Mistral Small 4 running on MLX is great.

One area worth discussing: the current MLA implementation decompresses KV via the explicit `kv_b_proj`, caching the full decompressed K/V. You noted that Mistral 4 uses "a single linear projection rather than per-head Kronecker-style decomposition" — that's a correct observation about weight format, but it doesn't prevent absorption. The math is the same as DeepSeek V2.

For reference, vLLM serves this model with dedicated MLA backends.

I've opened #1075 with an absorbed MLA variant.
Fix for FP8 weight loading — sanitize needs expert tensor remapping

Hey @ProducerGuy — thanks for the model architecture work on this, the MLA implementation is solid. I ran into issues getting the conversion to work on Apple Silicon (M3 Ultra, mlx 0.31.1) and traced it to two things in the sanitize method:

Problem 1: The HF weights use fused expert tensors that don't match SwitchGLU's expected names.
Problem 2: FP8 scale tensors on the expert weights use a different naming pattern than the shared experts.
Fix: Updated sanitize to (1) dequantize the expert FP8 tensors using their scale_inv tensors, (2) split the fused gate_up_proj, and (3) rename expert weights to SwitchGLU's expected names.

Working on my end — 222GB bf16 conversion, loads and runs on M3 Ultra 512GB. Happy to open a PR against your branch or paste the diff here, whatever's easier if/when helpful. Thank you again for your work!

Fix developed with Claude Code

UPDATE: the `weight * scale_inv` dequantization doesn't properly decode FP8 E4M3 sign bits — MLX reads them as uint8. Still working on the correct FP8→bf16 conversion. The expert remapping and naming fix seem to be correct though.
UPDATE 2: FP8 E4M3 decoding fix — working now ✅

The issue from my previous update: MLX loads FP8 E4M3 tensors as raw uint8. Fix: built a 256-entry lookup table that properly decodes each uint8 byte as E4M3 (1 sign bit, 4 exponent bits, 3 mantissa bits):

```python
import numpy as np

_e4m3_lut = np.zeros(256, dtype=np.float32)
for i in range(256):
    sign = (i >> 7) & 1
    exp = (i >> 3) & 0xF
    mant = i & 0x7
    if exp == 0 and mant == 0:
        _e4m3_lut[i] = 0.0
    elif exp == 0:  # subnormal
        _e4m3_lut[i] = ((-1.0) ** sign) * (2.0 ** (-6)) * (mant / 8.0)
    elif exp == 15 and mant == 7:
        _e4m3_lut[i] = 0.0  # NaN → 0
    else:  # normal
        _e4m3_lut[i] = ((-1.0) ** sign) * (2.0 ** (exp - 7)) * (1.0 + mant / 8.0)
```

Then index into it with the raw uint8 bytes. Verified working.
Full sanitize now handles: (1) E4M3 LUT decoding, (2) scale_inv multiplication, (3) fused gate_up_proj splitting, (4) expert weight renaming. Happy to share the complete diff or open a PR against your branch. Fix developed with Claude Code
Beautiful work! 🙇
Hi @ProducerGuy! Thanks for the great work and the credit ❤️. Everything looks great, very transparent. I'll probably test it soon locally too. The INT4 latent cache and fused kernel look great (above my current skill level): a 104× compression while maintaining perplexity is a very cool result. Looking forward to seeing your PR ml-explore/mlx#3373 get merged. In the meantime, what do you plan on doing? Keep support for Mistral Small 4 without the fused kernel, or would you rather wait for your MLX PR to get merged? Just curious. I'll close my PR anyway.

One small note for @geosh5676 on the FP8 decoding:

```python
weight = mx.from_fp8(weight, dtype=mx.bfloat16)
```

It's what I use in my #1075.
Oh no 😅 thank you @graelo — that's exactly the kind of thing I needed someone to point out! I clearly didn't know mx.from_fp8() existed; the LUT was my "okay I'll just give it a try" solution. Feels a bit like reinventing the wheel and then proudly presenting it to the wheel factory 🤣. I appreciate your kind guidance; should I delete my comment or leave it up? I'm new to trying to contribute. I'll have to figure out how to be more careful next time. Meanwhile, thank you again and I appreciate the very helpful pointer and the kind words! 🙏
No worries, according to the mlx contributors, you reinvented a pretty useful wheel anyway ;)
Summary
Updates Mistral Small 4 (119B MoE) support with absorbed Multi-head Latent
Attention and INT4 quantized latent cache. 104× KV cache reduction while
maintaining decode speed, enabling 256K context on 128GB machines.
Building on the absorbed MLA approach from @graelo's #1075 (with MoEGate fix).
Benchmarks (M5 Max, 128GB, 4-bit quantized model)
Performance
Quality (Perplexity)
Measured on allenai/tulu-3-sft-mixture, 100 samples, sequence length 512, seed 42:
Absorbed MLA adds +3% perplexity (4.473 → 4.606) — a known tradeoff that
matches the research prediction (arXiv 2603.04428). This cost comes from
the absorption technique, not from our work.
Our INT4 cache compression adds zero additional perplexity (4.606 → 4.606).
We turned 2.75 GiB of fp16 latent cache into 0.69 GiB of INT4 cache for free.
Cache sizes above 8K are calculated, not empirically verified at those lengths.
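A back-of-envelope check of these numbers, using only figures stated above (ratios only; the per-token byte layout and group-scale overhead are not modeled):

```python
# From figures quoted in this PR: 8,192 decompressed fp16 values/token/layer
# vs a 256-dim latent + 64-dim RoPE, and the measured 2.75 GiB -> 0.69 GiB
# fp16-to-INT4 shrink of the latent cache.
full_values = 8192
compressed_values = 256 + 64
absorption_ratio = full_values / compressed_values     # 25.6x from absorption
int4_ratio = 2.75 / 0.69                               # ~4x from INT4
combined = absorption_ratio * int4_ratio               # lands near the quoted ~104x
assert absorption_ratio == 25.6
assert 100 < combined < 105
```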
Design Decisions & Research Basis
Why absorbed MLA
The original PR cached full decompressed K/V (8,192 values/token/layer).
At 256K context that's 144 GiB for KV cache alone — impossible even on
128GB machines. Absorbed MLA (DeepSeek-V2, arXiv 2405.04434; DeepSeek-V3,
arXiv 2412.19437) decomposes `kv_b_proj` into absorbed weight matrices at
load time, caching only the compressed 256-dim latent + 64-dim RoPE instead
of full K/V. 25.6× cache reduction.
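The absorption trick is easy to verify numerically: folding the up-projection into the query gives the same attention scores as decompressing K first. A minimal NumPy sketch (the latent dim matches the PR; `W_uk`, `c`, and the head dim are illustrative names and sizes, not the PR's identifiers):

```python
import numpy as np

rng = np.random.default_rng(0)
d_latent, d_head, S = 256, 128, 5                # 256-dim latent per the PR
W_uk = rng.standard_normal((d_head, d_latent))   # hypothetical key up-projection
c = rng.standard_normal((S, d_latent))           # cached compressed latents
q = rng.standard_normal(d_head)                  # one query head

# Naive: decompress every cached K, then score
K = c @ W_uk.T
scores_naive = K @ q

# Absorbed: fold W_uk into the query once, score directly against the latents
q_abs = W_uk.T @ q
scores_absorbed = c @ q_abs

assert np.allclose(scores_naive, scores_absorbed)
```

Because `K[i] @ q == c[i] @ (W_uk.T @ q)`, only the small latent ever needs to live in the cache.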
Why INT4 quantized latent cache
arXiv 2603.04428 demonstrated Q4 quantization on DeepSeek's MLA cache adds
only +3% perplexity. FlashMLA (DeepSeek production) uses FP8 latent cache
with fp16 RoPE — we follow the same split pattern at INT4. TurboQuant
(ICLR 2026, Google) showed 3-bit KV cache with zero accuracy loss. MLA's
256-dim latent is effectively one large "head" — larger dimensions tolerate
quantization better than standard per-head K/V. Our perplexity results
confirm this: zero quality loss from INT4 compression.
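For intuition, a group-wise affine INT4 quantizer over one 256-dim latent vector can be sketched in NumPy (an illustrative scheme, not MLX's `mx.quantize` internals; group size and names are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

def quantize_int4(x, group_size=32):
    # Affine 4-bit quantization per contiguous group: 16 levels between min/max
    g = x.reshape(-1, group_size)
    lo = g.min(axis=1, keepdims=True)
    hi = g.max(axis=1, keepdims=True)
    scale = (hi - lo) / 15.0
    codes = np.clip(np.round((g - lo) / scale), 0, 15).astype(np.uint8)
    return codes, scale, lo

def dequantize_int4(codes, scale, lo):
    return (codes * scale + lo).reshape(-1)

latent = rng.standard_normal(256).astype(np.float32)   # one token's latent
codes, scale, lo = quantize_int4(latent)
err = np.abs(dequantize_int4(codes, scale, lo) - latent).max()
# Reconstruction error stays small relative to the ~N(0,1) latent scale
assert err < 0.3
```

The 256-dim latent gives each group plenty of dynamic range to estimate scales from, which is consistent with the observation that larger dimensions tolerate quantization better.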
Why RoPE stays in fp16
FlashMLA's production pattern. RoPE is position-sensitive and degrades under
quantization. The 64-dim RoPE portion is only 20% of total cache — quantizing
it saves little but risks positional accuracy.
Why a custom C++ Metal kernel
Per-operation profiling showed the decode bottleneck was dispatch overhead,
not compute. The absorbed MLA decode path required 5+ separate kernel
dispatches (dequant, nope scoring, rope scoring, softmax, value accumulation)
for what should be one fused operation. MLX's `sdpa_vector.h` proved the
fused attention pattern (online softmax + simd_sum) works on Apple Silicon.
arXiv 2507.15465 showed MLA's arithmetic intensity is 100× higher than MHA,
making it especially suited for kernel fusion.
Why direct cache update (copy_shared_buffer)
Profiling revealed MLX's `SliceUpdate` copies the entire cache array to
append one token — write amplification of S:1 at each decode step. Using
`copy_shared_buffer` aliasing (the same pattern MLX uses in its own SDPA
and RoPE kernels) eliminates this copy entirely. The kernel writes the new
token directly into the cache buffer at the correct offset.
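The write-amplification argument can be illustrated outside Metal: appending by copying the whole cache writes O(S²) elements over S steps, while a preallocated in-place write touches O(S). A conceptual NumPy sketch (`copy_shared_buffer` itself is an MLX internals mechanism, not modeled here):

```python
import numpy as np

S, d = 1000, 8   # decode steps, toy per-token width

# Append-by-copy (the SliceUpdate pattern): every step rewrites the whole cache
copied = 0
cache = np.empty((0, d))
for t in range(S):
    cache = np.concatenate([cache, np.ones((1, d))])
    copied += cache.size                  # elements written this step

# Preallocated in-place write: only the new token's slot is touched
inplace = 0
buf = np.empty((S, d))
for t in range(S):
    buf[t] = 1.0
    inplace += d

# Total writes: O(S^2) vs O(S) — amplification grows linearly with S
assert copied / inplace > S / 4
```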
Why split nope/rope scoring
Mistral Small 4's MLA uses separate nope (256-dim latent) and rope (64-dim
positional) scoring paths that combine before softmax. The original PR
concatenated them into one standard SDPA call, which works but prevents
INT4-specific optimizations on the nope path. The fused kernel scores them
separately using the same online softmax.
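Splitting the scoring paths is mathematically safe because a dot product over a concatenation decomposes into a sum of per-part dot products. A small NumPy check (latent/RoPE dims from the PR; identifiers hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
S, d_nope, d_rope = 16, 256, 64           # cache length, latent dim, RoPE dim
qn = rng.standard_normal(d_nope)
qr = rng.standard_normal(d_rope)
Kn = rng.standard_normal((S, d_nope))
Kr = rng.standard_normal((S, d_rope))

# Concatenated: one SDPA-style score over the 320-dim keys
scores_concat = np.concatenate([Kn, Kr], axis=1) @ np.concatenate([qn, qr])
# Split: nope and rope paths scored separately, summed before softmax
scores_split = Kn @ qn + Kr @ qr
assert np.allclose(scores_concat, scores_split)
```

Since the pre-softmax scores are identical either way, the split form is free to dequantize the nope path with INT4-specific logic.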
What Didn't Work (And What We Learned)
Life isn't about everything we did right. It's about what we learned,
what we know, what we learned we didn't know, and what we still don't know.
Token-level sparse attention (43 tok/s — abandoned)
Tried selecting top-K tokens from cache to reduce attention computation,
based on SALS (arXiv 2510.24273, NeurIPS 2025) and DeepSeek V3.2's Indexer.
Python-level overhead from `mx.take_along_axis`, `mx.sort`, and per-step
`mx.eval` for concrete indices exceeded the compute savings. The idea is
sound — DeepSeek uses it in production — but requires a C++ Metal
implementation, not Python ops. Block sparsity (fixed-size blocks instead
of per-token selection) remains a promising direction we didn't pursue.
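For reference, the abandoned idea in NumPy terms: pick the top-K tokens by score, then attend only over those. This is an illustrative sketch, not the PR's implementation; per the profiling above, the cost was the host-side index ops per step, not this math:

```python
import numpy as np

rng = np.random.default_rng(1)
S, d, k = 1024, 64, 128                   # cache length, head dim, tokens kept
K = rng.standard_normal((S, d))
V = rng.standard_normal((S, d))
q = rng.standard_normal(d)

scores = K @ q / np.sqrt(d)               # full scoring pass
idx = np.argpartition(scores, -k)[-k:]    # keep only the top-k token indices
sel = scores[idx]
w = np.exp(sel - sel.max())               # softmax over the selected subset
w /= w.sum()
out = w @ V[idx]                          # attend over k tokens instead of S
assert out.shape == (d,)
```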
Head tiling / shared latent reads (5% regression — abandoned)
MLA's latent is shared across all 32 heads, so each head reads the same
data independently — 32× redundant device memory reads. Tried H_TILE=4
(4 heads per threadgroup, 8 threadgroups instead of 32) to share reads
via threadgroup memory. Regressed 5% because at short-to-medium context
the kernel is latency-bound, not bandwidth-bound (measured via custom
diagnostic tools). Reducing threadgroups from 32 to 8 hurt parallelism
more than the bandwidth savings helped. Head tiling likely wins at very
long context (S>4096) where bandwidth utilization rises from 0.82% to
22%+, but we didn't reach that regime.
Fused quantize-on-store kernel (net zero — kept but not strategic)
Built a custom Metal kernel to replace `mx.quantize` for cache writes.
Microbenchmarks showed 8% faster in isolation, but end-to-end tok/s
didn't change. Diagnostic tools revealed the real bottleneck was
SliceUpdate's full-cache copy, not the quantize dispatch. Fixing
SliceUpdate (via copy_shared_buffer) was the actual win.
embed_q fusion into SDPA kernel (correct but imprecise)
Attempted to fuse the embed_q matmul (W_UK absorption) into the SDPA kernel
to eliminate a dispatch boundary. Worked — model produced correct output.
But MLX's `quantized_matmul` uses a specific internal computation path
(`qdot` with pre-divided x values per `load_vector` in `quantized.h`)
that differs from standard dequant+matmul at the floating-point level.
Replicating the formula achieved 1.54% relative error — good enough for
correct generation but slightly changes model behavior. Kept as a research
branch, not shipped.
TurboQuant-style scoring kernel (15-20 tok/s — abandoned)
Walsh-Hadamard rotation before quantization, based on ICLR 2026 paper (PR
#1067). Built on top of an early custom Metal kernel that had a fundamental
grid parameter bug (passing total threads instead of threadgroups to
`mx.fast.metal_kernel`). The TurboQuant idea was never tested on working
infrastructure. The rotation concept may still have value for improving
INT4 cache quality.
External C++ nanobind extension (crash — abandoned)
Attempted to build MLA kernels as an external Python extension using
nanobind. MLX's type registry prevents sharing `mlx.core.array` types
across nanobind domains — `TypeError` on every call. Led us to fork MLX
directly, which is what ultimately worked.
What We Still Don't Know
- Whether head tiling pays off at very long context where bandwidth becomes the bottleneck
- Whether alternative kernel designs (split-K quantized matmul, faster two-pass SDPA) would change the baseline
- Whether `qmv_quad`'s exact thread topology matters (4-thread quadgroups vs our 32-thread simdgroups)
- Full-pipeline behavior — we profiled attention thoroughly but never measured the complete pipeline
Architecture Notes
Mistral Small 4's MLA differs from DeepSeek V3:
Hardware
Tested on Apple M5 Max (128GB, 40-core GPU, LPDDR5X 614 GB/s).
Should work on any Apple Silicon Mac with sufficient memory (~67GB for
the 4-bit quantized model). Not tested on M1/M2/M3/M4 — performance
characteristics may differ, especially for the fused Metal kernel.
Files Changed
- `mlx_lm/models/mistral4.py`
- `mlx_lm/models/mistral3.py`
- `mlx_lm/models/cache.py` — `QuantizedLatentKVCache` with direct cache update

Companion PR
The fused Metal kernel lives in ml-explore/mlx: ml-explore/mlx#3373
Without the companion kernel, the model falls back to the standard
`quantized_matmul` + SDPA path (~99 tok/s with INT4 cache). The kernel
adds the fused decode optimization that recovers full speed.
Usage
Requires Apple Silicon Mac with ~67GB memory for the 4-bit quantized model.
Acknowledgments
- @graelo's #1075, whose absorbed MLA approach enabled everything that followed. Your work pushed us to go deeper.
- (contributions that did not ship directly but informed our understanding)
- MLX's `sdpa_vector.h`, which our fused kernel is based on

References