ggml : add NVFP4 quantization type support#19769

Merged
CISC merged 52 commits into ggml-org:master from richarddd:feat/nvfp4
Mar 11, 2026

Conversation

@richarddd
Contributor

@richarddd richarddd commented Feb 20, 2026

I'm not super experienced with the ggml/gguf internals so feedback is very welcome. Note on AI usage: Claude Opus 4.6 was used for navigating the codebase, debugging, and writing parts of the code. All changes have been reviewed and tested manually. Open to reworking anything that doesn't meet the project's standards.

This adds support for NVIDIA's NVFP4 quantization format (FP4 E2M1 weights, UE4M3 per-block scale, 16 elements per block). This is the format produced by NVIDIA ModelOpt's NVFP4 algorithm. The main difference from MXFP4 is the scale encoding (UE4M3 vs E8M0).

What's in here:

  • New GGML_TYPE_NVFP4 type, block struct, UE4M3 conversion helpers, reference quantize/dequantize
  • convert_hf_to_gguf.py detects NVFP4 ModelOpt models and repacks into the GGUF block format
  • CPU backend: scalar dot product + ARM NEON
  • gguf-py: type constant, quant/dequant, endian conversion
  • Tests added to test-backend-ops and test-quantize-fns

Tested with models from https://huggingface.co/NVFP4 on an Apple M5 MacBook (CPU, NEON). Ran llama-bench and a basic server smoke test. Would appreciate help with benchmarking if someone has a good baseline to compare against.

Here is a Qwen3-4B model to test with.
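For reference, the block format described above can be sketched in Python (illustrative names, not the actual ggml identifiers; the lo/hi nibble packing is an assumption based on the NEON commit messages in this PR):

```python
# Sketch of NVFP4 decode: 16 FP4 (E2M1) weights per block,
# one UE4M3 (unsigned E4M3, bias 7) scale byte per block.
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def ue4m3_to_fp32(x: int) -> float:
    """Decode an unsigned E4M3 scale: no sign bit, 4-bit exponent, 3-bit mantissa."""
    exp = (x >> 3) & 0xF
    man = x & 0x7
    if exp == 0:                          # subnormal: man * 2^-9
        return man * 2.0 ** -9
    return (1.0 + man / 8.0) * 2.0 ** (exp - 7)

def dequant_block(scale_byte: int, qs: bytes) -> list[float]:
    """Dequantize one 16-element block from 8 packed bytes (lo nibbles first)."""
    d = ue4m3_to_fp32(scale_byte)
    lo = [E2M1_LUT[b & 0x0F] * d for b in qs]
    hi = [E2M1_LUT[b >> 4] * d for b in qs]
    return lo + hi
```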

@github-actions github-actions bot added labels on Feb 20, 2026: testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), Vulkan (Issues specific to the Vulkan backend), python (python script changes), ggml (changes relating to the ggml tensor library for machine learning), Apple Metal (https://en.wikipedia.org/wiki/Metal_(API))
@JohannesGaessler
Contributor

As is clearly laid out in the llama.cpp contributing guidelines:

When adding support for a new model or feature, focus on CPU support only in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs

@pwilkin
Member

pwilkin commented Feb 20, 2026

I would really love NVFP4 support and I appreciate the work done here, but as @JohannesGaessler has already mentioned, the ratio of maintainer-needed work to verified information is way too high with this PR.

Please:

  • shelve all the backend implementations for now; they should be added in separate PRs so people specialized in specific backends can look at them
  • provide a GGUF of a converted model, preferably one that can be run comfortably by most maintainers (i.e. 8B or 12B rather than 400B)
  • make a KLD analysis against the full FP16 version as documented here
  • make perplexity and KLD checks for your quantized model as well as a comparable "standard" quant (Q4_1 would probably be a good choice here)
  • run benchmark tests for a known benchmark (you can use a tool such as Inspect AI; a good quick general benchmark to run is, for example, ARC Challenge)

@jeffbolznv
Contributor

It would be great if nvfp4 could be stored in larger blocks that are at least a multiple of 4B (16B would be better).

@JohannesGaessler
Contributor

I agree that memory alignment is relevant. As long as the tensor dimensions are multiples of e.g. 256, it should be feasible to permute the data upon load (except perhaps for CPU+GPU hybrid inference, where the overhead could be relevant).

@ggerganov
Member

  • make a KLD analysis against the full FP16 version as documented here
  • make perplexity and KLD checks for your quantized model as well as a comparable "standard" quant (Q4_1 would probably be a good choice here)
  • run benchmark tests for a known benchmark (you can use a tool such as Inspect AI; a good quick general benchmark to run is, for example, ARC Challenge)

Btw, @pwilkin these are not really necessary for NVFP4 - adding support for this data type would not depend on the outcome of these. They are good for sanity checks, but other than that do not matter much. The main use case of NVFP4 is to load models that are already trained in that format - not to quantize models with it.

Regarding the alignment - I guess we can make blocks of 256, which would result in an alignment of 16 bytes. Though we risk not being able to load tensors with a dimension that is not a multiple of 256. There was the same dilemma for MXFP4, and gpt-oss unfortunately has shapes that are only divisible by 64, not 256.
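The size arithmetic behind this can be checked quickly (a sketch: each element is half a byte of packed E2M1, plus one UE4M3 scale byte per 16 elements):

```python
def nvfp4_block_bytes(qk: int) -> int:
    """Bytes per (super-)block of qk elements: qk/2 packed nibbles + qk/16 scale bytes."""
    assert qk % 16 == 0
    return qk // 2 + qk // 16

# QK=16 -> 9 bytes (unaligned), QK=64 -> 36 bytes (4-byte aligned),
# QK=256 -> 144 bytes (16-byte aligned)
sizes = {qk: nvfp4_block_bytes(qk) for qk in (16, 64, 256)}
```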

@am17an
Contributor

am17an commented Feb 23, 2026

NVFP4 also has a separate per-tensor float scale which this PR doesn't take into account, unless I'm wrong. Also, this whole PR is pretty much AI-generated from what I can see. I had plans to add NVFP4 support after MXFP4, and another developer had promised to do it but has since not delivered, so I will also create a PR for NVFP4 support in the meantime.

@pwilkin
Member

pwilkin commented Feb 23, 2026

@ggerganov I know but I meant it exactly as a sanity check.

@pwilkin
Member

pwilkin commented Feb 23, 2026

NVFP4 also has a separate per-tensor float scale which this PR doesn't take into account, unless I'm wrong. Also, this whole PR is pretty much AI-generated from what I can see. I had plans to add NVFP4 support after MXFP4, and another developer had promised to do it but has since not delivered, so I will also create a PR for NVFP4 support in the meantime.

Yeah, I'm pretty frustrated, as I was also thinking about working on it and was hoping this PR would go somewhere, but it seems to be going nowhere so far :/

@richarddd richarddd marked this pull request as draft February 23, 2026 12:27
@richarddd
Contributor Author

NVFP4 also has a separate per-tensor float scale which this PR doesn't take into account, unless I'm wrong. Also, this whole PR is pretty much AI-generated from what I can see. I had plans to add NVFP4 support after MXFP4, and another developer had promised to do it but has since not delivered, so I will also create a PR for NVFP4 support in the meantime.

It's taken into account. And regarding AI, as mentioned in the PR, I leaned on AI while following the patterns applied in the MXFP4 PR. I'll remove the half-baked backend implementations and stick with NEON + the generic CPU implementation for now. Again, this is a WIP which proves the concept and implements a lot of the boilerplate. I'll also increase the block size to 64.

@am17an
Contributor

am17an commented Feb 23, 2026

It's taken into account.

It is not. Please see the f32 scale as presented here https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/

As a reminder: you are supposed to know the content of the PR even if the PR is written with AI help. See the contributing guidelines.
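The two-level scheme from the linked blog post can be sketched as follows (the constants are the format maxima, E2M1 = 6.0 and E4M3 = 448.0; the rounding of the block scale to FP8 is omitted, and all names are illustrative):

```python
E2M1_MAX = 6.0    # largest FP4 E2M1 magnitude
E4M3_MAX = 448.0  # largest FP8 E4M3 magnitude

def tensor_scale(tensor_amax: float) -> float:
    """Per-tensor fp32 scale, stored once for the whole tensor."""
    return tensor_amax / (E2M1_MAX * E4M3_MAX)

def block_scale(block_amax: float, t_scale: float) -> float:
    """Per-16-element scale (stored as FP8 E4M3 in the real format)."""
    return block_amax / (E2M1_MAX * t_scale)

def dequant(fp4_value: float, b_scale: float, t_scale: float) -> float:
    """A weight is the FP4 value times both scales."""
    return fp4_value * b_scale * t_scale
```

At the tensor-wide absolute maximum, the block scale hits the E4M3 maximum and the code hits the E2M1 maximum, so both encodings are fully utilized.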

@richarddd
Contributor Author

I would really love NVFP4 support and I appreciate the work done here, but as @JohannesGaessler has already mentioned, the ratio of maintainer-needed work to verified information is way too high with this PR.

Please:

  • shelve all the backend implementations for now; they should be added in separate PRs so people specialized in specific backends can look at them
  • provide a GGUF of a converted model, preferably one that can be run comfortably by most maintainers (i.e. 8B or 12B rather than 400B)
  • make a KLD analysis against the full FP16 version as documented here
  • make perplexity and KLD checks for your quantized model as well as a comparable "standard" quant (Q4_1 would probably be a good choice here)
  • run benchmark tests for a known benchmark (you can use a tool such as Inspect AI; a good quick general benchmark to run is, for example, ARC Challenge)

Addressed these comments.

Here are results for Qwen3-4B

|                            | NVFP4 (5.0 BPW) | Q4_1 (5.15 BPW) |
| -------------------------- | --------------: | --------------: |
| PPL                        |   15.25 (+8.0%) |  15.81 (+12.0%) |
| Mean KLD                   |           0.110 |           0.112 |
| tg128 t/s                  |            15.2 |            14.7 |
| ARC Challenge (Inspect AI) |             80% |                 |


@am17an
Contributor

am17an commented Feb 23, 2026

Okay, not sure if that works, but if it does then it's great, since it simplifies the implementation quite a bit. The current state of your PR is not OK though: I see random changes in the CUDA and Vulkan code. Can you fix it?

@richarddd
Contributor Author

I see random changes in the CUDA and Vulkan code. Can you fix it?

Thanks, I noticed that as well. The problem was a one-time thing from the shelve commit targeting an older master. The PR should be clean now.

@richarddd richarddd marked this pull request as ready for review February 23, 2026 17:06
@CISC
Member

CISC commented Mar 11, 2026

I guess we could hold off on merging this until we prototype this and make sure there aren't any surprises?

Ouch. :)

@ggerganov
Member

No worries, we don't have other alternatives either way, so if the repack does not work out we'll have to live with the 4 byte alignment.

@CISC
Member

CISC commented Mar 11, 2026

No worries, we don't have other alternatives either way, so if the repack does not work out we'll have to live with the 4 byte alignment.

Well, come to think of it, can we not have two NVFP4 quants? One with 16-byte alignment and this one to fall back on if that won't fit?

@ggerganov
Member

Sounds like too much redundancy and extra complexity for not much benefit.

@CISC
Member

CISC commented Mar 11, 2026

Sounds like too much redundancy and extra complexity for not much benefit.

True, let's hope repacking pans out.

@am17an
Contributor

am17an commented Mar 12, 2026

4-byte alignment is already quite good. Each CUDA thread reading 4 bytes in a warp leads to a 128-byte transaction, which is ideal.

ProgenyAlpha pushed a commit to ProgenyAlpha/llama.cpp that referenced this pull request Mar 12, 2026
* WIP: add NVFP4 quantization support

* tests

* improve NVFP4 dot product implementation performance and fix bad super call

* typo

* Use nvfp4 kvalues

* vulkan : fix NVFP4 shader compilation by including kvalues_mxfp4 lookup table

* vulcal and perf fixes

* wip

* Fix metal

* fix vulcan

* Rename threshold & fix wrong scale

* Fix MOE

* Shelf backend implementations (CUDA, Metal, Vulkan, arch-specific SIMD)

Remove NVFP4 support from GPU backends and architecture-specific
optimized dot products. These should be added in separate PRs so
backend specialists can review them independently.

Reverted files:
- ggml-cuda: common.cuh, convert.cu, mmq.cu/cuh, mmvq.cu, vecdotq.cuh,
  quantize.cu/cuh, mma.cuh, ggml-cuda.cu, fattn-tile.cuh
- ggml-metal: ggml-metal.metal, ggml-metal-device.cpp, ggml-metal-impl.h,
  ggml-metal-ops.cpp
- ggml-vulkan: ggml-vulkan.cpp, all vulkan-shaders/*
- ggml-cpu arch: arm/quants.c, x86/quants.c, powerpc/quants.c, s390/quants.c

Core NVFP4 support (type definition, CPU fallback dot product,
quantization, dequantization, conversion) is retained.

* Fix arch-fallback.h: add NVFP4 generic fallback for all platforms

After shelving backend-specific SIMD implementations, the generic
CPU dot product needs to be aliased on ARM, x86, PowerPC, and s390
platforms that previously relied on arch-specific versions.

* quantize: add NVFP4 as a quantization type option

* Fix ggml_fp32_to_ue4m3: handle subnormal values

Previously, values with ue4m3_exp <= 0 were clamped to 0, causing
all small scales to underflow. This made NVFP4 quantization via
llama-quantize produce garbage (PPL = 5.8M) since typical transformer
weights have amax/6.0 in the range 0.001-0.01, which falls in the
UE4M3 subnormal range.

Now subnormals are properly encoded as man * 2^-9 (exp=0, man=1..7),
matching the decode path in ggml_ue4m3_to_fp32.

Result: NVFP4 requantization now produces PPL = 15.25 (vs F16 = 14.33),
comparable to Q4_1 (PPL = 15.81) at slightly lower BPW (4.70 vs 5.15).

* Restore ARM NEON NVFP4 dot product implementation

Restores the optimized ggml_vec_dot_nvfp4_q8_0 for ARM NEON using
vqtbl1q_s8 lookup and ggml_vdotq_s32 dot products.

tg128 performance: 4.37 t/s (generic) -> 13.66 t/s (NEON) = 3.1x speedup

* Optimize ARM NEON NVFP4 dot product: LUT + vpaddq + vfmaq

- Add ue4m3_scale_lut[128] to ggml-common.h replacing branch-heavy
  ggml_ue4m3_to_fp32() in the hot loop
- Use vpaddq_s32 for pairwise int32 reduction instead of vaddvq_s32
- Accumulate with vfmaq_f32 into float32x4_t vector accumulators

tg128: 8.1 -> 31.0 t/s (3.8x speedup, 77% of Q4_1 speed)

* ARM NEON NVFP4: rearrange q8 to match nibble layout

Alternative approach: rearrange q8 data to match the NVFP4 lo/hi
nibble layout instead of rearranging the looked-up NVFP4 values.
Eliminates vcombine_s8(vget_low, vget_low) shuffles.

Performance is equivalent (~18.5 t/s) - the bottleneck is the 2x
block overhead from QK=16 vs QK=32, not the shuffle instructions.

* CPU only backend 64 super-block layout

* cleanup

* Remove unused LUT

* int

* exclude NVFP4 from unsupported ops in metal build

* remove quantization for now

* store scales as native UE4M3, preserve original model bits when possible

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* correct comment

* format

* reduce duplication and cleanup

* Address comments

* move detection to prepare_tensors

* Use math instead of const

* Move

* fix comment

* Shelf quantize tests

* Rebase and move check

* cleanup

* lint

* Update gguf-py/gguf/scripts/gguf_convert_endian.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Use fallback quant config

* Simplify

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* organize

* Refactor

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* add quantize_nvfp4 (required for test_quants.py)

* add quantize_nvfp4 (required for test_quants.py)

* add quantize_nvfp4 (required for test_quants.py)

* fix return type

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
tekintian added a commit to tekintian/llama.cpp that referenced this pull request Mar 12, 2026
* 'master' of github.com:ggml-org/llama.cpp: (33 commits)
  convert : better mtp check and fix return [no ci] (ggml-org#20419)
  vulkan: fix SSM_CONV PP scaling with large ubatch sizes (ggml-org#20379)
  New conversations now auto-select the first loaded model (ggml-org#20403)
  ggml-virtgpu: Fix some build commands (ggml-org#20341)
  metal : avoid divisions in bin kernel (ggml-org#20426)
  ci: Setup self-hosted CI for Intel Linux Vulkan backend (ggml-org#20154)
  vulkan: fix l2_norm epsilon handling (ggml-org#20350)
  vulkan: fix OOB check in flash_attn_mask_opt (ggml-org#20296)
  vulkan: Fix ErrorOutOfHostMemory on Intel GPU when loading large models with --no-mmap (ggml-org#20059)
  opencl: use larger workgroup size for get_rows (ggml-org#20316)
  opencl: add cumsum op (ggml-org#18981)
  hip: compile debug builds with -O2 on hip to avoid a compiler bug (ggml-org#20392)
  common/parser: add GigaChatV3/3.1 models support (ggml-org#19931)
  model : add support for Phi4ForCausalLMV (ggml-org#20168)
  graph : add optional scale parameter to build_lora_mm [no ci] (ggml-org#20427)
  common : fix --n-cpu-moe, --cpu-moe for models with fused gate + up (ggml-org#20416)
  ggml-webgpu: Add supports for `GGML_OP_REPEAT` (ggml-org#20230)
  llama : enable chunked fused GDN path (ggml-org#20340)
  llama : whitespace cleanup (ggml-org#20422)
  ggml : add NVFP4 quantization type support (ggml-org#19769)
  ...
@JohannesGaessler
Contributor

4-byte alignment is already quite good. Each CUDA thread reading 4 bytes in a warp leads to a 128-byte transaction, which is ideal.

For synchronous data copies I agree; for asynchronous copies, chunks of 16 bytes work better in my experience.

@michaelw9999
Contributor

I've got the current version working with CUDA, converting to packed SoA (without 4/6 or any fancy stuff), but it's not as fast as it should be (about 13,000 tk/s on Qwen3-4B). Should I post it anywhere, or do we have a thread to discuss follow-up NVFP4 tasks? I'm having issues converting models and have fixes for the py script. Hope I can contribute something. Thanks

@richarddd
Contributor Author

richarddd commented Mar 12, 2026

I've got the current version working with CUDA, converting to packed SoA (without 4/6 or any fancy stuff), but it's not as fast as it should be (about 13,000 tk/s on Qwen3-4B). Should I post it anywhere, or do we have a thread to discuss follow-up NVFP4 tasks? I'm having issues converting models and have fixes for the py script. Hope I can contribute something. Thanks

@michaelw9999 I think individual PRs. Small, isolated ones. If improvements are incremental, they should rather be separate PRs IMO. For example, one with basic CUDA support, one for 4/6 and maybe some fancy stuff, etc.

@JohannesGaessler
Contributor

The CUDA code should have the following pieces for basic support: NVFP4 dequantization + cuBLAS, MMVQ support, MMQ support via dp4a, MMQ support via tensor cores. New contributors should please submit these only as individual and self-contained PRs; for more experienced contributors I think it's fine to do multiple things at once. Fancy stuff should come after that, with evidence that it is an improvement.

@xkmire

xkmire commented Mar 12, 2026

Thanks very much for the NVFP4 work!!

I found two very interesting NVFP4 models on huggingface:

  1. txn545/Qwen3.5-122B-A10B-NVFP4 quantized using the NVIDIA Model Optimizer.
  2. AxionML/Qwen3.5-122B-A10B-NVFP4 quantized using NVIDIA TensorRT Model Optimizer.

I tried to convert them to gguf, but both failed.

  1. ValueError: Can not map tensor 'model.language_model.layers.0.mlp.shared_expert.down_proj.weight'
  2. ValueError: Can not map tensor 'model.language_model.layers.0.linear_attn.in_proj_a.weight'

I was just wondering if these are the kind of models that are intended to work with the NVFP4 support I have seen going into llama.cpp over the last few days.

If yes, I think I might have a go at trying to figure out why they fail. Not sure I will be able to find out how to fix it, but I'm eager to get my new expensive GPU to run at its best...


@vbooka1

vbooka1 commented Mar 13, 2026

Hello, I am getting the error "Quant method is not yet supported: 'modelopt'" when trying to convert NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 ( https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4/ ) to .gguf

error log: #20411 (comment)

@CISC
Member

CISC commented Mar 13, 2026

Hello, I am getting the error "Quant method is not yet supported: 'modelopt'" when trying to convert NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 ( https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4/ ) to .gguf

error log: #20411 (comment)

Seems they have per-tensor quant_algo, which we don't check for, so repacking never kicks in.
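A minimal sketch of a layout-agnostic check (the exact key structure of ModelOpt's quant config is an assumption here; the idea is just to walk the whole config instead of only reading a top-level key):

```python
def has_nvfp4(obj) -> bool:
    """Recursively search a quant config for any quant_algo == "NVFP4" entry,
    covering both a global quant_algo and per-tensor/per-layer layouts."""
    if isinstance(obj, dict):
        if obj.get("quant_algo") == "NVFP4":
            return True
        return any(has_nvfp4(v) for v in obj.values())
    if isinstance(obj, list):
        return any(has_nvfp4(v) for v in obj)
    return False
```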

@ORippler
Collaborator

The CUDA code should have the following pieces for basic support: NVFP4 dequantization + cuBLAS, MMVQ support, MMQ support via dp4a, MMQ support via tensor cores.

Most likely stating the obvious: for the MMVQ and MMQ dp4a paths, it makes sense to do computations in BF16/FP16, as throughput is equal for FP and ALU in CUDA cores and we can save the I2F conversion via fp4 intrinsics (on the hardware that supports those, of course).

Just wanted to point this out as the CPU path in this PR does ALU followed by I2F.

@ORippler
Collaborator

4-byte alignment is already quite good. Each CUDA thread reading 4 bytes in a warp leads to a 128-byte transaction, which is ideal.

For synchronous data copies I agree; for asynchronous copies, chunks of 16 bytes work better in my experience.

4 bytes is the minimum we need to be able to issue LDGSTS via cg::memcpy_async and reduce register pressure by bypassing registers for the store op, and wider should always be better (as it doesn't depend on the MMU to pack LDGs issued across threads into the same read call, and has higher IPC).

@JohannesGaessler
Contributor

Regarding MMVQ: currently the activations are unconditionally converted to q8_1, if we intend to use floating-point math we will need to extend this. More generally, if we add a path using floating-point math it may make sense to use it for small matrices to remove the overhead from quantizing the activations. This table doesn't seem to list the throughput of __dp4a but I would assume for a matrix vector multiplication it won't make much of a difference either way though. We should maybe also try to define what we want to put in mmvq.cu vs. mmvf.cu. The way that would make sense to me is to use MMVQ for block-wise src0 data types and to use MMVF for scalar data types. (MMVF is strictly not needed since cuBLAS could be used, we basically only have it for performance reasons, particularly for MoE models.)

@ORippler
Collaborator

ORippler commented Mar 18, 2026

Regarding MMVQ: currently the activations are unconditionally converted to q8_1, if we intend to use floating-point math we will need to extend this. More generally, if we add a path using floating-point math it may make sense to use it for small matrices to remove the overhead from quantizing the activations. This table doesn't seem to list the throughput of __dp4a but I would assume for a matrix vector multiplication it won't make much of a difference either way though. We should maybe also try to define what we want to put in mmvq.cu vs. mmvf.cu. The way that would make sense to me is to use MMVQ for block-wise src0 data types and to use MMVF for scalar data types. (MMVF is strictly not needed since cuBLAS could be used, we basically only have it for performance reasons, particularly for MoE models.)

I took the time to set up a script to bench the dp4a vs f16 paths for mxfp4 inputs (as the gist benches mxfp4/mxfp4 instead of mxfp4/q8_0, the raw gain of f16 vs. dp4a will be less, though we can offset it by foregoing activation quantization as you mentioned, which would be the much bigger perf gain):

https://gist.github.com/ORippler/1ac0757dc9bc462e4bf5c19a71b67c67

Cross-posting/quoting entries from there:


Some numbers on BW system (SM120)

(.venv) ➜  scratchpad ./run_nvfp4_compare.sh --elements=16777216 --iters=50 --warmup=5 --repeats=5 
GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (SM 12.0)
elements=16777216 iters=50 warmup=5 repeats=5 blocks=752 threads=256 requested_scale_a=0.125 requested_scale_b=0.0625 e8m0_scale_a=0.125 e8m0_scale_b=0.0625

Path                                              avg ms          GMac/s        checksum
FP4(E2M1) -> INT8 + DP4A + post-scale              0.150        5591.570       31150.195
FP4(E2M1) -> f16 + f16-domain multiply             0.108        7739.711       31150.195

DP4A path speed vs f16 path: 0.722x
L1 relative output delta (f16 ref): 0.000
(.venv) ➜  scratchpad ./run_nvfp4_compare.sh --elements=167772160 --iters=50 --warmup=5 --repeats=5
GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (SM 12.0)
elements=167772160 iters=50 warmup=5 repeats=5 blocks=752 threads=256 requested_scale_a=0.125 requested_scale_b=0.0625 e8m0_scale_a=0.125 e8m0_scale_b=0.0625

Path                                              avg ms          GMac/s        checksum
FP4(E2M1) -> INT8 + DP4A + post-scale              3.301        2540.992       18451.855
FP4(E2M1) -> f16 + f16-domain multiply             3.217        2607.946       18451.855

DP4A path speed vs f16 path: 0.974x
L1 relative output delta (f16 ref): 0.000

Why is the F16 path preferred even if the throughput of DP4A is theoretically higher? Two reasons:

  1. We use the ALU and FMA pipes instead of almost exclusively the ALU pipe (save for the multiplication with the chunk scale in the DP4A path).
  2. We have specific instructions available in our ISA to convert from fp4x2 to half2, yet none exist for conversion from packed fp4 to int32, meaning we have to spend more ALU instructions on this (basically the most costly part is repacking the loaded data in the registers).

The scaling of F16 vs. DP4A depends on the workload size. In the first setting we stay within cache, whereas in the second setting we are at 100% memory bandwidth (and thus waiting for data most of the time).
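The two compute domains being compared can be sketched numerically in Python (not the CUDA kernels; the int8 table stores 2x the E2M1 values so that all entries are integers, with the factor folded back into the scale, following how ggml handles the MXFP4 kvalues):

```python
# FP4 E2M1 code -> value tables; KV_I8 is 2x KV_F so all entries are integers.
KV_F  = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
KV_I8 = [0, 1, 2, 3, 4, 6, 8, 12,
         0, -1, -2, -3, -4, -6, -8, -12]

def dot_float(codes, q8, d):
    """Float-domain path: decode codes to floats, multiply, scale by d."""
    return d * sum(KV_F[c] * v for c, v in zip(codes, q8))

def dot_int(codes, q8, d):
    """Integer-domain (dp4a-style) path: int8 products, post-scale with d/2."""
    acc = sum(KV_I8[c] * v for c, v in zip(codes, q8))
    return (d * 0.5) * acc
```

Both paths produce identical results, matching the equal checksums in the benchmark output above.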

coyotte508 pushed a commit to huggingface/huggingface.js that referenced this pull request Mar 18, 2026
Added in ggml-org/llama.cpp#19769

@michaelw9999
Contributor

michaelw9999 commented Mar 18, 2026

Super helpful, thank you! Your numbers match almost perfectly the real numbers I am seeing with NVFP4. I've been working out the best path for updating the dp4a-only model. On F16 some models overflow (another issue), but forcing F32, which fixes that, slows max prefill to 0.72x: avg 9000 tk/s down to 6500 tk/s.

ggml_cuda_init: found 1 CUDA devices (Total VRAM: 32606 MiB):
  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes, VRAM: 32606 MiB
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3 4B BF16                  |   2.63 GiB |     4.02 B | CUDA       |  99 |           pp512 |     8914.99 ± 950.13 |
| qwen3 4B BF16                  |   2.63 GiB |     4.02 B | CUDA       |  99 |           tg128 |        280.78 ± 0.60 |

build: 6e5081b5a (8381)
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 32606 MiB):
  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes, VRAM: 32606 MiB
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3 4B BF16                  |   2.63 GiB |     4.02 B | CUDA       |  99 |           pp512 |     6646.92 ± 567.05 |
| qwen3 4B BF16                  |   2.63 GiB |     4.02 B | CUDA       |  99 |           tg128 |        280.52 ± 0.88 |

I've been trying to find a way to use F32 without such a huge loss, but I guess that is how it will be on SM120.
Side thought: we need to fix convert_hf_... to stop naming models BF16 on NVFP4 conversions.

Ethan-a2 pushed a commit to Ethan-a2/llama.cpp that referenced this pull request Mar 20, 2026
@JohannesGaessler
Contributor

Tests added to test-backend-ops and test-quantize-fns

There aren't actually any NVFP4 test cases in test-backend-ops. Was this overlooked?

@CISC
Member

CISC commented Mar 24, 2026

Tests added to test-backend-ops and test-quantize-fns

There aren't actually any NVFP4 test cases in test-backend-ops. Was this overlooked?

I think they were removed again, TBD at first backend support.

@michaelw9999
Contributor

Tests added to test-backend-ops and test-quantize-fns

There aren't actually any NVFP4 test cases in test-backend-ops. Was this overlooked?

@JohannesGaessler @CISC I can add them in a new PR if you like. I have various tests; I did not include any in the CUDA PR.

@JohannesGaessler
Contributor

No tests beyond test-backend-ops should be needed.


Labels

Apple Metal (https://en.wikipedia.org/wiki/Metal_(API)), examples, ggml (changes relating to the ggml tensor library for machine learning), model (Model specific), Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes), testing (Everything test related), Vulkan (Issues specific to the Vulkan backend)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature Request: The script convert_hf_to_gguf.py supports conversion of DeepSeek-R1-0528-FP4.