Commit 73ec4ca

KV cache: Add Q5_0 scale adjustment optimization
Implement the same scale adjustment optimization for Q5_0 KV cache that was already applied to Q4_0 (PR ikawrakow#1547) and Q6_0.

This optimization computes an optimal scale factor that minimizes quantization error by:

1. Computing weighted sums sumqx and sumq2 during quantization:
   - w0 = v0*v0, w1 = v1*v1 (weights based on actual values)
   - q0 = xi0 - 16, q1 = xi1 - 16 (quantized values offset)
   - sumqx += w0*q0*v0 + w1*q1*v1
   - sumq2 += w0*q0*q0 + w1*q1*q1
2. Setting the final scale as y->d = sumqx/sumq2 when sumq2 > 0

This produces a computationally cheap but noticeable improvement in perplexity for KV cache quantization, similar to the results seen for Q4_0 in PR ikawrakow#1547.

Based on work by Iwan Kawrakow (ikawrakow) - lead LLM quantization developer.
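One way to read the formula y->d = sumqx/sumq2 is as the closed-form solution of a weighted least-squares fit of the scale, with the quants held fixed. The derivation below is a sketch under that reading. The symbols v_i (input value), q_i (signed quant, xi - 16), w_i = v_i^2 (weight) and F (objective) are notation introduced here for illustration and do not appear in the commit.

```latex
% Weighted reconstruction error for a fixed set of quants q_i
F(d) = \sum_i w_i \,\bigl(v_i - d\,q_i\bigr)^2, \qquad w_i = v_i^2

% Setting the derivative with respect to the scale d to zero
\frac{dF}{dd} = -2 \sum_i w_i\, q_i \,\bigl(v_i - d\,q_i\bigr) = 0

% Solving for d gives the ratio of the two accumulated sums
d = \frac{\sum_i w_i\, q_i\, v_i}{\sum_i w_i\, q_i^2}
  = \frac{\texttt{sumqx}}{\texttt{sumq2}}
```

Because F is a convex quadratic in d whenever sumq2 > 0, this stationary point is its minimum, which is consistent with the guard in step 2. The weights v_i^2 make large-magnitude values count more in the fit.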
1 parent 0ddd2e9 commit 73ec4ca

1 file changed: +14 -2 lines changed

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 14 additions & 2 deletions
@@ -109,17 +109,29 @@ static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, blo
     y->d = d;
 
     uint32_t qh = 0;
+    float sumqx = 0, sumq2 = 0;
     for (int j = 0; j < QK5_0/2; ++j) {
-        const float x0 = x[0       + j]*id;
-        const float x1 = x[QK5_0/2 + j]*id;
+        const float v0 = x[0       + j];
+        const float v1 = x[QK5_0/2 + j];
+        const float x0 = v0*id;
+        const float x1 = v1*id;
 
         const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
         const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+        float q0 = xi0 - 16;
+        float q1 = xi1 - 16;
+        float w0 = v0*v0;
+        float w1 = v1*v1;
+        sumqx += w0*q0*v0 + w1*q1*v1;
+        sumq2 += w0*q0*q0 + w1*q1*q1;
 
         y->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
         qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
         qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
     }
+    if (sumq2 > 0) {
+        y->d = sumqx/sumq2;
+    }
     memcpy(y->qh, &qh, sizeof(qh));
 }
 
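To make the effect of the change easier to check off-device, here is a minimal host-side C++ sketch of the same per-block logic. It is not the code from the commit. The `Q5Block` struct, `quantize_q5_0_ref` and `weighted_error` are hypothetical names introduced here. The quants are kept unpacked rather than split into `qs`/`qh`, and the scale is kept as a `float` rather than the half-precision value a real block stores. The amax/vmax preamble is also an assumption, since the hunk starts after the point where `d` and `id` are computed.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int QK5_0 = 32;  // block size, matching the loop bounds in the diff

// Hypothetical, unpacked stand-in for a Q5_0 block (illustration only).
struct Q5Block {
    float   d;         // scale; kept as float here for simplicity
    uint8_t q[QK5_0];  // one 5-bit quant per element, in 0..31
};

// Quantize one block. `adjust` toggles the sumqx/sumq2 refit added by the diff.
inline Q5Block quantize_q5_0_ref(const float * x, bool adjust) {
    // Assumed preamble (outside the hunk): derive the initial scale from the
    // largest-magnitude value so that it maps to the quant -16.
    float amax = 0.0f, vmax = 0.0f;
    for (int j = 0; j < QK5_0; ++j) {
        if (std::fabs(x[j]) > amax) { amax = std::fabs(x[j]); vmax = x[j]; }
    }
    const float d  = vmax / -16;
    const float id = d ? 1.0f/d : 0.0f;

    Q5Block y;
    y.d = d;
    float sumqx = 0, sumq2 = 0;
    for (int j = 0; j < QK5_0; ++j) {
        const float v  = x[j];
        const int   xi = std::min(31, (int)(int8_t)(v*id + 16.5f));
        y.q[j] = (uint8_t)xi;
        const float q = xi - 16;  // signed quant in -16..15
        const float w = v*v;      // weight from the commit message
        sumqx += w*q*v;
        sumq2 += w*q*q;
    }
    // Same guard as the diff: keep the initial scale if the sums are degenerate.
    if (adjust && sumq2 > 0) y.d = sumqx/sumq2;
    return y;
}

// Weighted squared reconstruction error, sum of v^2 * (v - d*q)^2.
inline double weighted_error(const float * x, const Q5Block & y) {
    double err = 0;
    for (int j = 0; j < QK5_0; ++j) {
        const double v = x[j];
        const double r = v - (double)y.d * (y.q[j] - 16);
        err += v*v*r*r;
    }
    return err;
}
```

Calling `quantize_q5_0_ref(x, false)` and `quantize_q5_0_ref(x, true)` on the same 32 floats yields identical quants and differs only in `d`. Comparing `weighted_error` for the two results shows how much the refit reduces the weighted error on that block. The sketch accumulates the sums in a single pass rather than in the paired `j` / `j + QK5_0/2` order of the kernel, so the last bits of `d` may differ from the device result.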
