Mistral Small 4: Absorbed MLA + INT4 quantized latent cache#1037

Open
ProducerGuy wants to merge 1 commit into ml-explore:main from ProducerGuy:mistral-small-4-moe-support

Conversation


@ProducerGuy ProducerGuy commented Mar 21, 2026

Summary

Updates Mistral Small 4 (119B MoE) support with absorbed Multi-head Latent
Attention and INT4 quantized latent cache. 104× KV cache reduction while
maintaining decode speed, enabling 256K context on 128GB machines.

Building on the absorbed MLA approach from @graelo's #1075 (with MoEGate fix).

Benchmarks (M5 Max, 128GB, 4-bit quantized model)

Performance

| Metric | Original PR #1037 | This PR |
|---|---|---|
| Generation tok/s | 108 | ~110 |
| Prompt tok/s | 7 | ~190 |
| Peak memory | 67.0 GB | 67.0 GB |
| KV cache at 8K | 4.5 GiB | 0.04 GiB |
| KV cache at 128K | ~72 GiB | ~0.69 GiB |
| KV cache at 256K | ~144 GiB (impossible) | ~1.37 GiB |
| Cache compression | — | 104× |

Quality (Perplexity)

Measured on allenai/tulu-3-sft-mixture, 100 samples, sequence length 512, seed 42:

| Version | Perplexity | Cache at 128K | tok/s |
|---|---|---|---|
| Original PR #1037 (non-absorbed) | 4.473 ± 0.065 | ~72 GiB | 108 |
| @graelo's absorbed MLA (Phase 1) | 4.606 ± 0.064 | ~2.75 GiB | 102 |
| This PR (absorbed + INT4 cache) | 4.606 ± 0.064 | ~0.69 GiB | ~110 |

Absorbed MLA adds +3% perplexity (4.473 → 4.606) — a known tradeoff that
matches the research prediction (arXiv 2603.04428). This cost comes from
the absorption technique, not from our work.

Our INT4 cache compression adds zero additional perplexity (4.606 → 4.606).
We turned 2.75 GiB of fp16 latent cache into 0.69 GiB of INT4 cache for free.

Cache sizes above 8K are calculated, not empirically verified at those lengths.

Design Decisions & Research Basis

Why absorbed MLA

The original PR cached full decompressed K/V (8,192 values/token/layer).
At 256K context that's 144 GiB for KV cache alone — impossible even on
128GB machines. Absorbed MLA (DeepSeek-V2, arXiv 2405.04434; DeepSeek-V3,
arXiv 2412.19437) decomposes kv_b_proj into absorbed weight matrices at
load time, caching only the compressed 256-dim latent + 64-dim RoPE instead
of full K/V. 25.6× cache reduction.
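The quoted 25.6× follows directly from the per-token, per-layer counts above (a quick arithmetic check; dimensions are the ones stated in this thread):

```python
heads = 32
nope_dim, rope_dim, v_dim = 64, 64, 128     # per-head K/V dims from the PR thread
kv_lora_rank, rope_cache_dim = 256, 64      # absorbed-cache dims

# Non-absorbed: fully decompressed K and V cached per token per layer
full_kv = heads * (nope_dim + rope_dim + v_dim)   # 8192 values
# Absorbed: only the compressed latent + RoPE component
absorbed = kv_lora_rank + rope_cache_dim          # 320 values

reduction = full_kv / absorbed                    # 25.6x
```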

Why INT4 quantized latent cache

arXiv 2603.04428 demonstrated Q4 quantization on DeepSeek's MLA cache adds
only +3% perplexity. FlashMLA (DeepSeek production) uses FP8 latent cache
with fp16 RoPE — we follow the same split pattern at INT4. TurboQuant
(ICLR 2026, Google) showed 3-bit KV cache with zero accuracy loss. MLA's
256-dim latent is effectively one large "head" — larger dimensions tolerate
quantization better than standard per-head K/V. Our perplexity results
confirm this: zero quality loss from INT4 compression.

Why RoPE stays in fp16

FlashMLA's production pattern. RoPE is position-sensitive and degrades under
quantization. The 64-dim RoPE portion is only 20% of total cache — quantizing
it saves little but risks positional accuracy.

Why a custom C++ Metal kernel

Per-operation profiling showed the decode bottleneck was dispatch overhead,
not compute. The absorbed MLA decode path required 5+ separate kernel
dispatches (dequant, nope scoring, rope scoring, softmax, value accumulation)
for what should be one fused operation. MLX's sdpa_vector.h proved the
fused attention pattern (online softmax + simd_sum) works on Apple Silicon.
arXiv 2507.15465 showed MLA's arithmetic intensity is 100× higher than MHA,
making it especially suited for kernel fusion.
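The fused pattern referenced above (online softmax plus value accumulation in a single pass) can be sketched in numpy. This is an illustrative scalar-loop version of the algorithm, not the Metal kernel itself:

```python
import numpy as np

def online_softmax_attention(scores, values):
    # One pass over the sequence: keep a running max m, a running
    # normalizer l, and an unnormalized accumulator; rescale the old
    # terms whenever the running max grows. Numerically equivalent to
    # softmax(scores) @ values without a second pass over the cache.
    m, l = -np.inf, 0.0
    acc = np.zeros(values.shape[-1])
    for s, v in zip(scores, values):
        m_new = max(m, s)
        corr = np.exp(m - m_new)            # rescale previously accumulated terms
        l = l * corr + np.exp(s - m_new)
        acc = acc * corr + np.exp(s - m_new) * v
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
scores, values = rng.normal(size=16), rng.normal(size=(16, 8))
weights = np.exp(scores - scores.max())
reference = (weights / weights.sum()) @ values  # standard two-pass softmax
```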

Why direct cache update (copy_shared_buffer)

Profiling revealed MLX's SliceUpdate copies the entire cache array to
append one token — O(S) writes per decode step just to add one row. Using
copy_shared_buffer aliasing (the same pattern MLX uses in its own SDPA
and RoPE kernels) eliminates this copy entirely. The kernel writes the new
token directly into the cache buffer at the correct offset.
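As a toy numpy analogy for the difference (this only illustrates copy-on-append versus writing into a preallocated buffer at an offset; it is not the MLX buffer-aliasing code):

```python
import numpy as np

MAX_SEQ, DIM = 1024, 320
cache = np.zeros((MAX_SEQ, DIM), dtype=np.float16)  # preallocated once
length = 0

def append_copying(cache_view, token):
    # SliceUpdate-style: materializes a brand-new array of all S+1 rows
    return np.concatenate([cache_view, token[None]], axis=0)

def append_in_place(token):
    # copy_shared_buffer-style: write one row at the current offset and
    # hand back a view of the same underlying buffer -- no full-cache copy
    global length
    cache[length] = token
    length += 1
    return cache[:length]

for _ in range(3):
    view = append_in_place(np.ones(DIM, dtype=np.float16))
```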

Why split nope/rope scoring

Mistral Small 4's MLA uses separate nope (256-dim latent) and rope (64-dim
positional) scoring paths that combine before softmax. The original PR
concatenated them into one standard SDPA call, which works but prevents
INT4-specific optimizations on the nope path. The fused kernel scores them
separately using the same online softmax.
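The equivalence between split and concatenated scoring can be checked in numpy (toy shapes; the absorbed query and cache contents are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
H, S = 32, 128
q_nope = rng.normal(size=(H, 256))   # query absorbed into latent space
q_rope = rng.normal(size=(H, 64))    # positional query
c_kv   = rng.normal(size=(S, 256))   # cached compressed latent
k_rope = rng.normal(size=(S, 64))    # cached RoPE keys (fp16 in this PR)

# Split scoring: two matmuls, summed before softmax
scores_split = q_nope @ c_kv.T + q_rope @ k_rope.T

# Equivalent concatenated form (what a single standard SDPA call computes)
scores_cat = (np.concatenate([q_nope, q_rope], axis=1)
              @ np.concatenate([c_kv, k_rope], axis=1).T)
```

Because the two forms produce identical pre-softmax scores, splitting them costs nothing in quality while letting the nope path use an INT4-specific inner loop.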

What Didn't Work (And What We Learned)

Life isn't about everything we did right. It's about what we learned,
what we know, what we learned we didn't know, and what we still don't know.

Token-level sparse attention (43 tok/s — abandoned)

Tried selecting top-K tokens from cache to reduce attention computation,
based on SALS (arXiv 2510.24273, NeurIPS 2025) and DeepSeek V3.2's Indexer.
Python-level overhead from mx.take_along_axis, mx.sort, and per-step
mx.eval for concrete indices exceeded the compute savings. The idea is
sound — DeepSeek uses it in production — but requires a C++ Metal
implementation, not Python ops. Block sparsity (fixed-size blocks instead
of per-token selection) remains a promising direction we didn't pursue.
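For reference, the abandoned idea itself is simple; a toy numpy version of single-query top-K attention (illustrative only; the PR's finding was that doing this with Python-level ops costs more than it saves):

```python
import numpy as np

def topk_attention(q, k, v, top_k):
    # Score all cached keys, but run softmax + value accumulation over
    # only the top_k highest-scoring tokens.
    scores = k @ q                              # (S,)
    idx = np.argpartition(scores, -top_k)[-top_k:]
    s = scores[idx]
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ v[idx]

rng = np.random.default_rng(0)
q = rng.normal(size=64)
k = rng.normal(size=(512, 64))
v = rng.normal(size=(512, 64))
out = topk_attention(q, k, v, top_k=64)
```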

Head tiling / shared latent reads (5% regression — abandoned)

MLA's latent is shared across all 32 heads, so each head reads the same
data independently — 32× redundant device memory reads. Tried H_TILE=4
(4 heads per threadgroup, 8 threadgroups instead of 32) to share reads
via threadgroup memory. Regressed 5% because at short-to-medium context
the kernel is latency-bound, not bandwidth-bound (measured via custom
diagnostic tools). Reducing threadgroups from 32 to 8 hurt parallelism
more than the bandwidth savings helped. Head tiling likely wins at very
long context (S>4096) where bandwidth utilization rises from 0.82% to
22%+, but we didn't reach that regime.

Fused quantize-on-store kernel (net zero — kept but not strategic)

Built a custom Metal kernel to replace mx.quantize for cache writes.
Microbenchmarks showed 8% faster in isolation, but end-to-end tok/s
didn't change. Diagnostic tools revealed the real bottleneck was
SliceUpdate's full-cache copy, not the quantize dispatch. Fixing
SliceUpdate (via copy_shared_buffer) was the actual win.

embed_q fusion into SDPA kernel (correct but imprecise)

Attempted to fuse the embed_q matmul (W_UK absorption) into the SDPA kernel
to eliminate a dispatch boundary. Worked — model produced correct output.
But MLX's quantized_matmul uses a specific internal computation path
(qdot with pre-divided x values per load_vector in quantized.h)
that differs from standard dequant+matmul at the floating-point level.
Replicating the formula achieved 1.54% relative error — good enough for
correct generation but slightly changes model behavior. Kept as a research
branch, not shipped.

TurboQuant-style scoring kernel (15-20 tok/s — abandoned)

Walsh-Hadamard rotation before quantization, based on ICLR 2026 paper (PR
#1067). Built on top of an early custom Metal kernel that had a fundamental
grid parameter bug (passing total threads instead of threadgroups to
mx.fast.metal_kernel). The TurboQuant idea was never tested on working
infrastructure. The rotation concept may still have value for improving
INT4 cache quality.

External C++ nanobind extension (crash — abandoned)

Attempted to build MLA kernels as an external Python extension using
nanobind. MLX's type registry prevents sharing mlx.core.array types
across nanobind domains — TypeError on every call. Led us to fork MLX
directly, which is what ultimately worked.

What We Still Don't Know

  • Whether head tiling or split-KV (Flash-Decoding style) would win at
    very long context where bandwidth becomes the bottleneck
  • How much MLX's recent upstream improvements (Neural Accelerator support,
    split-K quantized matmul, faster two-pass SDPA) would change the baseline
  • Whether the embed_q precision gap can be fully closed by replicating
    qmv_quad's exact thread topology (4-thread quadgroups vs our
    32-thread simdgroups)
  • The full breakdown of decode step time between attention and MLP/experts
    — we profiled attention thoroughly but never measured the complete pipeline

Architecture Notes

Mistral Small 4's MLA differs from DeepSeek V3:

| Parameter | Mistral Small 4 | DeepSeek V3 |
|---|---|---|
| kv_lora_rank | 256 | 512 |
| qk_nope_head_dim | 64 | 128 |
| qk_rope_head_dim | 64 | 64 |
| num_attention_heads | 32 | 128 |
| n_group | 1 | 8 |
| first_k_dense_replace | 0 | 3 |
| rope_interleave | true | false |

Hardware

Tested on Apple M5 Max (128GB, 40-core GPU, LPDDR5X 614 GB/s).
Should work on any Apple Silicon Mac with sufficient memory (~67GB for
the 4-bit quantized model). Not tested on M1/M2/M3/M4 — performance
characteristics may differ, especially for the fused Metal kernel.

Files Changed

| File | Change |
|---|---|
| mlx_lm/models/mistral4.py | Absorbed MLA + fused decode path + MoEGate fix |
| mlx_lm/models/mistral3.py | Model routing |
| mlx_lm/models/cache.py | QuantizedLatentKVCache with direct cache update |

Companion PR

The fused Metal kernel lives in ml-explore/mlx: ml-explore/mlx#3373

Without the companion kernel, the model falls back to the standard
quantized_matmul + SDPA path (~99 tok/s with INT4 cache). The kernel
adds the fused decode optimization that recovers full speed.

Usage

# With companion MLX kernel (full speed, ~110 tok/s)
pip install git+https://github.com/ProducerGuy/mlx.git@phase3c-v3-kernel-opt
pip install git+https://github.com/ProducerGuy/mlx-lm.git@phase3c-v2-direct-write

python -m mlx_lm.generate \
  --model mistralai/Mistral-Small-4-Base \
  --prompt "What is 2+2?"

# Without the companion kernel, the model still works but falls back to
# the quantized_matmul + standard SDPA path (~99 tok/s with INT4 cache).
# No code changes needed — the fallback is automatic.

Requires Apple Silicon Mac with ~67GB memory for the 4-bit quantized model.

Acknowledgments

  • @graelo for the absorbed MLA approach in Add Mistral Small 4 (119B MoE) support with absorbed MLA #1075 — the foundation for
    everything that followed. Your work pushed us to go deeper.
  • @geosh5676 for extending this work with FP8 E4M3 decoding support
  • DeepSeek-V2/V3 papers for the absorbed MLA architecture
  • arXiv 2603.04428 for empirical validation of INT4 MLA cache
  • FlashMLA for the latent + RoPE cache split pattern
  • SALS (NeurIPS 2025) for sparse attention in latent space (not shipped
    but informed our understanding)
  • MLX team for sdpa_vector.h which our fused kernel is based on

References

  • DeepSeek-V2 (arXiv 2405.04434) — original MLA paper
  • DeepSeek-V3 (arXiv 2412.19437) — absorbed MLA in production
  • arXiv 2603.04428 — Q4 MLA cache validation
  • TurboQuant (ICLR 2026, Google) — 3-bit KV cache
  • FlashMLA (github.com/deepseek-ai/FlashMLA) — production MLA kernel
  • SALS (arXiv 2510.24273, NeurIPS 2025) — sparse attention in latent space
  • arXiv 2507.15465 — MLA arithmetic intensity analysis
  • MLRA (arXiv 2603.02188, ICLR 2026) — MLA tensor parallelism

Adds MoE + MLA model support for Mistral Small 4
(mistralai/Mistral-Small-4-119B-2603), enabling
mlx-community/Mistral-Small-4-119B-2603-4bit to load and run.

New file: mlx_lm/models/mistral4.py
- MoE feedforward with SwitchGLU routing (128 experts, top-4)
- Shared expert support
- MLA attention with compressed KV via explicit kv_b_proj
  (distinct from DeepSeek V3's MultiLinear approach)
- Standard attention fallback for dense layers
- All dimensions read from config, nothing hardcoded

Modified: mlx_lm/models/mistral3.py
- Structural routing: n_routed_experts presence routes to mistral4
- Forward-compatible with future MoE Mistral variants
- Dense Ministral 3B/8B/14B models unaffected

Tested on MacBook Pro M5 Max (128GB):
- 104 tok/s generation
- 67 GB peak memory
- Correct factual output confirmed

Before: ValueError: Received 1260 parameters not in model
After: Model loads and generates correctly

Chat template (apply_chat_template) works out of the box
once the model loads — no additional changes needed.

graelo commented Mar 30, 2026

Thanks for the work on this @ProducerGuy — getting Mistral Small 4 running on MLX is great.

One area worth discussing: the current MLA implementation decompresses KV via kv_b_proj before caching, so the KV cache stores the full decompressed tensors — 32 × (64 + 64 + 128) = 8,192 floats/token/layer. Since there's no nonlinearity between kv_a_layernorm and kv_b_proj, the decompression matrix can be split into its K/V sub-matrices and absorbed into the query and output paths, bringing the cache down to just the compressed latent + RoPE component: 256 + 64 = 320 floats/token/layer (~25× smaller).
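The linearity argument above can be demonstrated in a few lines of numpy (toy shapes; W_uk stands in for the K sub-matrix split out of kv_b_proj):

```python
import numpy as np

rng = np.random.default_rng(0)
d_latent, d_head = 256, 64
c = rng.normal(size=d_latent)                 # cached compressed latent (one token)
W_uk = rng.normal(size=(d_head, d_latent))    # K sub-matrix of kv_b_proj
q = rng.normal(size=d_head)

# Naive path: decompress K for every cached token, then score
k = W_uk @ c
score_naive = q @ k

# Absorbed path: fold W_uk into the query once per decode step;
# the cache stays compressed. Valid because there is no nonlinearity
# between the latent and the decompression: q.(W c) == (W^T q).c
q_absorbed = W_uk.T @ q
score_absorbed = q_absorbed @ c
```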

You noted that Mistral 4 uses "a single linear projection rather than per-head Kronecker-style decomposition" — that's a correct observation about weight format, but it doesn't prevent absorption. The math is the same as DeepSeek V2, and in fact deepseek_v3.py in this repo already does exactly this decomposition in its sanitize method.

For reference, vLLM serves this model with dedicated MLA backends (FLASH_ATTN_MLA, TRITON_MLA) that use absorbed attention, and SGLang implements forward_absorb as its MLA inference path.

I've opened #1075 with an absorbed MLA variant following the existing deepseek_v3.py pattern. Happy to coordinate if useful.


geosh5676 commented Apr 4, 2026

Fix for FP8 weight loading — sanitize needs expert tensor remapping

Hey @ProducerGuy — thanks for the model architecture work on this, the MLA implementation is solid.

I ran into issues getting the conversion to work on Apple Silicon (M3 Ultra, mlx 0.31.1) and traced it to two things in the sanitize method:

Problem 1: The HF weights use fused expert tensors that don't match SwitchGLU's expected names.

  • HF has experts.gate_up_proj [128, 4096, 4096] — gate and up concatenated
  • HF has experts.down_proj [128, 4096, 2048]
  • SwitchGLU expects switch_mlp.gate_proj.weight, switch_mlp.up_proj.weight, switch_mlp.down_proj.weight separately

Problem 2: FP8 scale tensors on the expert weights use a different naming pattern than the shared experts.

  • Expert scales: experts.down_proj_scale_inv [128, 1, 1] (suffix on tensor name, no .weight in path)
  • Shared expert scales: shared_experts.down_proj.weight_scale_inv (scalar, with .weight in path)
  • The original sanitize handled the shared expert pattern but not the fused expert pattern

Fix: Updated sanitize to (1) dequantize expert FP8 tensors using their _scale_inv, (2) split the fused gate_up_proj into separate gate and up, and (3) rename experts.Xswitch_mlp.X.weight.
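A minimal sketch of steps (2) and (3) in numpy. The fused-tensor layout and the split axis here are assumptions for illustration, not the exact HF layout, and the shapes are tiny stand-ins:

```python
import numpy as np

def remap_experts(weights):
    # Split the fused gate_up_proj into gate/up halves (assumed to be
    # concatenated along the last axis) and rename to the switch_mlp.*
    # keys that SwitchGLU expects.
    fused = weights.pop("experts.gate_up_proj")      # (E, d_hidden, 2*d_inter)
    gate, up = np.split(fused, 2, axis=-1)
    weights["switch_mlp.gate_proj.weight"] = gate
    weights["switch_mlp.up_proj.weight"] = up
    weights["switch_mlp.down_proj.weight"] = weights.pop("experts.down_proj")
    return weights

w = remap_experts({
    "experts.gate_up_proj": np.zeros((4, 8, 12)),    # tiny stand-in shapes
    "experts.down_proj": np.zeros((4, 6, 8)),
})
```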

Working on my end — 222GB bf16 conversion, loads and runs on M3 Ultra 512GB. Happy to open a PR against your branch or paste the diff here, whatever's easier if/when helpful.

Thank you again for your work!

Fix developed with Claude Code

UPDATE: the weight * scale_inv dequantization doesn't properly decode FP8 E4M3 sign bits — MLX reads them as uint8. Still working on the correct FP8→bf16 conversion. The expert remapping and naming fix seem to be correct though.

@geosh5676

UPDATE 2: FP8 E4M3 decoding fix — working now ✅

The issue from my previous update: MLX loads FP8 E4M3 tensors as uint8 (raw bytes), so tensor.astype(mx.bfloat16) just converts the integer values 0-255 to floats — losing all sign information and producing all-positive weights. The model generated garbage ("Tomorrow" repeated forever).

Fix: Built a 256-entry lookup table that properly decodes each uint8 byte as E4M3 (1 sign bit, 4 exponent bits, 3 mantissa bits):

_e4m3_lut = np.zeros(256, dtype=np.float32)
for i in range(256):
    sign = (i >> 7) & 1
    exp = (i >> 3) & 0xF
    mant = i & 0x7
    if exp == 0 and mant == 0:
        _e4m3_lut[i] = 0.0
    elif exp == 0:  # subnormal
        _e4m3_lut[i] = ((-1.0) ** sign) * (2.0 ** (-6)) * (mant / 8.0)
    elif exp == 15 and mant == 7:
        _e4m3_lut[i] = 0.0  # NaN → 0
    else:  # normal
        _e4m3_lut[i] = ((-1.0) ** sign) * (2.0 ** (exp - 7)) * (1.0 + mant / 8.0)

Then index into it: decoded = _e4m3_lut[tensor.reshape(-1).astype(mx.uint32)]. For the large expert tensors (128 × 4096 × 4096 = 2.1B elements, which exceeds the int32 index range), process per-expert to avoid overflow.
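As a sanity check on the decoding formula (restated as a scalar function here so the snippet is self-contained), a few bytes with well-known E4M3FN values decode as expected:

```python
def e4m3_decode(byte):
    # E4M3 (FN variant): 1 sign bit, 4 exponent bits (bias 7), 3 mantissa
    # bits. exp==0 covers both zero and subnormals; exp==15 with mant==7
    # is NaN, mapped to 0 as in the LUT above.
    sign, exp, mant = (byte >> 7) & 1, (byte >> 3) & 0xF, byte & 0x7
    if exp == 0:
        return ((-1.0) ** sign) * 2.0 ** -6 * (mant / 8.0)
    if exp == 15 and mant == 7:
        return 0.0
    return ((-1.0) ** sign) * 2.0 ** (exp - 7) * (1.0 + mant / 8.0)

# 0x38 -> 1.0 (exp=7, unbiased 0), 0x7E -> 448.0 (E4M3FN max normal)
checks = [e4m3_decode(b) for b in (0x00, 0x38, 0xB8, 0x30, 0x7E, 0x7F)]
```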

Verified working:

  • Weight stats now correct: mean ≈ 0, ~50% negative values, reasonable range
  • Model generates coherent text (chat template produces fluent English)
  • 222GB bf16 conversion loads and runs on M3 Ultra 512GB

Full sanitize now handles: (1) E4M3 LUT decoding, (2) scale_inv multiplication, (3) fused gate_up_proj splitting, (4) expert weight renaming. Happy to share the complete diff or open a PR against your branch.

Fix developed with Claude Code

@ProducerGuy ProducerGuy changed the title Add Mistral Small 4 (119B MoE) support via mistral4.py Mistral Small 4: Absorbed MLA + INT4 quantized latent cache Apr 4, 2026

graelo commented Apr 4, 2026

Beautiful work! 🙇


graelo commented Apr 4, 2026

Hi @ProducerGuy! Thanks for the great work and the credit ❤️. Everything looks great, very transparent. I'll probably test it soon locally too. The INT4 latent cache and fused kernel look great (above my current skill level): a 104x compression while maintaining perplexity is a very cool result.

Looking forward to seeing your PR ml-explore/mlx#3373 get merged. In the meantime, what do you plan on doing? Keeping support for Mistral Small 4 without the fused kernel or you'd rather wait for your MLX PR to get merged? Just curious. I'll close my PR anyway.

One small note for @geosh5676 on the FP8 decoding: mx.from_fp8() has been available since MLX 0.30 (see "Fp8 conversion" by @awni in ml-explore/mlx#2686). It handles E4M3→bf16 natively, so there's no need for the manual LUT. It's a one-liner:

weight = mx.from_fp8(weight, dtype=mx.bfloat16)

It's what I use in #1075 (_dequant_fp8), and it works on 0.31.1. (Re-)discovering that this conversion is necessary was a good catch anyway; congrats @geosh5676.

@geosh5676

Oh no 😅 thank you @graelo — that's exactly the kind of thing I needed someone to point out! I clearly didn't know mx.from_fp8() existed; the LUT was my "okay I'll just give it a try" solution. Feels a bit like reinventing the wheel and then proudly presenting it to the wheel factory 🤣. I appreciate your kind guidance; should I delete my comment or leave it up? I'm new to trying to contribute. I'll have to figure out how to be more careful next time.

Meanwhile, thank you again and I appreciate the very helpful pointer and the kind words! 🙏


graelo commented Apr 4, 2026

No worries, according to the mlx contributors, you reinvented a pretty useful wheel anyway ;)
Same as you, I'm just trying to contribute here and there as I can; none of us is always spot on.
