hexagon: Flash Attention optimizations (dma, mpyacc, multi-row) and MatMul updates#20118

Merged
max-krasnyansky merged 19 commits into ggml-org:master from qualcomm:hexagon-fa-updates-dma-mpyacc on Mar 5, 2026

hexagon: Flash Attention optimizations (dma, mpyacc, multi-row) and MatMul updates#20118
max-krasnyansky merged 19 commits intoggml-org:masterfrom
qualcomm:hexagon-fa-updates-dma-mpyacc

Conversation

@max-krasnyansky
Member

Further updates on top of #19780 by @chraac

  • Improved DMA pipelining in FA
  • Reduced FA block size from 128 to 64 to improve DMA prefetch (128 is too big for most models)
  • Improved use of vmpyacc intrinsics in dot products

Some quick perf numbers on S25+ (Gen4) with Llama3.2-3B-Q4_0 and FA on Hexagon

Before:
common_perf_print: prompt eval time =    2711.51 ms /   205 tokens (   13.23 ms per token,    75.60 tokens per second)
common_perf_print:        eval time =    3934.46 ms /    63 runs   (   62.45 ms per token,    16.01 tokens per second)

After:
common_perf_print: prompt eval time =    2538.64 ms /   205 tokens (   12.38 ms per token,    80.75 tokens per second)
common_perf_print:        eval time =    3586.00 ms /    63 runs   (   56.92 ms per token,    17.57 tokens per second)

Original notes from #19780

Dot Product Function Improvements

Replaced the previous hvx_dot_f16_f16_aa_rx2 function with new, more parallelized hvx_dot_f16_f16_aa_rx4 and
hvx_dot_f16_f16_aa_rx32 functions in flash-attn-ops.c, allowing computation of 4 and 32 dot products at a time,
respectively. This increases throughput and simplifies the code by leveraging vectorization.
Updated the main attention kernel (flash_attn_ext_f16_thread) to use the new hvx_dot_f16_f16_aa_rx32 function,
replacing the looped calls to the old function and removing the need for temporary arrays.

Vector Reduction Utilities

Added hvx_vec_reduce_sum_f32x4 utility in hvx-reduce.h for both HVX architectures, enabling efficient reduction
of four HVX vector results into a single vector. This supports the new parallel dot product functions.

chraac and others added 19 commits March 4, 2026 11:37
@max-krasnyansky max-krasnyansky requested a review from lhez as a code owner March 4, 2026 23:21
@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Mar 5, 2026
@max-krasnyansky max-krasnyansky merged commit 7a99dc8 into ggml-org:master Mar 5, 2026
78 checks passed
Diff under review:

-static inline HVX_Vector hvx_vec_splat_f16(float v) {
+static inline HVX_Vector hvx_vec_splat_f16(_Float16 v) {
Contributor:
nit: better to keep the same type (__fp16) as union below.

Member Author:

__fp16 can't be used as a function argument, would have to be a pointer.
_Float16 can. That's pretty much the only reason I used it.

Contributor:

From the gcc docs, there's a minor difference between __fp16 and _Float16 on some architectures, but both are okay in this case.

bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 10, 2026
…atMul updates (ggml-org#20118)

* ggml-hexagon: enhance hvx_dot_f16_f16_aa_rx4 for improved performance by expanding vector handling and optimizing accumulation

# Conflicts:
#	ggml/src/ggml-hexagon/htp/flash-attn-ops.c

* ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx4 and enhance hvx_vec_reduce_sum_f32x4 for improved performance and reduced complexity

* ggml-hexagon: add hvx_dot_f16_f16_aa_rx32 for enhanced vector processing in flash attention

# Conflicts:
#	ggml/src/ggml-hexagon/htp/flash-attn-ops.c

* optimize hvx_dot_f16_f16_aa_rx4 and hvx_dot_f16_f16_aa_rx32 by removing unused scale parameter and improving vector accumulation

# Conflicts:
#	ggml/src/ggml-hexagon/htp/flash-attn-ops.c

* ggml-hexagon: refactor hvx_dot_f16_f16_aa_rx4 for improved readability and return HVX_Vector for better integration

# Conflicts:
#	ggml/src/ggml-hexagon/htp/flash-attn-ops.c

* ggml-hexagon: initialize sums variable in hvx_dot_f16_f16_aa_rx32 for clarity

* ggml-hexagon: fix compiling error

* fix hvx_dot_f16_f16_aa_rx4 to handle leftover elements correctly using masking

* refactor hvx_dot_f16_f16_aa_rx4 to accept vector and leftover element counts as parameters for improved clarity and flexibility

* wip

* fa: instrumentation and dma reordering

* hex-fa: use block-size 64 to improve DMA pipelining

* hex-fa: optimize vec-dot for v79 and above

* hex-fa: use block size 64

* hex-fa: avoid scalar fp32->fp16 conversions

* hex-fa: simplify dot_f16 functions using optimized vec_mpyacc

* hex-fa: rewrite mad_f32_f16 using hvx_vec_mpyacc

* hex-mm: use mpyacc in matmul dot functions

---------

Co-authored-by: chraac <[email protected]>
Ethan-a2 pushed a commit to Ethan-a2/llama.cpp that referenced this pull request Mar 20, 2026