
POC: CUDA tensor parallel (MoE models) #1022

Merged

ikawrakow merged 32 commits into main from ik/poc_tp_glm4.5 on Dec 1, 2025

Conversation

@ikawrakow (Owner) commented Nov 29, 2025

This is a very rough-around-the-edges POC for tensor parallelism (TP) for MoE models. It is a follow-up to PR #1018.

So far I have made the necessary graph-building changes only for GLM-4.5/4.6, just to see how it performs. On a 2x3090 system I can only fully offload a low-bpw quantized GLM-4.5-AIR (full offload is needed because tensor overrides are not yet implemented in this new scheme). I'm using @ubergarm's IQ1_KT model, which is 38.7 GB and allows me to go to about 32k tokens of context (with f16 KV cache).
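For readers unfamiliar with the scheme: a toy sketch of what tensor parallelism does to a single FFN (this is illustrative pure Python, not ik_llama.cpp code; the actual implementation operates on ggml graphs and quantized tensors). The up-projection is split by columns and the down-projection by rows, so each device computes a partial output and one sum per layer (the all-reduce) recovers the full result:

```python
# Hypothetical illustration: tensor parallelism for one FFN layer, split
# across two "devices". W1 is sharded by columns, W2 by the matching rows;
# each device computes a partial output, and summing the partials (the
# all-reduce step) reproduces the single-device result exactly, because the
# elementwise activation commutes with the column split.

def matvec(W, x):  # y[j] = sum_i x[i] * W[i][j]
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def relu(v):
    return [max(t, 0.0) for t in v]

d_model, d_ff = 4, 8
x  = [0.5, -1.0, 2.0, 0.25]
W1 = [[(i + 1) * 0.1 - (j % 3) * 0.2 for j in range(d_ff)] for i in range(d_model)]
W2 = [[(j + 1) * 0.05 - (i % 2) * 0.1 for j in range(d_model)] for i in range(d_ff)]

# Reference: single-device FFN
y_ref = matvec(W2, relu(matvec(W1, x)))

# Tensor parallel across 2 devices: device d owns columns [lo, hi) of W1
# and the matching rows of W2.
partials = []
for dev in range(2):
    lo, hi = dev * d_ff // 2, (dev + 1) * d_ff // 2
    W1_shard = [row[lo:hi] for row in W1]   # column shard of W1
    W2_shard = W2[lo:hi]                    # row shard of W2
    partials.append(matvec(W2_shard, relu(matvec(W1_shard, x))))

y_tp = [a + b for a, b in zip(*partials)]   # the all-reduce (sum) step
assert all(abs(a - b) < 1e-9 for a, b in zip(y_ref, y_tp))
```

For MoE the same split is applied per expert; the trade-off measured below is the per-layer synchronization cost against the halved per-device work.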

Here the performance results are mixed. For PP-2048 the new split mode "graph" implementation beats split mode "layer" at all context lengths, being as much as 60% faster at a context of 30k tokens. TG, on the other hand, is significantly slower at zero context and only becomes faster than "layer" around a context of 20k tokens. Here are the sweep-bench results.

Split mode "graph"

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 2048 | 256 | 0 | 0.897 | 2283.44 | 4.231 | 60.51 |
| 2048 | 256 | 2048 | 0.928 | 2207.16 | 4.368 | 58.61 |
| 2048 | 256 | 4096 | 0.984 | 2081.52 | 4.482 | 57.12 |
| 2048 | 256 | 6144 | 1.044 | 1962.07 | 4.626 | 55.33 |
| 2048 | 256 | 8192 | 1.102 | 1858.33 | 4.808 | 53.24 |
| 2048 | 256 | 10240 | 1.161 | 1763.85 | 5.028 | 50.91 |
| 2048 | 256 | 12288 | 1.222 | 1676.09 | 5.224 | 49.00 |
| 2048 | 256 | 14336 | 1.279 | 1601.86 | 5.368 | 47.69 |
| 2048 | 256 | 16384 | 1.348 | 1518.93 | 5.502 | 46.53 |
| 2048 | 256 | 18432 | 1.406 | 1456.80 | 5.653 | 45.28 |
| 2048 | 256 | 20480 | 1.464 | 1399.11 | 5.805 | 44.10 |
| 2048 | 256 | 22528 | 1.523 | 1344.63 | 6.144 | 41.66 |
| 2048 | 256 | 24576 | 1.586 | 1291.57 | 6.174 | 41.47 |
| 2048 | 256 | 26624 | 1.651 | 1240.34 | 6.318 | 40.52 |
| 2048 | 256 | 28672 | 1.714 | 1194.88 | 6.634 | 38.59 |

Split mode "layer"

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 2048 | 256 | 0 | 1.018 | 2011.77 | 2.789 | 91.78 |
| 2048 | 256 | 2048 | 1.118 | 1831.98 | 3.088 | 82.89 |
| 2048 | 256 | 4096 | 1.226 | 1670.15 | 3.417 | 74.92 |
| 2048 | 256 | 6144 | 1.350 | 1517.54 | 3.776 | 67.80 |
| 2048 | 256 | 8192 | 1.468 | 1395.53 | 4.064 | 62.99 |
| 2048 | 256 | 10240 | 1.586 | 1291.43 | 4.370 | 58.59 |
| 2048 | 256 | 12288 | 1.707 | 1200.09 | 4.700 | 54.47 |
| 2048 | 256 | 14336 | 1.829 | 1119.91 | 4.990 | 51.31 |
| 2048 | 256 | 16384 | 1.958 | 1046.17 | 5.318 | 48.14 |
| 2048 | 256 | 18432 | 2.083 | 982.99 | 5.635 | 45.43 |
| 2048 | 256 | 20480 | 2.217 | 923.75 | 5.943 | 43.08 |
| 2048 | 256 | 22528 | 2.351 | 871.13 | 6.290 | 40.70 |
| 2048 | 256 | 24576 | 2.478 | 826.47 | 6.582 | 38.89 |
| 2048 | 256 | 26624 | 2.604 | 786.61 | 6.888 | 37.17 |
| 2048 | 256 | 28672 | 2.732 | 749.63 | 7.255 | 35.29 |
| 2048 | 256 | 30720 | 2.852 | 718.04 | 7.539 | 33.96 |
(figures: tp_glmair_pp, tp_glmair_tg — PP and TG throughput vs. context length)
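Reading the two tables against each other at the deepest shared context depth (values copied from the N_KV = 28672 rows above) confirms the crossover behavior described in the opening comment:

```python
# Split mode "graph" vs "layer" at N_KV = 28672, from the sweep-bench
# tables above (GLM-4.5-AIR IQ1_KT, 2x3090).
pp_graph, pp_layer = 1194.88, 749.63   # S_PP t/s
tg_graph, tg_layer = 38.59, 35.29      # S_TG t/s
print(f"PP speedup: {pp_graph / pp_layer - 1:+.0%}")  # → PP speedup: +59%
print(f"TG speedup: {tg_graph / tg_layer - 1:+.0%}")  # → TG speedup: +9%
```

At zero context the picture is reversed for TG (60.51 vs 91.78 t/s), with the TG crossover landing between N_KV = 18432 and 20480.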

@ikawrakow (Owner, Author) commented:

If I use Q8_0 KV cache, I can go to about 55k tokens or so. At 55k tokens split mode "graph" is 66% faster for PP and 42% faster for TG.
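The jump from ~32k to ~55k tokens is roughly what the quantization arithmetic predicts (my back-of-the-envelope numbers, not from the PR): Q8_0 stores a block of 32 values as 32 int8 quants plus one f16 scale, i.e. about 8.5 bits per element versus 16 for f16.

```python
# Rough bound on the context reachable with a fixed VRAM budget for the
# KV cache: context scales inversely with bits per cached element.
f16_bits  = 16.0
q8_0_bits = (32 * 8 + 16) / 32   # 8.5 bits per element (32 quants + f16 scale)
ctx_f16 = 32_000                  # ~32k tokens reported with f16 KV cache
ctx_q8  = ctx_f16 * f16_bits / q8_0_bits
print(round(ctx_q8))              # → 60235
```

That ~60k figure is an upper bound; the observed ~55k is plausible once compute buffers growing with context are accounted for.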

Here are some graphs, using a logarithmic scale for the y-axis to better show the relative change in performance.

(figures: tp_glmair_pp1, tp_glmair_tg1 — same data with log-scale y-axis)

@ikawrakow (Owner, Author) commented:

Graph reuse now works with split mode "graph". Here is an updated sweep-bench graph for LLaMA-70B TG.

(figure: tp_70B_tg)
ik_llama.cpp, split mode graph, graph reuse (-gr)
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 1024 | 256 | 0 | 1.160 | 882.66 | 8.796 | 29.10 |
| 1024 | 256 | 1024 | 1.168 | 876.37 | 8.839 | 28.96 |
| 1024 | 256 | 2048 | 1.188 | 862.06 | 8.900 | 28.76 |
| 1024 | 256 | 3072 | 1.206 | 849.33 | 8.952 | 28.60 |
| 1024 | 256 | 4096 | 1.224 | 836.62 | 9.008 | 28.42 |
| 1024 | 256 | 5120 | 1.242 | 824.76 | 9.176 | 27.90 |
| 1024 | 256 | 6144 | 1.263 | 810.70 | 9.261 | 27.64 |
| 1024 | 256 | 7168 | 1.281 | 799.61 | 9.216 | 27.78 |
| 1024 | 256 | 8192 | 1.299 | 788.24 | 9.251 | 27.67 |
| 1024 | 256 | 9216 | 1.317 | 777.39 | 9.279 | 27.59 |
| 1024 | 256 | 10240 | 1.335 | 767.10 | 9.362 | 27.34 |
| 1024 | 256 | 11264 | 1.356 | 755.26 | 9.474 | 27.02 |
| 1024 | 256 | 12288 | 1.375 | 744.69 | 9.484 | 26.99 |
| 1024 | 256 | 13312 | 1.394 | 734.67 | 9.521 | 26.89 |
| 1024 | 256 | 14336 | 1.412 | 725.32 | 9.571 | 26.75 |
| 1024 | 256 | 15360 | 1.429 | 716.39 | 9.592 | 26.69 |

@Ph0rk0z commented Nov 29, 2025

I only have big GLM. Will it work with hybrid inference?

@ikawrakow (Owner, Author) commented:

Not yet. I'm working on it.

@magikRUKKOLA commented Nov 30, 2025

@ikawrakow

I'm getting garbage output when -sm graph is enabled.

<details>
<summary>Details</summary>

```
XXXXXXXXXXXXXXXXXXXXX Setting only active experts offload
========= Program hit cudaErrorGraphExecUpdateFailure (error 910) due to "the graph update was not performed because it included changes which violated constraints specific to instantiated graph update" on CUDA API call to cudaGraphExecUpdate.
========= Saved host backtrace up to driver entry point at error
========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x3090ae] in libggml.so
========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so
========= Host Frame: llama_decode [0x7da6c] in libllama.so
========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server
========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server
========= Host Frame: main [0x5cb8b] in llama-server
=========
========= Program hit cudaErrorGraphExecUpdateFailure (error 910) due to "the graph update was not performed because it included changes which violated constraints specific to instantiated graph update" on CUDA API call to cudaGetLastError.
========= Saved host backtrace up to driver entry point at error
========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x309437] in libggml.so
========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so
========= Host Frame: llama_decode [0x7da6c] in libllama.so
========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server
========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server
========= Host Frame: main [0x5cb8b] in llama-server
=========
(the same cudaGraphExecUpdate / cudaGetLastError error pair, with identical backtraces, repeats several more times)
```

</details>

[EDIT]:

with initcheck:

<details>
<summary>Details</summary>

```
========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1
========= Host Frame: [0x2062a] in libcudart.so.13
========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13
========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so
========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so
========= Host Frame: llama_decode [0x7da6c] in libllama.so
========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server
========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server
========= Host Frame: main [0x5cb8b] in llama-server
=========
========= Uninitialized __global__ memory read of size 16 bytes
=========     at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10
=========     by thread (8,1,0) in block (45,0,0)
=========     Address 0x7ff3a8000780
========= Saved host backtrace up to driver entry point at kernel launch time
========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1
========= Host Frame: [0x2062a] in libcudart.so.13
========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13
========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so
========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so
========= Host Frame: llama_decode [0x7da6c] in libllama.so
========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server
========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server
========= Host Frame: main [0x5cb8b] in llama-server
=========
(the same uninitialized-read report in flash_attn_mma_ext_f16, with identical backtraces, repeats for threads (9,1,0) through (22,1,0) in block (45,0,0) at successive 16-byte addresses; log truncated)
```

</details>
flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (23,1,0) in block (45,0,0) ========= Address 0x7ff3a8000a70 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (24,1,0) in block (45,0,0) ========= Address 0x7ff3a8000a80 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: 
ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (25,1,0) in block (45,0,0) ========= Address 0x7ff3a8000a90 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const 
int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (26,1,0) in block (45,0,0) ========= Address 0x7ff3a8000aa0 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (27,1,0) in block (45,0,0) ========= Address 0x7ff3a8000ab0 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in 
libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (28,1,0) in block (45,0,0) ========= Address 0x7ff3a8000ac0 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, 
int, int, int, int, int, int)+0x4c10 ========= by thread (29,1,0) in block (45,0,0) ========= Address 0x7ff3a8000ad0 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (30,1,0) in block (45,0,0) ========= Address 0x7ff3a8000ae0 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server 
========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (31,1,0) in block (45,0,0) ========= Address 0x7ff3a8000af0 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (0,3,0) in block (9,0,0) ========= Address 0x7ff3a8001200 ========= Saved host backtrace up to driver 
entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (1,3,0) in block (9,0,0) ========= Address 0x7ff3a8001210 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= 
========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (2,3,0) in block (9,0,0) ========= Address 0x7ff3a8001220 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ========= Uninitialized __global__ memory read of size 16 bytes ========= at void flash_attn_mma_ext_f16<(int)128, (int)2, (int)4, (int)4, (int)64, (int)1, (bool)0>(const char *, const char *, const char *, const char *, const char *, const int2 *, float *, float2 *, float, float, float, float, float, unsigned int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x4c10 ========= by thread (3,3,0) in block (9,0,0) ========= Address 0x7ff3a8001230 ========= Saved host backtrace up to driver entry point at kernel launch time ========= Host Frame: cuGraphLaunch [0x3a1fef] in libcuda.so.1 ========= Host Frame: [0x2062a] in libcudart.so.13 ========= Host 
Frame: cudaGraphLaunch [0x76211] in libcudart.so.13 ========= Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x307a69] in libggml.so ========= Host Frame: ggml_backend_sched_graph_compute_async [0x1943a1] in libggml.so ========= Host Frame: llama_decode [0x7da6c] in libllama.so ========= Host Frame: llama_init_from_gpt_params(gpt_params&) [0x20c4aa] in llama-server ========= Host Frame: server_context::load_model(gpt_params const&) [0xf3bb4] in llama-server ========= Host Frame: main [0x5cb8b] in llama-server ========= ```

@magikRUKKOLA
Copy link
Copy Markdown

magikRUKKOLA commented Nov 30, 2025

[EDIT]: I forgot to remove the `cpu-moe` flag. The details can be safely ignored.

Details
I do not see any garbage output with `--split-mode graph` enabled anymore.

I will drop the sweep-bench results below in a few hours or so.

[EDIT]:

As of now it's the following:

main: n_kv_max = 98304, n_batch = 8192, n_ubatch = 2048, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |    512 |      0 |    3.270 |   626.28 |   22.979 |    22.28 |
|  2048 |    512 |   2048 |    3.295 |   621.48 |   23.186 |    22.08 |
|  2048 |    512 |   4096 |    3.318 |   617.22 |   23.373 |    21.91 |
|  2048 |    512 |   6144 |    3.344 |   612.35 |   23.626 |    21.67 |
|  2048 |    512 |   8192 |    3.378 |   606.30 |   23.725 |    21.58 |
|  2048 |    512 |  10240 |    3.424 |   598.18 |   24.025 |    21.31 |
|  2048 |    512 |  12288 |    3.470 |   590.21 |   24.282 |    21.09 |
|  2048 |    512 |  14336 |    3.518 |   582.22 |   24.422 |    20.96 |
|  2048 |    512 |  16384 |    3.569 |   573.86 |   24.496 |    20.90 |
|  2048 |    512 |  18432 |    3.610 |   567.30 |   24.777 |    20.66 |
|  2048 |    512 |  20480 |    3.651 |   560.91 |   24.895 |    20.57 |
|  2048 |    512 |  22528 |    3.691 |   554.85 |   24.956 |    20.52 |
|  2048 |    512 |  24576 |    3.739 |   547.78 |   25.216 |    20.30 |
|  2048 |    512 |  26624 |    3.794 |   539.75 |   25.382 |    20.17 |
|  2048 |    512 |  28672 |    3.839 |   533.47 |   25.607 |    19.99 |
|  2048 |    512 |  30720 |    3.889 |   526.56 |   25.933 |    19.74 |
|  2048 |    512 |  32768 |    3.923 |   522.06 |   26.150 |    19.58 |
|  2048 |    512 |  34816 |    3.985 |   513.86 |   26.631 |    19.23 |
|  2048 |    512 |  36864 |    4.043 |   506.61 |   26.745 |    19.14 |
|  2048 |    512 |  38912 |    4.089 |   500.91 |   26.961 |    18.99 |
|  2048 |    512 |  40960 |    4.132 |   495.63 |   27.280 |    18.77 |
|  2048 |    512 |  43008 |    4.182 |   489.72 |   27.538 |    18.59 |
|  2048 |    512 |  45056 |    4.267 |   479.96 |   27.826 |    18.40 |
|  2048 |    512 |  47104 |    4.330 |   472.98 |   28.035 |    18.26 |
|  2048 |    512 |  49152 |    4.392 |   466.29 |   28.246 |    18.13 |
|  2048 |    512 |  51200 |    4.420 |   463.34 |   28.539 |    17.94 |
|  2048 |    512 |  53248 |    4.489 |   456.24 |   28.764 |    17.80 |
|  2048 |    512 |  55296 |    4.540 |   451.09 |   28.982 |    17.67 |
|  2048 |    512 |  57344 |    4.607 |   444.51 |   29.213 |    17.53 |
|  2048 |    512 |  59392 |    4.654 |   440.02 |   29.445 |    17.39 |
|  2048 |    512 |  61440 |    4.706 |   435.23 |   29.813 |    17.17 |
|  2048 |    512 |  63488 |    4.782 |   428.26 |   30.036 |    17.05 |
|  2048 |    512 |  65536 |    4.888 |   418.94 |   30.312 |    16.89 |
|  2048 |    512 |  67584 |    4.948 |   413.92 |   30.524 |    16.77 |
|  2048 |    512 |  69632 |    5.115 |   400.37 |   30.712 |    16.67 |
|  2048 |    512 |  71680 |    5.240 |   390.80 |   30.944 |    16.55 |
|  2048 |    512 |  73728 |    5.486 |   373.33 |   31.135 |    16.44 |
|  2048 |    512 |  75776 |    5.489 |   373.08 |   31.627 |    16.19 |
|  2048 |    512 |  77824 |    5.478 |   373.89 |   31.763 |    16.12 |
|  2048 |    512 |  79872 |    5.572 |   367.56 |   31.799 |    16.10 |
|  2048 |    512 |  81920 |    5.716 |   358.27 |   32.104 |    15.95 |
|  2048 |    512 |  83968 |    5.620 |   364.40 |   32.330 |    15.84 |
|  2048 |    512 |  86016 |    5.724 |   357.79 |   32.667 |    15.67 |
|  2048 |    512 |  88064 |    5.914 |   346.31 |   32.985 |    15.52 |
|  2048 |    512 |  90112 |    5.882 |   348.21 |   33.114 |    15.46 |
|  2048 |    512 |  92160 |    5.910 |   346.51 |   33.411 |    15.32 |
|  2048 |    512 |  94208 |    5.938 |   344.87 |   33.663 |    15.21 |
|  2048 |    512 |  96256 |    6.001 |   341.29 |   33.864 |    15.12 |

@ikawrakow
Copy link
Copy Markdown
Owner Author

ikawrakow commented Nov 30, 2025

@magikRUKKOLA

You didn't say what kind of model you are using. There was still a bug with interleaved quants (`*_R4`, `*_R8`); this is fixed now. Also, tensor overrides appear to be working as of the last commit.

Here is what I have right now on the 2x3090 box.

GLM-4.6, ubergarm's IQ1_KT mix, all MoE tensors left in RAM

Split mode "graph"

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |     64 |      0 |   10.280 |   398.43 |    6.474 |     9.89 |
|  4096 |     64 |   4096 |   10.776 |   380.12 |    7.101 |     9.01 |
|  4096 |     64 |   8192 |   11.205 |   365.54 |    6.962 |     9.19 |
|  4096 |     64 |  12288 |   11.719 |   349.51 |    7.257 |     8.82 |
|  4096 |     64 |  16384 |   12.258 |   334.16 |    7.487 |     8.55 |
|  4096 |     64 |  20480 |   12.869 |   318.28 |    7.548 |     8.48 |
|  4096 |     64 |  24576 |   13.346 |   306.90 |    7.691 |     8.32 |
|  4096 |     64 |  28672 |   13.986 |   292.87 |    8.178 |     7.83 |

Split mode "layer"

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |     64 |      0 |   10.140 |   403.93 |    5.690 |    11.25 |
|  4096 |     64 |   4096 |   11.173 |   366.60 |    6.092 |    10.51 |
|  4096 |     64 |   8192 |   12.353 |   331.58 |    6.401 |    10.00 |
|  4096 |     64 |  12288 |   13.506 |   303.26 |    6.773 |     9.45 |
|  4096 |     64 |  16384 |   14.655 |   279.50 |    7.082 |     9.04 |
|  4096 |     64 |  20480 |   15.862 |   258.23 |    7.396 |     8.65 |
|  4096 |     64 |  24576 |   16.983 |   241.18 |    7.627 |     8.39 |
|  4096 |     64 |  28672 |   18.418 |   222.39 |    7.932 |     8.07 |

GLM-4.6, Thireus 5.5 bpw mix, all MoE tensors left in RAM

Split mode "graph"

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |     64 |      0 |   18.361 |   223.08 |    8.018 |     7.98 |
|  4096 |     64 |   4096 |   18.736 |   218.61 |    8.004 |     8.00 |
|  4096 |     64 |   8192 |   19.252 |   212.76 |    8.175 |     7.83 |
|  4096 |     64 |  12288 |   19.740 |   207.49 |    8.381 |     7.64 |
|  4096 |     64 |  16384 |   20.270 |   202.07 |    8.472 |     7.55 |
|  4096 |     64 |  20480 |   20.771 |   197.20 |    8.652 |     7.40 |
|  4096 |     64 |  24576 |   21.349 |   191.86 |    8.938 |     7.16 |
|  4096 |     64 |  28672 |   22.012 |   186.08 |    8.946 |     7.15 |

Split mode "layer"

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |     64 |      0 |   17.839 |   229.60 |    7.337 |     8.72 |
|  4096 |     64 |   4096 |   18.814 |   217.71 |    7.540 |     8.49 |
|  4096 |     64 |   8192 |   19.966 |   205.15 |    7.861 |     8.14 |
|  4096 |     64 |  12288 |   21.140 |   193.76 |    8.215 |     7.79 |
|  4096 |     64 |  16384 |   22.291 |   183.75 |    8.503 |     7.53 |
|  4096 |     64 |  20480 |   23.535 |   174.04 |    8.804 |     7.27 |
|  4096 |     64 |  24576 |   24.691 |   165.89 |    9.068 |     7.06 |
|  4096 |     64 |  28672 |   26.167 |   156.53 |    9.384 |     6.82 |

So, on my box PP is (almost) on par with split mode "layer" at zero context, and beats it by a non-negligible margin as the context grows.

TG is not as good with hybrid inference: one needs a sufficiently long context before split mode "graph" becomes faster than split mode "layer".

Iwan Kawrakow added 23 commits November 30, 2025 18:05
But it also looks like the backend scheduler is not going to help:
* It copies mask and input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed with attn evaluation, GPU 1 must wait for GPU 0 to finish its
     entire attn calculation
* Same with FFN. The rms_norm gets scheduled on GPU 0. Hence, GPU 1 must
  wait for GPU 0 to finish its entire FFN calculation before it can
  start (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduling
the graph is still not being computed in parallel.
Why? Because the scheduler creates graph splits where the
result of the computation on one GPU becomes an input for the
other split. Hence, to trigger the computation on the second GPU
one needs to wait for the computation on the first GPU to finish,
even though the two can be done in parallel up to the synchronization
point. So, all that is left to do is to trick the scheduler into creating
two splits that can be done in parallel, and then have a graph split
where the results get combined.
This change tricks it into doing the right thing^TM.
Still quite a bit slower than split mode layer for the 8B LLaMA model.
But for the 70B LLaMA it now beats split mode layer for TG:
28 t/s vs 24.4 t/s. PP is 627 t/s vs 744 t/s.
In comparison, split mode "row" in mainline gets
484 t/s PP and 19.3 t/s TG.
Granularity for Wq, Wo is not just head size, but
head size * gqa_ratio.
Else the Wk, Wv tensors end up not being a multiple of the
head size when we divide the split determined by Wo with
the gqa_ratio.
but no tensor overrides yet, just ngl < num_layers.
Now PP is faster than split mode layer for L3-70B.
PP is already better than split mode layer, but TG for zero context
is kind of low - 60 vs 92 t/s. TG becomes better than split mode layer
at around 20k tokens. PP at 26k tokens is 1.55X of sm layer.
It issues a warning that there is an extra semicolon outside of a function,
but there isn't. If I remove the anonymous namespace and turn the
functions inside into static, the warning disappears, so clearly
a compiler bug.
Runs with wrong results, don't see where the issue could be.
Still does not work for row-interleaved quants
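One of the commit messages above notes that the row-split granularity for Wq/Wo must be `head_size * gqa_ratio`, so that dividing a GPU's Wq/Wo split by the GQA ratio still yields a whole number of K/V heads for Wk/Wv. A minimal sketch of that constraint (hypothetical helper, not the actual code from this PR):

```python
def attn_row_splits(n_heads, head_size, gqa_ratio, n_gpu):
    # Wq/Wo rows are split at a granularity of head_size * gqa_ratio:
    # the Wk/Wv split is the Wq/Wo split divided by gqa_ratio, and it
    # must still be a whole number of K/V head rows.
    gran = head_size * gqa_ratio
    total = n_heads * head_size
    per_gpu = (total // n_gpu // gran) * gran
    splits = [per_gpu] * n_gpu
    splits[-1] += total - per_gpu * n_gpu  # remainder goes to the last GPU
    return splits
```

For a 70B-LLaMA-like shape (64 query heads of size 128, GQA ratio 8) on 2 GPUs this yields two 4096-row splits, each of which maps to 512 Wk/Wv rows, i.e. exactly 4 K/V heads per GPU.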
@ikawrakow
Copy link
Copy Markdown
Owner Author

@ubergarm

Thanks for these results. It looks like it works quite well on a rig with 2 GPUs that does not have an ancient CPU. Your results are even better than mine: at 64k context PP is 1.9X, so almost as good as it gets.

@ubergarm
Copy link
Copy Markdown
Contributor

ubergarm commented Dec 4, 2025

Since I used Q4_0, I ran the same test against mainline. The PP delta is pretty wild; I'm not sure if there is something I could pass there to improve it.

sweep-bench-PR1022-GLM-4 5-Air-vs-mainline
👈 Details
model=/mnt/raid/hf/GLM-4.5-Air-GGUF/Q4_0/GLM-4.5-Air-Q4_0-00001-of-00002.gguf
$ ./build/bin/llama-sweep-bench \
    --model "$model" \
    -c 69632 \
    -ngl 99 \
    -ub 4096 -b 4096 \
    --threads 1

load_tensors: offloaded 48/48 layers to GPU
load_tensors:   CPU_Mapped model buffer size =   333.00 MiB
load_tensors:        CUDA0 model buffer size = 32937.69 MiB
load_tensors:        CUDA1 model buffer size = 28166.12 MiB
....................................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 69632
llama_context: n_ctx_seq     = 69632
llama_context: n_batch       = 4096
llama_context: n_ubatch      = 4096
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 1000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (69632) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.58 MiB
llama_kv_cache:      CUDA0 KV buffer size =  6800.00 MiB
llama_kv_cache:      CUDA1 KV buffer size =  5712.00 MiB
llama_kv_cache: size = 12512.00 MiB ( 69632 cells,  46 layers,  1/1 seqs), K (f16): 6256.00 MiB, V (f16): 6256.00 MiB
llama_context: pipeline parallelism enabled (n_copies=1)
llama_context: Flash Attention was auto, set to enabled
llama_context:      CUDA0 compute buffer size =  2144.08 MiB
llama_context:      CUDA1 compute buffer size =  2560.00 MiB
llama_context:  CUDA_Host compute buffer size =  1152.11 MiB
llama_context: graph nodes  = 3146
llama_context: graph splits = 3

main: n_kv_max = 69632, n_batch = 4096, n_ubatch = 4096, flash_attn_type = -1, n_gpu_layers = 99, n_threads = 1, n_threads_batch = 1
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    4.396 |   931.76 |   19.128 |    53.53 |
|  4096 |   1024 |   4096 |    7.329 |   558.89 |   21.328 |    48.01 |
|  4096 |   1024 |   8192 |   10.328 |   396.58 |   23.675 |    43.25 |
|  4096 |   1024 |  12288 |   13.304 |   307.88 |   26.093 |    39.24 |
|  4096 |   1024 |  16384 |   16.251 |   252.05 |   28.511 |    35.92 |
|  4096 |   1024 |  20480 |   19.236 |   212.94 |   30.916 |    33.12 |
|  4096 |   1024 |  24576 |   22.202 |   184.49 |   33.334 |    30.72 |
|  4096 |   1024 |  28672 |   25.114 |   163.10 |   35.724 |    28.66 |
|  4096 |   1024 |  32768 |   28.067 |   145.93 |   38.110 |    26.87 |
|  4096 |   1024 |  36864 |   30.910 |   132.51 |   40.705 |    25.16 |
|  4096 |   1024 |  40960 |   33.764 |   121.31 |   42.821 |    23.91 |
|  4096 |   1024 |  45056 |   36.780 |   111.37 |   45.222 |    22.64 |
|  4096 |   1024 |  49152 |   39.663 |   103.27 |   47.608 |    21.51 |
|  4096 |   1024 |  53248 |   42.607 |    96.13 |   50.001 |    20.48 |
|  4096 |   1024 |  57344 |   45.502 |    90.02 |   52.271 |    19.59 |
|  4096 |   1024 |  61440 |   48.322 |    84.76 |   54.741 |    18.71 |
|  4096 |   1024 |  65536 |   51.343 |    79.78 |   57.096 |    17.93 |
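As a sanity check on how these sweep-bench tables read, the speed columns are simply the token counts divided by the elapsed times (S_PP = PP / T_PP, S_TG = TG / T_TG). A small sketch, with rows copied from the table above:

```python
# Verify that the sweep-bench speed columns match tokens / elapsed time.
# Rows copied from the table above: (PP, TG, N_KV, T_PP, S_PP, T_TG, S_TG).
rows = [
    (4096, 1024, 0,     4.396,  931.76, 19.128, 53.53),
    (4096, 1024, 32768, 28.067, 145.93, 38.110, 26.87),
    (4096, 1024, 65536, 51.343,  79.78, 57.096, 17.93),
]

for pp, tg, n_kv, t_pp, s_pp, t_tg, s_tg in rows:
    assert abs(pp / t_pp - s_pp) < 0.5, (n_kv, pp / t_pp)  # prompt processing speed
    assert abs(tg / t_tg - s_tg) < 0.5, (n_kv, tg / t_tg)  # token generation speed
print("speed columns consistent with time columns")
```

TG here drops from ~54 t/s at zero context to ~18 t/s at 64k tokens, i.e. to about a third of the zero-context speed.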

@aikitoria

aikitoria commented Dec 4, 2025

I could try to run the test on my system with 8 GPUs; however, I am unable to locate the GLM-4.5-Air-GGUF/Q4_0 test file that everyone is using on hf.

Edit: I was looking at the wrong user, I guess it's the unsloth one?

@magikRUKKOLA

magikRUKKOLA commented Dec 4, 2025

@ikawrakow

Speaking of tailscale, I would never ever do
as @magikRUKKOLA suggested.

To clarify my point I have to add the following...

I was pointing out that, in order to access a machine that does not have a dedicated IP, one has to employ a NAT traversal technique. The easiest example would be to use Tailscale (which is open source, etc.). If one doesn't trust the code for one reason or another, one could always run it inside something like a VM (QEMU, etc.).

Many alternative options exist, for example Tor. It's written in C and the code is really easy to follow, etc. The service provider would need to create an onion service, and the client (you, potentially) could just install tor and use the torsocks utility as a prefix to ssh in order to connect to the machine the service provider set up. The downside is that one would have to use things like Eternal Terminal, autossh, etc. to avoid reconnection problems on errors.

I use ssh.

If so, I suggest the following.

flowchart TD
    A[Client: ikawrakow] -->|SSH with public key| B(VPS with Dedicated IP)
    
    subgraph B [VPS Portal]
        C[tmux Session 1]
        D[tmux Session 2]
        E[tmux Session ...]
    end

    C -->|autossh + eternal terminal<br>via TOR or similar NAT traversal| F[Target Machine 1]
    D -->|similar setup| G[Target Machine 2]
    E --> H[Other Machines]

Someone would need to order any kind of VPS with a dedicated IP address (these, with IPv6, are pretty cheap nowadays). Then we could add your public ssh key to the authorized_keys file on that server.

The next step would be to use a terminal multiplexer like screen or tmux in which we set up a session (via Eternal Terminal and autossh as described above) plus the backconnect option (as described above, say, via Tor). In that case it would be really convenient for you, since the only thing you'd have to do is connect to that server and we'd take care of the rest in order to provide root access to the server in question. Basically such a VPS would be like a portal to other servers, where you can switch between them by simply switching tabs (Ctrl+b, then 1/2/3, etc.).

@ubergarm
Contributor

ubergarm commented Dec 4, 2025

@aikitoria

I could try to run the test on my system with 8 GPUs, however I am unable to locate the GLM-4.5-Air-GGUF/Q4_0 test file that everyone is using on hf.

Edit: I was looking at the wrong user, I guess it's the unsloth one?

I didn't upload the one I am using, which I quanted myself, but you could probably use this one, which is close enough: https://huggingface.co/bartowski/zai-org_GLM-4.5-Air-GGUF/tree/main/zai-org_GLM-4.5-Air-Q4_0

UPDATE
Oh, actually, mine uses q8_0 for attn/shexp/the first dense layer... if you want to compare against mine specifically, let me know and I could upload it, but it's probably fine to just use whatever you have handy and get a relative comparison between modes of operation on your rig.

@aikitoria

It's fine, I'm already using that one; I just need to see the relative comparison, I think.

@aikitoria

So at least in its current configuration it does not seem to work correctly on my system, unless I did it wrong (just copied the commands you have there)

CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7,8 ./build/bin/llama-sweep-bench \
    --model "/models/raw/zai-org_GLM-4.5-Air-GGUF/zai-org_GLM-4.5-Air-Q4_0/zai-org_GLM-4.5-Air-Q4_0-00001-of-00002.gguf" \
    --merge-qkv \
    -c 69632 \
    -ngl 99 \
    -ub 4096 -b 4096 \
    --threads 1 \
    --warmup-batch \
    --numa isolate

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    0.894 |  4583.40 |    7.399 |   138.40 |
|  4096 |   1024 |   4096 |    1.054 |  3887.39 |    8.091 |   126.56 |
|  4096 |   1024 |   8192 |    1.214 |  3375.09 |    8.705 |   117.63 |
|  4096 |   1024 |  12288 |    1.379 |  2971.20 |    9.421 |   108.69 |
|  4096 |   1024 |  16384 |    1.549 |  2645.07 |   10.039 |   102.01 |
|  4096 |   1024 |  20480 |    1.734 |  2362.33 |   10.692 |    95.77 |
|  4096 |   1024 |  24576 |    1.912 |  2142.43 |   11.430 |    89.59 |
|  4096 |   1024 |  28672 |    2.123 |  1929.62 |   12.170 |    84.14 |
|  4096 |   1024 |  32768 |    2.320 |  1765.85 |   12.851 |    79.68 |
|  4096 |   1024 |  36864 |    2.520 |  1625.43 |   13.534 |    75.66 |
|  4096 |   1024 |  40960 |    2.725 |  1502.97 |   14.259 |    71.81 |
|  4096 |   1024 |  45056 |    2.933 |  1396.74 |   14.914 |    68.66 |
|  4096 |   1024 |  49152 |    3.144 |  1302.82 |   15.544 |    65.88 |
|  4096 |   1024 |  53248 |    3.358 |  1219.95 |   16.223 |    63.12 |
|  4096 |   1024 |  57344 |    3.726 |  1099.18 |   16.922 |    60.51 |
|  4096 |   1024 |  61440 |    3.938 |  1040.08 |   17.696 |    57.86 |
|  4096 |   1024 |  65536 |    4.243 |   965.31 |   18.306 |    55.94 |

CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7,8 ./build/bin/llama-sweep-bench \
    --model "/models/raw/zai-org_GLM-4.5-Air-GGUF/zai-org_GLM-4.5-Air-Q4_0/zai-org_GLM-4.5-Air-Q4_0-00001-of-00002.gguf" \
    -sm graph \
    -c 69632 \
    -ngl 99 \
    -ub 4096 -b 4096 \
    --threads 1 \
    --warmup-batch \
    --numa isolate

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    2.677 |  1530.07 |   41.220 |    24.84 |
|  4096 |   1024 |   4096 |    2.695 |  1519.93 |   41.494 |    24.68 |
canceled

System info:

2x AMD Epyc 9575F on ASRockRack TURIN2D24G-2L+/500W (configured as 2 NUMA nodes)
24x Micron 32GB 6400MT/s modules
8x RTX 5090 (not power limited)
1x RTX PRO 6000 (not used for this bench)
All GPUs P2P enabled at full PCIe Gen5 speed
Running inside cuda:13.0.2-devel-ubuntu24.04 docker image on Debian 14
NVIDIA driver 590.44.01

The GPUs only had like 15% usage during the test.

@magikRUKKOLA

magikRUKKOLA commented Dec 4, 2025

@aikitoria

What does /usr/share/doc/nvidia-cuda-toolkit/examples/Samples/5_Domain_Specific/p2pBandwidthLatencyTest/p2pBandwidthLatencyTest show, BTW? (I am curious)

[EDIT]:

All GPUs P2P enabled at full PCIe Gen5 speed

Waaah! That's pretty neat. Zen 2 single-CPU boards only support up to 4 PCIe Gen4 x16, so it seems like I need to pick up a different board then. How much is the ASRockRack TURIN2D24G-2L+ going for nowadays? Did you get them from China?

@ubergarm
Contributor

ubergarm commented Dec 4, 2025

@aikitoria

So at least in its current configuration it does not seem to work correctly on my system, unless I did it wrong (just copied the commands you have there)

Your test looks correct to me: you duplicated my example commands, including the first, working configuration with the default split mode and --merge-qkv, and the strangely slow -sm graph second test case.

Not sure if an explicit -ts 1,1,1,1,1,1,1,1 would be needed for the 8 GPUs in use, but I assume that is the default. You probably don't need --numa isolate given it is a full GPU offload.
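If it helps to reason about it: a -ts list is just a set of relative weights normalized into per-device proportions, so 1,1,1,1,1,1,1,1 and 100,...,100 mean the same thing. A minimal sketch of that normalization (my own illustration, not the actual ik_llama.cpp splitting code):

```python
def split_fractions(ts):
    """Normalize a -ts style list of relative weights into per-device fractions."""
    total = sum(ts)
    return [t / total for t in ts]

# Eight equal weights: each GPU gets an equal 1/8 share.
assert split_fractions([1] * 8) == [0.125] * 8

# A zero weight excludes a device entirely, as in the --tensor-split 100,100,0,0
# example later in this thread: everything goes to the first two GPUs.
assert split_fractions([100, 100, 0, 0]) == [0.5, 0.5, 0.0, 0.0]
```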

I had one other tester on Beaver AI discord tell me:

I'll give it a try. Have two RTX 8000s myself.
Well, I get an 'illegal memory address' CUDA error attempting to run it in graph mode 😭
https://discord.com/channels/1238219753324281886/1250677089191985184/1445609994316943370

So yours seems to run but is not performing correctly... hrm.. This feature is still pretty new so thanks for testing with your cool rig! Not sure at the moment what you could try next.

Maybe try it with just 2x GPUs and see if that works?

@magikRUKKOLA

magikRUKKOLA commented Dec 4, 2025

@ubergarm

So yours seems to run but is not performing correctly... hrm.. This feature is still pretty new so thanks for testing with your cool rig! Not sure at the moment what you could try next.

Has he tried running it with things like compute-sanitizer --tool initcheck or cuda-gdb --args?

@aikitoria

p2pBandwidthLatencyTest

It's basically flawless:

./p2pBandwidthLatencyTest
[P2P (Peer-to-Peer) GPU Bandwidth Latency Test]
Device: 0, NVIDIA RTX PRO 6000 Blackwell Workstation Edition, pciBusID: c1, pciDeviceID: 0, pciDomainID:0
Device: 1, NVIDIA GeForce RTX 5090, pciBusID: 1, pciDeviceID: 0, pciDomainID:0
Device: 2, NVIDIA GeForce RTX 5090, pciBusID: 11, pciDeviceID: 0, pciDomainID:0
Device: 3, NVIDIA GeForce RTX 5090, pciBusID: 61, pciDeviceID: 0, pciDomainID:0
Device: 4, NVIDIA GeForce RTX 5090, pciBusID: 71, pciDeviceID: 0, pciDomainID:0
Device: 5, NVIDIA GeForce RTX 5090, pciBusID: 81, pciDeviceID: 0, pciDomainID:0
Device: 6, NVIDIA GeForce RTX 5090, pciBusID: 91, pciDeviceID: 0, pciDomainID:0
Device: 7, NVIDIA GeForce RTX 5090, pciBusID: e1, pciDeviceID: 0, pciDomainID:0
Device: 8, NVIDIA GeForce RTX 5090, pciBusID: f1, pciDeviceID: 0, pciDomainID:0
Device=0 CAN Access Peer Device=1
Device=0 CAN Access Peer Device=2
Device=0 CAN Access Peer Device=3
Device=0 CAN Access Peer Device=4
Device=0 CAN Access Peer Device=5
Device=0 CAN Access Peer Device=6
Device=0 CAN Access Peer Device=7
Device=0 CAN Access Peer Device=8
Device=1 CAN Access Peer Device=0
Device=1 CAN Access Peer Device=2
Device=1 CAN Access Peer Device=3
Device=1 CAN Access Peer Device=4
Device=1 CAN Access Peer Device=5
Device=1 CAN Access Peer Device=6
Device=1 CAN Access Peer Device=7
Device=1 CAN Access Peer Device=8
Device=2 CAN Access Peer Device=0
Device=2 CAN Access Peer Device=1
Device=2 CAN Access Peer Device=3
Device=2 CAN Access Peer Device=4
Device=2 CAN Access Peer Device=5
Device=2 CAN Access Peer Device=6
Device=2 CAN Access Peer Device=7
Device=2 CAN Access Peer Device=8
Device=3 CAN Access Peer Device=0
Device=3 CAN Access Peer Device=1
Device=3 CAN Access Peer Device=2
Device=3 CAN Access Peer Device=4
Device=3 CAN Access Peer Device=5
Device=3 CAN Access Peer Device=6
Device=3 CAN Access Peer Device=7
Device=3 CAN Access Peer Device=8
Device=4 CAN Access Peer Device=0
Device=4 CAN Access Peer Device=1
Device=4 CAN Access Peer Device=2
Device=4 CAN Access Peer Device=3
Device=4 CAN Access Peer Device=5
Device=4 CAN Access Peer Device=6
Device=4 CAN Access Peer Device=7
Device=4 CAN Access Peer Device=8
Device=5 CAN Access Peer Device=0
Device=5 CAN Access Peer Device=1
Device=5 CAN Access Peer Device=2
Device=5 CAN Access Peer Device=3
Device=5 CAN Access Peer Device=4
Device=5 CAN Access Peer Device=6
Device=5 CAN Access Peer Device=7
Device=5 CAN Access Peer Device=8
Device=6 CAN Access Peer Device=0
Device=6 CAN Access Peer Device=1
Device=6 CAN Access Peer Device=2
Device=6 CAN Access Peer Device=3
Device=6 CAN Access Peer Device=4
Device=6 CAN Access Peer Device=5
Device=6 CAN Access Peer Device=7
Device=6 CAN Access Peer Device=8
Device=7 CAN Access Peer Device=0
Device=7 CAN Access Peer Device=1
Device=7 CAN Access Peer Device=2
Device=7 CAN Access Peer Device=3
Device=7 CAN Access Peer Device=4
Device=7 CAN Access Peer Device=5
Device=7 CAN Access Peer Device=6
Device=7 CAN Access Peer Device=8
Device=8 CAN Access Peer Device=0
Device=8 CAN Access Peer Device=1
Device=8 CAN Access Peer Device=2
Device=8 CAN Access Peer Device=3
Device=8 CAN Access Peer Device=4
Device=8 CAN Access Peer Device=5
Device=8 CAN Access Peer Device=6
Device=8 CAN Access Peer Device=7

***NOTE: In case a device doesn't have P2P access to other one, it falls back to normal memcopy procedure.
So you can see lesser Bandwidth (GB/s) and unstable Latency (us) in those cases.

P2P Connectivity Matrix
     D\D     0     1     2     3     4     5     6     7     8
     0       1     1     1     1     1     1     1     1     1
     1       1     1     1     1     1     1     1     1     1
     2       1     1     1     1     1     1     1     1     1
     3       1     1     1     1     1     1     1     1     1
     4       1     1     1     1     1     1     1     1     1
     5       1     1     1     1     1     1     1     1     1
     6       1     1     1     1     1     1     1     1     1
     7       1     1     1     1     1     1     1     1     1
     8       1     1     1     1     1     1     1     1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3      4      5      6      7      8
     0 1617.49  43.13  43.19  43.12  43.25  42.78  42.72  42.67  42.78
     1  42.94 1655.19  43.63  43.48  43.71  42.97  42.82  42.82  42.83
     2  42.92  43.50 1665.78  43.37  43.52  42.87  42.90  42.78  42.83
     3  43.02  43.82  43.85 1662.29  43.65  42.87  42.80  42.89  42.95
     4  42.91  43.71  43.80  43.74 1655.25  43.02  43.21  42.81  42.84
     5  42.80  43.23  43.48  43.33  43.52 1662.23  42.70  42.67  42.76
     6  42.75  43.23  43.39  43.23  43.43  42.83 1651.80  42.62  42.72
     7  42.77  43.27  43.45  43.36  43.47  42.90  42.93 1658.76  42.79
     8  42.83  43.46  43.52  43.29  43.63  42.95  42.89  42.76 1658.70
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2      3      4      5      6      7      8
     0 1610.82  55.61  55.63  55.65  55.65  56.58  56.58  56.58  56.55
     1  55.63 1637.84  56.55  56.58  56.58  55.63  55.63  55.64  55.64
     2  55.61  56.58 1641.28  56.58  56.57  55.63  55.63  55.63  55.63
     3  55.63  56.58  56.58 1644.74  56.58  55.63  55.63  55.63  55.64
     4  55.64  56.55  56.55  56.58 1641.28  55.61  55.63  55.59  55.64
     5  56.57  55.64  55.63  55.64  55.64 1641.34  56.55  56.58  56.55
     6  56.58  55.60  55.61  55.61  55.61  56.57 1641.28  56.55  56.58
     7  56.57  55.64  55.64  55.62  55.61  56.55  56.55 1641.28  56.58
     8  56.55  55.63  55.63  55.61  55.64  56.58  56.55  56.57 1644.74
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3      4      5      6      7      8
     0 1604.16  56.95  56.78  57.30  57.30  56.43  56.55  56.77  56.35
     1  56.87 1641.20  57.28  57.00  57.36  57.03  57.04  56.93  56.81
     2  56.64  57.83 1642.95  57.18  57.23  56.91  56.79  56.95  56.69
     3  56.74  57.45  57.03 1642.95  57.17  56.82  57.06  56.80  56.94
     4  56.61  57.55  57.11  57.10 1644.68  56.81  56.87  56.93  56.64
     5  56.91  57.30  57.17  57.07  57.11 1644.68  56.45  56.81  56.26
     6  56.71  56.85  56.86  57.20  57.29  56.27 1644.66  56.32  56.44
     7  56.57  56.96  56.76  56.82  57.60  56.66  56.62 1646.42  56.67
     8  56.21  56.81  56.87  57.15  56.43  56.20  56.60  56.84 1639.48
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3      4      5      6      7      8
     0 1600.87 111.12 111.09 111.10 111.11 111.38 111.38 111.38 111.38
     1 111.09 1637.76 111.34 111.39 111.42 111.17 111.00 111.13 111.10
     2 111.09 111.39 1637.79 111.39 111.45 111.16 111.10 111.10 111.10
     3 111.08 111.39 111.39 1636.07 111.40 111.15 111.16 111.13 111.12
     4 111.21 111.39 111.39 111.40 1636.07 111.19 111.04 111.14 111.13
     5 111.38 111.09 111.13 111.14 111.08 1641.20 111.34 111.38 111.39
     6 111.40 111.12 111.11 111.09 111.03 111.35 1637.79 111.38 111.38
     7 111.34 111.14 111.09 111.11 111.14 111.38 111.39 1639.51 111.38
     8 111.40 111.13 111.15 111.16 111.14 111.39 111.39 111.40 1634.33
P2P=Disabled Latency Matrix (us)
   GPU     0      1      2      3      4      5      6      7      8
     0   2.07  14.34  14.35  14.31  14.31  14.34  14.31  14.32  14.33
     1  14.42   2.07  14.33  14.33  14.33  14.33  14.33  14.31  14.31
     2  14.27  14.25   2.07  14.33  14.32  14.27  14.34  14.25  14.35
     3  14.32  13.58  13.83   2.07  14.31  14.30  14.32  14.33  14.32
     4  14.35  14.07  14.07  14.33   2.07  14.30  14.31  14.15  14.33
     5  14.31  14.33  14.31  14.32  14.32   2.07  14.31  14.31  14.32
     6  14.31  14.46  14.33  14.31  14.32  14.33   2.07  14.32  14.33
     7  14.32  14.32  14.31  14.31  14.31  14.31  14.31   2.07  14.31
     8  14.32  14.33  14.32  14.33  14.33  14.32  14.29  14.33   2.07

   CPU     0      1      2      3      4      5      6      7      8
     0   2.21   5.74   5.92   5.88   5.51   6.29   6.60   6.56   6.17
     1   5.90   1.87   5.25   5.16   4.88   5.59   5.91   5.90   5.53
     2   6.07   5.14   1.96   5.39   5.10   5.80   6.17   6.10   5.80
     3   6.00   5.04   5.37   2.02   5.04   5.79   6.12   6.09   5.75
     4   5.81   4.81   5.14   5.12   1.85   5.55   5.88   5.86   5.48
     5   6.21   5.29   5.59   5.58   5.27   2.09   6.39   6.34   6.03
     6   6.47   5.49   5.83   5.84   5.50   6.24   2.21   6.61   6.24
     7   6.47   5.51   5.83   5.83   5.50   6.26   6.58   2.22   6.20
     8   6.25   5.27   5.57   5.55   5.25   6.01   6.36   6.36   2.08
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1      2      3      4      5      6      7      8
     0   2.07   0.44   0.43   0.43   0.37   0.37   0.37   0.37   0.36
     1   0.37   2.07   0.36   0.43   0.36   0.37   0.44   0.37   0.37
     2   0.43   0.36   2.07   0.36   0.35   0.36   0.37   0.43   0.36
     3   0.35   0.42   0.36   2.07   0.35   0.36   0.43   0.36   0.42
     4   0.35   0.36   0.35   0.35   2.07   0.36   0.35   0.35   0.35
     5   0.44   0.36   0.37   0.36   0.37   2.07   0.38   0.37   0.43
     6   0.37   0.36   0.43   0.43   0.36   0.37   2.07   0.43   0.37
     7   0.36   0.36   0.37   0.37   0.42   0.43   0.43   2.07   0.43
     8   0.38   0.38   0.44   0.45   0.44   0.37   0.38   0.44   2.07

   CPU     0      1      2      3      4      5      6      7      8
     0   2.19   1.63   1.57   1.57   1.57   1.57   1.57   1.58   1.57
     1   1.29   1.82   1.26   1.27   1.27   1.28   1.26   1.26   1.26
     2   1.45   1.40   1.95   1.39   1.38   1.39   1.40   1.38   1.39
     3   1.50   1.39   1.38   1.96   1.38   1.37   1.38   1.39   1.38
     4   1.32   1.27   1.26   1.26   1.85   1.26   1.26   1.26   1.27
     5   1.60   1.54   1.53   1.53   1.54   2.08   1.53   1.53   1.54
     6   1.71   1.66   1.65   1.65   1.65   1.66   2.19   1.65   1.66
     7   1.72   1.66   1.64   1.65   1.64   1.65   1.64   2.20   1.64
     8   1.59   1.54   1.53   1.52   1.52   1.52   1.53   1.53   2.08

NOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.

@magikRUKKOLA

magikRUKKOLA commented Dec 4, 2025

@aikitoria

It's basically flawless:

Whoa, this is nuts!! You basically achieved RTX 3090 NVLink speed with that driver of yours. GJ, LGTM.
That said, NVLink actually sucks because it can only link two adjacent GPUs, so your setup is much better indeed. The only problem is that a used RTX 5090 is about two times more expensive than a used RTX 3090 when accounting for the price per GB of VRAM. :/
What cooling setup do you use? Water cooling, or just a bunch of risers [to spread the GPUs out from each other], etc.?

@aikitoria

aikitoria commented Dec 4, 2025

Current gen nvlink is (much) faster than this, but for our single user inference tasks it does not matter. The actual data exchanged is tiny, we just want the low latency.

I have the GPUs mounted in a custom open frame I built out of alu extrusions, so they all get good air and it's actually nice and quiet. Using MCIO cables from 10gtek and device adapters from C-Payne in "Bottom" style for short traces. It's kinda off topic for this thread though. I've shared ~all about this setup (and the previous one with pcie gen4, problems I hit and how I fixed them, etc) in the old TheBloke and ExLlama discord servers if you want to build something like it

@aikitoria

Maybe try it with just 2x GPUs and see if that works?

2 GPUs won't be big enough to load that model, but let's try 4 from the same NUMA node later

@ikawrakow
Owner Author

@ubergarm

Now that you have posted comparisons with mainline, it looks like something goes wrong with FA for TG. In all of my testing ik_llama.cpp has faster TG than mainline with the default split mode "layer", with the performance gap increasing with context length. But in your results mainline is 1.7X faster than ik_llama.cpp at 64k tokens (with split mode "layer"). If I look at @aikitoria's TG results with split mode "layer", they retain about 40% of zero-context TG performance at 64k tokens, which is in line with what I see on the 2x3090 box. But your ik_llama.cpp TG with sm "layer" is just 20% of zero-context performance. So, I guess, the wrong FA kernel gets picked for your GPUs. Can you add printf statements in ggml/src/ggml-cuda/fattn.cu just before every call to ggml_cuda_flash_attn_ext_vec_f32, ggml_cuda_flash_attn_ext_vec_f16, ggml_cuda_flash_attn_ext_tile_f32, ggml_cuda_flash_attn_ext_tile_f16, ggml_cuda_flash_attn_ext_mma_new, ggml_cuda_flash_attn_ext_wmma_f16, and ggml_cuda_flash_attn_ext_mma_f16, then run TG with this model? Thanks.

The PP results are as expected. ik_llama.cpp tends to be 2-3X times faster than mainline for 64k context and split mode "layer", so with split mode "graph" working as intended it becomes close to 5X.

@ikawrakow
Owner Author

@aikitoria

Thanks for testing on your amazing system!

I think we have established that the current TP implementation is only useful on 2 GPUs, and becomes completely inadequate on 4 or more GPUs. If you want to test again, I think you will get best results with 4 GPUs, but using only 2 of them for TP, and the other 2 for computing the routed experts in part of the layers. Something like this:

CUDA_VISIBLE_DEVICES=1,2,3,4 ./build/bin/llama-sweep-bench \
    --model "/models/raw/zai-org_GLM-4.5-Air-GGUF/zai-org_GLM-4.5-Air-Q4_0/zai-org_GLM-4.5-Air-Q4_0-00001-of-00002.gguf" \
    -c 69632 \
    -ngl 99 \
    -ub 4096 -b 4096 \
    --threads 1 \
    --warmup-batch \
    --tensor-split 100,100,0,0 \
    -ot "blk\.(45|44|43|42|41|40|39|38|37|36|35|34|33)\.ffn_(up|gate|down)_exps\.weight=CUDA3" \
    -ot "blk\.(32|31|30|29|28|27|26|25|24|23|22|21|20)\.ffn_(up|gate|down)_exps\.weight=CUDA2" \
    --numa isolate

You may need to adjust the layers on CUDA2 and CUDA3 to avoid OOM. But basically this will use TP for self-attention, shared experts, and the routed experts in the first 20 layers, and use the 3rd and 4th GPUs to compute the routed experts in the remaining layers without TP.
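The -ot regexes in the command above are mechanical enough to generate by hand or script. A small hypothetical helper (make_override is my name, not part of ik_llama.cpp) that rebuilds the same override strings for an arbitrary set of layers:

```python
def make_override(layers, device):
    """Build an -ot override string pinning the routed-expert tensors
    of the given layers to one device."""
    alternation = "|".join(str(layer) for layer in layers)
    return rf"blk\.({alternation})\.ffn_(up|gate|down)_exps\.weight={device}"

# Reproduce the two overrides from the command above.
assert make_override(range(45, 32, -1), "CUDA3") == (
    r"blk\.(45|44|43|42|41|40|39|38|37|36|35|34|33)\.ffn_(up|gate|down)_exps\.weight=CUDA3")
assert make_override(range(32, 19, -1), "CUDA2") == (
    r"blk\.(32|31|30|29|28|27|26|25|24|23|22|21|20)\.ffn_(up|gate|down)_exps\.weight=CUDA2")
```

Shift the layer ranges passed in until the per-GPU memory fits.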

@ubergarm
Contributor

ubergarm commented Dec 5, 2025

@ikawrakow

In testing it seems like the default -sm layer mode has better TG now after pulling latest tip of main and rebuilding. Also -sm graph mode is giving faster results too...

I see a few PRs came through in the past couple of hours and I'll go read them next to catch up. So either something was updated, or perhaps I made a mistake, or something on the rig wasn't quite right after updating all the drivers yesterday... huh.

EDIT: PP is also faster on mainline today than yesterday with the exact same commit (I checked my logs and the git SHA looks right, so maybe the GPUs had something on them I didn't notice, or something was wonky)...

[image: sweep-bench-PR1022-GLM-4.5-Air]

So, I guess, the wrong FA kernel gets picked for your GPUs. Can you add printf statement

I added a small patch on top of the latest main@f4def9b3, and running it prints:

DEBUG: /home/w/projects/ik_llama.cpp/ggml/src/ggml-cuda/fattn.cu:140: ggml_cuda_flash_attn_ext_mma_f16

This suggests the only print statement being hit is this one: https://github.com/ikawrakow/ik_llama.cpp/blob/main/ggml/src/ggml-cuda/fattn.cu#L125

Patch:
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 83c7cf40..4d522c40 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -16,6 +16,10 @@
 
 #include <cstdint>
 
+#define DEBUG(msg) \
+    fprintf(stderr, "DEBUG: %s:%d: ", __FILE__, __LINE__); \
+    fprintf(stderr, "%s\n", msg);
+
 #define FATTN_KQ_STRIDE 256
 
 static inline bool mma_better_than_turing(const int cc) {
@@ -55,8 +59,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            DEBUG("ggml_cuda_flash_attn_ext_vec_f16");
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
+            DEBUG("ggml_cuda_flash_attn_ext_vec_f32");
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         }
         return;
@@ -64,8 +70,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     if (!fast_fp16_available(cc)) {
         if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
+            DEBUG("ggml_cuda_flash_attn_ext_vec_f32");
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         } else {
+            DEBUG("ggml_cuda_flash_attn_ext_tile_f32");
             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
         }
         return;
@@ -74,14 +82,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (!fp16_mma_available(cc)) {
         if (precision == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
+                DEBUG("ggml_cuda_flash_attn_ext_vec_f16");
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             } else {
+                DEBUG("ggml_cuda_flash_attn_ext_tile_f16");
                 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
             }
         } else {
             if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
+                DEBUG("ggml_cuda_flash_attn_ext_vec_f32");
                 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
             } else {
+                DEBUG("ggml_cuda_flash_attn_ext_tile_f32");
                 ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
             }
         }
@@ -96,6 +108,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+        DEBUG("ggml_cuda_flash_attn_ext_vec_f32");
         ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         return;
     }
@@ -107,6 +120,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // so no other implementation works.
     //
     if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
+        DEBUG("ggml_cuda_flash_attn_ext_mma_new");
         ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
         return;
     }
@@ -117,11 +131,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // We also need it if the new MMA is not available
     //
     if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
+        DEBUG("ggml_cuda_flash_attn_ext_wmma_f16");
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
 
     // As mentioned above, the new-new MMA is slower then the new MMA.
+    DEBUG("ggml_cuda_flash_attn_ext_mma_f16");
     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
     //ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
 }

I didn't go back and re-test on the previous commit, but will leave this here for now and go read up on the new stuff you merged. hah.. Thanks!

EDIT: here is the same graph as above, but with just today's measurements, which look better than yesterday's for some reason...

[image: sweep-bench-PR1022-GLM-4.5-Air-vs-mainline]

@ikawrakow
Owner Author

@ubergarm

Thanks! Yes, this is the kernel that is supposed to get used. I was concerned that somehow the vector kernel was getting invoked for TG, and that's why we were seeing such a performance decline with context (more than expected).

So, in your case, split mode "graph" looks like a real winner.

@ikawrakow
Owner Author

@magikRUKKOLA

To clarify my point I have to add the following...

I was pointing out that, in order to access a machine that does not have a dedicated IP, one has to employ a NAT traversal technique. The easiest example would be to use Tailscale (which is open source, etc.). If one doesn't trust the code for one reason or another, one could always run it inside something like a VM (QEMU, etc.).

I guess, in that case I would try to learn how to use tailscale. When I wrote that I would never pipe a script that I downloaded from the internet into sh, it wasn't about me trusting or not trusting the tailscale folks, but about not piping a script into sh before first taking a look. Just like you are concerned that someone might hack into your computer in the 3 minutes for which you enabled ssh access with a password so I can log in and add my public key, I'm worried that someone could have hacked the server and replaced the script you are downloading and piping into sh. It is not that this sort of thing never happens. That's why one should always first look at the script, check the hash, etc.

@magikRUKKOLA

magikRUKKOLA commented Dec 5, 2025

@ubergarm

In testing it seems like the default -sm layer mode has better TG now after pulling latest tip of main and rebuilding. Also -sm graph mode is giving faster results too...

Shalom! I am having the exact same thoughts with RTX 3090 now ))

#1029 (comment)

@magikRUKKOLA

@ikawrakow

[...] That's why always first look at the script, check the hash, etc.

Ahh!! Got it now. Fully agree! The explanation is highly appreciated. Very cool point! Thanks!!

@magikRUKKOLA

magikRUKKOLA commented Dec 9, 2025

@Ph0rk0z

Heh.. so what's your preferred solution since no static IP and generally I block all remote access to the lan. It probably has to be some kind of service that keeps it up for you. And whenever you get done with it I just go back to being an anonymous pirate.

It turned out that Tailscale does not work for us. So we dropped the attempts to make it work and instead just used ngrok.

ngrok setup:
Install:

curl -s https://ngrok-agent.s3.amazonaws.com/ngrok.asc | sudo tee /etc/apt/trusted.gpg.d/ngrok.asc >/dev/null
echo "deb https://ngrok-agent.s3.amazonaws.com buster main" | sudo tee /etc/apt/sources.list.d/ngrok.list
sudo apt update
sudo apt install ngrok

Sign up & get token:
Go to https://ngrok.com (free account)
Get your authtoken from dashboard
Run: ngrok authtoken YOUR_TOKEN
Expose SSH:

ngrok tcp 22

It gives you something like tcp://0.tcp.ngrok.io:12345
Then anyone connects with:

ssh <user>@0.tcp.ngrok.io -p 12345

No special client software needed. Regular SSH works.

That would require providing debit card details without the CVV (they do not charge any money; as stated, they use this to combat fraud, etc.). So the perfect solution is still to buy a VPS with crypto. But ngrok works for free. So it's the privacy vs. time-invested dilemma again.

@abc-nix mentioned this pull request Dec 10, 2025
@Ph0rk0z

Ph0rk0z commented Dec 10, 2025

Oh wow, that's interesting... I wonder if they take debit gift cards. I guess I will have to see.

@hksdpc255
Contributor

How complex would it be to add tensor parallel support for a new model based on this work? Is it worth trying to add the MiniMax-M2 series models using vibe coding?

@magikRUKKOLA

magikRUKKOLA commented Dec 30, 2025

using vibe coding?

Hahaha!!! You're a funny lad! :))

@hksdpc255
Contributor

I'm not familiar with the logic of LLM inference. The only way I can help with this part is vibe coding and validating the result. :)

@ikawrakow
Owner Author

@hksdpc255 I appreciate that you want to help, but I cannot say that I like the current trend of vibe-coded contributions. In general, I think if someone wants to seriously contribute to a project, they need to become familiar with at least the parts of the project where they want to contribute. Even more so when it comes to building the compute graph, as this is a core part of the inference engine.

@hksdpc255
Contributor

None of my contributions are vibe-coded unless explicitly stated. Since you mentioned that you prefer not to rely on vibe-coding for unfamiliar parts of the codebase, I will avoid doing so accordingly.

@Ph0rk0z

Ph0rk0z commented Dec 31, 2025

The big issue with vibe coding is that people don't read what the LLM spits out and barely test whether it works.
