Skip to content

Commit 3978865

Browse files
ikawrakowabc-nix
authored andcommitted
Fused delta net 2 (ikawrakow#1320)
* Revive fused delta-net * Add command line argument for fused delta net * Simplify/improve CUDA delta-net * Add -fdn to llama-bench * More CUDA fused delta net optimizations * CPU optimizations * Much faster fused delta-net on the CPU It seems it is faster than the chunked implementation! * Change meaning of fdn from bool flag to threshold value * Use eps = 1e-6 * Give some nodes a name * Don't re-apply L2 norm - it has already been done * This seems quite a bit better * More tweaks * Restore per context buffer size log Not everybody uses models split in 2000 parts, and those who do, actually want to see the biffer sizes.
1 parent 4753e32 commit 3978865

File tree

3 files changed

+41
-78
lines changed

3 files changed

+41
-78
lines changed

ggml/src/ggml-cuda/delta-net.cu

Lines changed: 33 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -84,93 +84,75 @@ __global__ void delta_net_recurrent_f32(
8484
float * sQ = smem; // HEAD_DIM
8585
float * sK = sQ + HEAD_DIM; // HEAD_DIM
8686
float * sV = sK + HEAD_DIM; // HEAD_DIM
87-
float * sKBeta = sV + HEAD_DIM; // HEAD_DIM (plain k for state update)
88-
float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta))
89-
float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM
90-
float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g))
91-
float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
87+
float * sVNew = sV + HEAD_DIM; // HEAD_DIM
9288

9389
const float scale = rsqrtf((float)HEAD_DIM);
9490

9591
__shared__ float sum_helper[block_size/WARP_SIZE];
9692

9793
// Copy initial state to output buffer (will be updated in place)
98-
for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) {
94+
for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += block_size) {
9995
state_dst[i] = state_src[i];
10096
}
101-
__syncthreads();
97+
98+
constexpr int HEAD_DIM_S = HEAD_DIM + 1;
99+
__shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS];
100+
auto all_sum1 = all_sum;
101+
auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
102102

103103
// Process each token sequentially
104104
for (int64_t t = 0; t < n_tokens; t++) {
105105

106-
float q_sq = 0.0f;
107-
float k_sq = 0.0f;
108-
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
109-
sQ[i] = q_ptr[t * qkv_stride_token + i];
106+
float sum_kq = 0.0f;
107+
for (int i = tid; i < HEAD_DIM; i += block_size) {
108+
sQ[i] = q_ptr[t * qkv_stride_token + i] * scale;
110109
sK[i] = k_ptr[t * qkv_stride_token + i];
111110
sV[i] = v_ptr[t * qkv_stride_token + i];
112-
q_sq += sQ[i] * sQ[i];
113-
k_sq += sK[i] * sK[i];
111+
sum_kq += sK[i] * sQ[i];
114112
}
115113

116-
q_sq = reduce_sum<block_size>(q_sq, sum_helper);
117-
k_sq = reduce_sum<block_size>(k_sq, sum_helper);
118-
119-
float q_norm = rsqrtf(q_sq + eps);
120-
float k_norm = rsqrtf(k_sq + eps);
114+
float attn_score = reduce_sum<block_size>(sum_kq, sum_helper);
121115

122116
float beta_val = sigmoid_f(beta_ptr[t]);
123117
float decay = expf(fminf(g_ptr[t], 50.0f));
124118

125-
float sum = 0;
126-
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
127-
sQ[i] = sQ[i] * q_norm * scale;
128-
sK[i] = sK[i] * k_norm;
129-
sKBeta[i] = sK[i];
130-
sVBeta[i] = sV[i] * beta_val;
131-
sKCumdecay[i] = sK[i] * beta_val * decay;
132-
sum += sK[i] * sQ[i];
133-
}
134-
float attn_score = reduce_sum<block_size>(sum, sum_helper);
135-
//__syncthreads();
136-
137-
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
119+
for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
138120
float sum1 = 0.0f;
139121
float sum2 = 0.0f;
140122
#pragma unroll
141-
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
123+
for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
142124
float sval = state_dst[row_out + col * HEAD_DIM];
143-
sum1 += sval * sKCumdecay[col];
125+
sum1 += sval * sK[col];
144126
sum2 += sval * sQ[col];
145127
}
146-
sum1 = warp_reduce_sum(sum1);
147-
sum2 = warp_reduce_sum(sum2);
148-
if (lane_id == 0) {
149-
sVNew[row_out] = sVBeta[row_out] - sum1;
150-
float v_attn = sVNew[row_out] * attn_score;
151-
//sOut[row_out] = sum2 * decay + v_attn;
152-
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
128+
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
129+
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
130+
}
131+
__syncthreads();
132+
133+
for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
134+
float sum1 = all_sum1[row_out];
135+
float sum2 = all_sum2[row_out];
136+
#pragma unroll
137+
for (int i = 1; i < NUM_WARPS; ++i) {
138+
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
139+
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
153140
}
141+
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
142+
float v_attn = sVNew[row_out] * attn_score;
143+
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
154144
}
155145
__syncthreads();
156146

