POC: CUDA tensor parallel #1018

Closed
ikawrakow wants to merge 15 commits into main from ik/poc_tp

Conversation

Owner

@ikawrakow ikawrakow commented Nov 27, 2025

This is a very rough around the edges POC for tensor parallelism (TP). It is a completely different implementation than TP via split mode "row" in mainline llama.cpp. I prefer to call it "graph parallel" rather than "tensor parallel", but I guess TP is the better known term.

I have (almost?) completely removed the split mode "row" related code inherited from mainline (but broken in ik_llama.cpp for a while). TP is realized when building the computation graph, instead of black-magic voodoo in the CUDA code. Model tensors are split across rows or columns (depending on the tensor) to allow whole portions of the computation graph to be computed in parallel without synchronization between GPUs (hence "graph parallel"). Instead of having to synchronize after every matrix multiplication, as is the case in mainline's split mode "row", there are basically two synchronization points per model layer -- one after self attention, and one after the feed-forward network -- where the results from the GPUs involved get added together (instead of being concatenated as in split mode "row"). Most importantly, the KV cache for each layer is also split between GPUs, and self attention is computed in parallel, including V * softmax(K*Q), which is the most computationally expensive part of the transformer architecture for long contexts.

The POC is very rough because

  • Only dense LlaMA models are implemented. I wanted to have a version working for one architecture before moving on to other arches and MoE models
  • Full and partial GPU offload works, but not tensor overrides (but there is also no point in using tensor overrides for dense models)
  • Some options are not functional yet (-mqkv, -gr)

Nevertheless, I wanted to put it out there for visibility (and to give mainline developers more time to fully independently discover the approach /s).

To use it, the command line option is -sm graph or --split-mode graph. How much gets offloaded to what GPU can still be controlled via -ts or --tensor-split (but if not provided, split is determined by available VRAM).

I have developed and tested on a 2x3090 system donated by @magikRUKKOLA, so many thanks again!

For the 8B LlaMA model, TP is still slower than split mode "layer", except for long context PP. But for the 70B LlaMA, this TP implementation beats split mode "layer" for PP and TG. Here are some sweep-bench graphs and tables for these two models quantized with Q4_0 (so I can run the same model in ik_llama.cpp and llama.cpp). With the 70B model I can only go to a context of 16k tokens with the 2x3090 system and full offload.

LlaMA-8B

[Figures: tp_8B_pp (PP) and tp_8B_tg (TG) sweep-bench plots]
ik_llama.cpp, split mode "layer"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 2048 | 256 | 0 | 0.321 | 6376.90 | 1.480 | 173.00 |
| 2048 | 256 | 2048 | 0.351 | 5833.68 | 1.548 | 165.41 |
| 2048 | 256 | 4096 | 0.378 | 5411.12 | 1.636 | 156.52 |
| 2048 | 256 | 6144 | 0.405 | 5053.40 | 1.726 | 148.34 |
| 2048 | 256 | 8192 | 0.436 | 4694.29 | 1.813 | 141.19 |
| 2048 | 256 | 10240 | 0.462 | 4434.87 | 1.864 | 137.30 |
| 2048 | 256 | 12288 | 0.492 | 4162.61 | 1.941 | 131.86 |
| 2048 | 256 | 14336 | 0.520 | 3935.15 | 2.016 | 126.97 |
| 2048 | 256 | 16384 | 0.549 | 3732.12 | 2.108 | 121.47 |
| 2048 | 256 | 18432 | 0.575 | 3559.41 | 2.185 | 117.15 |
| 2048 | 256 | 20480 | 0.604 | 3389.22 | 2.228 | 114.90 |
| 2048 | 256 | 22528 | 0.634 | 3232.06 | 2.310 | 110.84 |
| 2048 | 256 | 24576 | 0.662 | 3094.48 | 2.392 | 107.01 |
| 2048 | 256 | 26624 | 0.693 | 2954.93 | 2.482 | 103.12 |
| 2048 | 256 | 28672 | 0.728 | 2812.98 | 2.557 | 100.10 |
| 2048 | 256 | 30720 | 0.757 | 2703.94 | 2.608 | 98.17 |
ik_llama.cpp, split mode "graph"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 2048 | 256 | 0 | 0.419 | 4891.67 | 2.156 | 118.72 |
| 2048 | 256 | 2048 | 0.423 | 4837.56 | 2.183 | 117.28 |
| 2048 | 256 | 4096 | 0.438 | 4679.16 | 2.234 | 114.61 |
| 2048 | 256 | 6144 | 0.452 | 4529.99 | 2.311 | 110.80 |
| 2048 | 256 | 8192 | 0.466 | 4393.16 | 2.319 | 110.41 |
| 2048 | 256 | 10240 | 0.481 | 4261.16 | 2.365 | 108.26 |
| 2048 | 256 | 12288 | 0.496 | 4130.27 | 2.418 | 105.89 |
| 2048 | 256 | 14336 | 0.511 | 4007.35 | 2.446 | 104.65 |
| 2048 | 256 | 16384 | 0.524 | 3904.96 | 2.503 | 102.28 |
| 2048 | 256 | 18432 | 0.540 | 3794.69 | 2.521 | 101.55 |
| 2048 | 256 | 20480 | 0.554 | 3696.18 | 2.555 | 100.20 |
| 2048 | 256 | 22528 | 0.569 | 3601.70 | 2.611 | 98.03 |
| 2048 | 256 | 24576 | 0.584 | 3507.90 | 2.730 | 93.77 |
| 2048 | 256 | 26624 | 0.598 | 3426.16 | 2.826 | 90.60 |
| 2048 | 256 | 28672 | 0.613 | 3341.58 | 2.841 | 90.10 |
| 2048 | 256 | 30720 | 0.628 | 3261.98 | 2.749 | 93.13 |
llama.cpp, split mode "layer"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 2048 | 512 | 0 | 0.352 | 5823.61 | 3.101 | 165.11 |
| 2048 | 512 | 2048 | 0.392 | 5229.45 | 3.264 | 156.88 |
| 2048 | 512 | 4096 | 0.440 | 4659.23 | 3.484 | 146.97 |
| 2048 | 512 | 6144 | 0.486 | 4213.65 | 3.746 | 136.70 |
| 2048 | 512 | 8192 | 0.534 | 3836.10 | 3.890 | 131.62 |
| 2048 | 512 | 10240 | 0.580 | 3529.11 | 4.097 | 124.96 |
| 2048 | 512 | 12288 | 0.658 | 3113.51 | 4.273 | 119.83 |
| 2048 | 512 | 14336 | 0.677 | 3024.10 | 4.445 | 115.18 |
| 2048 | 512 | 16384 | 0.725 | 2823.67 | 4.679 | 109.43 |
| 2048 | 512 | 18432 | 0.774 | 2647.29 | 4.817 | 106.28 |
| 2048 | 512 | 20480 | 0.823 | 2488.96 | 4.985 | 102.72 |
| 2048 | 512 | 22528 | 0.873 | 2347.26 | 5.199 | 98.49 |
| 2048 | 512 | 24576 | 0.924 | 2215.87 | 5.373 | 95.29 |
| 2048 | 512 | 26624 | 0.974 | 2103.34 | 5.600 | 91.43 |
| 2048 | 512 | 28672 | 1.024 | 2000.81 | 5.741 | 89.19 |
| 2048 | 512 | 30720 | 1.075 | 1905.11 | 5.906 | 86.69 |
llama.cpp, split mode "row"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 2048 | 512 | 0 | 0.830 | 2467.54 | 7.133 | 71.78 |
| 2048 | 512 | 2048 | 0.867 | 2360.85 | 7.276 | 70.37 |
| 2048 | 512 | 4096 | 0.915 | 2237.79 | 7.523 | 68.06 |
| 2048 | 512 | 6144 | 0.966 | 2119.49 | 7.801 | 65.63 |
| 2048 | 512 | 8192 | 1.013 | 2021.22 | 7.947 | 64.43 |
| 2048 | 512 | 10240 | 1.062 | 1929.13 | 8.123 | 63.03 |
| 2048 | 512 | 12288 | 1.110 | 1844.48 | 8.305 | 61.65 |
| 2048 | 512 | 14336 | 1.157 | 1770.47 | 8.486 | 60.34 |
| 2048 | 512 | 16384 | 1.177 | 1739.33 | 8.722 | 58.70 |
| 2048 | 512 | 18432 | 1.259 | 1626.91 | 8.857 | 57.80 |
| 2048 | 512 | 20480 | 1.268 | 1614.77 | 9.028 | 56.71 |
| 2048 | 512 | 22528 | 1.362 | 1503.41 | 9.246 | 55.37 |
| 2048 | 512 | 24576 | 1.413 | 1449.69 | 9.384 | 54.56 |
| 2048 | 512 | 26624 | 1.410 | 1452.73 | 9.604 | 53.31 |
| 2048 | 512 | 28672 | 1.506 | 1359.76 | 9.752 | 52.50 |
| 2048 | 512 | 30720 | 1.555 | 1316.77 | 9.905 | 51.69 |

LlaMA 70B

[Figures: tp_70B_pp (PP) and tp_70B_tg (TG) sweep-bench plots]
ik_llama.cpp, split mode "layer"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 1024 | 256 | 0 | 1.363 | 751.06 | 11.322 | 22.61 |
| 1024 | 256 | 1024 | 1.399 | 732.08 | 11.397 | 22.46 |
| 1024 | 256 | 2048 | 1.433 | 714.56 | 11.506 | 22.25 |
| 1024 | 256 | 3072 | 1.468 | 697.43 | 11.679 | 21.92 |
| 1024 | 256 | 4096 | 1.501 | 682.06 | 11.740 | 21.81 |
| 1024 | 256 | 5120 | 1.542 | 664.07 | 11.948 | 21.43 |
| 1024 | 256 | 6144 | 1.576 | 649.55 | 11.964 | 21.40 |
| 1024 | 256 | 7168 | 1.611 | 635.77 | 12.013 | 21.31 |
| 1024 | 256 | 8192 | 1.646 | 622.03 | 12.144 | 21.08 |
| 1024 | 256 | 9216 | 1.681 | 609.28 | 12.167 | 21.04 |
| 1024 | 256 | 10240 | 1.719 | 595.74 | 12.286 | 20.84 |
| 1024 | 256 | 11264 | 1.753 | 584.27 | 12.419 | 20.61 |
| 1024 | 256 | 12288 | 1.783 | 574.23 | 12.470 | 20.53 |
| 1024 | 256 | 13312 | 1.817 | 563.46 | 12.611 | 20.30 |
| 1024 | 256 | 14336 | 1.852 | 552.97 | 12.663 | 20.22 |
| 1024 | 256 | 15360 | 1.886 | 543.00 | 12.741 | 20.09 |
ik_llama.cpp, split mode "graph"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 1024 | 256 | 0 | 1.296 | 790.38 | 9.114 | 28.09 |
| 1024 | 256 | 1024 | 1.304 | 785.01 | 9.163 | 27.94 |
| 1024 | 256 | 2048 | 1.322 | 774.56 | 9.225 | 27.75 |
| 1024 | 256 | 3072 | 1.340 | 764.03 | 9.271 | 27.61 |
| 1024 | 256 | 4096 | 1.359 | 753.55 | 9.350 | 27.38 |
| 1024 | 256 | 5120 | 1.378 | 743.08 | 9.519 | 26.89 |
| 1024 | 256 | 6144 | 1.396 | 733.38 | 9.555 | 26.79 |
| 1024 | 256 | 7168 | 1.413 | 724.53 | 9.545 | 26.82 |
| 1024 | 256 | 8192 | 1.431 | 715.46 | 9.580 | 26.72 |
| 1024 | 256 | 9216 | 1.449 | 706.64 | 9.606 | 26.65 |
| 1024 | 256 | 10240 | 1.469 | 696.94 | 9.687 | 26.43 |
| 1024 | 256 | 11264 | 1.487 | 688.76 | 9.799 | 26.12 |
| 1024 | 256 | 12288 | 1.505 | 680.36 | 9.803 | 26.11 |
| 1024 | 256 | 13312 | 1.523 | 672.37 | 9.830 | 26.04 |
| 1024 | 256 | 14336 | 1.541 | 664.69 | 9.886 | 25.90 |
| 1024 | 256 | 15360 | 1.558 | 657.19 | 9.913 | 25.83 |
llama.cpp, split mode "layer"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 1024 | 256 | 0 | 1.453 | 704.55 | 11.581 | 22.10 |
| 1024 | 256 | 1024 | 1.483 | 690.49 | 11.690 | 21.90 |
| 1024 | 256 | 2048 | 1.518 | 674.54 | 11.778 | 21.73 |
| 1024 | 256 | 3072 | 1.557 | 657.47 | 11.895 | 21.52 |
| 1024 | 256 | 4096 | 1.601 | 639.76 | 12.025 | 21.29 |
| 1024 | 256 | 5120 | 1.639 | 624.76 | 12.271 | 20.86 |
| 1024 | 256 | 6144 | 1.688 | 606.75 | 12.331 | 20.76 |
| 1024 | 256 | 7168 | 1.722 | 594.62 | 12.376 | 20.69 |
| 1024 | 256 | 8192 | 1.765 | 580.25 | 12.475 | 20.52 |
| 1024 | 256 | 9216 | 1.822 | 562.07 | 12.572 | 20.36 |
| 1024 | 256 | 10240 | 1.845 | 554.95 | 12.663 | 20.22 |
| 1024 | 256 | 11264 | 1.882 | 544.03 | 12.838 | 19.94 |
| 1024 | 256 | 12288 | 1.924 | 532.16 | 12.888 | 19.86 |
| 1024 | 256 | 13312 | 1.963 | 521.56 | 12.980 | 19.72 |
| 1024 | 256 | 14336 | 2.007 | 510.23 | 13.087 | 19.56 |
| 1024 | 256 | 15360 | 2.058 | 497.52 | 13.169 | 19.44 |
llama.cpp, split mode "row"
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 1024 | 256 | 0 | 2.115 | 484.16 | 13.277 | 19.28 |
| 1024 | 256 | 1024 | 2.143 | 477.90 | 13.402 | 19.10 |
| 1024 | 256 | 2048 | 2.172 | 471.42 | 13.483 | 18.99 |
| 1024 | 256 | 3072 | 2.216 | 462.11 | 13.605 | 18.82 |
| 1024 | 256 | 4096 | 2.256 | 453.88 | 13.721 | 18.66 |
| 1024 | 256 | 5120 | 2.286 | 447.92 | 13.974 | 18.32 |
| 1024 | 256 | 6144 | 2.318 | 441.70 | 14.008 | 18.28 |
| 1024 | 256 | 7168 | 2.355 | 434.77 | 14.083 | 18.18 |
| 1024 | 256 | 8192 | 2.393 | 427.91 | 14.136 | 18.11 |
| 1024 | 256 | 9216 | 2.438 | 420.07 | 14.248 | 17.97 |
| 1024 | 256 | 10240 | 2.466 | 415.18 | 14.324 | 17.87 |
| 1024 | 256 | 11264 | 2.507 | 408.52 | 14.546 | 17.60 |
| 1024 | 256 | 12288 | 2.533 | 404.34 | 14.583 | 17.55 |
| 1024 | 256 | 13312 | 2.586 | 395.99 | 14.679 | 17.44 |
| 1024 | 256 | 14336 | 2.603 | 393.34 | 14.750 | 17.36 |
| 1024 | 256 | 15360 | 2.659 | 385.04 | 14.854 | 17.23 |

Some details

Feed-forward-network (FFN)

Ignoring for simplicity biases and such, one needs to compute D * ((U*X) . unary(G*X)), where * stands for matrix multiplication, the dot indicates element wise multiplication, X are the input activations, G is the gate tensor (ffn_gate), U is the up tensor (ffn_up), and D is the down tensor (ffn_down). U and G are M x N, X is M x K, and D is N x M (with M being the embedding size, N the FFN size, and K the batch size, i.e., number of tokens). The outcome of U*X and G*X are N x K matrices. Let us for simplicity consider splitting these operations between 2 GPUs.

The split mode "row" approach does the following

  • U, G are split into two M x N/2 tensors, and D is split into N x M/2 (i.e., split along rows)
  • U*X and G*X result in N/2 x K matrices
  • These two are concatenated along dimension zero into an N x K matrix. This requires synchronization between the GPUs
  • The unary op is performed on the main GPU (so, no parallelism)
  • The element wise multiplication is performed on the main GPU (so, no parallelism)
  • The matrix multiplication with the two D halves is done in parallel
  • The final result is obtained by concatenating the two results, so a 3rd GPU synchronization

The approach implemented here does the following

  • U, G are split into two M x N/2 tensors (i.e., split along rows), while D is split into N/2 x M tensors (i.e., split along columns)
  • This allows each GPU to compute the entire D * ((U*X) . unary(G*X)) operation with its half
  • The final result is obtained by adding the two results, which requires a single synchronization between the GPUs
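The algebra behind this split can be checked with a toy-sized NumPy sketch (standard row-by-column matrix convention here, not ggml's layout; SiLU stands in for the unary op, and the sizes are made up):

```python
import numpy as np

# Toy sizes: M = embedding, N = FFN size, K = batch (names as in the text)
M, N, K = 8, 12, 3
rng = np.random.default_rng(0)
U = rng.standard_normal((N, M))   # ffn_up
G = rng.standard_normal((N, M))   # ffn_gate
D = rng.standard_normal((M, N))   # ffn_down
X = rng.standard_normal((M, K))   # input activations

silu = lambda t: t / (1.0 + np.exp(-t))   # the unary op

# Reference: D * ((U*X) . unary(G*X)) computed on one device
ref = D @ ((U @ X) * silu(G @ X))

# Split the rows of U, G and the matching columns of D between 2 GPUs;
# each half of the FFN can then be computed fully independently
h = N // 2
gpu0 = D[:, :h] @ ((U[:h] @ X) * silu(G[:h] @ X))
gpu1 = D[:, h:] @ ((U[h:] @ X) * silu(G[h:] @ X))

# A single synchronization: add the two partial results
assert np.allclose(ref, gpu0 + gpu1)
```

The identity holds because the element-wise product splits row-wise along with U and G, and D times a row-split matrix is the sum of the column-blocks of D times the corresponding row-blocks.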

Self attention

Ignoring biases and intermediate normalizations for simplicity, here one needs to compute W_O * (V * softmax(K * rope(W_Q*X))). Here K and V stand for the K- and V-cache, X are the activations as above, and W_O is the attention output tensor. Before computing these operations, one needs to compute rope(W_K * X) and W_V * X, and store (copy) the result into the K- and V-cache. The PR does the following:

  • Splits W_K, W_Q and W_V along rows
  • Splits W_O along columns
  • As the RoPE operations must be done over complete attention heads, and because we wish to be able to complete all operations independently with the corresponding portion of the model tensors, splitting granularity is defined by the attention head size, so there is less flexibility compared to split mode "row".
  • With this, each GPU can compute rope(W_K * X) and W_V * X, and can store (copy) the result into its own KV cache
  • Hence, each GPU can use its own portion of the KV cache to compute V * softmax(K * rope(W_Q * X)) (using flash attention), and then perform the matrix multiplication with its portion of W_O.
  • The final self-attention result is obtained by adding the individual GPU results, which requires a single synchronization
  • I think the above graphs clearly demonstrate the benefit of this approach as performance decline with increasing context length is significantly slower than with split mode "layer" (or split mode "row").
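The head-granularity split can likewise be sketched in NumPy (a toy example that ignores RoPE, the KV cache, GQA, and masking; head count and sizes are made up):

```python
import numpy as np

# Toy multi-head attention split across 2 "GPUs" at whole-head granularity
n_head, d_head, n_tok = 4, 8, 5
d_model = n_head * d_head
rng = np.random.default_rng(1)
Wq, Wk, Wv, Wo = (rng.standard_normal((d_model, d_model)) for _ in range(4))
X = rng.standard_normal((n_tok, d_model))

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attn(Wq, Wk, Wv, Wo, nh):
    # Wq/Wk/Wv hold nh complete heads (rows); Wo holds the matching columns
    Q = (X @ Wq.T).reshape(n_tok, nh, d_head)
    K = (X @ Wk.T).reshape(n_tok, nh, d_head)
    V = (X @ Wv.T).reshape(n_tok, nh, d_head)
    out = np.empty_like(Q)
    for h in range(nh):
        out[:, h] = softmax(Q[:, h] @ K[:, h].T / np.sqrt(d_head)) @ V[:, h]
    return out.reshape(n_tok, nh * d_head) @ Wo.T

ref = attn(Wq, Wk, Wv, Wo, n_head)

# Each GPU gets half the heads: rows of Wq/Wk/Wv, columns of Wo
half = d_model // 2
gpu0 = attn(Wq[:half], Wk[:half], Wv[:half], Wo[:, :half], n_head // 2)
gpu1 = attn(Wq[half:], Wk[half:], Wv[half:], Wo[:, half:], n_head // 2)

# One synchronization: the partial W_O projections are simply added
assert np.allclose(ref, gpu0 + gpu1)
```

Because each head's attention only touches its own slices of the Q/K/V projections, the head groups never need to exchange data until the final add.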

Compute graph handling

The "split backend", which in the split mode "row" implementation is used to coordinate GPU synchronization and to split the work between the GPUs, is only used to coordinate model loading and the copying of the appropriate model tensor portions to the corresponding GPUs. It is not used during computation at all. Instead, when the compute graph is built, the self attention and FFN portions are done per GPU, which allows the ggml back-end to construct graph splits that can be computed on the GPUs in parallel. The ggml back-end required a bit of help to do the right thing. I have added a parameter to the ggml_add operation. When set, the back-end starts a new graph split. Without this little hack, the back-end would construct graph splits that would prevent GPU1 from starting work before GPU0 has finished its work. Doing it this way has the distinct advantage of allowing an easier TP implementation for other back-ends (e.g., Vulkan), as the only thing required is to implement the copying of model tensor data to the appropriate GPUs. Still, in retrospect I'm not 100% sure that this was the right decision because

  • We lose the ability to use CUDA graphs.
  • We rely on the back-end for synchronization and intermediate result copies, instead of perhaps using events (but that again is CUDA specific)

Next steps

  • Implement for one MoE architecture. GLM-4.5/4.6, being so popular, is a likely candidate
  • Extend to additional dense and MoE models
  • Look into self attention TP for DeepSeek models, which use MLA where the above self attention considerations do not apply
  • Look into partial offload for MoE models. Ideally one should be parallelizing the offload and computations with MoE tensors stored in RAM (instead of copying to a single GPU as it is done now)

Additional notes

Data transfer between GPUs is entirely non-negligible during PP. Here is some napkin math for LlaMA-70B

  • Embedding size is 8192. Let's assume batch size is 1024, and we are running on 2 GPUs
  • As explained above, we need to copy the result of GPU-1 self-attention to GPU-0 (to perform the addition of these two results), then copy the result to GPU-1 (so it can start computing FFN)
  • Then we need to copy the FFN result from GPU-1 to GPU-0 (to perform the addition), and then copy the result back to GPU-1 (to start computing the next layer)
  • So, there are 4 copies per layer. Each copy is 8192 x 1024 x 4 bytes = 32 MiB, so 128 MiB per layer
  • There are 80 layers, so 10 GiB of data
  • My PCI-E is 30 GiB/s theoretical, but in practice I observe more like 20-22 GiB/s (we are not copying one giant block of data at once, so there are latencies etc. involved). So, copying around intermediate results takes about 0.5 seconds per batch
  • For split mode "layer", where only one GPU is active at a time (and there is very little data to be copied), we get ~750 t/s, so computation of the batch on 1 GPU is about 1.36 seconds
  • If we were able to run both GPUs fully concurrently, the computation would take 1.36/2 = 0.68 seconds
  • I.e., without any additional synchronization latencies involved, it would take 0.68 + 0.5 = 1.18 seconds to compute the batch, and data transfer consumed 42% of the time!
  • In that ideal case we would get 1024/1.18 = 868 tokens/second
  • The above data tables show 790 t/s, so quite close to the theoretical maximum. But there is a bit of cheating involved, as I have converted 3 of the 4 copies per layer to fp16, so the data actually copied is 6.25 GiB, which takes about 0.3 seconds. So, the maximum possible t/s would be 1024/(0.68 + 0.3) = 1045 t/s. Hence, the observed 790 t/s is about 75% of what is theoretically achievable, which means that in addition to the data transfer bottleneck, there is also significant synchronization overhead.
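The napkin math above can be reproduced in a few lines (a sketch; the 20 GiB/s and 750 t/s figures are taken from the text, and the result differs from the 868 t/s above only by rounding):

```python
# Napkin math for LlaMA-70B PP on 2 GPUs, reproducing the figures above
embed, batch, layers = 8192, 1024, 80
copies_per_layer = 4                      # 2 syncs per layer, 2 copies each

per_copy = embed * batch * 4              # fp32: 32 MiB per copy
total_gib = copies_per_layer * per_copy * layers / 2**30
assert total_gib == 10.0                  # 10 GiB of transfers per batch

bw = 20.0                                 # observed GiB/s (20-22 quoted above)
t_copy = total_gib / bw                   # 0.5 s of copying
t_one_gpu = batch / 750                   # ~1.36 s at 750 t/s (split mode layer)
t_ideal = t_one_gpu / 2 + t_copy          # ~1.18 s with perfect concurrency
print(round(batch / t_ideal))             # prints 866, the ideal t/s
```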

Iwan Kawrakow added 14 commits November 27, 2025 14:58
But it also looks like the backend scheduler is not going to help:
* It copies mask and input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed attn evaluation, GPU 1 must wait for GPU 0 to finish its
     entire attn calculation
* Same with FFN. The rms_norm gets scheduled on GPU 0. Hence, GPU 1 must
  wait for GPU 0 to finish its entire FFN calculation before it can
  start (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduling
the graph is still not being computed in parallel.
Why? Because the scheduler creates graph splits where the
result of the computation on one GPU becomes an input for the
other split. Hence, to trigger the computation on the second GPU
one needs to wait for the computation on the first GPU to finish,
even though the two can be done in parallel up to the synchronization
point. So, all that is left to do is to trick the scheduler into creating
two splits that can be done in parallel, and then have a graph split
where the results get combined.
This change tricks it into doing the right thing^TM.
Still quite a bit slower than split mode layer for the 8B LlaMA model.
But for the 70B LlaMA it now beats split mode layer for TG:
28 t/s vs 24.4 t/s. PP is 627 t/s vs 744 t/s.
In comparison, split mode "row" in mainline gets
484 t/s PP and 19.3 t/s TG.
Granularity for Wq, Wo is not just head size, but
head size * gqa_ratio.
Else the Wk, Wv tensors end up not being a multiple of the
head size when we divide the split determined by Wo by the
gqa_ratio.
but no tensor overrides yet, just ngl < num_layers.
Now PP is faster than split mode layer for L3-70B.
@ikawrakow
Owner Author

Here are nvtop screenshots showing ~75% GPU utilization during LlaMA-70B inference

[Screenshot: nvtop, 2025-11-27 7:36:34 PM]


Ph0rk0z commented Nov 27, 2025

CUDA_VISIBLE_DEVICES=0,1 ./bin/llama-sweep-bench \
-m /Dusk-Miqu-70B-i1-GGUF/Dusk-Miqu-70B.i1-Q4_K_M.gguf \
-t 48 \
-c 32768 \
-fa on \
-ctk q8_0 \
-ctv q8_0 \
--no-mmap \
-ngl 99 

2x regular GPU

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 512 | 128 | 0 | 0.815 | 628.11 | 6.401 | 20.00 |
| 512 | 128 | 512 | 0.820 | 624.45 | 6.483 | 19.75 |
| 512 | 128 | 1024 | 0.830 | 617.18 | 6.514 | 19.65 |
| 512 | 128 | 1536 | 0.839 | 610.07 | 6.559 | 19.52 |
| 512 | 128 | 2048 | 0.849 | 603.18 | 6.614 | 19.35 |
| 512 | 128 | 2560 | 0.858 | 596.92 | 6.745 | 18.98 |
| 512 | 128 | 3072 | 0.868 | 590.01 | 6.797 | 18.83 |
| 512 | 128 | 3584 | 0.878 | 583.11 | 6.849 | 18.69 |
| 512 | 128 | 4096 | 0.887 | 577.33 | 6.895 | 18.56 |
| 512 | 128 | 4608 | 0.897 | 570.69 | 6.963 | 18.38 |
| 512 | 128 | 5120 | 0.907 | 564.44 | 7.096 | 18.04 |
| 512 | 128 | 5632 | 0.917 | 558.51 | 7.127 | 17.96 |
| 512 | 128 | 6144 | 0.937 | 546.18 | 7.181 | 17.83 |
| 512 | 128 | 6656 | 0.936 | 546.75 | 7.235 | 17.69 |
| 512 | 128 | 7168 | 0.946 | 541.50 | 7.271 | 17.60 |
| 512 | 128 | 7680 | 0.955 | 536.07 | 7.367 | 17.37 |
| 512 | 128 | 8192 | 0.965 | 530.34 | 7.422 | 17.24 |

4x Regular GPU

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 512 | 128 | 0 | 0.832 | 615.13 | 6.414 | 19.96 |
| 512 | 128 | 512 | 0.826 | 619.67 | 6.497 | 19.70 |
| 512 | 128 | 1024 | 0.843 | 607.41 | 6.529 | 19.61 |
| 512 | 128 | 1536 | 0.883 | 579.87 | 6.575 | 19.47 |
| 512 | 128 | 2048 | 0.855 | 598.86 | 6.629 | 19.31 |
| 512 | 128 | 2560 | 0.865 | 591.84 | 6.757 | 18.94 |
| 512 | 128 | 3072 | 0.874 | 585.57 | 6.812 | 18.79 |
| 512 | 128 | 3584 | 0.885 | 578.61 | 6.863 | 18.65 |
| 512 | 128 | 4096 | 0.895 | 572.21 | 6.909 | 18.53 |
| 512 | 128 | 4608 | 0.904 | 566.41 | 6.965 | 18.38 |
| 512 | 128 | 5120 | 0.914 | 560.26 | 7.103 | 18.02 |
| 512 | 128 | 5632 | 0.924 | 554.27 | 7.145 | 17.92 |
| 512 | 128 | 6144 | 0.934 | 548.17 | 7.196 | 17.79 |
| 512 | 128 | 6656 | 0.943 | 542.81 | 7.233 | 17.70 |
| 512 | 128 | 7168 | 0.955 | 535.96 | 7.291 | 17.56 |
| 512 | 128 | 7680 | 0.963 | 531.80 | 7.383 | 17.34 |
| 512 | 128 | 8192 | 0.973 | 526.31 | 7.438 | 17.21 |

2xGPU - Mainline

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 512 | 128 | 0 | 0.810 | 632.48 | 6.285 | 20.37 |
| 512 | 128 | 512 | 0.822 | 622.73 | 6.423 | 19.93 |
| 512 | 128 | 1024 | 0.834 | 613.84 | 6.462 | 19.81 |
| 512 | 128 | 1536 | 0.845 | 606.01 | 6.550 | 19.54 |
| 512 | 128 | 2048 | 0.857 | 597.54 | 6.634 | 19.29 |
| 512 | 128 | 2560 | 0.866 | 591.20 | 6.720 | 19.05 |
| 512 | 128 | 3072 | 0.879 | 582.47 | 6.800 | 18.82 |
| 512 | 128 | 3584 | 0.889 | 576.12 | 6.831 | 18.74 |
| 512 | 128 | 4096 | 0.900 | 569.03 | 6.920 | 18.50 |
| 512 | 128 | 4608 | 0.909 | 563.03 | 7.008 | 18.26 |
| 512 | 128 | 5120 | 0.920 | 556.35 | 7.086 | 18.06 |
| 512 | 128 | 5632 | 0.932 | 549.15 | 7.149 | 17.90 |
| 512 | 128 | 6144 | 0.942 | 543.70 | 7.189 | 17.80 |
| 512 | 128 | 6656 | 0.954 | 536.75 | 7.273 | 17.60 |
| 512 | 128 | 7168 | 0.964 | 531.07 | 7.363 | 17.38 |
| 512 | 128 | 7680 | 0.974 | 525.48 | 7.447 | 17.19 |
| 512 | 128 | 8192 | 0.986 | 519.19 | 7.513 | 17.04 |

4x GPU - Mainline

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 512 | 128 | 0 | 0.807 | 634.39 | 6.294 | 20.34 |
| 512 | 128 | 512 | 0.818 | 626.05 | 6.434 | 19.89 |
| 512 | 128 | 1024 | 0.829 | 617.94 | 6.475 | 19.77 |
| 512 | 128 | 1536 | 0.839 | 610.05 | 6.566 | 19.50 |
| 512 | 128 | 2048 | 0.850 | 602.43 | 6.652 | 19.24 |
| 512 | 128 | 2560 | 0.861 | 594.43 | 6.741 | 18.99 |
| 512 | 128 | 3072 | 0.872 | 587.44 | 6.822 | 18.76 |
| 512 | 128 | 3584 | 0.882 | 580.48 | 6.855 | 18.67 |
| 512 | 128 | 4096 | 0.894 | 572.86 | 6.947 | 18.42 |
| 512 | 128 | 4608 | 0.905 | 565.63 | 7.036 | 18.19 |
| 512 | 128 | 5120 | 0.917 | 558.31 | 7.117 | 17.98 |
| 512 | 128 | 5632 | 0.928 | 551.47 | 7.185 | 17.82 |
| 512 | 128 | 6144 | 0.939 | 545.47 | 7.227 | 17.71 |
| 512 | 128 | 6656 | 0.951 | 538.62 | 7.313 | 17.50 |
| 512 | 128 | 7168 | 0.962 | 532.45 | 7.402 | 17.29 |
| 512 | 128 | 7680 | 0.973 | 526.42 | 7.485 | 17.10 |
| 512 | 128 | 8192 | 0.985 | 519.86 | 7.555 | 16.94 |

2x GPU - TP

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 512 | 128 | 0 | 1.043 | 490.80 | 5.356 | 23.90 |
| 512 | 128 | 512 | 1.039 | 492.94 | 5.409 | 23.66 |
| 512 | 128 | 1024 | 1.044 | 490.29 | 5.419 | 23.62 |
| 512 | 128 | 1536 | 1.052 | 486.87 | 5.439 | 23.53 |
| 512 | 128 | 2048 | 1.056 | 484.74 | 5.463 | 23.43 |
| 512 | 128 | 2560 | 1.061 | 482.37 | 5.491 | 23.31 |
| 512 | 128 | 3072 | 1.067 | 480.02 | 5.527 | 23.16 |
| 512 | 128 | 3584 | 1.072 | 477.46 | 5.545 | 23.08 |
| 512 | 128 | 4096 | 1.078 | 475.04 | 5.586 | 22.91 |
| 512 | 128 | 4608 | 1.083 | 472.60 | 5.621 | 22.77 |
| 512 | 128 | 5120 | 1.090 | 469.92 | 5.715 | 22.40 |
| 512 | 128 | 5632 | 1.095 | 467.72 | 5.749 | 22.27 |
| 512 | 128 | 6144 | 1.100 | 465.28 | 5.785 | 22.12 |
| 512 | 128 | 6656 | 1.106 | 462.92 | 5.815 | 22.01 |
| 512 | 128 | 7168 | 1.112 | 460.63 | 5.818 | 22.00 |
| 512 | 128 | 7680 | 1.117 | 458.46 | 5.826 | 21.97 |
| 512 | 128 | 8192 | 1.122 | 456.14 | 5.867 | 21.82 |

4x GPU - TP

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 512 | 128 | 0 | 2.032 | 251.98 | 7.015 | 18.25 |
| 512 | 128 | 512 | 2.017 | 253.83 | 7.079 | 18.08 |
| 512 | 128 | 1024 | 2.021 | 253.37 | 7.083 | 18.07 |
| 512 | 128 | 1536 | 2.023 | 253.11 | 7.081 | 18.08 |
| 512 | 128 | 2048 | 2.026 | 252.76 | 7.111 | 18.00 |
| 512 | 128 | 2560 | 2.029 | 252.39 | 7.105 | 18.02 |
| 512 | 128 | 3072 | 2.032 | 252.02 | 7.125 | 17.97 |
| 512 | 128 | 3584 | 2.035 | 251.58 | 7.140 | 17.93 |
| 512 | 128 | 4096 | 2.037 | 251.30 | 7.164 | 17.87 |
| 512 | 128 | 4608 | 2.041 | 250.89 | 7.165 | 17.86 |
| 512 | 128 | 5120 | 2.047 | 250.14 | 7.201 | 17.78 |
| 512 | 128 | 5632 | 2.046 | 250.19 | 7.251 | 17.65 |
| 512 | 128 | 6144 | 2.050 | 249.76 | 7.241 | 17.68 |
| 512 | 128 | 6656 | 2.053 | 249.43 | 7.259 | 17.63 |
| 512 | 128 | 7168 | 2.057 | 248.93 | 7.276 | 17.59 |
| 512 | 128 | 7680 | 2.059 | 248.65 | 7.294 | 17.55 |
| 512 | 128 | 8192 | 2.063 | 248.22 | 7.312 | 17.51 |

With 4x GPU, I don't see a lot of b/w. Maybe 2-3 GiB/s in nvtop. GPU usage is from 27-37%, with one GPU at 46%. Most of the transfers really happen during PP anyway. I have turbo disabled (in his screenshot, he doesn't) and PCIE3 through 2 PLX with an x16 link per PLX. I have seen higher transfer with NCCL on Wan, so I don't think it's a transfer bottleneck just yet. Must be something else?

@ikawrakow
Owner Author

What are the GPUs? 3090s? If so, why would your split mode "layer" PP performance be 25% lower than mine?
Have you tried running full GPU offload with 1-4 CPU threads instead of 48?
If you are seeing 2-3 GiB/s during PP with TP, there is something seriously wrong with your system.
If the 2-3 GiB/s is during TG, that's irrelevant. Data transfer between GPUs is not a bottleneck during TG.


magikRUKKOLA commented Nov 27, 2025

@ikawrakow

What a cool release and the read!! Thanks a lot!

Here nvtop screenshots showing ~75% GPU utilization during LlaMA-70B inference

Please note that the first GPU temperature is 79°C and it's getting thermally throttled, so the core clock is only 1920 MHz. The second GPU, which is placed at the bottom, is at 57°C, which allows it to utilize almost all of the applied offset of +100 MHz, so 1950 -> 2040 MHz.

So this means that air cooling is obsolete and it's necessary to use liquid cooling.

Also I got a question regarding the BW. Is it only about 1.6 Gbit/sec? Would it make any difference if you installed an NVLink bridge (a 3-slot bridge in your case)?

[EDIT]:

which means that in addition to the data transfer bottleneck, there is also a significant synchronization overhead.

Aha. So NVLink should help then. But will it help for a config with 4 GPUs where only two NVLinks are installed?


Ph0rk0z commented Nov 27, 2025

why would your split mode "layer" PP performance be 25% lower than mine?

You run with full clocks. I have it capped at 1695. I can do at least 2 cards without.

Have you tried running full GPU offload with 1-4 CPU threads instead of 48?

-t 48 is left over from CPU running. It should be using a single thread. That's how mainline has been for ages.

Also I got a question regarding the BW. Is it only about 1.6 Gbit/sec ? Will it make any difference if you would install the NvLink bridge (3-slot bridge in your case)?

1.6 isn't maximal bandwidth. That's just what I see on this model. Wan sees 6-8 in NCCL. I have also done p2p tests and all that jazz.

BTW, device 4 is a 2080 Ti 22 GB, not P2P.

Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3      4 
     0 839.61  25.38  19.72  19.70   8.54 
     1  25.37 839.36  19.72  19.70   8.57 
     2  19.72  19.72 840.28  25.36   8.58 
     3  19.70  19.70  25.37 839.15   8.58 
     4   8.54   8.54   8.52   8.56 529.73 
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2      3      4 
     0 833.78  13.18  10.29  10.27   5.82 
     1  13.18 835.11  10.29  10.26   5.81 
     2  10.29  10.29 833.78  13.11   5.83 
     3  10.29  10.29  13.18 836.46   5.83 
     4   5.86   5.85   5.86   5.85 528.74 

Watch your nvtop and see what you get. It only appears briefly.

edit: ok, so I did some tests and it's none of those things. All I can think of is the crappy single-threaded performance of the Xeon vs the Threadripper, or the NVIDIA driver/CUDA disfavoring Ampere. It's 580.82.09 and CUDA release 12.6, V12.6.20.

Contributor

Nexesenex commented Nov 27, 2025

I tested the branch yesterday, after "playing games with the scheduler", and today again once you released the PR.

My GPU/PCIE setup is as such:

  • 0: RTX 3090 on PCIE 5.0 (4.0 effective) 16x, undervolt at 64% TDP.
  • 1: RTX 3090 on PCIE 4.0 4x, undervolt at 80% TDP.
  • 2: RTX A4000 on PCIE 4.0 4x, undervolt at 70% TDP.

Command : llama-server -m LLaMa-70B-Q5_K_M.gguf -ngl 150 -sm graph -b 256 -mg 1 -ts 31,31,19 -c 65536 -ctk q8_0 -ctv q8_0 --host 127.0.0.1 --port 8080

First observation, the GPU occupation is now sufficient in TG to trigger the P2 State of my GPUs, instead of playing with crappy locked overclocks. And that's great for my comfort of use and the durability of my GPUs.

Second, the PP is still lower for me than layer split, by a ratio of approximately 2.5 at 200 ctx and 2 at 10,000 ctx. Maybe my PCIE bandwidth on GPUs 1 and 2 is not enough to enjoy the benefit of TP, PP-wise.

Third, the TG is 5% to 10% higher, mainly due to the better thermal management of my cards now that I don't need to lock the frequency. This, with no overclock and a higher undervolt to keep the GPUs cool. If I maintain the fixed overclock, the TG drops below split layer due to thermal constraints.

Fourth, there's a bug when resetting the context. Let's say I chat using llama-server, then start another chat in the same client or another, and it provokes this:

======== Prompt cache: cache size: 13565, n_keep: 0, n_discarded_prompt: 0, cache_ram_n_min: 0, f_keep: 0.00, cache_ram_similarity: 0.50
updating prompt cache
 - saving prompt with length 13565, total state size = 2252.107 MiB
Q:\GitHub\ik_llama.cpp.fks\ggml\src\ggml-cuda.cu:994: GGML_ASSERT(size == ggml_nbytes(tensor)) failed

I can also have this error :

Q:\GitHub\ik_llama.cpp.fks\src\llama-load-tensors.cpp:242: GGML_ASSERT(nr % granularity == 0) failed

When loading a Q5_K_S 70B Llama 3.x model with split graph, while it poses no problem with split layer.

I patched myself up by replacing that:

    GGML_ASSERT(nr % granularity == 0);
    GGML_ASSERT(!splits.empty());
    if (granularity < 0) return std::vector<int>(splits.size(), nr);

With that:

    // GGML_ASSERT(nr % granularity == 0);
    GGML_ASSERT(!splits.empty());
    if (granularity < 0){
	    return std::vector<int>(splits.size(), nr);
    }
    else {
        granularity == 0;
    }

And then, in backend.cpp, line 217, commented out an assert.

// GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");

And it works.


magikRUKKOLA commented Nov 27, 2025

@Nexesenex

And that's great for my comfort of use and the durability of my GPUs.

You do know that you can make a daemon that monitors the GPU utilization and overclocks it or undervolts it depending on load? Plus you can manage the fan speed.

If I maintain the fixed overclock the TG drops below split layer due to thermal constraints.

That would be a very bad idea for sure. The better way is to apply the OC offsets for the GPU clock and the VRAM clock. This way the GPU itself would do the overclocking depending on the current temperatures.

@Nexesenex
Contributor

@magikRUKKOLA

I'm not that savvy, especially under Windows 11. Nvidia is very pesky about states and frequencies, CUDA-wise. So with my undervolt, I use MSI Afterburner until I find the right kick to trigger P2. Now, with TP, it triggers much more easily. I'm a true amateur, Rukkola! ^^


magikRUKKOLA commented Nov 28, 2025

@ikawrakow

-sm graph performance notes

The graphs below apparently show that the two-GPU config works faster than the three-GPU one (not surprisingly).
They also show that the -sm graph prefill efficiency improves with longer context (compared to -sm layer).
The decode speed with the three-GPU config drops by only about 10% for 32k ctx, that is, from 25.81 t/s to 23.17 t/s. I also made an extended test at 72k ctx; the total drop for that case is about 18% (from 25.67 t/s to 21.24 t/s).

3 x 3090, Llama-3.3-70B-Instruct-Q4_0.gguf

[Figures: prefill and decode sweep-bench plots]

Details
#!/usr/bin/env bash

export MALLOC_CONF="background_thread:true,percpu_arena:phycpu,metadata_thp:auto,dirty_decay_ms:10000,muzzy_decay_ms:60000"
export LD_PRELOAD=/usr/local/lib/libjemalloc.so

ulimit -n 9999
ulimit -l unlimited

export CUDA_VISIBLE_DEVICES="0,1,2"

/opt/ik_llama.cpp/ik_llama.cpp/build/bin/llama-sweep-bench \
    --warmup-batch \
    -m /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/Llama-3.3-70B-Instruct-Q4_0.gguf \
    -sm graph \
    -ub 1024 \
    --threads $(grep ^cpu\\scores /proc/cpuinfo | uniq | awk '{print $4}' | xargs -I{} echo "{}-0" | bc) \
    -c 32768 \
    -mla 3 \
    -fa on \
    -ctk q8_0 \
    -ctv q8_0 \
    --no-mmap \
    -ngl 99

File: /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/ik_llama.cpp-bench-sm-graph-f16-2gpu.log


main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  1024 |    256 |      0 |    1.280 |   800.29 |    9.058 |    28.26 |
|  1024 |    256 |   1024 |    1.300 |   787.96 |    9.102 |    28.13 |
|  1024 |    256 |   2048 |    1.316 |   777.85 |    9.165 |    27.93 |
|  1024 |    256 |   3072 |    1.334 |   767.45 |    9.220 |    27.77 |
|  1024 |    256 |   4096 |    1.352 |   757.34 |    9.306 |    27.51 |
|  1024 |    256 |   5120 |    1.370 |   747.55 |    9.464 |    27.05 |
|  1024 |    256 |   6144 |    1.387 |   738.23 |    9.476 |    27.02 |
|  1024 |    256 |   7168 |    1.407 |   727.93 |    9.501 |    26.95 |
|  1024 |    256 |   8192 |    1.425 |   718.83 |    9.528 |    26.87 |
|  1024 |    256 |   9216 |    1.442 |   710.13 |    9.560 |    26.78 |
|  1024 |    256 |  10240 |    1.460 |   701.59 |    9.608 |    26.64 |
|  1024 |    256 |  11264 |    1.478 |   693.06 |    9.747 |    26.26 |
|  1024 |    256 |  12288 |    1.495 |   685.11 |    9.761 |    26.23 |
|  1024 |    256 |  13312 |    1.512 |   677.18 |    9.759 |    26.23 |
|  1024 |    256 |  14336 |    1.529 |   669.83 |    9.805 |    26.11 |
|  1024 |    256 |  15360 |    1.550 |   660.63 |    9.839 |    26.02 |

File: /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/ik_llama.cpp-bench-sm-layer-f16-2gpu.log


main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  1024 |    256 |      0 |    1.342 |   763.16 |   11.492 |    22.28 |
|  1024 |    256 |   1024 |    1.376 |   743.93 |   11.570 |    22.13 |
|  1024 |    256 |   2048 |    1.407 |   728.02 |   11.676 |    21.93 |
|  1024 |    256 |   3072 |    1.442 |   710.23 |   11.842 |    21.62 |
|  1024 |    256 |   4096 |    1.477 |   693.24 |   11.909 |    21.50 |
|  1024 |    256 |   5120 |    1.511 |   677.55 |   12.109 |    21.14 |
|  1024 |    256 |   6144 |    1.544 |   663.25 |   12.125 |    21.11 |
|  1024 |    256 |   7168 |    1.577 |   649.15 |   12.183 |    21.01 |
|  1024 |    256 |   8192 |    1.608 |   636.80 |   12.313 |    20.79 |
|  1024 |    256 |   9216 |    1.643 |   623.39 |   12.336 |    20.75 |
|  1024 |    256 |  10240 |    1.674 |   611.68 |   12.451 |    20.56 |
|  1024 |    256 |  11264 |    1.707 |   600.04 |   12.586 |    20.34 |
|  1024 |    256 |  12288 |    1.741 |   588.32 |   12.637 |    20.26 |
|  1024 |    256 |  13312 |    1.772 |   577.89 |   12.781 |    20.03 |
|  1024 |    256 |  14336 |    1.802 |   568.40 |   12.831 |    19.95 |
|  1024 |    256 |  15360 |    1.834 |   558.41 |   12.911 |    19.83 |

File: /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/ik_llama.cpp-bench-sm-graph-f16-3gpu-72kctx.log


main: n_kv_max = 73728, n_batch = 2048, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  1024 |    256 |      0 |    1.685 |   607.58 |    9.972 |    25.67 |
|  1024 |    256 |   1024 |    1.697 |   603.57 |   10.026 |    25.53 |
|  1024 |    256 |   2048 |    1.706 |   600.21 |   10.068 |    25.43 |
|  1024 |    256 |   3072 |    1.716 |   596.73 |   10.125 |    25.29 |
|  1024 |    256 |   4096 |    1.727 |   592.86 |   10.158 |    25.20 |
|  1024 |    256 |   5120 |    1.738 |   589.11 |   10.192 |    25.12 |
|  1024 |    256 |   6144 |    1.747 |   586.02 |   10.235 |    25.01 |
|  1024 |    256 |   7168 |    1.757 |   582.78 |   10.299 |    24.86 |
|  1024 |    256 |   8192 |    1.768 |   579.11 |   10.354 |    24.72 |
|  1024 |    256 |   9216 |    1.778 |   575.99 |   10.402 |    24.61 |
|  1024 |    256 |  10240 |    1.788 |   572.59 |   10.439 |    24.52 |
|  1024 |    256 |  11264 |    1.801 |   568.51 |   10.570 |    24.22 |
|  1024 |    256 |  12288 |    1.810 |   565.60 |   10.581 |    24.19 |
|  1024 |    256 |  13312 |    1.820 |   562.59 |   10.573 |    24.21 |
|  1024 |    256 |  14336 |    1.831 |   559.23 |   10.587 |    24.18 |
|  1024 |    256 |  15360 |    1.841 |   556.17 |   10.610 |    24.13 |
|  1024 |    256 |  16384 |    1.851 |   553.22 |   10.610 |    24.13 |
|  1024 |    256 |  17408 |    1.861 |   550.30 |   10.631 |    24.08 |
|  1024 |    256 |  18432 |    1.871 |   547.33 |   10.660 |    24.02 |
|  1024 |    256 |  19456 |    1.881 |   544.43 |   10.662 |    24.01 |
|  1024 |    256 |  20480 |    1.891 |   541.47 |   10.694 |    23.94 |
|  1024 |    256 |  21504 |    1.902 |   538.49 |   10.805 |    23.69 |
|  1024 |    256 |  22528 |    1.910 |   536.00 |   10.824 |    23.65 |
|  1024 |    256 |  23552 |    1.924 |   532.29 |   10.846 |    23.60 |
|  1024 |    256 |  24576 |    1.934 |   529.40 |   10.874 |    23.54 |
|  1024 |    256 |  25600 |    1.944 |   526.87 |   10.857 |    23.58 |
|  1024 |    256 |  26624 |    1.954 |   524.02 |   10.879 |    23.53 |
|  1024 |    256 |  27648 |    1.962 |   521.82 |   10.916 |    23.45 |
|  1024 |    256 |  28672 |    1.975 |   518.48 |   10.923 |    23.44 |
|  1024 |    256 |  29696 |    1.990 |   514.46 |   10.967 |    23.34 |
|  1024 |    256 |  30720 |    1.999 |   512.14 |   10.971 |    23.33 |
|  1024 |    256 |  31744 |    2.013 |   508.73 |   11.109 |    23.04 |
|  1024 |    256 |  32768 |    2.023 |   506.29 |   11.132 |    23.00 |
|  1024 |    256 |  33792 |    2.037 |   502.79 |   11.117 |    23.03 |
|  1024 |    256 |  34816 |    2.049 |   499.79 |   11.151 |    22.96 |
|  1024 |    256 |  35840 |    2.061 |   496.87 |   11.125 |    23.01 |
|  1024 |    256 |  36864 |    2.073 |   494.08 |   11.180 |    22.90 |
|  1024 |    256 |  37888 |    2.085 |   491.10 |   11.245 |    22.77 |
|  1024 |    256 |  38912 |    2.097 |   488.31 |   11.260 |    22.73 |
|  1024 |    256 |  39936 |    2.119 |   483.25 |   11.273 |    22.71 |
|  1024 |    256 |  40960 |    2.139 |   478.78 |   11.278 |    22.70 |
|  1024 |    256 |  41984 |    2.131 |   480.54 |   11.388 |    22.48 |
|  1024 |    256 |  43008 |    2.149 |   476.44 |   11.424 |    22.41 |
|  1024 |    256 |  44032 |    2.155 |   475.10 |   11.434 |    22.39 |
|  1024 |    256 |  45056 |    2.163 |   473.31 |   11.421 |    22.42 |
|  1024 |    256 |  46080 |    2.167 |   472.51 |   11.443 |    22.37 |
|  1024 |    256 |  47104 |    2.191 |   467.29 |   11.461 |    22.34 |
|  1024 |    256 |  48128 |    2.197 |   466.19 |   11.482 |    22.30 |
|  1024 |    256 |  49152 |    2.236 |   457.87 |   11.492 |    22.28 |
|  1024 |    256 |  50176 |    2.217 |   461.90 |   11.510 |    22.24 |
|  1024 |    256 |  51200 |    2.230 |   459.24 |   11.528 |    22.21 |
|  1024 |    256 |  52224 |    2.252 |   454.71 |   11.560 |    22.15 |
|  1024 |    256 |  53248 |    2.264 |   452.28 |   11.645 |    21.98 |
|  1024 |    256 |  54272 |    2.279 |   449.22 |   11.672 |    21.93 |
|  1024 |    256 |  55296 |    2.282 |   448.64 |   11.694 |    21.89 |
|  1024 |    256 |  56320 |    2.299 |   445.34 |   11.693 |    21.89 |
|  1024 |    256 |  57344 |    2.321 |   441.12 |   11.704 |    21.87 |
|  1024 |    256 |  58368 |    2.320 |   441.47 |   11.746 |    21.79 |
|  1024 |    256 |  59392 |    2.333 |   438.84 |   11.735 |    21.81 |
|  1024 |    256 |  60416 |    2.350 |   435.69 |   11.737 |    21.81 |
|  1024 |    256 |  61440 |    2.363 |   433.31 |   11.774 |    21.74 |
|  1024 |    256 |  62464 |    2.390 |   428.39 |   11.860 |    21.59 |
|  1024 |    256 |  63488 |    2.397 |   427.17 |   11.873 |    21.56 |
|  1024 |    256 |  64512 |    2.413 |   424.39 |   11.926 |    21.47 |
|  1024 |    256 |  65536 |    2.437 |   420.19 |   11.953 |    21.42 |
|  1024 |    256 |  66560 |    2.430 |   421.32 |   11.967 |    21.39 |
|  1024 |    256 |  67584 |    2.450 |   417.92 |   11.966 |    21.39 |
|  1024 |    256 |  68608 |    2.467 |   415.05 |   11.969 |    21.39 |
|  1024 |    256 |  69632 |    2.474 |   413.90 |   12.006 |    21.32 |
|  1024 |    256 |  70656 |    2.492 |   410.88 |   12.012 |    21.31 |
|  1024 |    256 |  71680 |    2.500 |   409.68 |   12.023 |    21.29 |
|  1024 |    256 |  72704 |    2.518 |   406.68 |   12.055 |    21.24 |

File: /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/ik_llama.cpp-bench-sm-graph-f16.log

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  1024 |    256 |      0 |    1.680 |   609.58 |    9.918 |    25.81 |
|  1024 |    256 |   1024 |    1.690 |   605.97 |    9.965 |    25.69 |
|  1024 |    256 |   2048 |    1.704 |   600.81 |   10.000 |    25.60 |
|  1024 |    256 |   3072 |    1.711 |   598.45 |   10.050 |    25.47 |
|  1024 |    256 |   4096 |    1.721 |   595.09 |   10.200 |    25.10 |
|  1024 |    256 |   5120 |    1.731 |   591.66 |   10.128 |    25.28 |
|  1024 |    256 |   6144 |    1.741 |   588.08 |   10.170 |    25.17 |
|  1024 |    256 |   7168 |    1.753 |   584.29 |   10.198 |    25.10 |
|  1024 |    256 |   8192 |    1.761 |   581.35 |   10.222 |    25.04 |
|  1024 |    256 |   9216 |    1.772 |   578.04 |   10.278 |    24.91 |
|  1024 |    256 |  10240 |    1.782 |   574.69 |   10.380 |    24.66 |
|  1024 |    256 |  11264 |    1.791 |   571.63 |   10.469 |    24.45 |
|  1024 |    256 |  12288 |    1.804 |   567.66 |   10.488 |    24.41 |
|  1024 |    256 |  13312 |    1.813 |   564.68 |   10.512 |    24.35 |
|  1024 |    256 |  14336 |    1.824 |   561.55 |   10.500 |    24.38 |
|  1024 |    256 |  15360 |    1.834 |   558.31 |   10.528 |    24.32 |
|  1024 |    256 |  16384 |    1.843 |   555.56 |   10.528 |    24.32 |
|  1024 |    256 |  17408 |    1.853 |   552.60 |   10.555 |    24.25 |
|  1024 |    256 |  18432 |    1.863 |   549.66 |   10.578 |    24.20 |
|  1024 |    256 |  19456 |    1.872 |   546.94 |   10.598 |    24.16 |
|  1024 |    256 |  20480 |    1.887 |   542.63 |   10.645 |    24.05 |
|  1024 |    256 |  21504 |    1.896 |   540.05 |   10.765 |    23.78 |
|  1024 |    256 |  22528 |    1.906 |   537.23 |   10.775 |    23.76 |
|  1024 |    256 |  23552 |    1.916 |   534.36 |   10.787 |    23.73 |
|  1024 |    256 |  24576 |    1.925 |   531.97 |   10.783 |    23.74 |
|  1024 |    256 |  25600 |    1.935 |   529.15 |   10.792 |    23.72 |
|  1024 |    256 |  26624 |    1.945 |   526.55 |   10.799 |    23.71 |
|  1024 |    256 |  27648 |    1.955 |   523.84 |   10.828 |    23.64 |
|  1024 |    256 |  28672 |    1.965 |   521.10 |   10.839 |    23.62 |
|  1024 |    256 |  29696 |    1.978 |   517.57 |   10.880 |    23.53 |
|  1024 |    256 |  30720 |    1.989 |   514.86 |   10.920 |    23.44 |
|  1024 |    256 |  31744 |    1.998 |   512.39 |   11.048 |    23.17 |

File: /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/ik_llama.cpp-bench-sm-graph.log


main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  1024 |    256 |      0 |    1.685 |   607.79 |   10.311 |    24.83 |
|  1024 |    256 |   1024 |    1.698 |   603.13 |   10.443 |    24.51 |
|  1024 |    256 |   2048 |    1.708 |   599.42 |   10.541 |    24.29 |
|  1024 |    256 |   3072 |    1.717 |   596.51 |   10.575 |    24.21 |
|  1024 |    256 |   4096 |    1.728 |   592.62 |   10.617 |    24.11 |
|  1024 |    256 |   5120 |    1.750 |   585.19 |   10.699 |    23.93 |
|  1024 |    256 |   6144 |    1.748 |   585.73 |   10.753 |    23.81 |
|  1024 |    256 |   7168 |    1.759 |   582.15 |   10.836 |    23.62 |
|  1024 |    256 |   8192 |    1.767 |   579.46 |   10.918 |    23.45 |
|  1024 |    256 |   9216 |    1.782 |   574.68 |   11.000 |    23.27 |
|  1024 |    256 |  10240 |    1.792 |   571.57 |   11.095 |    23.07 |
|  1024 |    256 |  11264 |    1.801 |   568.54 |   11.251 |    22.75 |
|  1024 |    256 |  12288 |    1.812 |   565.09 |   11.319 |    22.62 |
|  1024 |    256 |  13312 |    1.834 |   558.38 |   11.378 |    22.50 |
|  1024 |    256 |  14336 |    1.833 |   558.64 |   11.435 |    22.39 |
|  1024 |    256 |  15360 |    1.844 |   555.18 |   11.484 |    22.29 |
|  1024 |    256 |  16384 |    1.855 |   552.15 |   11.532 |    22.20 |
|  1024 |    256 |  17408 |    1.864 |   549.40 |   11.597 |    22.07 |
|  1024 |    256 |  18432 |    1.876 |   545.95 |   11.656 |    21.96 |
|  1024 |    256 |  19456 |    1.886 |   542.91 |   11.750 |    21.79 |
|  1024 |    256 |  20480 |    1.896 |   539.99 |   11.784 |    21.72 |
|  1024 |    256 |  21504 |    1.937 |   528.73 |   11.956 |    21.41 |
|  1024 |    256 |  22528 |    1.921 |   532.92 |   12.015 |    21.31 |
|  1024 |    256 |  23552 |    1.936 |   528.85 |   12.075 |    21.20 |
|  1024 |    256 |  24576 |    1.950 |   525.25 |   12.108 |    21.14 |
|  1024 |    256 |  25600 |    1.958 |   522.87 |   12.182 |    21.01 |
|  1024 |    256 |  26624 |    1.973 |   518.89 |   12.232 |    20.93 |
|  1024 |    256 |  27648 |    1.983 |   516.40 |   12.290 |    20.83 |
|  1024 |    256 |  28672 |    2.001 |   511.74 |   12.370 |    20.70 |
|  1024 |    256 |  29696 |    2.010 |   509.38 |   12.423 |    20.61 |
|  1024 |    256 |  30720 |    2.025 |   505.62 |   12.478 |    20.52 |
|  1024 |    256 |  31744 |    2.043 |   501.18 |   12.647 |    20.24 |

File: /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/ik_llama.cpp-bench-sm-layer-f16.log


main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  1024 |    256 |      0 |    1.358 |   754.12 |   11.506 |    22.25 |
|  1024 |    256 |   1024 |    1.393 |   735.31 |   11.579 |    22.11 |
|  1024 |    256 |   2048 |    1.423 |   719.36 |   11.688 |    21.90 |
|  1024 |    256 |   3072 |    1.460 |   701.43 |   11.861 |    21.58 |
|  1024 |    256 |   4096 |    1.496 |   684.63 |   11.923 |    21.47 |
|  1024 |    256 |   5120 |    1.532 |   668.26 |   12.125 |    21.11 |
|  1024 |    256 |   6144 |    1.564 |   654.57 |   12.144 |    21.08 |
|  1024 |    256 |   7168 |    1.597 |   641.22 |   12.197 |    20.99 |
|  1024 |    256 |   8192 |    1.635 |   626.37 |   12.331 |    20.76 |
|  1024 |    256 |   9216 |    1.671 |   612.96 |   12.360 |    20.71 |
|  1024 |    256 |  10240 |    1.704 |   600.86 |   12.470 |    20.53 |
|  1024 |    256 |  11264 |    1.739 |   588.68 |   12.605 |    20.31 |
|  1024 |    256 |  12288 |    1.772 |   577.99 |   12.661 |    20.22 |
|  1024 |    256 |  13312 |    1.807 |   566.78 |   12.805 |    19.99 |
|  1024 |    256 |  14336 |    1.838 |   557.11 |   12.852 |    19.92 |
|  1024 |    256 |  15360 |    1.871 |   547.41 |   12.937 |    19.79 |
|  1024 |    256 |  16384 |    1.903 |   538.00 |   13.079 |    19.57 |
|  1024 |    256 |  17408 |    1.940 |   527.94 |   13.131 |    19.50 |
|  1024 |    256 |  18432 |    1.974 |   518.64 |   13.267 |    19.30 |
|  1024 |    256 |  19456 |    2.007 |   510.22 |   13.308 |    19.24 |
|  1024 |    256 |  20480 |    2.045 |   500.71 |   13.377 |    19.14 |
|  1024 |    256 |  21504 |    2.080 |   492.31 |   13.558 |    18.88 |
|  1024 |    256 |  22528 |    2.116 |   483.98 |   13.586 |    18.84 |
|  1024 |    256 |  23552 |    2.150 |   476.17 |   13.736 |    18.64 |
|  1024 |    256 |  24576 |    2.185 |   468.65 |   13.791 |    18.56 |
|  1024 |    256 |  25600 |    2.220 |   461.34 |   13.860 |    18.47 |
|  1024 |    256 |  26624 |    2.254 |   454.21 |   14.015 |    18.27 |
|  1024 |    256 |  27648 |    2.289 |   447.42 |   14.042 |    18.23 |
|  1024 |    256 |  28672 |    2.327 |   440.12 |   14.182 |    18.05 |
|  1024 |    256 |  29696 |    2.361 |   433.67 |   14.245 |    17.97 |
|  1024 |    256 |  30720 |    2.400 |   426.68 |   14.297 |    17.91 |
|  1024 |    256 |  31744 |    2.444 |   418.95 |   14.489 |    17.67 |

File: /opt/unsloth/Llama-3.3-70B-Instruct-GGUF/Q4_0/ik_llama.cpp-bench-sm-layer.log


main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  1024 |    256 |      0 |    1.377 |   743.38 |   11.629 |    22.01 |
|  1024 |    256 |   1024 |    1.414 |   724.38 |   11.842 |    21.62 |
|  1024 |    256 |   2048 |    1.449 |   706.63 |   12.065 |    21.22 |
|  1024 |    256 |   3072 |    1.482 |   691.03 |   12.389 |    20.66 |
|  1024 |    256 |   4096 |    1.522 |   672.92 |   12.591 |    20.33 |
|  1024 |    256 |   5120 |    1.556 |   658.15 |   12.946 |    19.77 |
|  1024 |    256 |   6144 |    1.591 |   643.70 |   13.126 |    19.50 |
|  1024 |    256 |   7168 |    1.627 |   629.49 |   13.325 |    19.21 |
|  1024 |    256 |   8192 |    1.668 |   614.09 |   13.620 |    18.80 |
|  1024 |    256 |   9216 |    1.704 |   600.90 |   13.794 |    18.56 |
|  1024 |    256 |  10240 |    1.737 |   589.46 |   14.044 |    18.23 |
|  1024 |    256 |  11264 |    1.773 |   577.53 |   14.334 |    17.86 |
|  1024 |    256 |  12288 |    1.805 |   567.21 |   14.527 |    17.62 |
|  1024 |    256 |  13312 |    1.835 |   558.02 |   14.829 |    17.26 |
|  1024 |    256 |  14336 |    1.871 |   547.39 |   15.018 |    17.05 |
|  1024 |    256 |  15360 |    1.904 |   537.91 |   15.231 |    16.81 |
|  1024 |    256 |  16384 |    1.938 |   528.34 |   15.525 |    16.49 |
|  1024 |    256 |  17408 |    1.976 |   518.31 |   15.719 |    16.29 |
|  1024 |    256 |  18432 |    2.015 |   508.31 |   16.019 |    15.98 |
|  1024 |    256 |  19456 |    2.056 |   498.03 |   16.202 |    15.80 |
|  1024 |    256 |  20480 |    2.096 |   488.58 |   16.415 |    15.60 |
|  1024 |    256 |  21504 |    2.127 |   481.32 |   16.730 |    15.30 |
|  1024 |    256 |  22528 |    2.171 |   471.75 |   16.901 |    15.15 |
|  1024 |    256 |  23552 |    2.204 |   464.59 |   17.208 |    14.88 |
|  1024 |    256 |  24576 |    2.244 |   456.39 |   17.404 |    14.71 |
|  1024 |    256 |  25600 |    2.278 |   449.43 |   17.609 |    14.54 |
|  1024 |    256 |  26624 |    2.311 |   443.11 |   17.913 |    14.29 |
|  1024 |    256 |  27648 |    2.348 |   436.06 |   18.078 |    14.16 |
|  1024 |    256 |  28672 |    2.386 |   429.15 |   18.362 |    13.94 |
|  1024 |    256 |  29696 |    2.420 |   423.17 |   18.576 |    13.78 |
|  1024 |    256 |  30720 |    2.472 |   414.20 |   18.776 |    13.63 |
|  1024 |    256 |  31744 |    2.535 |   403.99 |   19.110 |    13.40 |

@magikRUKKOLA

@ikawrakow

ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);

Apparently the P2P access is never getting enabled huh?

@ikawrakow
Owner Author

Thank you all for testing!

@Nexesenex @Ph0rk0z Is it possible that you both don't have P2P copy enabled? If I disable P2P copy on my system (cmake -DGGML_CUDA_NO_PEER_COPY=ON), I get ~25% drop in PP and ~10% drop in TG performance:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---|---|---|---|---|---|---|
| 1024 | 256 | 0 | 1.681 | 609.33 | 10.212 | 25.07 |
| 1024 | 256 | 1024 | 1.607 | 637.02 | 10.457 | 24.48 |
| 1024 | 256 | 2048 | 1.622 | 631.39 | 10.293 | 24.87 |
| 1024 | 256 | 3072 | 1.712 | 598.11 | 10.362 | 24.71 |
| 1024 | 256 | 4096 | 1.658 | 617.55 | 10.509 | 24.36 |
| 1024 | 256 | 5120 | 1.675 | 611.18 | 10.574 | 24.21 |
| 1024 | 256 | 6144 | 1.691 | 605.49 | 10.609 | 24.13 |
| 1024 | 256 | 7168 | 1.708 | 599.48 | 10.603 | 24.14 |
| 1024 | 256 | 8192 | 1.824 | 561.44 | 10.621 | 24.10 |
| 1024 | 256 | 9216 | 1.742 | 587.99 | 10.664 | 24.01 |
| 1024 | 256 | 10240 | 1.763 | 580.75 | 10.738 | 23.84 |
| 1024 | 256 | 11264 | 1.780 | 575.37 | 10.885 | 23.52 |
| 1024 | 256 | 12288 | 1.795 | 570.63 | 10.882 | 23.53 |
| 1024 | 256 | 13312 | 1.812 | 565.07 | 10.883 | 23.52 |
| 1024 | 256 | 14336 | 1.827 | 560.51 | 10.926 | 23.43 |
| 1024 | 256 | 15360 | 1.845 | 555.10 | 10.964 | 23.35 |

From @magikRUKKOLA's results it looks like I need to think about limiting TP to 2 GPUs, so using GPU0,1 for the first N1 layers, GPU2,3 for the second N2 layers, etc. An odd number of GPUs is of course awkward, as things don't divide nicely into equally sized portions. This is perhaps less of an issue for the FFN part, but it certainly is an issue for self-attention. For instance, Llama-70B has 8 KV heads, so with 3 GPUs we can only split them as 3,3,2.

@ikawrakow
Owner Author

@magikRUKKOLA

Apparently the P2P access is never getting enabled huh?

This code is an irrelevant remnant of the split mode "row" implementation (and I guess I should remove it to avoid confusion). The copy between GPUs is now always done using this:

CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));

@ikawrakow
Owner Author

Also I got a question regarding the BW. Is it only about 1.6 Gbit/sec ? Will it make any difference if you would install the NvLink bridge (3-slot bridge in your case)?

I have zero experience with NvLink, so I don't know the answer. The BW for the P2P copy is determined by the PCI-E BW of the two GPUs involved. It should be 30 GB/s in my case, but I observe 20-22 GB/s in practice.

@ikawrakow
Owner Author

OK, I just made another change that allows copying all partial results as fp16. With this change I now see PP-1024 = 879 t/s at zero context for LlaMA3-70B! Here is the full sweep-bench:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---|---|---|---|---|---|---|
| 1024 | 256 | 0 | 1.165 | 879.32 | 9.101 | 28.13 |
| 1024 | 256 | 1024 | 1.172 | 873.59 | 9.156 | 27.96 |
| 1024 | 256 | 2048 | 1.189 | 861.37 | 9.263 | 27.64 |
| 1024 | 256 | 3072 | 1.208 | 847.53 | 9.260 | 27.65 |
| 1024 | 256 | 4096 | 1.226 | 835.35 | 9.321 | 27.47 |
| 1024 | 256 | 5120 | 1.244 | 823.40 | 9.481 | 27.00 |
| 1024 | 256 | 6144 | 1.262 | 811.63 | 9.527 | 26.87 |
| 1024 | 256 | 7168 | 1.282 | 798.47 | 9.520 | 26.89 |
| 1024 | 256 | 8192 | 1.301 | 787.25 | 9.555 | 26.79 |
| 1024 | 256 | 9216 | 1.320 | 775.95 | 9.582 | 26.72 |
| 1024 | 256 | 10240 | 1.337 | 765.86 | 9.667 | 26.48 |
| 1024 | 256 | 11264 | 1.356 | 755.25 | 9.787 | 26.16 |
| 1024 | 256 | 12288 | 1.376 | 743.94 | 9.788 | 26.16 |
| 1024 | 256 | 13312 | 1.394 | 734.56 | 9.889 | 25.89 |
| 1024 | 256 | 14336 | 1.413 | 724.56 | 9.985 | 25.64 |
| 1024 | 256 | 15360 | 1.431 | 715.61 | 9.917 | 25.82 |

Some napkin math: before the last change we had 790 t/s, so 1024/790 = 1.296 seconds per batch of 1024. We now have 879 t/s, so 1024/879 = 1.165 s/batch, so we saved 0.131 seconds per batch. Before the change we were copying 7.5 GiB of data per batch; now we copy 5 GiB/batch, so 2.5 GiB less per batch. This works out to 2.5 GiB / 0.131 seconds = 19 GiB/second effectively for the P2P copy.

@magikRUKKOLA

magikRUKKOLA commented Nov 28, 2025

@ikawrakow

OK, I just made another change that allows to copy all partial results as fp16.

It's from about +15% (beginning of the context) down to +5% (end of the 72k context) -- for the 3-GPU setup.

<details><summary>Details</summary>

main: n_kv_max = 73728, n_batch = 2048, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---|---|---|---|---|---|---|
| 1024 | 256 | 0 | 1.435 | 713.36 | 9.974 | 25.67 |
| 1024 | 256 | 1024 | 1.448 | 707.22 | 10.035 | 25.51 |
| 1024 | 256 | 2048 | 1.457 | 702.61 | 10.065 | 25.43 |
| 1024 | 256 | 3072 | 1.467 | 697.96 | 10.154 | 25.21 |
| 1024 | 256 | 4096 | 1.475 | 694.12 | 10.159 | 25.20 |
| 1024 | 256 | 5120 | 1.485 | 689.35 | 10.189 | 25.13 |
| 1024 | 256 | 6144 | 1.497 | 684.22 | 10.255 | 24.96 |
| 1024 | 256 | 7168 | 1.509 | 678.50 | 10.280 | 24.90 |
| 1024 | 256 | 8192 | 1.520 | 673.63 | 10.320 | 24.81 |
| 1024 | 256 | 9216 | 1.530 | 669.13 | 10.380 | 24.66 |
| 1024 | 256 | 10240 | 1.540 | 664.77 | 10.408 | 24.60 |
| 1024 | 256 | 11264 | 1.551 | 660.42 | 10.551 | 24.26 |
| 1024 | 256 | 12288 | 1.560 | 656.32 | 10.554 | 24.26 |
| 1024 | 256 | 13312 | 1.570 | 652.05 | 10.557 | 24.25 |
| 1024 | 256 | 14336 | 1.581 | 647.65 | 10.569 | 24.22 |
| 1024 | 256 | 15360 | 1.590 | 644.02 | 10.601 | 24.15 |
| 1024 | 256 | 16384 | 1.601 | 639.76 | 10.594 | 24.16 |
| 1024 | 256 | 17408 | 1.611 | 635.69 | 10.618 | 24.11 |
| 1024 | 256 | 18432 | 1.622 | 631.41 | 10.653 | 24.03 |
| 1024 | 256 | 19456 | 1.633 | 627.21 | 10.673 | 23.99 |
| 1024 | 256 | 20480 | 1.642 | 623.70 | 10.707 | 23.91 |
| 1024 | 256 | 21504 | 1.655 | 618.57 | 10.835 | 23.63 |
| 1024 | 256 | 22528 | 1.667 | 614.26 | 10.834 | 23.63 |
| 1024 | 256 | 23552 | 1.678 | 610.26 | 10.876 | 23.54 |
| 1024 | 256 | 24576 | 1.688 | 606.61 | 10.876 | 23.54 |
| 1024 | 256 | 25600 | 1.699 | 602.61 | 10.867 | 23.56 |
| 1024 | 256 | 26624 | 1.716 | 596.68 | 10.892 | 23.50 |
| 1024 | 256 | 27648 | 1.727 | 592.98 | 10.947 | 23.39 |
| 1024 | 256 | 28672 | 1.743 | 587.58 | 10.935 | 23.41 |
| 1024 | 256 | 29696 | 1.757 | 582.74 | 10.972 | 23.33 |
| 1024 | 256 | 30720 | 1.785 | 573.52 | 10.957 | 23.36 |
| 1024 | 256 | 31744 | 1.792 | 571.47 | 11.126 | 23.01 |
| 1024 | 256 | 32768 | 1.817 | 563.41 | 11.139 | 22.98 |
| 1024 | 256 | 33792 | 1.835 | 558.16 | 11.134 | 22.99 |
| 1024 | 256 | 34816 | 1.858 | 551.17 | 11.155 | 22.95 |
| 1024 | 256 | 35840 | 1.885 | 543.18 | 11.156 | 22.95 |
| 1024 | 256 | 36864 | 1.890 | 541.93 | 11.180 | 22.90 |
| 1024 | 256 | 37888 | 1.885 | 543.19 | 11.199 | 22.86 |
| 1024 | 256 | 38912 | 1.917 | 534.30 | 11.251 | 22.75 |
| 1024 | 256 | 39936 | 1.922 | 532.76 | 11.275 | 22.70 |
| 1024 | 256 | 40960 | 1.948 | 525.75 | 11.269 | 22.72 |
| 1024 | 256 | 41984 | 1.976 | 518.30 | 11.379 | 22.50 |
| 1024 | 256 | 43008 | 2.001 | 511.77 | 11.406 | 22.44 |
| 1024 | 256 | 44032 | 1.993 | 513.92 | 11.404 | 22.45 |
| 1024 | 256 | 45056 | 2.019 | 507.07 | 11.431 | 22.40 |
| 1024 | 256 | 46080 | 2.037 | 502.75 | 11.465 | 22.33 |
| 1024 | 256 | 47104 | 2.064 | 496.01 | 11.504 | 22.25 |
| 1024 | 256 | 48128 | 2.085 | 491.20 | 11.512 | 22.24 |
| 1024 | 256 | 49152 | 2.089 | 490.19 | 11.549 | 22.17 |
| 1024 | 256 | 50176 | 2.108 | 485.72 | 11.536 | 22.19 |
| 1024 | 256 | 51200 | 2.143 | 477.91 | 11.572 | 22.12 |
| 1024 | 256 | 52224 | 2.136 | 479.35 | 11.612 | 22.05 |
| 1024 | 256 | 53248 | 2.148 | 476.64 | 11.690 | 21.90 |
| 1024 | 256 | 54272 | 2.206 | 464.23 | 11.731 | 21.82 |
| 1024 | 256 | 55296 | 2.148 | 476.61 | 11.740 | 21.81 |
| 1024 | 256 | 56320 | 2.175 | 470.75 | 11.715 | 21.85 |
| 1024 | 256 | 57344 | 2.180 | 469.78 | 11.709 | 21.86 |
| 1024 | 256 | 58368 | 2.207 | 464.06 | 11.790 | 21.71 |
| 1024 | 256 | 59392 | 2.211 | 463.15 | 11.783 | 21.73 |
| 1024 | 256 | 60416 | 2.231 | 458.93 | 11.773 | 21.75 |
| 1024 | 256 | 61440 | 2.234 | 458.46 | 11.810 | 21.68 |
| 1024 | 256 | 62464 | 2.234 | 458.46 | 11.877 | 21.55 |
| 1024 | 256 | 63488 | 2.281 | 448.99 | 11.945 | 21.43 |
| 1024 | 256 | 64512 | 2.278 | 449.43 | 11.946 | 21.43 |
| 1024 | 256 | 65536 | 2.293 | 446.64 | 11.952 | 21.42 |
| 1024 | 256 | 66560 | 2.318 | 441.84 | 11.950 | 21.42 |
| 1024 | 256 | 67584 | 2.325 | 440.50 | 11.971 | 21.39 |
| 1024 | 256 | 68608 | 2.330 | 439.46 | 11.981 | 21.37 |
| 1024 | 256 | 69632 | 2.388 | 428.83 | 12.011 | 21.31 |
| 1024 | 256 | 70656 | 2.371 | 431.94 | 12.031 | 21.28 |
| 1024 | 256 | 71680 | 2.365 | 432.90 | 12.023 | 21.29 |
| 1024 | 256 | 72704 | 2.400 | 426.63 | 12.034 | 21.27 |
</details>

@ikawrakow
Owner Author

@magikRUKKOLA

Thanks for the data point. Yes, the amount of data exchanged between the GPUs does not change with context length. With increasing context length the computation becomes more expensive, so the fraction of time spent copying data between GPUs decreases, and hence so does the impact of this optimization.

But looking at all these datapoints, I'm starting to think that I should add another CUDA parameter that controls the precision of the intermediate GPU results being exchanged. It may well be that Q8_0 is enough, at least for some models, which would allow reducing the amount of data being exchanged by another factor of almost 2.

@magikRUKKOLA

@ikawrakow

But looking at all these datapoints, I'm starting to think that I should add another CUDA parameter that controls the precision of the intermediate GPU results being exchanged. It might as well be that Q8_0 is enough at least for some models, so that would allow to reduce the amount of data being exchanged by another factor of almost 2.

I was wondering if it's possible to estimate the quantization error prior to the F16 to Q8_0 conversion. That is, to use the CUDA code to first check the scale per block, etc. Mock code like:

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

// Simple Q8 block: 1 scale + 32 int8s
struct Q8Block {
    half scale;
    int8_t vals[32];
};

// Single decision function combining all checks
__device__ bool should_use_q8(const half* block, float* out_scale) {
    // Find max and mean
    float max_abs = 0.0f;
    float sum_abs = 0.0f;
    
    #pragma unroll
    for (int i = 0; i < 32; i++) {
        float v = fabsf(__half2float(block[i]));
        max_abs = fmaxf(max_abs, v);
        sum_abs += v;
    }
    
    *out_scale = max_abs / 127.0f;
    float mean_abs = sum_abs / 32.0f;
    
    // Check 1: Scale too small or too large
    if (*out_scale < 1e-5f || *out_scale > 1.0f) return false;
    
    // Check 2: Granularity (step size vs mean)
    float granularity = *out_scale / (mean_abs + 1e-8f);
    if (granularity > 0.02f) return false; // Max 2% steps
    
    // Check 3: Actual max error
    float max_err = 0.0f;
    #pragma unroll
    for (int i = 0; i < 32; i++) {
        float orig = __half2float(block[i]);
        int8_t q = __float2int_rn(orig / *out_scale);
        q = max(-127, min(127, q));
        float recon = (float)q * *out_scale;
        float err = fabsf(orig - recon) / (fabsf(orig) + 1e-8f);
        max_err = fmaxf(max_err, err);
    }
    
    return max_err < 0.01f; // Max 1% error
}

// Main kernel: activations in, mixed Q8/F16 out
__global__ void quantize_activations(
    const half* __restrict__ activations,  // Input activations (FP16)
    Q8Block* __restrict__ q8_out,          // Q8 output
    half* __restrict__ f16_out,            // FP16 fallback output
    uint8_t* __restrict__ flags,           // 1=Q8, 0=F16
    int n_elements
) {
    int block_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int offset = block_idx * 32;
    
    if (offset >= n_elements) return;
    
    // Load 32 elements
    half vals[32];
    #pragma unroll
    for (int i = 0; i < 32; i++) {
        vals[i] = (offset + i < n_elements) ? 
                  activations[offset + i] : __float2half(0.0f);
    }
    
    // Decide
    float scale;
    if (should_use_q8(vals, &scale)) {
        // Store as Q8
        Q8Block block;
        block.scale = __float2half(scale);
        #pragma unroll
        for (int i = 0; i < 32; i++) {
            float v = __half2float(vals[i]);
            block.vals[i] = max(-127, min(127, __float2int_rn(v / scale)));
        }
        q8_out[block_idx] = block;
        flags[block_idx] = 1;
    } else {
        // Store as F16
        for (int i = 0; i < 32 && offset + i < n_elements; i++) {
            f16_out[offset + i] = vals[i];
        }
        flags[block_idx] = 0;
    }
}

// Launch wrapper
void quantize_activation_tensor(
    const half* d_activations,
    Q8Block* d_q8_out,
    half* d_f16_out,
    uint8_t* d_flags,
    int n_elements,
    cudaStream_t stream = 0
) {
    int n_blocks = (n_elements + 31) / 32;
    int threads = 256;
    int blocks = (n_blocks + threads - 1) / threads;
    
    quantize_activations<<<blocks, threads, 0, stream>>>(
        d_activations, d_q8_out, d_f16_out, d_flags, n_elements
    );
}

@ikawrakow
Owner Author

This is of course possible, but relatively expensive. It also becomes a nightmare to handle: one GPU decides that Q8_0 is OK, another decides that it is not, so the GPU that needs to combine them now gets a mix of different formats. I can see adding a kernel that performs the addition of two Q8_0 tensors, but handling mixes becomes way too complicated, so we would need to cast back to f32 first, which would again add a performance penalty.

Overall I think it is just easier to make it a command line option. We convert activations to Q8_0 anyway for quantized model weights, so my guess is that it will almost always be fine to just use Q8_0, with the option to turn it off if one notices issues.

@Ph0rk0z

Ph0rk0z commented Nov 28, 2025

Is it possible that you both don't have P2P copy enabled?

I use ccmake and that is not enabled. But I do have GGML_CUDA_PEER_MAX_BATCH_SIZE set to 8192.

Post that commit:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---|---|---|---|---|---|---|
| 512 | 128 | 0 | 0.919 | 557.04 | 5.392 | 23.74 |
| 512 | 128 | 512 | 0.915 | 559.57 | 5.496 | 23.29 |
| 512 | 128 | 1024 | 0.920 | 556.41 | 5.453 | 23.47 |

And here is what is happening with 4-GPU transfers/usage (screenshot omitted).

It did help PP now that I look at the bench:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---|---|---|---|---|---|---|
| 512 | 128 | 0 | 1.681 | 304.62 | 7.057 | 18.14 |
| 512 | 128 | 512 | 1.663 | 307.79 | 7.116 | 17.99 |

@ikawrakow
Owner Author

@Ph0rk0z

OK, the P2P copy is not disabled, but is it also supported? The easiest way to find out is to just rebuild with -DGGML_CUDA_NO_PEER_COPY=ON and see if the performance is any different.

Btw, in your initial benchmarks, where you also have llama.cpp results, it would have been useful to also run llama.cpp with -sm row, so we have a comparison between split mode "row" and split mode "graph" on a system other than my development/test system.

@Ph0rk0z

Ph0rk0z commented Nov 28, 2025

I've got a patched driver and tested with NCCL. It is definitely supported. Transfer speeds without P2P are terrible, and so is latency.

Mainline, 2 cards (transfer screenshot omitted):
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---|---|---|---|---|---|---|
| 512 | 128 | 0 | 1.587 | 322.60 | 7.209 | 17.76 |
| 512 | 128 | 512 | 1.634 | 313.29 | 7.335 | 17.45 |
| 512 | 128 | 1024 | 1.635 | 313.14 | 7.372 | 17.36 |
| 512 | 128 | 1536 | 1.657 | 308.92 | 7.469 | 17.14 |
| 512 | 128 | 2048 | 1.622 | 315.67 | 7.560 | 16.93 |
| 512 | 128 | 2560 | 1.663 | 307.94 | 7.646 | 16.74 |
| 512 | 128 | 3072 | 1.709 | 299.61 | 7.722 | 16.58 |
| 512 | 128 | 3584 | 1.695 | 302.07 | 7.754 | 16.51 |
| 512 | 128 | 4096 | 1.650 | 310.26 | 7.840 | 16.33 |
| 512 | 128 | 4608 | 1.731 | 295.79 | 7.932 | 16.14 |
| 512 | 128 | 5120 | 1.743 | 293.71 | 8.012 | 15.98 |
| 512 | 128 | 5632 | 1.715 | 298.49 | 8.079 | 15.84 |
| 512 | 128 | 6144 | 1.725 | 296.89 | 8.118 | 15.77 |
| 512 | 128 | 6656 | 1.762 | 290.57 | 8.210 | 15.59 |
| 512 | 128 | 7168 | 1.768 | 289.64 | 8.293 | 15.43 |
| 512 | 128 | 7680 | 1.784 | 286.99 | 8.371 | 15.29 |
| 512 | 128 | 8192 | 1.757 | 291.46 | 8.430 | 15.18 |

4 cards

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---|---|---|---|---|---|---|
| 512 | 128 | 0 | 3.841 | 133.29 | 8.071 | 15.86 |
| 512 | 128 | 512 | 3.850 | 132.97 | 8.198 | 15.61 |
| 512 | 128 | 1024 | 3.860 | 132.64 | 8.250 | 15.52 |
| 512 | 128 | 1536 | 3.873 | 132.18 | 8.334 | 15.36 |
| 512 | 128 | 2048 | 3.881 | 131.92 | 8.430 | 15.18 |
| 512 | 128 | 2560 | 3.900 | 131.29 | 8.528 | 15.01 |
| 512 | 128 | 3072 | 3.907 | 131.03 | 8.606 | 14.87 |
| 512 | 128 | 3584 | 3.918 | 130.68 | 8.633 | 14.83 |
| 512 | 128 | 4096 | 3.931 | 130.24 | 8.731 | 14.66 |
| 512 | 128 | 4608 | 3.942 | 129.89 | 8.811 | 14.53 |
| 512 | 128 | 5120 | 3.953 | 129.53 | 8.902 | 14.38 |
| 512 | 128 | 5632 | 3.961 | 129.26 | 8.965 | 14.28 |
| 512 | 128 | 6144 | 3.972 | 128.91 | 9.014 | 14.20 |
| 512 | 128 | 6656 | 3.985 | 128.50 | 9.107 | 14.06 |
| 512 | 128 | 7168 | 3.996 | 128.13 | 9.203 | 13.91 |
| 512 | 128 | 7680 | 4.005 | 127.86 | 9.288 | 13.78 |
| 512 | 128 | 8192 | 4.019 | 127.40 | 9.355 | 13.68 |

There's overhead with that many GPUs, and only a single CPU thread is used at 100%. I remember getting 17.99 t/s on what were probably the same llama models back when I had only 2x3090, some P40s and a Broadwell Xeon. CPU single-core performance and probably the driver play into it.

I got my OG server in April of '23 and the upgrade board I had to fix around January of '24. Wish I had bought Epyc on an H12SSL, because those appreciated in price and I would probably have had fewer issues. I was just lazy and wanted a "full" solution. I didn't know small caveats like PLX, QPI, etc. would have such effects.

I can also try to do some runs for exllama2/3 with similar-sized models, but there's no sweep bench there, so it might be apples to oranges.

@magikRUKKOLA

magikRUKKOLA commented Nov 28, 2025

@ikawrakow

It should be 30 GB/s in my case, but I observe 20-22 GB/s in practice.

Just a small note.

Well, in case you're sure that using cudaMemcpyPeerAsync automatically utilizes P2P transfers (and optimizations in terms of topology etc.), that's cool (perhaps I need to read more docs).

But keep in mind that if you're using bidirectional transfers with the currently installed P2P-enabled 580.105.08 driver, then your unidirectional speed via P2P PCIe will be: ....

/usr/share/doc/nvidia-cuda-toolkit/examples/Samples/5_Domain_Specific/p2pBandwidthLatencyTest/p2pBandwidthLatencyTest

quote:

Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2
     0 835.56  25.92  25.85
     1  26.28 902.14  25.92
     2  25.83  25.87 901.10

*Please note that the Lenovo P620 is usually about 5% slower in P2P transfers, for some unknown reason, than a regular motherboard with an AMD Threadripper. But that is negligible for our particular case.

So the theoretical max for P2P transfers in one direction is about 26 GB/s.
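To put that number in perspective, here is a back-of-the-envelope sketch of the per-token synchronization traffic for the "graph parallel" scheme. The model parameters (hidden size 8192, fp16 activations, 80 layers) are assumptions chosen to represent a dense 70B-class LLaMA, not measurements from this PR; the two sync points per layer come from the PR description.

```python
# Rough per-token synchronization cost for "graph parallel" TP.
# Hypothetical model: dense 70B-class LLaMA with hidden size 8192,
# fp16 activations, 80 layers, and the two sync points per layer
# (after self attention, after the FFN) described in the PR.
hidden_size = 8192
bytes_per_elem = 2          # fp16
layers = 80
sync_points_per_layer = 2

bytes_per_token = hidden_size * bytes_per_elem * layers * sync_points_per_layer
p2p_bw = 26e9               # ~26 GB/s unidirectional P2P from the matrix above

ms_per_token = bytes_per_token / p2p_bw * 1e3
print(f"{bytes_per_token / 1e6:.2f} MB per token -> {ms_per_token:.3f} ms at 26 GB/s")
```

So even at "only" 26 GB/s, the raw transfer volume is on the order of 0.1 ms per generated token; latency per copy, not bandwidth, is likely the bigger cost at TG batch sizes.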

You could get about 2 times more, like:

Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2
     0 907.11  51.30  50.67
     1  52.01 906.78  51.21
     2  51.72  50.25 850.30

but that's only for bidirectional transfers, which have no place (right?) in the sequential algo you described?

That said, there is a chance that NvLink would perform somewhat better, since the NvLink specs say the speed should be about 2 times better than P2P PCIe transfers.

Well, it would also take load off the CPU servicing the PCIe lanes. In the case of an AMD Threadripper PRO it's likely pointless -- the CPU supports up to 128 PCIe lanes (the 3995WX), so that's about 4 GPUs at PCIe 4.0 with full x16, with further GPUs dropping to PCIe 4.0 x8; the rest of the lanes went somewhere else. So ideally one would strive for a 4-GPU workstation. In the case of the RTX 3090 (which is really cheap nowadays... about 700 EUR a pop) one could get a workstation equal in total VRAM to an RTX 6000 Pro Blackwell for a third of the price! Should we consider building workstations with 8 GPUs instead of 4? I am suggesting we could use NvLink for the GPUs that are connected to the CPU via PCIe 4.0 x8 (or whatever), so that each NvLink would connect a GPU sitting at x16 to a GPU sitting at x8. This way we would be able to connect 8 x RTX 3090 per workstation. That's 192 GB per workstation. The next step would of course be to find the right network cards to connect two or three of them together (P2P, via a ring architecture) and reach the target space of 400 or 600 GB of pure VRAM.

Systems such as these would require dedicated AC and liquid cooling, of course (which would also be beneficial for overclocking).

What do you think? Is it worth it? Do you have any better ideas of what to do in the pursuit of the optimal performance?

@Ph0rk0z

Ph0rk0z commented Nov 29, 2025

FWIW, NvLink and the P2P driver don't work together.

@ikawrakow
Owner Author

@magikRUKKOLA

I thought about bidirectional copies, but it was difficult (difficult as in I didn't manage to do it) to convince the backend to do the right graph splits. I didn't know that for 2 GPUs bidirectional is twice as fast as one direction, so I didn't try harder. But if that is true, then for 2 GPUs doing the following will be faster:

  • GPU-0 copies its partial result to GPU-1 and vice versa
  • Both GPUs compute the sum of the partial results and continue without further synchronization

For 4 GPUs we currently have 6 copies per synchronization point (GPU-1,2,3 copy their data to GPU-0, GPU-0 computes the sum, then copies the result back to GPU-1,2,3). If we change to each GPU copying its result to every other GPU, we will make 3x4 = 12 copies. If bidirectional is twice as fast, that will take the same time as it currently does.

For 8 GPUs we currently make 14 copies. Going to each GPU copying its result to every other GPU will require 7x8 = 56 copies. Hence, even if bidirectional is twice as fast, it will still be 2 times slower.
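The copy-count arithmetic above can be sketched as follows; this is just the counting argument, not the actual implementation, and the function names are made up for illustration:

```python
# Copies per synchronization point for the two reduction schemes discussed:
# hub-and-spoke (current: reduce on GPU-0, broadcast back) vs all-to-all.

def hub_copies(n_gpus):
    # GPU-1..n-1 copy partial results to GPU-0, GPU-0 computes the sum
    # and copies it back to the other n-1 GPUs: 2*(n-1) copies total.
    return 2 * (n_gpus - 1)

def all_to_all_copies(n_gpus):
    # Every GPU copies its partial result to every other GPU,
    # then each GPU sums locally: n*(n-1) copies total.
    return n_gpus * (n_gpus - 1)

for n in (2, 4, 8):
    print(f"{n} GPUs: hub = {hub_copies(n)}, all-to-all = {all_to_all_copies(n)}")
```

With a 2x bidirectional speedup, the all-to-all cost is effectively halved: for 4 GPUs 12/2 = 6, which matches the hub scheme's 6 copies; for 8 GPUs 56/2 = 28, still twice the hub scheme's 14.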

Concerning building multi-GPU boxes: I'm not an expert in hardware, but from what I have seen so far with my TP implementation, having fewer, bigger (more VRAM) and faster GPUs seems better than having more GPUs with the same total amount of VRAM. I now have TP for GLM-4.5/4.6 ready, and there TG is slower than split mode "layer" even for 2 GPUs. I'll make another PR in a bit and give more details there.

@ikawrakow
Owner Author

Closed in favor of #1022, which includes all changes in this branch.

@ikawrakow ikawrakow closed this Dec 1, 2025