@@ -84,93 +84,75 @@ __global__ void delta_net_recurrent_f32(
8484 float * sQ = smem; // HEAD_DIM
8585 float * sK = sQ + HEAD_DIM; // HEAD_DIM
8686 float * sV = sK + HEAD_DIM; // HEAD_DIM
87- float * sKBeta = sV + HEAD_DIM; // HEAD_DIM (plain k for state update)
88- float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta))
89- float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM
90- float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g))
91- float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
87+ float * sVNew = sV + HEAD_DIM; // HEAD_DIM
9288
9389 const float scale = rsqrtf ((float )HEAD_DIM);
9490
9591 __shared__ float sum_helper[block_size/WARP_SIZE];
9692
9793 // Copy initial state to output buffer (will be updated in place)
98- for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim . x ) {
94+ for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += block_size ) {
9995 state_dst[i] = state_src[i];
10096 }
101- __syncthreads ();
97+
98+ constexpr int HEAD_DIM_S = HEAD_DIM + 1 ;
99+ __shared__ float all_sum[2 *HEAD_DIM_S*NUM_WARPS];
100+ auto all_sum1 = all_sum;
101+ auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
102102
103103 // Process each token sequentially
104104 for (int64_t t = 0 ; t < n_tokens; t++) {
105105
106- float q_sq = 0 .0f ;
107- float k_sq = 0 .0f ;
108- for (int i = tid; i < HEAD_DIM; i += blockDim .x ) {
109- sQ [i] = q_ptr[t * qkv_stride_token + i];
106+ float sum_kq = 0 .0f ;
107+ for (int i = tid; i < HEAD_DIM; i += block_size) {
108+ sQ [i] = q_ptr[t * qkv_stride_token + i] * scale;
110109 sK [i] = k_ptr[t * qkv_stride_token + i];
111110 sV [i] = v_ptr[t * qkv_stride_token + i];
112- q_sq += sQ [i] * sQ [i];
113- k_sq += sK [i] * sK [i];
111+ sum_kq += sK [i] * sQ [i];
114112 }
115113
116- q_sq = reduce_sum<block_size>(q_sq, sum_helper);
117- k_sq = reduce_sum<block_size>(k_sq, sum_helper);
118-
119- float q_norm = rsqrtf (q_sq + eps);
120- float k_norm = rsqrtf (k_sq + eps);
114+ float attn_score = reduce_sum<block_size>(sum_kq, sum_helper);
121115
122116 float beta_val = sigmoid_f (beta_ptr[t]);
123117 float decay = expf (fminf (g_ptr[t], 50 .0f ));
124118
125- float sum = 0 ;
126- for (int i = tid; i < HEAD_DIM; i += blockDim .x ) {
127- sQ [i] = sQ [i] * q_norm * scale;
128- sK [i] = sK [i] * k_norm;
129- sKBeta [i] = sK [i];
130- sVBeta [i] = sV [i] * beta_val;
131- sKCumdecay [i] = sK [i] * beta_val * decay;
132- sum += sK [i] * sQ [i];
133- }
134- float attn_score = reduce_sum<block_size>(sum, sum_helper);
135- // __syncthreads();
136-
137- for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
119+ for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
138120 float sum1 = 0 .0f ;
139121 float sum2 = 0 .0f ;
140122 #pragma unroll
141- for (int col = lane_id ; col < HEAD_DIM; col += WARP_SIZE ) {
123+ for (int col = warp_id ; col < HEAD_DIM; col += NUM_WARPS ) {
142124 float sval = state_dst[row_out + col * HEAD_DIM];
143- sum1 += sval * sKCumdecay [col];
125+ sum1 += sval * sK [col];
144126 sum2 += sval * sQ [col];
145127 }
146- sum1 = warp_reduce_sum (sum1);
147- sum2 = warp_reduce_sum (sum2);
148- if (lane_id == 0 ) {
149- sVNew [row_out] = sVBeta [row_out] - sum1;
150- float v_attn = sVNew [row_out] * attn_score;
151- // sOut[row_out] = sum2 * decay + v_attn;
152- out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
128+ all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
129+ all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
130+ }
131+ __syncthreads ();
132+
133+ for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
134+ float sum1 = all_sum1[row_out];
135+ float sum2 = all_sum2[row_out];
136+ #pragma unroll
137+ for (int i = 1 ; i < NUM_WARPS; ++i) {
138+ sum1 += all_sum1[row_out + i*HEAD_DIM_S];
139+ sum2 += all_sum2[row_out + i*HEAD_DIM_S];
153140 }
141+ sVNew [row_out] = sV [row_out] * beta_val - sum1 * beta_val * decay;
142+ float v_attn = sVNew [row_out] * attn_score;
143+ out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
154144 }
155145 __syncthreads ();
156146
157147 for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
158148 #pragma unroll
159149 for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
160150 float state_val = state_dst[row + out_dim * HEAD_DIM];
161- float safe_decay = decay;
162- if (isnan (safe_decay) || isinf (safe_decay)) {
163- safe_decay = 1 .0f ;
164- }
165- float new_state_val = safe_decay * state_val + sVNew [row] * sKBeta [out_dim];
151+ float new_state_val = decay * state_val + sVNew [row] * sK [out_dim];
166152 new_state_val = fminf (fmaxf (new_state_val, -1e6f), 1e6f);
167153 state_dst[row + out_dim * HEAD_DIM] = new_state_val;
168154 }
169155 }
170- if (t < n_tokens - 1 ) {
171- __syncthreads ();
172- }
173-
174156 }
175157}
176158
@@ -408,9 +390,7 @@ static void delta_net_f32_cuda(
408390 const int num_blocks = n_seqs * n_heads;
409391 constexpr int threads_per_block = 512 ; // 256;
410392
411- // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
412- // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
413- const size_t smem_size = (9 * head_dim + 6 ) * sizeof (float );
393+ const size_t smem_size = 4 * head_dim * sizeof (float );
414394
415395 // Use templated kernel for common head dimensions, generic for others
416396 if (head_dim == 64 ) {
@@ -421,6 +401,7 @@ static void delta_net_f32_cuda(
421401 delta_net_recurrent_f32<128 , threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>> (
422402 q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
423403 } else {
404+ GGML_ASSERT (" Unsupported delta net head size" );
424405 delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>> (
425406 q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
426407 }
0 commit comments