157147
for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
158148
#pragma unroll
159149
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
160150
float state_val = state_dst[row + out_dim * HEAD_DIM];
161-
float safe_decay = decay;
162-
if (isnan(safe_decay) || isinf(safe_decay)) {
163-
safe_decay = 1.0f;
164-
}
165-
float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim];
151+
float new_state_val = decay * state_val + sVNew[row] * sK[out_dim];
166152
new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
167153
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
168154
}
169155
}
170-
if (t < n_tokens - 1) {
171-
__syncthreads();
172-
}
173-
174156
}
175157
}
176158

@@ -408,9 +390,7 @@ static void delta_net_f32_cuda(
408390
const int num_blocks = n_seqs * n_heads;
409391
constexpr int threads_per_block = 512; //256;
410392

411-
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
412-
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
413-
const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
393+
const size_t smem_size = 4 * head_dim * sizeof(float);
414394

415395
// Use templated kernel for common head dimensions, generic for others
416396
if (head_dim == 64) {
@@ -421,6 +401,7 @@ static void delta_net_f32_cuda(
421401
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
422402
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
423403
} else {
404+
GGML_ASSERT("Unsupported delta net head size");
424405
delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
425406
q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
426407
}

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,6 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
13971397
static_assert(head_dim % 8 == 0);
13981398
#endif
13991399

1400-
const float eps = 1e-6f;
14011400
const float scale = 1.0f / sqrtf((float) head_dim);
14021401

14031402
float v_new_buf[head_dim];
@@ -1428,42 +1427,25 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
14281427
const float g_val = g_data[g_head_offset + t];
14291428
const float beta_raw = beta_data[g_head_offset + t];
14301429

1431-
float q_norm_sq = 0.0f;
1432-
float k_norm_sq = 0.0f;
14331430
float kq_sum = 0.0f;
14341431
#ifdef __AVX2__
1435-
auto vqsum = _mm256_setzero_ps();
1436-
auto vksum = _mm256_setzero_ps();
14371432
auto vqksum = _mm256_setzero_ps();
14381433
for (int i = 0; i < head_dim; i += 8) {
14391434
auto vq = _mm256_loadu_ps(q_t + i);
14401435
auto vk = _mm256_loadu_ps(k_t + i);
1441-
vqsum = _mm256_fmadd_ps(vq, vq, vqsum);
1442-
vksum = _mm256_fmadd_ps(vk, vk, vksum);
14431436
vqksum = _mm256_fmadd_ps(vk, vq, vqksum);
14441437
}
1445-
q_norm_sq = hsum_float_8(vqsum);
1446-
k_norm_sq = hsum_float_8(vksum);
1447-
kq_sum = hsum_float_8(vqksum);
1438+
kq_sum = hsum_float_8(vqksum);
14481439
#else
14491440
for (int i = 0; i < head_dim; ++i) {
1450-
q_norm_sq += q_t[i] * q_t[i];
1451-
k_norm_sq += k_t[i] * k_t[i];
1452-
kq_sum += k_t[i] * q_t[i];
1441+
kq_sum += k_t[i] * q_t[i];
14531442
}
14541443
#endif
1455-
const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
1456-
const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
14571444

14581445
const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
14591446
const float decay = expf(fminf(g_val, 50.0f));
14601447

1461-
float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale;
1462-
1463-
//float attn_score = 0.0f;
1464-
//for (int i = 0; i < head_dim; ++i) {
1465-
// attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
1466-
//}
1448+
float attn_score = kq_sum * scale;
14671449

14681450
float * out_t = out_data + out_head_offset + t * out_token_stride;
14691451

@@ -1479,17 +1461,17 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
14791461
}
14801462
}
14811463
for (int row = 0; row < head_dim; ++row) {
1482-
const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv;
1464+
const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay;
14831465
v_new_buf[row] = v_new;
1484-
out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score;
1466+
out_t[row] = out_val[row] * decay * scale + v_new * attn_score;
14851467
}
14861468

14871469
#ifdef __AVX2__
14881470
auto vd = _mm256_set1_ps(decay);
14891471
auto vmin = _mm256_set1_ps(-1e6f);
14901472
auto vmax = _mm256_set1_ps( 1e6f);
14911473
for (int col = 0; col < head_dim; ++col) {
1492-
auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv);
1474+
auto vk = _mm256_set1_ps(k_t[col]);
14931475
for (int row = 0; row < head_dim; row += 8) {
14941476
auto vs = _mm256_loadu_ps(state + col * head_dim + row);
14951477
auto vn = _mm256_loadu_ps(v_new_buf + row);
@@ -1503,7 +1485,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
15031485
}
15041486
#else
15051487
for (int col = 0; col < head_dim; ++col) {
1506-
const float k_col = k_t[col] * k_norm_inv;
1488+
const float k_col = k_t[col];
15071489
for (int row = 0; row < head_dim; ++row) {
15081490
float s = state[row + col * head_dim];
15091491
s = decay * s + v_new_buf[row] * k_col;

src/llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2222,7 +2222,7 @@ static bool llm_load_tensors(
22222222

22232223
// print memory requirements
22242224
for (ggml_backend_buffer_t buf : model.bufs) {
2225-
LLAMA_LOG_DEBUG("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
2225+
LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
22262226
}
22272227

22282228
// populate tensors_by_name

0 commit comments

Comments
 (0)