

@nurmukhametov

Technical Details

Replace per-transpose loops with a single unified loop that processes all transposes simultaneously, computing indices once and reusing them across all operations.
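To illustrate the idea, here is a minimal C++ sketch, not the actual XLA emitter code; the function name, element type, and buffer layout are hypothetical. Rather than emitting one loop nest per transpose hero, a single loop computes the source and destination offsets once per element and reuses them for every hero's buffer.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical illustration: N transpose heroes share the same permutation
// and shape, so their indexing is identical and can be computed once.
void TransposeAllHeroes(const std::vector<const uint16_t*>& srcs,  // one source buffer per hero
                        const std::vector<uint16_t*>& dsts,        // one destination buffer per hero
                        int64_t rows, int64_t cols) {
  // Before: one loop nest per hero, each recomputing the same indices.
  // After: a single unified loop; the linear source/destination offsets
  // are computed once and reused across all heroes.
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      const int64_t src_idx = r * cols + c;  // computed once...
      const int64_t dst_idx = c * rows + r;  // ...and shared by every hero
      for (size_t h = 0; h < srcs.size(); ++h) {
        dsts[h][dst_idx] = srcs[h][src_idx];
      }
    }
  }
}
```

In the emitted IR the per-hero results are carried as separate iter_args of that one loop, which is what the updated packed_transpose_multiple_heroes.hlo test checks.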

Update packed_transpose_multiple_heroes.hlo test to verify the single-loop structure with multiple iter_args.

Motivation

This fixes the second performance regression (the first was addressed in #434) caused by the new fused transpose emitter for floating-point types narrower than 32 bits (PackedTranspose).

Test Result

This change reduces the execution time of the fused_convert_transpose_3.hlo kernel from Llama 3 8B FP8 by ~30%, bringing it almost back to v0.6.0 performance (~4% gap). Together with #434, the performance of the top 4 fused_convert_transpose kernels improves by ~17%, resulting in an end-to-end model performance improvement of ~1% (tokens per second per GPU).

@charleshofer

Could you check whether this is also applicable to rocm-jaxlib-v0.8.0 and, if needed, apply the fix there as well?

@i-chaochen requested a review from pemeliya on November 21, 2025.
