feat: CUDA port of TurboQuant3 KV cache — 3.47x compression, 98.5% of F16 decode speed on RTX 5090#3

Open
signalnine wants to merge 4 commits into TheTom:feature/turboquant-kv-cache from signalnine:feature/turboquant-kv-cache
Conversation

@signalnine

Summary

This PR ports TurboQuant3 (turbo3) KV cache compression to CUDA, targeting SM 12.0 (RTX 5090 / Blackwell) with near-parity decode performance vs F16.

What's included

CUDA kernel port (set-rows.cu, turbo-quant.cuh, turbo-wht.cu/cuh):

  • k_set_rows_turbo3: quantises incoming F32 KV tokens into turbo3 blocks on GPU
  • Fully parallel design: one block per 128-element group, 128 threads/block
  • WHT via shared-memory butterfly (7 stages), L2 norm via warp reduction
  • Bit packing with __shfl_sync (qs) and __ballot_sync (signs) — no atomics
  • Reconstruction norm corrected for quantisation error before writing
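The kernel's shared-memory butterfly can be illustrated with a CPU reference: 128 = 2^7 elements means 7 butterfly stages, each pairing elements `stride` apart. This is a hypothetical host-side sketch of what the device code computes (the real kernel runs one thread per element with `__syncthreads()` between stages), not the actual kernel:

```cpp
#include <cassert>
#include <cmath>

// CPU reference of the 7-stage in-place Walsh–Hadamard butterfly that the
// quantisation kernel runs in shared memory (128 = 2^7 elements -> 7 stages).
void wht128(float *x) {
    for (int stride = 1; stride < 128; stride <<= 1) {      // 7 stages
        for (int i = 0; i < 128; ++i) {
            if ((i & stride) == 0) {                        // lower half of each pair
                const float a = x[i], b = x[i + stride];
                x[i]          = a + b;                      // butterfly: sum
                x[i + stride] = a - b;                      // butterfly: difference
            }
        }
    }
}

// L2 norm of the (rotated) group — the kernel computes this with a warp
// reduction instead of a serial loop.
float l2_norm(const float *x, int n) {
    float s = 0.f;
    for (int i = 0; i < n; ++i) s += x[i] * x[i];
    return std::sqrt(s);
}
```

The unnormalised transform satisfies H·H = 128·I, so applying the butterfly twice and dividing by 128 recovers the input, which makes the forward/inverse pair in `turbo-wht.cu` a single templated kernel.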

Flash attention integration (fattn-common.cuh, fattn-vec.cuh, fattn.cu):

  • vec_dot_fattn_vec_KQ_turbo3_0: optimised KQ dot product — elem0/elem1 always share the same turbo3 block, so qs/signs loaded once per pair instead of twice
  • dequantize_V_turbo3_0: ne==4 fast path — single qs byte + single signs byte covers all 4 elements; unrolled float2/half2 output
  • Routes decode (Q→ne[1] ≤ 2) through VEC flash attention kernel on Ada/Blackwell (CC ≥ 890)
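The ne==4 fast path relies on the turbo3 layout putting four elements' worth of data in one `qs` byte (four 2-bit magnitude-LUT indices) plus one sign bit each. A CPU sketch, with a stand-in magnitude LUT (the real centroid values live in `turbo-quant.cuh`; these numbers are purely illustrative):

```cpp
#include <cassert>
#include <cmath>

// Hypothetical 4-entry magnitude LUT (2-bit index); stand-in values only.
static const float kMagLUT[4] = {0.095f, 0.28f, 0.52f, 0.90f};

// CPU sketch of dequantize_V_turbo3_0's ne==4 fast path: one qs byte carries
// four 2-bit magnitude indices, one signs byte carries four sign bits, and
// the stored group norm is applied as a single multiply per element.
void dequant4(unsigned char qs, unsigned char signs, float norm, float out[4]) {
    for (int j = 0; j < 4; ++j) {
        const float mag  = kMagLUT[(qs >> (2 * j)) & 3];    // 2-bit LUT index
        const float sign = ((signs >> j) & 1) ? -1.f : 1.f; // 1 sign bit
        out[j] = sign * mag * norm;
    }
}
```

Loading the two bytes once and unrolling the loop into `float2`/`half2` stores is what the bullet above describes; the key point is that no second memory access is needed for elements 1–3.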

Quality gate / auto-enable (llama.cpp, ggml-cuda.cu):

  • Flash attention auto-enabled when turbo cache types are detected
  • ggml_context overflow fix for large KV cache allocations

Benchmark results (Qwen3.5 35B, RTX 5090, tg128)

| KV cache | Decode (t/s) | vs F16 |
|:---------|-------------:|-------:|
| F16      | 95.4         | 1.00×  |
| q8_0     | 95.7         | 1.00×  |
| turbo3   | 94.0         | 0.985× |

Memory: 3.47× compression vs F16 (3-bit vs 16-bit KV cache)

Key design choices

  • Group size = 128 (one WHT per head-dim for typical 128-dim heads), matching the Python reference
  • Norm correction: stores grp_norm / recon_norm (not raw grp_norm) in the half-precision norm field so dequant is a single multiply
  • __launch_bounds__(128) on the quantisation kernel prevents spilling with the large shared memory footprint
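The norm-correction choice can be sketched on the CPU. If the unit-scale reconstruction (LUT magnitudes with signs applied) has norm `recon_norm`, storing `grp_norm / recon_norm` in the half-precision field means dequantisation is one multiply and the reconstructed group lands exactly on the original L2 norm. A minimal sketch (function name hypothetical):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Norm-correction sketch: the stored scale is grp_norm / recon_norm rather
// than the raw group norm, so recon[i] * scale has the original L2 norm and
// dequant needs no second normalisation pass.
float corrected_norm(const std::vector<float> &grp,
                     const std::vector<float> &recon) {
    float grp_sq = 0.f, recon_sq = 0.f;
    for (float v : grp)   grp_sq   += v * v;
    for (float v : recon) recon_sq += v * v;
    return std::sqrt(grp_sq) / std::sqrt(recon_sq);
}
```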

Testing

Validated with llama-bench tg32/tg128 on Qwen3.5 35B. NIAH (needle-in-a-haystack) quality tested at multiple context lengths — results in the companion Python repo.

🤖 Generated with Claude Code

signalnine and others added 3 commits March 26, 2026 15:46
Ports the Metal turbo3 implementation to NVIDIA CUDA. End-to-end working
on RTX 5090 with Qwen3.5 35B A3B Q4_K_M: 3.47x KV compression at 32K
context, ~4x max context extension (256K → 1M tokens on 32GB VRAM).

New files:
- ggml-cuda/turbo-quant.cuh   — block_turbo3_0 layout, WHT sign arrays,
                                 3-bit centroid LUT, dequant helpers,
                                 quantize kernel (set_rows path)
- ggml-cuda/turbo-wht.cu/.cuh — GGML_OP_TURBO_WHT CUDA kernel; 128-thread
                                 blocks, in-place butterfly in shared memory,
                                 forward + inverse WHT via compile-time template
- ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
                               — VEC flash-attention instance for D=64/128/256

Modified files:
- dequantize.cuh    — dequantize_turbo3_0 (produces float2 pairs)
- convert.cu        — all 5 to-fp16/fp32 dispatchers
- fattn-common.cuh  — vec_dot_fattn_vec_KQ_turbo3_0, dequantize_V_turbo3_0,
                       dispatcher extensions
- fattn-vec.cuh     — turbo3 treated as unquantized (f16-style nthreads_KQ)
- fattn.cu          — route turbo3 exclusively to VEC kernel; add dispatch macro
- set-rows.cu       — k_set_rows_turbo3 kernel: per-128-elem group quantization
                       with WHT rotation and norm correction
- ggml-cuda.cu      — supports_op + compute dispatch for TURBO_WHT + SET_ROWS
- llama-kv-cache.cpp — +2 tensor overhead for rotation matrices

Benchmark (RTX 5090, Qwen3.5 35B A3B Q4_K_M, FA on):
  KV memory @32k: 702 MiB (f16) → 202 MiB (turbo3)  = 3.47x compression
  Max context:    ~256K (f16)   → ~1M  (turbo3)       = ~4x extension
  Decode @short:  233 t/s (q8_0) → 190 t/s (turbo3)  = 0.82x
  Prefill @32k:   6335 t/s (q8_0) → 1215 t/s (turbo3) = 0.19x

Note: prefill speed degrades significantly vs Metal (Metal: 0.99x q8_0 at all
contexts; CUDA: 0.19x at 32K). Root cause: turbo3 currently uses the VEC
flash-attention kernel; q8_0 uses the more efficient TILE/MMA kernels at long
context. TILE/MMA support for turbo3 is the next milestone.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove the 5-line early-return that forced turbo3 onto the VEC flash
attention kernel.  The VEC kernel is still used for decode (Q->ne[1]==1)
via the existing dispatch logic, but prefill now goes through the Turing
MMA kernel (RTX 5090 is SM 12.0 >> 7.5).

launch_fattn already pre-dequantizes K/V to FP16 when need_f16_K/V are
set (which TILE/MMA always pass as true).  Our ggml_get_to_fp16_cuda and
ggml_get_to_fp16_nc_cuda dispatchers for TURBO3_0 — added in the original
CUDA port commit — provide that conversion automatically.  Stride
recalculation (nb11 = nb11*bs*sizeof(half)/ts) also works out correctly
for turbo3 (bs=32, ts=14):  nb11*32*2/14 = ne[0]*sizeof(half). ✓
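The stride arithmetic above checks out for any head dim that is a multiple of the block size: with bs=32 elements per block and ts=14 bytes per block (32 × 3 bits = 12 bytes of quants + 2 bytes of half-precision norm), a quick sanity check (function name hypothetical):

```cpp
#include <cassert>

// Sanity check of launch_fattn's stride recalculation for turbo3:
// nb11 is bytes per quantized K row; after pre-dequantization to FP16 the
// kernel expects ne[0] * sizeof(half) bytes per row.
long recalc_nb11(long ne0) {
    const long bs = 32, ts = 14;       // turbo3 block size / type size
    const long nb11 = ne0 / bs * ts;   // bytes per turbo3 K row
    return nb11 * bs * 2 / ts;         // the nb11*bs*sizeof(half)/ts recalc
}
```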

Before (VEC only):                    After (MMA for prefill):
  2K prefill:  5032 t/s (0.73× q8_0)   6734 t/s (0.98× q8_0)
  8K prefill:  3110 t/s (0.46× q8_0)   6613 t/s (0.98× q8_0)
 32K prefill:  1215 t/s (0.19× q8_0)   6168 t/s (0.97× q8_0)

Matches Metal M5 Max result (0.99× q8_0 flat across all context sizes).

Decode unchanged (VEC, ~0.64-0.82× q8_0 depending on context depth).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…e (71→94 t/s, 98.5% of F16)

k_set_rows_turbo3 was the decode bottleneck: 1 thread/group serial kernel
gave 3.1% GPU utilisation (36.5 µs × 80 calls/token = ~21% of decode budget).

Replace with a fully parallel kernel — 1 block per 128-element group,
128 threads per block (one thread per element):
  • Shared-memory WHT butterfly (7 stages, no atomics)
  • Warp-reduce L2 norm + inter-warp accumulate via smem
  • qs packed with __shfl_sync (4-wide gather), signs with __ballot_sync
  • Reconstruction norm same pattern; one write per sub-block (warp lane 0)
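The sign-packing step maps naturally onto a warp vote: each of 32 lanes contributes one "element is negative" predicate bit, and `__ballot_sync` returns the packed 32-bit mask in a single instruction, no atomics or shared-memory OR needed. A CPU emulation of that ballot (the device code is one `__ballot_sync(0xffffffff, x < 0.f)` call):

```cpp
#include <cassert>
#include <cstdint>

// CPU emulation of ballot-based sign packing: lane i contributes bit i of
// the mask when its post-WHT element is negative.
uint32_t pack_signs(const float *vals32) {
    uint32_t mask = 0;
    for (int lane = 0; lane < 32; ++lane)
        if (vals32[lane] < 0.f) mask |= 1u << lane;
    return mask;
}
```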

Also tighten flash-attention dequant paths (fattn-common.cuh):
  • vec_dot_fattn_vec_KQ_turbo3_0: elem0/elem1 always share the same
    turbo3 block — load qs/signs once instead of twice per pair
  • dequantize_V_turbo3_0: ne==4 fast path — load one qs byte and one
    signs byte for all 4 elements; unrolled float2 / half2 output pairs

Benchmark (Qwen3.5 35B, RTX 5090, tg128):
  Before: 71.86 t/s (0.75× q8_0)
  After:  94.04 t/s (0.985× q8_0, within measurement noise of parity)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@signalnine signalnine changed the base branch from master to feature/turboquant-kv-cache March 27, 2026 00:34
seanrasch pushed a commit to seanrasch/llama-cpp-turboquant that referenced this pull request Mar 27, 2026
Codex post-commit review found:
1. TURBO_D was QK_TURBO3 (now 32) — broke turbo4 C array sizes
2. SET_ROWS kernel turbo3-specific but instantiated for turbo4
3. Tail block drop for non-128 head dims

Fixed #1 (TURBO_D). #2 and #3 don't affect the turbo3+dk128 path.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Mar 27, 2026
Complete experiment log:
  #1  4-mag LUT:           15.1 at 8K (BEST, +38%)
  #2  Batched extract:     13.7 (+25%)
  #3  Inline FA block:     13.5 (I-cache pressure)
  #4  Deferred norm:       12.9 (loses ILP)
  #5  2-pair half2:        12.0 (ternary overhead)
  #6  Select chain:        11.9 (branches kill)
  #7  Bit-arithmetic:      11.6 (ALU too heavy)
  #8  FMA branchless:      11.4 (ALU still too heavy)
  #9  Named-reg ternary:   10.3 (branches worst)
  #10 Main (8-LUT):        10.95 (baseline)
  #11 Non-vec FA:          10.2 (wrong kernel)
  Ceiling:                 24.5 (no dequant)

Apple8 hardware truth:
  1 divergent constant read < 7 ALU ops (even with fma)
  Branches cost MORE than divergent constant reads
  Array indexing ALWAYS spills on Metal
  4 constant addresses is the sweet spot

The 4-mag LUT is the dequant-level ceiling on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
@signalnine signalnine marked this pull request as draft March 27, 2026 05:50
@signalnine
Author

Looks like I've got some quality issues with larger context windows; converting to draft while I work those out.

The outer loop in vec_dot_fattn_vec_KQ_turbo3_0 stepped k_KQ_0 by
`nthreads` (8), but the Q register file is loaded in blocks of
`nthreads*cpy_ne` (32) elements per thread — the same pattern used by
the f16/bf16 VEC kernels. This caused thread t>0 to pair K element
(16s + 2t) against Q element (64*(s/4) + 8t + 2*(s%4)), a complete
index mismatch. Every generated token had garbage attention scores.

Fix: match the f16 kernel pattern — step by nthreads*cpy_ne, add an
inner k_KQ_1 loop over cpy_ne pairs, and index Q_v as
Q_v[k_KQ_0/nthreads + k_KQ_1].

Also clean up stale "PPL 23.5 vs 6.19" TODO comments in llama-graph.cpp
that documented the symptom of this bug.

Tested on RTX 5090, Qwen3.5-35B-A3B-Q4_K_M:
- PPL (wikitext-2): 6.2023 → 6.2996 (+1.57%, within 5% target)
- NIAH: 11/11 at 4K–256K (matches q8_0; was 0/11 before fix)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@signalnine
Author

Bug fix: VEC flash-attention Q/K stride mismatch (commit 4c91451)

Root cause

vec_dot_fattn_vec_KQ_turbo3_0 in fattn-common.cuh had the wrong outer loop stride. It stepped k_KQ_0 by nthreads (8), but Q registers are loaded in blocks of nthreads*cpy_ne (32) elements per thread — the same pattern the f16/bf16 VEC kernels use.

This caused thread t > 0 to pair K element 16s + 2t against Q element 64*(s/4) + 8t + 2*(s%4). For example, thread 1 paired K[2] with Q[8] instead of Q[2]. The f16 kernel avoids this by stepping its outer loop by nthreads*cpy_ne and processing cpy_ne K elements per thread per iteration.

This kernel is used for all generation steps (n_tokens ≤ 2), so every generated token had garbage attention scores. Prefill (MMA kernel path) was unaffected.

Fix (3 lines changed)

// Before — wrong stride, Q/K indices misaligned for thread t > 0:
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads) {
    const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads);
    ...
    const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads];

// After — matches f16 kernel pattern:
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
    for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
        const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1;
        ...
        const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1];

Test results (RTX 5090, Qwen3.5-35B-A3B-Q4_K_M)

Perplexity (wikitext-2, 512 ctx):

|                           | PPL              |
|:--------------------------|-----------------:|
| Baseline (no compression) | 6.2023           |
| turbo3 after fix          | 6.2996 (+1.57%)  |
| turbo3 before fix         | ~23.5 (garbage)  |

NIAH single-needle (11 depth points, q8_0 vs turbo3):

|        | 4K    | 8K    | 16K   | 32K   | 64K   | 128K  | 256K  |
|:-------|------:|------:|------:|------:|------:|------:|------:|
| q8_0   | 11/11 | 11/11 | 11/11 | 10/11 | 10/11 | 9/11  | 10/11 |
| turbo3 | 11/11 | 11/11 | 11/11 | 10/11 | 9/11  | 10/11 | 11/11 |

Aggregate score is identical (turbo3 = q8_0). The few single-cell differences are in opposite directions and consistent with model-level retrieval variance, not KV compression degradation.

Speed (llama-bench):

|              | q8_0     | turbo3   | ratio |
|:-------------|---------:|---------:|------:|
| Prefill 4K   | 6947 t/s | 6853 t/s | 98.6% |
| Prefill 32K  | 6380 t/s | 6301 t/s | 98.8% |
| Prefill 128K | 4731 t/s | 4711 t/s | 99.6% |
| Generation   | 218 t/s  | 207 t/s  | 95.0% |

Prefill is at parity. Generation is ~5% slower due to the centroid lookup overhead — this model is compute-bound during decoding (Q4_K_M MoE weights dominate bandwidth), so the 3.47x smaller KV cache doesn't help here. On a weight-fp16 or large-batch workload the bandwidth savings would show.

@signalnine signalnine marked this pull request as ready for review March 27, 2026 17:40
@signalnine signalnine closed this Mar 27, 2026
@signalnine signalnine reopened this Mar 27, 2026
@dan-and

dan-and commented Mar 27, 2026

That's interesting. I will give it a try. I gave the fork from Madreag/turbo3-cuda a run tonight, and it also had issues at large context sizes.

q8

llama-benchy --base-url http://127.0.0.1:18080 --model llamacpp-model --depth 0 4096 8192 16384 204800  --tg 32 128 --latency-mode generation


| model          |             test |             t/s |     peak t/s |             ttfr (ms) |          est_ppt (ms) |         e2e_ttft (ms) |
|:---------------|-----------------:|----------------:|-------------:|----------------------:|----------------------:|----------------------:|
| llamacpp-model |           pp2048 | 1282.76 ± 39.17 |              |       1593.28 ± 30.46 |       1439.92 ± 30.46 |       1593.33 ± 30.46 |
| llamacpp-model |             tg32 |    63.62 ± 0.46 | 65.71 ± 0.48 |                       |                       |                       |
| llamacpp-model |           pp2048 | 1153.16 ± 50.08 |              |       1752.58 ± 46.67 |       1599.21 ± 46.67 |       1752.62 ± 46.67 |
| llamacpp-model |            tg128 |    62.00 ± 0.28 | 62.67 ± 0.47 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d4096 | 1212.91 ± 52.05 |              |      4744.25 ± 242.83 |      4590.88 ± 242.83 |      4744.29 ± 242.83 |
| llamacpp-model |     tg32 @ d4096 |    60.08 ± 0.20 | 62.05 ± 0.20 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d4096 | 1057.25 ± 18.41 |              |      5444.91 ± 207.05 |      5291.54 ± 207.05 |      5444.95 ± 207.05 |
| llamacpp-model |    tg128 @ d4096 |    58.47 ± 0.52 | 59.67 ± 0.47 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d8192 |  988.52 ± 36.40 |              |      9567.42 ± 275.26 |      9414.05 ± 275.26 |      9567.46 ± 275.26 |
| llamacpp-model |     tg32 @ d8192 |    57.53 ± 0.88 | 59.45 ± 0.90 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d8192 |  878.16 ± 20.97 |              |     10821.39 ± 312.45 |     10668.03 ± 312.45 |     10821.43 ± 312.45 |
| llamacpp-model |    tg128 @ d8192 |    55.67 ± 0.21 | 57.00 ± 0.00 |                       |                       |                       |
| llamacpp-model |  pp2048 @ d16384 |  797.96 ± 49.51 |              |    21042.87 ± 1204.39 |    20889.50 ± 1204.39 |    21042.91 ± 1204.39 |
| llamacpp-model |    tg32 @ d16384 |    53.09 ± 1.42 | 54.86 ± 1.47 |                       |                       |                       |
| llamacpp-model |  pp2048 @ d16384 |  661.82 ± 18.54 |              |     25294.80 ± 747.67 |     25141.44 ± 747.67 |     25294.84 ± 747.67 |
| llamacpp-model |   tg128 @ d16384 |    49.39 ± 0.13 | 51.00 ± 0.00 |                       |                       |                       |
| llamacpp-model | pp2048 @ d204800 |  281.94 ± 66.24 |              | 698075.03 ± 140499.74 | 697921.66 ± 140499.74 | 698075.09 ± 140499.73 |
| llamacpp-model |   tg32 @ d204800 |    29.33 ± 0.67 | 30.00 ± 0.82 |                       |                       |                       |
| llamacpp-model | pp2048 @ d204800 |   230.29 ± 2.67 |              |   814309.55 ± 9528.99 |   814156.18 ± 9528.99 |   814309.59 ± 9528.99 |
| llamacpp-model |  tg128 @ d204800 |    28.13 ± 0.06 | 29.00 ± 0.00 |                       |                       |                       |

llama-benchy (0.3.2.dev1+g17b42667a)
date: 2026-03-27 18:08:57 | latency mode: generation

CUDA_VISIBLE_DEVICES=0,1,2,3 build/bin/llama-server --webui-mcp-proxy --alias llamacpp-model -m ../models/Qwen3.5-35B-A3B-UD-Q8_K_XL.gguf --temp 0.6 --top-p 0.95 --top-k 20 --min-p 0.00 --kv-unified -ctk q8_0 -ctv q8_0 --swa-full --presence-penalty 1.5 --repeat-penalty 1.0 --ctx-size 260000 -fa on --no-mmap --jinja --threads -1 --reasoning on --metrics --host 0.0.0.0 --port 18080 --alias llamacpp-model


llama_memory_breakdown_print: | memory breakdown [MiB] | total   free     self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - CUDA0 (RTX 3080)   | 20054 = 3910 + (15352 = 12367 +     615 +    2370) +         792 |
llama_memory_breakdown_print: |   - CUDA1 (RTX 3080)   | 20054 = 6064 + (13196 = 11228 +     868 +    1100) +         794 |
llama_memory_breakdown_print: |   - CUDA2 (RTX 3080)   | 20054 = 6314 + (12947 = 11240 +     606 +    1100) +         793 |
llama_memory_breakdown_print: |   - CUDA3 (RTX 3080)   | 20054 = 6260 + (13001 = 10616 +     859 +    1525) +         793 |
llama_memory_breakdown_print: |   - Host               |                  3010 =   970 +       0 +    2040                |


Madreag/turbo3-cuda turboquant

CUDA_VISIBLE_DEVICES=0,1,2,3 build/bin/llama-server --webui-mcp-proxy --alias llamacpp-model -m ../models/Qwen3.5-35B-A3B-UD-Q8_K_XL.gguf --temp 0.6 --top-p 0.95 --top-k 20 --min-p 0.00 --kv-unified -ctk turbo3 -ctv turbo3 --swa-full --presence-penalty 1.5 --repeat-penalty 1.0 --ctx-size 260000 -fa on --no-mmap --jinja --threads -1 --reasoning on --metrics --host 0.0.0.0 --port 18080 --alias llamacpp-model

| model          |             test |             t/s |     peak t/s |             ttfr (ms) |          est_ppt (ms) |         e2e_ttft (ms) |
|:---------------|-----------------:|----------------:|-------------:|----------------------:|----------------------:|----------------------:|
| llamacpp-model |           pp2048 | 1287.54 ± 66.23 |              |       1528.95 ± 58.07 |       1434.45 ± 58.07 |       1528.99 ± 58.07 |
| llamacpp-model |             tg32 |    53.96 ± 1.86 | 55.72 ± 1.91 |                       |                       |                       |
| llamacpp-model |           pp2048 | 1106.94 ± 53.64 |              |       1744.38 ± 48.83 |       1649.88 ± 48.83 |       1744.42 ± 48.82 |
| llamacpp-model |            tg128 |    48.93 ± 0.09 | 50.00 ± 0.00 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d4096 | 1189.36 ± 46.89 |              |      4849.93 ± 249.32 |      4755.43 ± 249.32 |      4849.97 ± 249.32 |
| llamacpp-model |     tg32 @ d4096 |    43.29 ± 1.98 | 44.70 ± 2.04 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d4096 | 1040.50 ± 22.35 |              |       5517.25 ± 61.20 |       5422.75 ± 61.20 |       5517.29 ± 61.20 |
| llamacpp-model |    tg128 @ d4096 |    36.40 ± 0.05 | 38.00 ± 0.00 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d8192 |  981.78 ± 33.00 |              |      9510.98 ± 489.26 |      9416.48 ± 489.26 |      9511.02 ± 489.26 |
| llamacpp-model |     tg32 @ d8192 |    33.75 ± 1.54 | 34.85 ± 1.59 |                       |                       |                       |
| llamacpp-model |   pp2048 @ d8192 |  873.38 ± 19.11 |              |     10607.80 ± 207.11 |     10513.30 ± 207.11 |     10607.85 ± 207.11 |
| llamacpp-model |    tg128 @ d8192 |    28.35 ± 0.55 | 30.33 ± 0.47 |                       |                       |                       |
| llamacpp-model |  pp2048 @ d16384 |  799.01 ± 50.18 |              |    20877.31 ± 1192.24 |    20782.81 ± 1192.24 |    20877.35 ± 1192.24 |
| llamacpp-model |    tg32 @ d16384 |    25.43 ± 1.50 | 26.00 ± 1.63 |                       |                       |                       |
| llamacpp-model |  pp2048 @ d16384 |  663.46 ± 16.78 |              |     25337.12 ± 764.88 |     25242.63 ± 764.88 |     25337.17 ± 764.89 |
| llamacpp-model |   tg128 @ d16384 |    19.75 ± 0.32 | 21.33 ± 0.47 |                       |                       |                       |
| llamacpp-model | pp2048 @ d204800 |  281.80 ± 66.36 |              | 698821.15 ± 142220.41 | 698726.65 ± 142220.41 | 698821.21 ± 142220.39 |
| llamacpp-model |   tg32 @ d204800 |     5.36 ± 0.27 |  6.00 ± 0.00 |                       |                       |                       |
| llamacpp-model | pp2048 @ d204800 |   235.28 ± 2.53 |              |   796568.05 ± 8491.18 |   796473.55 ± 8491.18 |   796568.10 ± 8491.18 |
| llamacpp-model |  tg128 @ d204800 |     5.38 ± 0.23 |  6.00 ± 0.00 |                       |                       |                       |

llama-benchy (0.3.2.dev1+g17b42667a)
date: 2026-03-27 19:55:00 | latency mode: generation

llama_memory_breakdown_print: | memory breakdown [MiB] | total   free     self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - CUDA0 (RTX 3080)   | 20054 = 4218 + (15042 = 12367 +     297 +    2378) +         794 |
llama_memory_breakdown_print: |   - CUDA1 (RTX 3080)   | 20054 = 6540 + (12720 = 11228 +     392 +    1100) +         794 |
llama_memory_breakdown_print: |   - CUDA2 (RTX 3080)   | 20054 = 6630 + (12629 = 11240 +     289 +    1100) +         795 |
llama_memory_breakdown_print: |   - CUDA3 (RTX 3080)   | 20054 = 6736 + (12525 = 10616 +     383 +    1525) +         793 |
llama_memory_breakdown_print: |   - Host               |                  3010 =   970 +       0 +    2040                |

(will update with your fork)
