Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 72 additions & 36 deletions ggml/src/ggml-hexagon/htp/flash-attn-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,53 +54,96 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}

static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
const void * restrict y,
const void * restrict x0,
const void * restrict x1,
unsigned int n,
float s) {
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16

uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements

HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t nvec,
const size_t nloe) {
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16

HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
HVX_Vector rsum2 = Q6_V_vsplat_R(0);
HVX_Vector rsum3 = Q6_V_vsplat_R(0);

uint32_t i = 0;

#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_Vector x2_hf = vx2[i];
HVX_Vector x3_hf = vx3[i];

HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);

rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
HVX_VectorPair xy2_qf = Q6_Wqf32_vmpy_VhfVhf(x2_hf, y_hf);
HVX_VectorPair xy3_qf = Q6_Wqf32_vmpy_VhfVhf(x3_hf, y_hf);

rsum0 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
rsum2 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy2_qf), Q6_V_hi_W(xy2_qf)), rsum2));
rsum3 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy3_qf), Q6_V_hi_W(xy3_qf)), rsum3));
}

if (nloe) {
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);

HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
HVX_VectorPair xy2_qf = Q6_Wqf32_vmpy_VhfVhf(x2_hf, y_hf);
HVX_VectorPair xy3_qf = Q6_Wqf32_vmpy_VhfVhf(x3_hf, y_hf);

rsum0 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
rsum2 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy2_qf), Q6_V_hi_W(xy2_qf)), rsum2));
rsum3 = Q6_Vsf_equals_Vqf32(
Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy3_qf), Q6_V_hi_W(xy3_qf)), rsum3));
}

HVX_Vector_x4 rsum0123 = {
.v = { rsum0, rsum1, rsum2, rsum3 }
};
return hvx_vec_reduce_sum_f32x4(rsum0123);
}

rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t n,
float s) {

const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
const size_t nloe = n % VLEN_FP16; // leftover elements

HVX_Vector sums; // initialize at j = 0
const size_t stride_x_4 = stride_x * 4;
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Key Improvement: Instead of using hvx_vec_store_u, we now use vmux to directly insert the 4 floats into the result vector.

x += stride_x_4;
}

HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
return Q6_Vsf_equals_Vqf32(sums);
}

// MAD: y (F32) += x (F16) * s (F32)
Expand Down Expand Up @@ -179,9 +222,9 @@ static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
HVX_Vector xs = xs_p_lo;
i = 2 * i; // index for ptr_y

if (nloe >= 32) {
if (nloe >= VLEN_FP32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i;
nloe -= VLEN_FP32; ++i;
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
}

Expand Down Expand Up @@ -402,14 +445,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
}

HVX_Vector scores = *(HVX_Vector *) scores_arr;
HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);

// 2. Softcap
if (factx->logit_softcap != 0.0f) {
Expand Down
30 changes: 30 additions & 0 deletions ggml/src/ggml-hexagon/htp/hvx-reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {

#if __HVX_ARCH__ > 75

static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) {
HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4);
HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4);
HVX_Vector sum_sf01 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01));
HVX_Vector sum_sf23 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23));

HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(sum_sf23, sum_sf01, 8);
HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123));

sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
return sum_sf;
}

static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
Expand All @@ -72,6 +87,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n)

#else

static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) {
HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4);
HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4);
HVX_Vector sum_qf01 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01));
HVX_Vector sum_qf23 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23));

HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(sum_qf23), Q6_Vsf_equals_Vqf32(sum_qf01), 8);
HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123));

sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
return Q6_Vsf_equals_Vqf32(sum_qf);
}

static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
Expand Down
Loading