Merged
85 changes: 33 additions & 52 deletions ggml/src/ggml-cuda/delta-net.cu
@@ -84,93 +84,75 @@ __global__ void delta_net_recurrent_f32(
     float * sQ = smem;          // HEAD_DIM
     float * sK = sQ + HEAD_DIM; // HEAD_DIM
     float * sV = sK + HEAD_DIM; // HEAD_DIM
-    float * sKBeta = sV + HEAD_DIM;        // HEAD_DIM (plain k for state update)
-    float * sVBeta = sKBeta + HEAD_DIM;    // HEAD_DIM (v * sigmoid(beta))
-    float * sOut = sVBeta + HEAD_DIM;      // HEAD_DIM
-    float * sKCumdecay = sOut + HEAD_DIM;  // HEAD_DIM (k * sigmoid(beta) * exp(g))
-    float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
+    float * sVNew = sV + HEAD_DIM; // HEAD_DIM
 
     const float scale = rsqrtf((float)HEAD_DIM);
 
     __shared__ float sum_helper[block_size/WARP_SIZE];
 
     // Copy initial state to output buffer (will be updated in place)
-    for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) {
+    for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += block_size) {
         state_dst[i] = state_src[i];
     }
     __syncthreads();
 
+    constexpr int HEAD_DIM_S = HEAD_DIM + 1;
+    __shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS];
+    auto all_sum1 = all_sum;
+    auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
+
     // Process each token sequentially
     for (int64_t t = 0; t < n_tokens; t++) {
 
-        float q_sq = 0.0f;
-        float k_sq = 0.0f;
-        for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
-            sQ[i] = q_ptr[t * qkv_stride_token + i];
+        float sum_kq = 0.0f;
+        for (int i = tid; i < HEAD_DIM; i += block_size) {
+            sQ[i] = q_ptr[t * qkv_stride_token + i] * scale;
             sK[i] = k_ptr[t * qkv_stride_token + i];
             sV[i] = v_ptr[t * qkv_stride_token + i];
-            q_sq += sQ[i] * sQ[i];
-            k_sq += sK[i] * sK[i];
+            sum_kq += sK[i] * sQ[i];
         }
 
-        q_sq = reduce_sum<block_size>(q_sq, sum_helper);
-        k_sq = reduce_sum<block_size>(k_sq, sum_helper);
-
-        float q_norm = rsqrtf(q_sq + eps);
-        float k_norm = rsqrtf(k_sq + eps);
+        float attn_score = reduce_sum<block_size>(sum_kq, sum_helper);
 
         float beta_val = sigmoid_f(beta_ptr[t]);
         float decay = expf(fminf(g_ptr[t], 50.0f));
 
-        float sum = 0;
-        for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
-            sQ[i] = sQ[i] * q_norm * scale;
-            sK[i] = sK[i] * k_norm;
-            sKBeta[i] = sK[i];
-            sVBeta[i] = sV[i] * beta_val;
-            sKCumdecay[i] = sK[i] * beta_val * decay;
-            sum += sK[i] * sQ[i];
-        }
-        float attn_score = reduce_sum<block_size>(sum, sum_helper);
-        //__syncthreads();
-
-        for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
+        for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
             float sum1 = 0.0f;
             float sum2 = 0.0f;
 #pragma unroll
-            for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
+            for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
                 float sval = state_dst[row_out + col * HEAD_DIM];
-                sum1 += sval * sKCumdecay[col];
+                sum1 += sval * sK[col];
                 sum2 += sval * sQ[col];
             }
-            sum1 = warp_reduce_sum(sum1);
-            sum2 = warp_reduce_sum(sum2);
-            if (lane_id == 0) {
-                sVNew[row_out] = sVBeta[row_out] - sum1;
-                float v_attn = sVNew[row_out] * attn_score;
-                //sOut[row_out] = sum2 * decay + v_attn;
-                out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
-            }
+            all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
+            all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
         }
         __syncthreads();
 
+        for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
+            float sum1 = all_sum1[row_out];
+            float sum2 = all_sum2[row_out];
+#pragma unroll
+            for (int i = 1; i < NUM_WARPS; ++i) {
+                sum1 += all_sum1[row_out + i*HEAD_DIM_S];
+                sum2 += all_sum2[row_out + i*HEAD_DIM_S];
+            }
+            sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
+            float v_attn = sVNew[row_out] * attn_score;
+            out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
+        }
+        __syncthreads();
+
         for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
 #pragma unroll
             for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
                 float state_val = state_dst[row + out_dim * HEAD_DIM];
-                float safe_decay = decay;
-                if (isnan(safe_decay) || isinf(safe_decay)) {
-                    safe_decay = 1.0f;
-                }
-                float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim];
+                float new_state_val = decay * state_val + sVNew[row] * sK[out_dim];
                 new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
                 state_dst[row + out_dim * HEAD_DIM] = new_state_val;
             }
         }
+        if (t < n_tokens - 1) {
+            __syncthreads();
+        }
 
     }
 }
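
The heart of this hunk is a layout swap for the two state-vector products (S k and S q): the old kernel gave each warp one output row and reduced its partials across lanes with warp_reduce_sum; the new one gives each lane whole rows and each warp a strided column slice, writes the per-warp partials into the padded all_sum scratch, and folds them in a second block-wide pass. A minimal sketch of that pattern, assuming WARP_SIZE is 32; the names mirror the diff, but the helper itself is illustrative, not the PR's code.

// Sketch only: two-stage block-wide mat-vec, as in the new kernel layout.
#define WARP_SIZE 32

template <int HEAD_DIM, int block_size>
__device__ void matvec_two_stage(const float * state,  // HEAD_DIM x HEAD_DIM, column-major
                                 const float * vec,    // HEAD_DIM
                                 float       * out) {  // HEAD_DIM
    constexpr int NUM_WARPS  = block_size / WARP_SIZE;
    constexpr int HEAD_DIM_S = HEAD_DIM + 1; // padded stride to avoid bank conflicts
    __shared__ float partial[HEAD_DIM_S * NUM_WARPS];

    const int tid     = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;

    // Stage 1: each lane owns whole rows, its warp owns a column slice; every
    // lane writes its own partial, so no intra-warp shuffle reduction is needed.
    for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
        float sum = 0.0f;
        for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
            sum += state[row + col * HEAD_DIM] * vec[col];
        }
        partial[warp_id * HEAD_DIM_S + row] = sum;
    }
    __syncthreads();

    // Stage 2: fold the NUM_WARPS partials for each row.
    for (int row = tid; row < HEAD_DIM; row += block_size) {
        float sum = partial[row];
#pragma unroll
        for (int w = 1; w < NUM_WARPS; ++w) {
            sum += partial[row + w * HEAD_DIM_S];
        }
        out[row] = sum;
    }
}

Padding the scratch stride to HEAD_DIM + 1 floats keeps the stage-2 reads from landing in a single shared-memory bank.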

@@ -408,9 +390,7 @@ static void delta_net_f32_cuda(
     const int num_blocks = n_seqs * n_heads;
    constexpr int threads_per_block = 512; //256;
 
-    // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
-    // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
-    const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
+    const size_t smem_size = 4 * head_dim * sizeof(float);
 
     // Use templated kernel for common head dimensions, generic for others
     if (head_dim == 64) {
@@ -421,6 +401,7 @@
         delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
             q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
     } else {
+        GGML_ASSERT(false && "Unsupported delta net head size");
         delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
             q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
     }
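
Read together with the pointer carving at the top of the kernel hunk, the new sizing is consistent: dynamic shared memory now holds only four head-sized vectors, since the reduction scratch moved into static __shared__ arrays. A sketch of the assumed layout; the extern __shared__ declaration is not visible in this diff, so it is an assumption here.

// Assumed layout behind smem_size = 4 * head_dim * sizeof(float):
extern __shared__ float smem[]; // 4 * HEAD_DIM floats of dynamic shared memory
float * sQ    = smem;           // query, pre-scaled by rsqrtf((float)HEAD_DIM)
float * sK    = sQ + HEAD_DIM;  // key
float * sV    = sK + HEAD_DIM;  // value
float * sVNew = sV + HEAD_DIM;  // beta * (v - decay * (S k))
// all_sum[2*(HEAD_DIM+1)*NUM_WARPS] and sum_helper are static __shared__ and
// therefore do not count against the dynamic smem_size above.
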
32 changes: 7 additions & 25 deletions ggml/src/iqk/iqk_mul_mat.cpp
@@ -1397,7 +1397,6 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
     static_assert(head_dim % 8 == 0);
 #endif
 
-    const float eps = 1e-6f;
     const float scale = 1.0f / sqrtf((float) head_dim);
 
     float v_new_buf[head_dim];
@@ -1428,42 +1427,25 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
         const float g_val = g_data[g_head_offset + t];
         const float beta_raw = beta_data[g_head_offset + t];
 
-        float q_norm_sq = 0.0f;
-        float k_norm_sq = 0.0f;
         float kq_sum = 0.0f;
 #ifdef __AVX2__
-        auto vqsum = _mm256_setzero_ps();
-        auto vksum = _mm256_setzero_ps();
         auto vqksum = _mm256_setzero_ps();
         for (int i = 0; i < head_dim; i += 8) {
             auto vq = _mm256_loadu_ps(q_t + i);
             auto vk = _mm256_loadu_ps(k_t + i);
-            vqsum = _mm256_fmadd_ps(vq, vq, vqsum);
-            vksum = _mm256_fmadd_ps(vk, vk, vksum);
             vqksum = _mm256_fmadd_ps(vk, vq, vqksum);
         }
-        q_norm_sq = hsum_float_8(vqsum);
-        k_norm_sq = hsum_float_8(vksum);
-        kq_sum    = hsum_float_8(vqksum);
+        kq_sum = hsum_float_8(vqksum);
 #else
         for (int i = 0; i < head_dim; ++i) {
-            q_norm_sq += q_t[i] * q_t[i];
-            k_norm_sq += k_t[i] * k_t[i];
-            kq_sum    += k_t[i] * q_t[i];
+            kq_sum += k_t[i] * q_t[i];
         }
 #endif
-        const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
-        const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
 
         const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
         const float decay = expf(fminf(g_val, 50.0f));
 
-        float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale;
-
-        //float attn_score = 0.0f;
-        //for (int i = 0; i < head_dim; ++i) {
-        //    attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
-        //}
+        float attn_score = kq_sum * scale;
 
         float * out_t = out_data + out_head_offset + t * out_token_stride;
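
These removals strip the fused q/k normalization (and its eps) from the CPU path, mirroring the CUDA change: only the raw dot product k.q survives, scaled by 1/sqrt(head_dim). For reference, a scalar sketch of one recurrent step as the updated code computes it, assuming the state S is column-major (state[row + col*head_dim], as in the loops below); the function and buffer names are illustrative, not the file's API.

#include <math.h>

// Scalar reference for one delta-net step (sketch, not the PR's code):
//   beta  = sigmoid(beta_raw)        decay = exp(min(g, 50))
//   attn  = (k . q) * scale
//   v_new = beta * v - beta * decay * (S k)
//   out   = decay * scale * (S q) + attn * v_new
//   S    <- clamp(decay * S + v_new k^T, -1e6, 1e6)
static void delta_net_step_ref(int d, const float * q, const float * k, const float * v,
                               float g, float beta_raw,
                               float * S,      // d x d state, column-major: S[row + col*d]
                               float * v_new,  // scratch, d floats
                               float * out) {  // d floats
    const float scale = 1.0f / sqrtf((float) d);
    const float beta  = 1.0f / (1.0f + expf(-beta_raw));
    const float decay = expf(fminf(g, 50.0f));

    float attn = 0.0f;
    for (int i = 0; i < d; ++i) {
        attn += k[i] * q[i];
    }
    attn *= scale;

    for (int row = 0; row < d; ++row) {
        float Sk = 0.0f, Sq = 0.0f;
        for (int col = 0; col < d; ++col) {
            Sk += S[row + col * d] * k[col];
            Sq += S[row + col * d] * q[col];
        }
        v_new[row] = beta * v[row] - beta * decay * Sk;
        out[row]   = decay * scale * Sq + attn * v_new[row];
    }

    // Rank-1 state update with the same clamp both kernels apply.
    for (int col = 0; col < d; ++col) {
        for (int row = 0; row < d; ++row) {
            float s = decay * S[row + col * d] + v_new[row] * k[col];
            S[row + col * d] = fminf(fmaxf(s, -1e6f), 1e6f);
        }
    }
}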

@@ -1479,17 +1461,17 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
             }
         }
         for (int row = 0; row < head_dim; ++row) {
-            const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv;
+            const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay;
             v_new_buf[row] = v_new;
-            out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score;
+            out_t[row] = out_val[row] * decay * scale + v_new * attn_score;
         }
 
 #ifdef __AVX2__
         auto vd = _mm256_set1_ps(decay);
         auto vmin = _mm256_set1_ps(-1e6f);
         auto vmax = _mm256_set1_ps( 1e6f);
         for (int col = 0; col < head_dim; ++col) {
-            auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv);
+            auto vk = _mm256_set1_ps(k_t[col]);
             for (int row = 0; row < head_dim; row += 8) {
                 auto vs = _mm256_loadu_ps(state + col * head_dim + row);
                 auto vn = _mm256_loadu_ps(v_new_buf + row);
@@ -1503,7 +1485,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
         }
 #else
         for (int col = 0; col < head_dim; ++col) {
-            const float k_col = k_t[col] * k_norm_inv;
+            const float k_col = k_t[col];
             for (int row = 0; row < head_dim; ++row) {
                 float s = state[row + col * head_dim];
                 s = decay * s + v_new_buf[row] * k_col;
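
The AVX2 state update is truncated in this view; below is a self-contained sketch of the full loop it implements, the rank-1 update decay*S + v_new*k^T with the same -1e6/1e6 clamp, assuming head_dim is a multiple of 8 (the static_assert near the top of the function enforces this for the AVX2 build). The function name is illustrative.

#include <immintrin.h>

// Sketch: S[:, col] = clamp(decay * S[:, col] + v_new * k[col], -1e6, 1e6),
// eight rows per iteration. Compile with -mavx2 -mfma.
static void state_update_avx2(int head_dim, float decay,
                              const float * k_t, const float * v_new_buf,
                              float * state) { // head_dim x head_dim, column-major
    const __m256 vd   = _mm256_set1_ps(decay);
    const __m256 vmin = _mm256_set1_ps(-1e6f);
    const __m256 vmax = _mm256_set1_ps( 1e6f);
    for (int col = 0; col < head_dim; ++col) {
        const __m256 vk = _mm256_set1_ps(k_t[col]);
        for (int row = 0; row < head_dim; row += 8) {
            __m256 vs = _mm256_loadu_ps(state + col * head_dim + row);
            __m256 vn = _mm256_loadu_ps(v_new_buf + row);
            vs = _mm256_fmadd_ps(vd, vs, _mm256_mul_ps(vn, vk)); // decay*s + v_new*k
            vs = _mm256_min_ps(_mm256_max_ps(vs, vmin), vmax);   // clamp
            _mm256_storeu_ps(state + col * head_dim + row, vs);
        }
    }
}
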
2 changes: 1 addition & 1 deletion src/llama.cpp
@@ -2209,7 +2209,7 @@ static bool llm_load_tensors(
 
     // print memory requirements
     for (ggml_backend_buffer_t buf : model.bufs) {
-        LLAMA_LOG_DEBUG("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
+        LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
     }
 
     // populate tensors_by_name