hexagon: Flash Attention optimizations (dma, mpyacc, multi-row) and MatMul updates #20118
Merged

max-krasnyansky merged 19 commits into ggml-org:master on Mar 5, 2026
Conversation
Commits (the full squashed message appears in the referencing commits below):

* ggml-hexagon: enhance hvx_dot_f16_f16_aa_rx4 for improved performance by expanding vector handling and optimizing accumulation (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx4 and enhance hvx_vec_reduce_sum_f32x4 for improved performance and reduced complexity
* ggml-hexagon: add hvx_dot_f16_f16_aa_rx32 for enhanced vector processing in flash attention (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* optimize hvx_dot_f16_f16_aa_rx4 and hvx_dot_f16_f16_aa_rx32 by removing unused scale parameter and improving vector accumulation (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* ggml-hexagon: refactor hvx_dot_f16_f16_aa_rx4 for improved readability and return HVX_Vector for better integration (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* refactor hvx_dot_f16_f16_aa_rx4 to accept vector and leftover element counts as parameters for improved clarity and flexibility
lhez approved these changes on Mar 5, 2026

chraac reviewed on Mar 5, 2026
```diff
 }

-static inline HVX_Vector hvx_vec_splat_f16(float v) {
+static inline HVX_Vector hvx_vec_splat_f16(_Float16 v) {
```
chraac (Contributor) commented:

nit: better to keep the same type (__fp16) as the union below.
max-krasnyansky (Member, Author) replied:

__fp16 can't be used as a function argument; it would have to be a pointer. _Float16 can, and that's pretty much the only reason I used it.
chraac (Contributor) replied:

From the GCC docs there's a minor difference between __fp16 and _Float16 on some architectures, but both are okay in this case.
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request on Mar 10, 2026:

hexagon: Flash Attention optimizations (dma, mpyacc, multi-row) and MatMul updates (ggml-org#20118)

* ggml-hexagon: enhance hvx_dot_f16_f16_aa_rx4 for improved performance by expanding vector handling and optimizing accumulation (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx4 and enhance hvx_vec_reduce_sum_f32x4 for improved performance and reduced complexity
* ggml-hexagon: add hvx_dot_f16_f16_aa_rx32 for enhanced vector processing in flash attention (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* optimize hvx_dot_f16_f16_aa_rx4 and hvx_dot_f16_f16_aa_rx32 by removing unused scale parameter and improving vector accumulation (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* ggml-hexagon: refactor hvx_dot_f16_f16_aa_rx4 for improved readability and return HVX_Vector for better integration (conflicts: ggml/src/ggml-hexagon/htp/flash-attn-ops.c)
* ggml-hexagon: initialize sums variable in hvx_dot_f16_f16_aa_rx32 for clarity
* ggml-hexagon: fix compiling error
* fix hvx_dot_f16_f16_aa_rx4 to handle leftover elements correctly using masking
* refactor hvx_dot_f16_f16_aa_rx4 to accept vector and leftover element counts as parameters for improved clarity and flexibility
* wip
* fa: instrumentation and dma reordering
* hex-fa: use block-size 64 to improve DMA pipelining
* hex-fa: optimize vec-dot for v79 and above
* hex-fa: use block size 64
* hex-fa: avoid scalar fp32->fp16 conversions
* hex-fa: simplify dot_f16 functions using optimized vec_mpyacc
* hex-fa: rewrite mad_f32_f16 using hvx_vec_mpyacc
* hex-mm: use mpyacc in matmul dot functions

Co-authored-by: chraac <[email protected]>
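The "block-size 64 to improve DMA pipelining" commits suggest a double-buffered pattern: start fetching the next K/V block while computing on the current one. A hedged sketch of that idea, where memcpy stands in for the Hexagon user-DMA transfer and all names and sizes are illustrative, not the PR's actual code:

```c
#include <string.h>

#define BLOCK 64  /* elements per block, per the commit messages */

/* Toy "compute" step standing in for the flash-attention block work. */
static float process_block(const float *blk) {
    float s = 0.0f;
    for (int i = 0; i < BLOCK; i++) s += blk[i];
    return s;
}

/* Double buffering: while block k is processed out of buf[cur], the
 * transfer of block k+1 into buf[nxt] is (conceptually) in flight. */
static float fa_pipeline(const float *kv, int nblocks) {
    float buf[2][BLOCK];                     /* two local scratch buffers */
    float total = 0.0f;
    memcpy(buf[0], kv, sizeof(buf[0]));      /* prefetch block 0 */
    for (int k = 0; k < nblocks; k++) {
        int cur = k & 1, nxt = cur ^ 1;
        if (k + 1 < nblocks)                 /* issue "DMA" of next block */
            memcpy(buf[nxt], kv + (k + 1) * BLOCK, sizeof(buf[nxt]));
        total += process_block(buf[cur]);    /* compute overlaps the transfer */
    }
    return total;
}
```

On real hardware the payoff comes from overlapping the asynchronous DMA with compute; a scalar memcpy only models the buffer rotation and ordering.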
Ethan-a2 pushed a commit to Ethan-a2/llama.cpp that referenced this pull request on Mar 20, 2026, carrying the same squashed commit message as the bartowski1182 commit above.
Further updates on top of #19780 by @chraac.

Some quick perf numbers on S25+ (Gen4) with Llama3.2-3B-Q4_0 and FA on Hexagon