
--numa mirror: mirror model weights to every NUMA node in the system #16000

Draft

dbsanfte wants to merge 25 commits into ggml-org:master from dbsanfte:numa-mirror

Conversation

@dbsanfte

@dbsanfte dbsanfte commented Sep 15, 2025

This PR adds a new --numa mirror option which mirrors the model weights to every NUMA node on the system and uses a thread-local variable in the OMP threadpool to select, at runtime, the mirror copy local to the current thread, eliminating cross-socket traffic.
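The hot-path mechanism can be sketched roughly like this (a minimal illustration, not the PR's actual ggml code - the real ggml_tensor layout and the plumbing that sets ggml_current_numa_node per thread differ):

```cpp
#include <cstddef>

#define GGML_NUMA_MAX_NODES 8

// Set once per OMP worker thread after it is pinned to a core,
// e.g. from numa_node_of_cpu(sched_getcpu()).
static thread_local int ggml_current_numa_node = 0;

// Hypothetical holder of one weight copy per NUMA node.
struct tensor_mirrors {
    void * data[GGML_NUMA_MAX_NODES];
};

// Hot path: O(1), no syscalls - just a thread-local read and an index.
static inline void * tensor_data(const tensor_mirrors * t) {
    return t->data[ggml_current_numa_node];
}
```

Because the node index is resolved per thread rather than per tensor, threads pinned to different sockets transparently read different physical copies through the same accessor.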

Build instructions:

apt-get update
apt-get install -y libnuma-dev libgomp1
cmake -B build -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_FLAGS="-march=native" -DCMAKE_CXX_FLAGS="-march=native"  -DGGML_OPENMP=ON
cmake --build build --parallel

To test:

# No mirroring:
./build/bin/llama-bench -m ~/models/Qwen3-30B-A3B-UD-Q4_K_XL.gguf

# NUMA mirroring of model weights to every node:
./build/bin/llama-bench -m ~/models/Qwen3-30B-A3B-UD-Q4_K_XL.gguf --numa mirror

Test system is a two-socket Xeon 6238R Cascade Lake, with 768GB of DDR4-2933 (6 channels per socket).

Without --numa mirror:

developer@81ec6c6e6af6:/workspaces/llama-cpp-dbsanfte-dev$ ./build/bin/llama-bench -m ./.devcontainer/Qwen3-32B-Q6_K.gguf     
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3 32B Q6_K                 |  25.03 GiB |    32.76 B | CPU        |      56 |           pp512 |         20.99 ± 0.01 |
| qwen3 32B Q6_K                 |  25.03 GiB |    32.76 B | CPU        |      56 |           tg128 |          1.91 ± 0.00 |

build: c665d3c9 (6468)

With --numa mirror:

developer@81ec6c6e6af6:/workspaces/llama-cpp-dbsanfte-dev$ ./build/bin/llama-bench -m ./.devcontainer/Qwen3-32B-Q6_K.gguf --numa mirror
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3 32B Q6_K                 |  25.03 GiB |    32.76 B | CPU        |      56 |           pp512 |         21.36 ± 0.11 |
| qwen3 32B Q6_K                 |  25.03 GiB |    32.76 B | CPU        |      56 |           tg128 |          2.70 ± 0.00 |

build: c665d3c9 (6468)

Intel PCM tool during mirror inference showing both sockets using local mem:

[screenshot: Intel PCM per-socket memory traffic]

There's still a bit of cross-socket traffic (5%) because only model weights are mirrored, not tensors created at inference time. I'll play with that - maybe mirroring those aggressively will help too, or maybe not. Right now anything created at inference time just lives on node 0.

- Achieved 5% inference speed improvement (14.6 -> 15.3 t/s)
- Clean explicit NUMA setup during model loading
- Ultra-minimal hot path with thread-local NUMA node access
- Working NUMA mirrors for all model weights
- Performance: text generation improved, prompt processing needs optimization

Performance Results (Qwen3-30B-A3B):
- Text Generation: 14.6 -> 15.3 t/s (+5% improvement)
- Prompt Processing: 176 -> 152 t/s (14% regression - needs investigation)

Technical Implementation:
- tensor_data(): O(1) NUMA-aware access via thread-local ggml_current_numa_node
- tensor_set_data_with_numa_mirrors(): Explicit NUMA setup for model weights
- NUMA coordinator: Thread binding and memory locality
- Clean separation: model loading (explicit setup) vs inference (fast access)
@github-actions github-actions bot added testing Everything test related Nvidia GPU Issues specific to Nvidia GPUs Vulkan Issues specific to the Vulkan backend examples python python script changes devops improvements to build systems and github actions ggml changes relating to the ggml tensor library for machine learning SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language Apple Metal https://en.wikipedia.org/wiki/Metal_(API) Ascend NPU issues specific to Ascend NPUs OpenCL Issues specific to the OpenCL backend IBM zDNN issues specific to IBM zDNN Accelerator labels Sep 15, 2025
@dbsanfte dbsanfte marked this pull request as draft September 15, 2025 06:14
@dbsanfte
Author

dbsanfte commented Sep 15, 2025

Physical-core detection was badly broken in arg.cpp / common.cpp: it assumed every second consecutive core was a hyperthread. That isn't true on Xeons, at least - physical cores are my first 56, and the next 56 are the hyperthreads. This led to wildly inconsistent results at inference time. Now I use proper CPU topology detection to select only physical cores. If you still want to use hyperthreads, I added a new option: --cpu-use-hyperthreading.
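For reference, on Linux the physical/SMT layout can be read from sysfs: each logical CPU's topology/thread_siblings_list names its hyperthread siblings (e.g. "0,56"), and keeping only the lowest CPU id per group leaves one logical CPU per physical core. A hedged sketch of just that parsing step (hypothetical helper names, not the PR's actual detection code):

```cpp
#include <set>
#include <sstream>
#include <string>
#include <vector>

// Lowest CPU id in a sysfs siblings list such as "3,59" or "0-1".
// std::stoi stops at the '-' in a range, so the range start is parsed.
static int first_sibling(const std::string & list) {
    int best = -1;
    std::stringstream ss(list);
    std::string tok;
    while (std::getline(ss, tok, ',')) {
        const int lo = std::stoi(tok);
        if (best < 0 || lo < best) best = lo;
    }
    return best;
}

// One siblings list per logical CPU in -> sorted physical core ids out.
static std::vector<int> physical_cores(const std::vector<std::string> & lists) {
    std::set<int> cores;
    for (const auto & l : lists) cores.insert(first_sibling(l));
    return std::vector<int>(cores.begin(), cores.end());
}
```

On the Xeon layout described above, CPUs 0 and 56 share a core, so both "0,56" and "56,0" collapse to physical core 0.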

@dbsanfte
Author

dbsanfte commented Dec 2, 2025

Pretty close. There's still some overhead - I have too many OpenMP regions and need to consolidate them. But the best I got from llama.cpp even with the mirror fix was ~85GB/s; I regularly see 165+ GB/s with Llaminar.

@Ph0rk0z

Ph0rk0z commented Dec 2, 2025

It still uses GGUF? I guess I'll try it when it's out.

@dbsanfte
Author

dbsanfte commented Dec 2, 2025

Yup, it uses GGUF.

I might add other formats too, it's not really that hard. It all gets repacked for AVX512-VNNI anyway.

@usrlocalben

@dbsanfte will your system be able to model expert parallelism? It doesn't seem to fit well in llama.cpp, since NUMA nodes would need to be a first-class concept (addressable at the model-graph level), and as I understand it that isn't the case.

@dbsanfte
Author

dbsanfte commented Dec 2, 2025

Yeah. I started with Qwen 2.5 to begin with but I'll do Qwen3 and Qwen3-MoE next. I designed it around a reusable pipeline system so it's extremely flexible.

@DocShotgun
Contributor

@usrlocalben

I've been looking for further NUMA optimizations and found this thread. My setup is a dual 5th-gen Xeon server with 2 nodes per socket (768GB DDR5-5600 total) and a single RTX Pro 6000, which is somewhat similar to your setup. I had always been using --no-mmap and NUMA interleave+distribute across as many nodes as I need to bind to, based on the size of the model. The problem is that for smaller models such as DeepSeekV3 or GLM 4.7 I'm effectively not using the second socket's worth of compute and bandwidth.

I was curious about your suggestion to use drop_caches followed by loading the model with --mmap and --numa distribute so I tried it out - loading deepseek 4.92bpw across both sockets, and wow, it really improves TG speed noticeably from ~11 T/s on single socket to around ~15 T/s... but it really hurts my prompt processing speed compared to the --no-mmap case! Previously, llama.cpp could essentially max out my RTX Pro 6000's compute and PCIe 5.0 x16 bandwidth during large batch PP with host->device offload, but with these new settings, that's not the case. The GPU utilization is far lower and the PP is far slower, not even 100 T/s PP even on 4096 batches (previously could get up to 400).

I'm wondering if this is something you noticed as well. Perhaps the mmap-based NUMA "migration" results in a less efficient memory configuration for host->device offload during PP?

@jukofyork
Collaborator

jukofyork commented Jan 25, 2026

use drop_caches followed by loading the model with --mmap and --numa distribute so I tried it out - loading deepseek 4.92bpw across both sockets, and wow, it really improves TG speed noticeably from ~11 T/s on single socket to around ~15 T/s...

I also found this on my dual Xeons.

One thing that might help further is to use all your threads (ie: hyperthreads included). This gained me back some of the lost PP for llama.cpp (and oddly makes the model load faster from SSD), but didn't help ik_llama.cpp which seems best to just use non-hyperthread threads only...

but it really hurts my prompt processing speed compared to the --no-mmap case! Previously, llama.cpp could essentially max out my RTX Pro 6000's compute and PCIe 5.0 x16 bandwidth during large batch PP with host->device offload, but with these new settings, that's not the case. The GPU utilization is far lower and the PP is far slower, not even 100 T/s PP even on 4096 batches (previously could get up to 400).

Yeah, I can't figure this out either:

  • I can get 12.7GB/s through my PCI-E 3.0 16x (max theoretical 16GB/s) and max out my RTX 6000 Ada IF I use the new direct-io loading method (which then kills my TG due to loading everything onto a single node [numactl --interleave=all doesn't help either]).
  • If I do the drop caches, --mmap and --numa distribute trick I can only ever get 4.7GB/s through my PCI-E 3.0 16x (but very occasionally I get 6.5GB/s!), but this then kills my offload PP...

I suspected that it could be an alignment problem due to the weird 6.5GB/s I can get occasionally, but alas I tried changing the stock 32-byte alignment inside ggml for tensors in a GGUF file to 4k, 64k, 1MB and 2MB, and sadly it made no difference at all (it's actually quite involved to do this and requires hacking several places to read in a 32-byte-aligned GGUF, write out an XXX-byte-aligned GGUF, and then read it in and actually use it!).

So now, all I can think is that the QPI link between nodes is causing problems (for me it's 80GB/s which seems like it shouldn't be a problem for getting 12.7GB/s to the GPU, but who knows).

I get exactly the same problem on ik_llama.cpp - the same weird 4.7GB/s about 90% of the time, but then randomly 6.5GB/s the other 10%...

@DocShotgun
Contributor

Yeah, I can't figure this out either:

  • I can get 12.7GB/s through my PCI-E 3.0 16x (max theoretical 16GB/s) and max out my RTX 6000 Ada IF I use the new direct-io loading method (which then kills my TG due to loading everything onto a single node [numactl --interleave=all doesn't help either]).
  • If I do the drop caches, --mmap and --numa distribute trick I can only ever get 4.7GB/s through my PCI-E 3.0 16x (but very occasionally I get 6.5GB/s!), but this then kills my offload PP...

I suspected that it could be an alignment problem due to the weird 6.5GB/s I can get occasionally, but alas I tried changing the stock 32-byte alignment inside ggml for tensors in a GGUF file to 4k, 64k, 1MB and 2MB, and sadly it made no difference at all (it's actually quite involved to do this and requires hacking several places to read in a 32-byte-aligned GGUF, write out an XXX-byte-aligned GGUF, and then read it in and actually use it!).

So now, all I can think is that the QPI link between nodes is causing problems (for me it's 80GB/s which seems like it shouldn't be a problem for getting 12.7GB/s to the GPU, but who knows).

My guess is that the mmap+distribute “migrate” trick puts the weights all over the place in a way that makes it harder to do host->device offload efficiently, but that’s just me throwing stuff out there lol.

The best overall setup I’ve gotten so far is to interleave+distribute+no mmap on 2 nodes on the same socket (if the model fits in 384gb ram). During PP the GPU can hit 100% util and I can get like 50+ GB/s across PCIe, and the TG speed is fair. It just feels bad not using the other socket lol. With the drop_caches+mmap+distribute trick, I could only get like 8 GB/s across PCIe and GPU like 5-15% util during PP (but excellent TG!).

Interleave+distribute+no mmap across all nodes is… acceptable for models that are too big for a single socket, but for smaller models it’s still a performance loss in TG for me.

I asked local Minimax last night to investigate whether multi-NUMA expert parallelism would be a viable solution but I think it got lost in the weeds lol.

@Ph0rk0z

Ph0rk0z commented Jan 26, 2026

Mirroring is just a kludge. It would be better if local threads only accessed local memory and worked in parallel. I have even turned off NUMA in the system and it behaves identically to interleave/distribute, but the QPI still gets saturated. When calculations are moved to the GPU they might be crossing QPI and then PCIe as well.

My PCIE3 all to all speeds are similar, around 6.5-7gb/s. I'm not sure that is a bottleneck. So what am I trying to say... that there's no way to fix this with small changes or settings.

@dbsanfte
Author

It's not just the weights either. It's the KV cache, residual buffer, workspaces for kernels, etc. All of it needs to be bound to the local NUMA node. The KV cache and weights need to be sharded across sockets, and you need to do TP with a cross-socket allreduce. I don't want to say that's impossible to do in llama.cpp, but it was way faster just to write all the infrastructure myself in my own engine.

Just working on CUDA and ROCm now, btw. I had to design a graph orchestrator and tune the CUDA and HIP kernels. I have cross-vendor TP working (CUDA and HIP in the same build, doing P2P TP allreduce via PCIe). Getting closer.

@jukofyork
Collaborator

Do either of you who are bottlenecked by the PCI-E bus have enough RAM to fit ~2x the expert weights in?

I have narrowed it down to misalignment, but can't seem to get an mmapped GGUF to align properly. I have, however, found a way to make a cache that basically stores a second copy of the experts in properly aligned memory.

@usrlocalben

Do either of you who are bottlenecked by the PCI-E bus have enough RAM to fit ~2x the expert weights in?

I have narrowed it down to misalignment, but can't seem to get an mmapped GGUF to align properly. I have, however, found a way to make a cache that basically stores a second copy of the experts in properly aligned memory.

This would be normal numa distribute + an extra copy? I can fit it for e.g. K2 Q4_X. That is, unless the extra copy is all on one node, in which case it won't fit.

@Ph0rk0z

Ph0rk0z commented Feb 7, 2026

Depends on the model. I have 192g per node but only one node has the GPUs.

@jukofyork
Collaborator

jukofyork commented Feb 8, 2026

So only tested on ik_llama.cpp, but this is working for me and should work for llama.cpp the same I think:

static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {

Needs changing to this:

GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

    ggml_cuda_set_device(ctx->device);
    if (strstr(tensor->name, "_exps") != nullptr) {
        static std::unordered_map<const void*, void*> pinned_cache;
        static std::mutex cache_mutex;
        void* pinned_copy = nullptr;
        {
            std::lock_guard<std::mutex> lock(cache_mutex);
            auto it = pinned_cache.find(data);
            if (it != pinned_cache.end()) {
                pinned_copy = it->second;
                //fprintf(stderr, "[CACHE HIT] %s: %.2f MB\n", tensor->name, size/1e6);
            } else {
                CUDA_CHECK(cudaHostAlloc(&pinned_copy, size, cudaHostAllocDefault));
                memcpy(pinned_copy, data, size);
                pinned_cache[data] = pinned_copy;
                //fprintf(stderr, "[CACHE MISS] %s: %.2f MB\n", tensor->name, size/1e6);
            }
        }
        CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, pinned_copy, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
    } else {
        CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
    }
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}

To use with ik_llama.cpp you need to change here:

https://github.com/ikawrakow/ik_llama.cpp/blob/e22b2d124635d7f9403b8ee4644e472a29a9b332/ggml/src/ggml-cuda.cu#L610

but also make it skip this:

https://github.com/ikawrakow/ik_llama.cpp/blob/e22b2d124635d7f9403b8ee4644e472a29a9b332/src/llama.cpp#L4928

like this:

    //if (params.only_active_experts) {
    //    LLAMA_LOG_INFO("XXXXXXXXXXXXXXXXXXXXX Setting only active experts offload\n");
    //    ggml_backend_sched_set_only_active_experts(ctx->sched, true);
    //}

(or else it's impossible to cache whole tensors!)


You should run as you usually do to get maximum NUMA throughput, eg:

# Turn off NUMA balancing
echo 0 | sudo tee /proc/sys/kernel/numa_balancing > /dev/null

# Drop caches (if first time)
echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null

./llama-server --numa distribute --threads "$(nproc)"

or whatever you've found works best for you...

The first batch will take ages and you'll see your memory use slowly grow to around double normal.

After this you should get full PCI-E throughput until you unload the model and have to do the caching again.

It's a bit of a crappy method, but I can't seem to get it to have properly aligned memory any other way (eg: creating aligned tensors in GGUF and memory mapping to 64k or 2MB boundary doesn't work...).

@DocShotgun
Contributor

I could probably do that with Deepseek (since that was already fitting on a single node for me), but not for Kimi. If I'm understanding correctly, this would allow us to get the speed benefits of mmap migration without killing the disaggregated prefill speed, at the cost of double the RAM use?

@jukofyork
Collaborator

jukofyork commented Feb 8, 2026

I could probably do that with Deepseek (since that was already fitting on a single node for me), but not for Kimi. If I'm understanding correctly, this would allow us to get the speed benefits of mmap migration without killing the disaggregated prefill speed, at the cost of double the RAM use?

Yeah, I only get 4.5-4.7GB/s through the PCI-E bus during offloaded PP (and randomly 7.5GB/s about 10-15% of the time for who knows what reason!).

By doing this I get 12.7GB/s through the PCI-E bus during offloaded PP (same as the new "Direct IO" method, but sadly I get halved TG and PP for non-offloaded using this as everything ends up on a single NODE and badly laid out for NUMA threads).

This essentially lets you get the best NUMA TG, the best (non-offloaded) NUMA PP, and the best offloaded PP, at the cost of double the RAM use.

@DocShotgun
Contributor

Maybe it's worth me trying the next time I run a model that I have room to double in memory lol - been using Kimi K2.5 Q4_X recently, which is too large for that.

When using both sockets, I noticed I was getting better TG with --no-mmap + --no-direct-io compared to using direct IO, unsure of the reason why - maybe some shenanigans with how stuff lands across nodes, which is strange because I'm interleaving across nodes anyways.

@usrlocalben

@jukofyork

Here's Kimi K2.5 on ik_llama with your patch. tl;dr - it's effective.
2x 9115 NPS2, 24x ddr5 4800, 1x rtx6000pro (CPU1, not CPU0)

This should really be done with a sweep tool. Usually I'm only interested in decode-phase and the CPU-side of MoE (i.e. constant time) so I don't have a go-to sweep script.

But, I can ensure the system is warm, runs are the same, and have unique prefixes so they don't hit KV-cache.

this is 9115, which only has 2xCCD. I'm not sure what would be expected wrt. PCIe transfers.
Also, this is NPS2. I did a couple runs on NPS1 and your patch showed no change in perf.

Improving the overhead of PCIe transfers for offload would benefit smaller batch more, so I chose 4k for smallish.

Bonus: sglang/kt-kernel INT4, which has been my running system recently since it supports K2.5 vision and has comparable perf.

## bs=16384
20Ktok
bs=16384 Q4_X ik-main 2.4ms/tok  418t/s
bs=16384 Q4_X ik-juk  1.7ms/tok  590t/s
bs=16384 INT4 sg/kt   2.7ms/tok  370t/s

100Ktok
bs=16384 Q4_X ik-main 3.2ms/tok  308t/s
bs=16384 Q4_X ik-juk  2.8ms/tok  357t/s
bs=16384 INT4 sg/kt   2.4ms/tok  413t/s

## bs=4096
20K
bs=4096 Q4_X ik-main 4.9ms/tok  205t/s
bs=4096 Q4_X ik-juk  3.1ms/tok  320t/s
bs=4096 INT4 sg/kt   6.4ms/tok  156t/s

100K
bs=4096 Q4_X ik-main 5.9ms/tok  169t/s
bs=4096 Q4_X ik-juk  4.2ms/tok  236t/s
bs=4096 INT4 sg/kt   6.6ms/tok  151t/s

@jukofyork
Collaborator

Maybe it's worth me trying the next time I run a model that I have room to double in memory lol - been using Kimi K2.5 Q4_X recently, which is too large for that.

When using both sockets, I noticed I was getting better TG with --no-mmap + --no-direct-io compared to using direct IO, unsure of the reason why - maybe some shenanigans with how stuff lands across nodes, which is strange because I'm interleaving across nodes anyways.

Yeah, don't worry about it! It opens up the hope that there may be a way to align the single copy of the tensor though.

It also proves that the QPI link between nodes isn't a bottleneck as I feared might be the case when "Direct IO" got the full PCI-E speed, but with everything on a single node anyway... I've actually turned on Sub-NUMA Clustering for my Xeon Gold 6248 CPUs (get around 5-7% better TG with this), so the second copy is well and truly split up over nodes.

@Ph0rk0z

Ph0rk0z commented Feb 8, 2026

Will it double the entire weights of the model or just the portion that is in sysram? For instance, with Qwen I have only some 60GB in sysram; that I can comfortably double within the same node.

And just to clarify: do we need only your code changes, or to merge the PR and then apply the code changes?

@jukofyork
Collaborator

@jukofyork

Here's Kimi K2.5 on ik_llama with your patch. tl;dr - it's effective. 2x 9115 NPS2, 24x ddr5 4800, 1x rtx6000pro (CPU1, not CPU0)

This should really be done with a sweep tool. Usually I'm only interested in decode-phase and the CPU-side of MoE (i.e. constant time) so I don't have a go-to sweep script.

But, I can ensure the system is warm, runs are the same, and have unique prefixes so they don't hit KV-cache.

this is 9115, which only has 2xCCD. I'm not sure what would be expected wrt. PCIe transfers. Also, this is NPS2. I did a couple runs on NPS1 and your patch showed no change in perf.

Improving the overhead of PCIe transfers for offload would benefit smaller batch more, so I chose 4k for smallish.

Bonus: sglang/kt-kernel INT4, which has been my running system recently since it supports K2.5 vision and has comparable perf.

## bs=16384
20Ktok
bs=16384 Q4_X ik-main 2.4ms/tok  418t/s
bs=16384 Q4_X ik-juk  1.7ms/tok  590t/s
bs=16384 INT4 sg/kt   2.7ms/tok  370t/s

100Ktok
bs=16384 Q4_X ik-main 3.2ms/tok  308t/s
bs=16384 Q4_X ik-juk  2.8ms/tok  357t/s
bs=16384 INT4 sg/kt   2.4ms/tok  413t/s

## bs=4096
20K
bs=4096 Q4_X ik-main 4.9ms/tok  205t/s
bs=4096 Q4_X ik-juk  3.1ms/tok  320t/s
bs=4096 INT4 sg/kt   6.4ms/tok  156t/s

100K
bs=4096 Q4_X ik-main 5.9ms/tok  169t/s
bs=4096 Q4_X ik-juk  4.2ms/tok  236t/s
bs=4096 INT4 sg/kt   6.6ms/tok  151t/s

Yeah, I find it hard to get stable results on NUMA, but from my quick tests yesterday running the full Kimi K2.5 (ie: hacked QAT matching Q4_0 for all experts + Q8_0 for everything else) I saw 250 tokens/s for PP at 70-75k context with batch size 24k. This was around 125 tokens/s when the model decided to use 4.7GB/s through the PCI-E bus and 165 tokens/s when it decided to use 7.5GB/s (I would really love to know what causes this!).

My machine for reference:

  • 2x Xeon Gold 6248 with Hyper-threading disabled and Sub-NUMA Clustering on.
  • 1.5TB of 2666MHz DDR4 fully populated (hex channel; ~120GB/s memory bandwidth per node).
  • 2x RTX 6000 Ada (with a 10% memory overclock).
#!/bin/bash

host_address=192.168.1.1
port_number=8080

# Turn off NUMA balancing
echo 0 | sudo tee /proc/sys/kernel/numa_balancing > /dev/null

# Ask for permission to drop caches
read -p "Do you want to drop caches? (y/n) " -n 1 -r
echo    # Move to a new line
if [[ $REPLY =~ ^[Yy]$ ]]
then
    echo "Dropping caches..."
    echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null
fi

# Run the main command
export CUDA_VISIBLE_DEVICES=0,1
~/ik_llama.cpp/build/bin/llama-server \
        --host "$host_address" \
        --port "$port_number" \
        --alias "Kimi-K2.5" \
        --model ~/models/gguf/Kimi-K2.5-Q4_0_lloyd_max.gguf \
        --log-disable \
        --jinja \
        --chat-template-file ~/models/Kimi-K2.5.jinja \
        --mla-use 3 \
        --flash-attn 1 \
        --n-gpu-layers 99 \
        --tensor-split 24,37 \
        --numa distribute \
        --threads "$(nproc)" \
        --override-tensor exps=CPU \
        --ctx_size 262144 \
        --cache-type-k f16 \
        --batch-size 24576 \
        --ubatch-size 24576 \
        --attention-max-batch 2048 \
        --cuda-params offload-batch-size=64 \
        --parallel 1 \
        --no-cont-batching \
        --cache-ram 65536 \
        --temp 1.0 \
        --min-p 0.01

NOTE: To use the 24k batch size I'm having to use ikawrakow/ik_llama.cpp#1191 (ie: git checkout 109686af6f7ca4441adb556569dafcf9fa235478) as ikawrakow/ik_llama.cpp#1192 has some problem with NaNs from huge batches that I haven't had time to investigate yet.


Also, in case anybody else wants to pick up on this: I found that ik_llama.cpp uses ggml_backend_cuda_set_tensor_async rather than ggml_backend_cuda_buffer_set_tensor if you don't disable the ggml_backend_sched_set_only_active_experts(ctx->sched, true) call (not sure which llama.cpp uses by default; it could be either). You obviously can't cache the pointer addresses with ggml_backend_sched_set_only_active_experts(ctx->sched, true) though, and I'm not sure if llama.cpp does a similar optimisation or not.

@jukofyork
Collaborator

Will it double the entire weights of the model or just the portion that is in sysram? For instance, with Qwen I have only some 60GB in sysram; that I can comfortably double within the same node.

It will just double anything that gets offloaded from RAM to VRAM and matched by the strstr(tensor->name, "_exps") if-statement - you can print out debug info as it gets called and possibly tune this to fit your own setup if needed.

And just to clarify: do we need only your code changes, or to merge the PR and then apply the code changes?

No, it's a completely self-contained hack using a static:

        static std::unordered_map<const void*, void*> pinned_cache;
        static std::mutex cache_mutex;

This is generally a horrible thing to do, but the code is really just to prove it works for now.

I haven't actually tested it on llama.cpp, and the replacement function might need a bit of massaging to get it to compile (eg: I don't see the GGML_CALL at the front in the current version of llama.cpp). It also might end up being ggml_backend_cuda_set_tensor_async that needs a similar hack in llama.cpp.

@jukofyork
Collaborator

I probably won't have time to do much more on this as my back is killing me from PC use and I need a few days break... Feel free to try to pick up on this! My next steps would be:

  • Try to read up and/or debug how cudaHostAlloc is aligning the allocations in RAM.
  • Investigate more how GGUF_DEFAULT_ALIGNMENT works (is it really aligning the tensor starts, etc).
  • This code:
        size_t alignment = 64 * 1024; // 64KB

        void *temp = mmap(NULL, file->size() + alignment, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
        if (temp == MAP_FAILED) {
            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
        }

        uintptr_t aligned = ((uintptr_t)temp + alignment - 1) & ~(alignment - 1);
        munmap(temp, file->size() + alignment);

        addr = mmap((void*)aligned, file->size(), PROT_READ, flags | MAP_FIXED_NOREPLACE, fd, 0);
        if (addr == MAP_FAILED) {
            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
        }

        fprintf(stderr, "***** addr: %p *****\n", addr);

can be used to modify:

addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0);
to align the start of the GGUF file to a 64k boundary when it gets mmapped:

eg:

***** addr: 0x7f5e8bb60000 *****

(confirmed working as 0x10000 = 64k)

There is a lot that can likely be tried like this.

@jukofyork
Collaborator

jukofyork commented Feb 8, 2026

Also whilst doing this I found what might be a better way to pin the NUMA threads to the cores:

static void set_numa_thread_affinity(int thread_n) {
    cpu_set_t cpus;
    CPU_ZERO(&cpus);
    CPU_SET(thread_n, &cpus);
    pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpus);
}

If you run with --threads "$(nproc)" this ends up with exactly one thread on each core, and thread #0 usually does the work consistently when running offloaded PP (my "main" GPU is attached to CPU0), and:

[screenshot]

it looks like this might avoid a lot of the "CMS" mesh stuff.

If your CPU has hyper-threading turned on then you might still be able to use this via --threads "$(($(nproc)/2))" if your hyper-threads are all high-numbered (like mine), but I think AMD CPUs might alternate them... It's not hard to hack this to do whatever you want if you ask Claude or a similar LLM, though!

@jukofyork
Collaborator

Also the thread ordering of this:

NODE0: 0, 1, 2, 3, 4, ..., 19
NODE1: 20, 21, 22, ..., 39

might work better for page faults when say ggml_mul_mat_id() is working on a single tensor, rather than the current --numa distribute:

NODE0: 0, 2, 4, ..., 36, 38
NODE1: 1, 3, 5, ..., 37, 39

The latter is likely to fault the 4k pages of the huge 3GB tensors over alternate nodes, etc... Worth investigating more IMO.
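The two orderings can be written as a tiny mapping from thread index to node (illustrative only - the real placement comes from ggml's threadpool/affinity code, not from a function like this):

```cpp
#include <vector>

// Map each thread to a NUMA node. "Blocked": threads 0..N-1 on node 0,
// N..2N-1 on node 1. "Interleaved" (like --numa distribute): alternate nodes.
static std::vector<int> node_of_thread(int n_threads, int n_nodes, bool blocked) {
    std::vector<int> node(n_threads);
    const int per_node = n_threads / n_nodes;
    for (int t = 0; t < n_threads; ++t) {
        node[t] = blocked ? (t / per_node) : (t % n_nodes);
    }
    return node;
}
```

With 40 threads over 2 nodes, blocked gives node 0 to threads 0-19 and node 1 to threads 20-39; interleaved alternates 0,1,0,1,... so adjacent threads (and hence the page faults they trigger on a shared tensor) land on alternating nodes.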

@Ph0rk0z

Ph0rk0z commented Feb 8, 2026

Interesting to mess with mmap, because generally I disable it to get faster speeds. I also put exp layers on the GPU in addition to sysram and use only physical cores. Disabling HT caused worse TG performance.

@jukofyork
Collaborator

Interesting to mess with mmap, because generally I disable it to get faster speeds. I also put exp layers on the GPU in addition to sysram and use only physical cores. Disabling HT caused worse TG performance.

Yeah, I think every setup needs lots of tweaking - I have some older machines with the previous generation of Xeons (dual E5-2699v4 with quad-channel 2400MHz DDR4) and IIRC these most liked numactl --interleave=all!

I've tweaked mine to the nth degree now and using ik_llama.cpp I get:

  • Around 10.5-11 tokens/s TG with no context and around 8.5-9.5 tokens/s TG with a huge (100k+) context.
  • Anywhere from 25-65 tokens/s non-offloaded PP depending on the batch size.

It's actually quite usable in opencode now; even for very large context.

@Ph0rk0z

Ph0rk0z commented Feb 8, 2026

Do you see a difference in your pcm-memory benchmarks? I will try this stuff when I download StepFun - I think it's small enough, around 110-120GB.

@rankaiyx

rankaiyx commented Feb 9, 2026

One thing I've learned is that it's best to turn off memory thermal management in the BIOS.
It helps avoid some weird issues.

@Ph0rk0z

Ph0rk0z commented Feb 9, 2026

One thing I've learned is that it's best to turn off memory thermal management in the BIOS.

In my server that doesn't exist.

So I tried both patches (affinity and tensor loading) with StepFun; they didn't seem to make any difference to speed one way or another. Is this tied to a specific model, or to doing a full exp=CPU offload? In my setups I put some layers on the GPUs to help prompt processing and textgen. StepFun gives me about 400 PP and 40 t/s at 4KL.

