Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1510,8 +1510,12 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,

const float scale = 1.0f / sqrtf((float) head_dim);

#ifdef __AVX512F__
__m512 v_prime[head_dim/16], out_val[head_dim/16];
#else
float v_new_buf[head_dim];
float v_prime[head_dim], out_val[head_dim];
#endif

for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
const int batch_idx = h_idx / n_heads;
Expand Down Expand Up @@ -1539,7 +1543,15 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
const float beta_raw = beta_data[g_head_offset + t];

float kq_sum = 0.0f;
#ifdef __AVX2__
#if defined __AVX512F__
auto vqksum = _mm512_setzero_ps();
for (int i = 0; i < head_dim; i += 16) {
auto vq = _mm512_loadu_ps(q_t + i);
auto vk = _mm512_loadu_ps(k_t + i);
vqksum = _mm512_fmadd_ps(vk, vq, vqksum);
}
kq_sum = _mm512_reduce_add_ps(vqksum);
#elif defined __AVX2__
auto vqksum = _mm256_setzero_ps();
for (int i = 0; i < head_dim; i += 8) {
auto vq = _mm256_loadu_ps(q_t + i);
Expand All @@ -1560,6 +1572,42 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,

float * out_t = out_data + out_head_offset + t * out_token_stride;

#ifdef __AVX512F__
for (int j = 0; j < head_dim/16; ++j) {
v_prime[j] = out_val[j] = _mm512_setzero_ps();
}
for (int col = 0; col < head_dim; ++col) {
auto k_col = _mm512_set1_ps(k_t[col]);
auto q_col = _mm512_set1_ps(q_t[col]);
for (int j = 0; j < head_dim/16; ++j) {
auto s = _mm512_loadu_ps(state + col * head_dim + 16*j);
v_prime[j] = _mm512_fmadd_ps(s, k_col, v_prime[j]);
out_val[j] = _mm512_fmadd_ps(s, q_col, out_val[j]);
}
}
auto c1 = _mm512_set1_ps(beta_val);
auto c2 = _mm512_set1_ps(beta_val*decay);
auto c3 = _mm512_set1_ps(decay*scale);
auto c4 = _mm512_set1_ps(attn_score);
for (int j = 0; j < head_dim/16; ++j) {
auto v = _mm512_loadu_ps(v_t + 16*j);
v_prime[j] = _mm512_sub_ps(_mm512_mul_ps(v, c1), _mm512_mul_ps(v_prime[j], c2));
auto oval = _mm512_fmadd_ps(v_prime[j], c4, _mm512_mul_ps(out_val[j], c3));
_mm512_storeu_ps(out_t + 16*j, oval);
}
auto vmin = _mm512_set1_ps(-1e6f);
auto vmax = _mm512_set1_ps( 1e6f);
auto vd = _mm512_set1_ps(decay);
for (int col = 0; col < head_dim; ++col) {
auto vk = _mm512_set1_ps(k_t[col]);
for (int j = 0; j < head_dim/16; ++j) {
auto vs = _mm512_loadu_ps(state + col * head_dim + 16*j);
vs = _mm512_fmadd_ps(v_prime[j], vk, _mm512_mul_ps(vs, vd));
vs = _mm512_max_ps(vmin, _mm512_min_ps(vmax, vs));
_mm512_storeu_ps(state + col * head_dim + 16*j, vs);
}
}
#else
std::memset(v_prime, 0, head_dim*sizeof(float));
std::memset(out_val, 0, head_dim*sizeof(float));
for (int col = 0; col < head_dim; ++col) {
Expand Down Expand Up @@ -1603,6 +1651,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
}
}
#endif
#endif
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/llama-delta-net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;

ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result,
auto output_tokens = ggml_view_4d(ctx0, fused_result,
S_v, H_v, n_tokens, n_seqs,
ggml_row_size(fused_result->type, S_v),
ggml_row_size(fused_result->type, S_v * H_v),
ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0);
output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
//output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);

ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size,
output_size * ggml_element_size(fused_result));
Expand Down