Skip to content

Conversation

@plognjen
Copy link
Contributor

@plognjen plognjen commented Jul 22, 2025

This PR implements test kernel for efficient scale packing for CDNA4 arch as well as opSel for scaled MFMA instructions.

Scaled MFMA instructions expect scale operands as 32-bit values,
even though each individual scale is only 8 bits.
To reduce register usage, we pack 4 scales into a single 32-bit value and use the opSel
field to select the appropriate byte during execution.
Packing is done along the K dimension first. if there aren’t enough values in K, we
continue along the non-K dimension.

@plognjen plognjen force-pushed the preshuffling_opsel_support branch 3 times, most recently from 0e133f9 to ed9258a Compare July 24, 2025 15:36
@plognjen plognjen changed the title [WIP!!] Scale preshuffling and opSel implementation [AMD] Scale preshuffling and opSel implementation Jul 24, 2025
Copy link
Collaborator

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool stuff! Thanks for working this out! Just a few minor comments inlined.

@antiagainst antiagainst marked this pull request as ready for review July 25, 2025 00:11
Copy link
Collaborator

@zhanglx13 zhanglx13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good work @plognjen!


assert(scaleAKBase % akPackedVals == 0 && scaleBKBase % bkPackedVals == 0);
int nonAKPackedVals = scaleAKBase / akPackedVals;
int nonBKPackedVals = scaleBKBase / bkPackedVals;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name should be aNonKPackedVals and bNonKPackedVals

}

if (2 == kBase)
results = b.zext(i32_ty, b.bitcast(vec, i16_ty));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments to explain this case?
Also the comment at line 429 and 446 needs to be updated.

if (numVecInKBase == 0) {
numVecInKBase = 1;
nonKRep /= kBase / (kRepInKWidth * kWidth);
assert(nonKRep > 0 && "nonKrep too small");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this assert?
This if can only happen for scales, and scale's kBase is bounded by numRepK * numRepM.



@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024), [512, 1024, 2048]])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(256, 256, 256), (128, 128, 256)])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add another parameter so that we can test the kernel w/ and w/o preshuffling?


# Create pointers for the first block of A and B scales
offs_asn = (pid_n * (BLOCK_N // 32) + tl.arange(0, (BLOCK_N // 32))) % N
offs_ks = tl.arange(0, BLOCK_K // SCALE_GROUP_SIZE * 32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These hard coded "32" happens to be the same as SCALE_GROUP_SIZE. I think it's better to create a variable for them so we know they are part of the preshuffling algorithm.



@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024), [512, 1024, 2048]])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(256, 256, 256), (128, 128, 256)])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now you are making preshuffling generic, then we don't have any limitations on BLOCK_K, right? If so, can you add more BLOCK sizes?

tl.store(output_ptrs, accumulator, mask=c_mask)


def shuffle_scales_amd(scales: torch.Tensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pre-shuffling algorithm is designed for mfma_16x16 instructions with a particular order (top-left, bottom-left, top-right, bottom-right) of the 4 scale value for each thread. Can you add some notes here explaining the pre-shuffling algorithm?
Can we also add support for mfma_32x32?

Then we know for sure the compiler-side change is agnostic to the pre-shuffling algorithm.

@plognjen plognjen force-pushed the preshuffling_opsel_support branch from ed9258a to 13394f1 Compare July 29, 2025 16:42
Comment on lines 582 to 583
@pytest.mark.skipif(is_cuda(), reason="AMD specific scale shuffling")
@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires hardware support for scaled mfma instructions")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason why those should be disabled? This should technically run and would confirm that it is correct independently of the optimization

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't fully support scale dot emulation on other architectures yet, this is why I skipped for now.
I enabled it for nv hardware, let's see if CI passes.

tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")


@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024), [512, 1024, 2048], [2048, 2048, 2048]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a point trying all those matrix sizes? It is unlikely to change much as this will call the same exact kernel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, I will leave only (1024, 1024, 1024).



@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024), [512, 1024, 2048], [2048, 2048, 2048]])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(256, 256, 256), (128, 128, 256), (128, 128, 512), [32, 32, 64]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here also it would be good to pick sizes that test interesting cases, some of those block size are likely to heavily spill and make things slow so we should keep them only if they test interesting corner cases

@plognjen plognjen force-pushed the preshuffling_opsel_support branch 2 times, most recently from b407a47 to 8a294c7 Compare July 29, 2025 23:03
Copy link
Collaborator

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Just some final nits and questions.

// For scale tensor preshuffling, the minimum block size is 32x32x256.
// When using MFMA16 instructions, each warp should compute two MFMA ops
// along the non-K dimension. To support this, we must set tilesPerWarp to
// {2, 2}. Failing to do so won't break correctness, but it will prevent
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now. But is it really a limitation due to we need to set dot operand layout? If we use linear layout throughout everywhere will it be "automically" figured out? Maybe need some generic linear layout algorithms to deduce there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly, we would need to do some analysis to figure out which linear layout is "the best" in terms for, in this case, local load vectorization. I agree it would be good to do this eventually. We can create lower priority ticket for this. This is somewhat ad-hoc, but I also think it's fine for now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. Let's create an issue to track it to improve later.

@ThomasRaoux
Copy link
Collaborator

looks like GB200 is failing, feel free to disable the pass with a TODO and I can take a look at it later. Would be great if you could keep the other targets on

@plognjen plognjen force-pushed the preshuffling_opsel_support branch from 8a294c7 to 1efeccc Compare July 30, 2025 13:52
@plognjen
Copy link
Contributor Author

@antiagainst @zhanglx13 @ThomasRaoux Thanks for detailed review, appreciate it!

return false;
}

const std::array<int, 7> transposeOrder{0, 5, 3, 1, 4, 2, 6};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this order bounded with the following order?

Packing order: mfma_op_0, mfma_op_2, mfma_op_1, mfma_op_3

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is "unshuffling" order which is related to the way we shuffle. We should generalize this by actually analyzing the LL, but I guess it's fine as a workaround for now.

@antiagainst antiagainst merged commit a1f42ef into triton-lang:main Jul 30, 2025
9 checks passed
AlexAUT pushed a commit to ROCm/triton that referenced this pull request Jul 31, 2025
This PR implements test kernel for efficient scale packing for CDNA4
arch as well as opSel for scaled MFMA instructions.

Scaled MFMA instructions expect scale operands as 32-bit values,
even though each individual scale is only 8 bits.
To reduce register usage, we pack 4 scales into a single 32-bit value
and use the opSel field to select the appropriate byte during execution.
Packing is done along the K dimension first. if there aren’t enough
values in K, we continue along the non-K dimension.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
jataylo pushed a commit to ROCm/triton that referenced this pull request Aug 1, 2025
* [AMD] Avoid async load to pipeline for less than 32bit load (triton-lang#7250)

We can only use AsyncCopy if the final load width can be >= 4 bytes. 
`triton::canBeConvertedToAsyncLoad` checks that the vecSize of the
source is large enough. Additionally we need to ensure the register to
shared layout (blocked+shared) does have enough contiguous elements
since we cannot scatter into LDS.

Before this PR we will abort compilation instead of falling back to
pipelining through registers.

* [AMD] Pipeline small tensors w/ registers only on GFX950 (triton-lang#7171)

Fixes a perf regression on gfx942 but preserves functionality for
gfx950 (and above).

* Reland "[AMD] Optimize reduction with v_permlane intrinsics in GFX950" (triton-lang#7321)

triton-lang#7291 fixed the
LLVM issue that caused correctness problems. Now
we can reland this patch.

* [Pipeliner] Expose core pipeliner utilities to reuse in AMD backend (triton-lang#7222)

This PR exposes (via header) core pipelining utilities/helpers:

```c++
bool hasGpuBarriers(scf::ForOp forOp);
bool isSafeToPipeline(scf::ForOp forOp);
llvm::MapVector<Operation *, std::pair<int, Operation *>>
loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
                          triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
                          int numStages, bool filterSmall = true);
void scheduleDistanceOneDependencies(scf::ForOp forOp,
                                     CoarseSchedule &schedule);
void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule,
                                  CoarseSchedule::Cluster afterPrologue);
```

They are directly useable by AMD's pipeliner. 

Note, this is basically NFC for AMD because AMD's pipeliner simply had
copy-paste of the same functions from ~last year.


Small API changes:

1. On NV we do not pipeline small loads (vec width < 32b). On AMD we do.
The choice is made inside `isPipeliningBeneficial` inside
`loadOpsToIndirectionLevel`. To support AMD I have added a flag
`filterSmall`.
2. On AMD the load `use`s (computed as a matter of course in
`loadOpsToIndirectionLevel`) are used (no pun intended) whereas on NV
they are not. To support AMD I keep those `use`s in the
`llvm::MapVector<Operation *, std::pair<int, Operation *>>` return from
`loadOpsToIndirectionLevel`.

These two small changes are the only "non-NFC" changes.

* [AMD] Retire local prefetch schedule hint variant (triton-lang#7395)

This variant was from some prior experiments. We have a better way
to implement later.

* [AMD] Retire TritonAMDGPU_OpIdxAttr and TritonAMDGPU_InstCounter  (triton-lang#7476)

triton-lang#7395 retired the local
prefetch schedule variant. This made `TritonAMDGPU_OpIdxAttr` and
`TritonAMDGPU_InstCounter` unused which are removed by this PR.

* [AMD][NFC] Split createAndSchedule* in stream pipeliner(triton-lang#7514)

Splits `createAndScheduleAsyncCopy` and `createAndScheduleStreamCopy` to
make it reusable if we want to schedule the ops differently in a future
PR.

* [AMD] Refactor StreamPipeliner to use more common functions (triton-lang#7526)

Further refactoring of Streampipeliner.cpp to use more common pipeliner
functionality: `triton::createAllocation`,
`triton::createSingleBufferView`, `triton::replaceWithSharedLoad` and a
bit of general cleanup.

Overall NFC except:
- The order of LocalDealloc is reversed now
- The memdesc of the subview additionally includes the allocSize

Also we had no lit test checking that the LocalLoad consumes the
AsyncToken so I adjusted one to include the check.

* [AMD] NFC: Refactor stream pipeliner to better encapsulate functions (triton-lang#7540)

Mostly moves code around to reduce the dependencies between functions
and further splits up functions doing more than one thing
(`createAndSchedule*`,` preprocessAndBuildSchedule`). This will also
allow us to use more common pipeliner functionality in a future PR, e.g.
`createAsyncCopy`.

* [FA] Set vecSize=nonKDim for V shared layout to avoid bank conflicts

I'll submit a PR upstream later.

* [GEMM] Add combine dot_scaled and addF

* [AMD][NFC] Consolidate initialization in initSchedule for pipeliner (triton-lang#7556)

Moves all initializations of stages to `initSchedule`. Missed this one
in the last PRs.

* [AMD] NFC: Drop version minor for AMD MFMA layout (triton-lang#7285)

AMD's MFMA layout does not need version minor information like NVIDIA.
It always defaults to 0 in the current codebase. The PR drops version
minor and change to a single `version` parameter for MFMA layout.

* [AMD] Add tilesPerWarp parameter to mfma layout (triton-lang#7283)

This PR introduces the tilesPerWarp parameter to the MFMA layout.
Previously, the MFMA layout assumed that each warp within a CTA tile
computed a single MFMA tile.
When the tensor was larger than a single CTA tile, these tiles were
repeated across the tensor.
In this setup, the output tiles computed by each wave were strided by
the number of warps
per CTA in both row and column dimensions.

For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the
distribution of
warps across the MFMA tiles looked like:

w0 w1 w0 w1
w2 w3 w2 w3
w0 w1 w0 w1
w2 w3 w2 w3

The new tilesPerWarp parameter allows each warp to compute contiguous
MFMA tiles
in the row and/or column dimensions. Using the same example with
tilesPerWarp = [2, 2], the layout becomes:

w0 w0 w1 w1
w0 w0 w1 w1
w2 w2 w3 w3
w2 w2 w3 w3

While this is a general enhancement, the main motivation for introducing
this parameter
is to improve memory access efficiency for scale tensors in scaled dot
operations.
Specific patterns and use cases will be implemented in follow-up PRs.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>

* [AMD] Add support for pingpong GEMM using async_copy

* [BACKEND] combineRedundantWaitOps should not combine across loops/branches (triton-lang#7593)

`combineRedundantWaitOps` did skip over branches/loops, so if we end up
with something like:

```mlir
ttg.async_wait
scf.for
  ....
  scf.yield
ttg.async_wait
```
we merge the async_waits in the prologue and epilogue because we do not
find a `ttg.commit_group` in between. This PR stops the forward search
if we encounter a branch/loop. I can also walk through all successor
blocks if we think this is worth the effort.

This problem was not triggered before because the `ttg.async_wait` was
scheduled in the same stage as its user(s) so we ended up with no
`ttg.async_wait` in the prologue or there was another prefetch after it
in the prologue.

Since triton-lang#7458 we might place the
`ttg.async_wait` in the previous stage compared to its user(s) so we
might end up with the problematic IR.

* [AMD][NFC] Group scheduling functions in StreamPipeliner (triton-lang#7607)

NFC: Groups all scheduling related function to a namespace to prepare
for additional scheduling variants.

* [AMD] Add pingpong transformation for chained dot schedule (triton-lang#7638)

Adds support to enable pingpong for loops scheduled with the new
`ChainedDotSchedule` introduced by
triton-lang#7601.

The schedule already places the ops in the correct order so we just have
to insert the sync ops to ensure proper pingpong'ing.

* [AMD] Fix pingpong ChainedDot for empty second memory cluster (triton-lang#7694)

triton-lang#7638 introduced a null
pointer access (during review adjustments) if the second memory cluster
is empty or if there are no memory clusters at all.

Added a lit test to catch it and revert to the old logic.

* [AMD] Remove bypass permute optimization for AsyncCopy (triton-lang#7704)

We can only bypass ds_bpermute to apply the swizzling if lanes loading
the same row read a contiguous chunk of memory from HBM, which we cannot
infer when lowering to LLVM.
The current selection does only check if the elements for each lane are
contiguous which is not strict enough.

* [AMD] Add ChainedDotSchedule to StreamPipeliner (triton-lang#7601)

Adds a new scheduling variant which kicks in for loop which have 2
chained dots and `num_stages==4`. It places the two dots in consecutive
stages so we can interleave operations using the result of the first dot
with both dots in the loop, a pseudo example IR:

```
   %1 = tt.dot ...
   %2 = arith.addf %1, %arg1
   %3 = arith.subf %2, %arg2
   %4 = tt.dot %X, %Y, %3
```
Which could result in the following pseudo schedule (ignoring mem ops)
to interleave with both dots:
```
   stage N,   Cluster0: [%1 = tt.dot, %3 = arith.subf]
   stage N+1, Cluster1: [%4 = tt.dot, %2 = arith.addf]
```

As a first step the schedule splits the op chain between dot1 and dot2
when it encounters an operation which has more than 2 users. This aims
to avoid adding too many loop carried dependencies but does not
guarantee a good work balance between the two clusters. In future PRs we
might make this more sophisticated.

* [AMD] Add scale preshuffling and opSel implementation (triton-lang#7603)

This PR implements test kernel for efficient scale packing for CDNA4
arch as well as opSel for scaled MFMA instructions.

Scaled MFMA instructions expect scale operands as 32-bit values,
even though each individual scale is only 8 bits.
To reduce register usage, we pack 4 scales into a single 32-bit value
and use the opSel field to select the appropriate byte during execution.
Packing is done along the K dimension first. if there aren’t enough
values in K, we continue along the non-K dimension.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>

* [AMD] Enable Pingpong by default on gfx950 arch (triton-lang#7697)

List of enabling conditions
- FP/BF16 GEMM with M,N>64 tilesize when num_stages=3 and num_warps=8
- GEMM using `dot_scaled` with M=N=256 tile size when num_stages=2 and
num_warps=8
- FA with num_stages=4
Only with using async_copy.

* [Backend] Bump to llvm/llvm-project@570885128351 (triton-lang#7291)

This picks up a bug fix for AMDGPU v_permlane_swap:

llvm/llvm-project#144423
Without this fix, the v_permlane_swap is wrongly sunk.

Along the way we need to fix API changes:

Add header file for the class IRBuilder
Add missing default parameter in convertFuncOpToLLVMFuncOp

---------

Co-authored-by: Maksim Levental <[email protected]>
Co-authored-by: Yi Qian <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
Co-authored-by: Lixun Zhang <[email protected]>
Co-authored-by: Jungwook Park <[email protected]>
Co-authored-by: Pengzhan Zhao <[email protected]>
Co-authored-by: plognjen <[email protected]>
Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Zeng Wu <[email protected]>
stadlmax pushed a commit to stadlmax/triton that referenced this pull request Aug 4, 2025
This PR implements test kernel for efficient scale packing for CDNA4
arch as well as opSel for scaled MFMA instructions.

Scaled MFMA instructions expect scale operands as 32-bit values, 
even though each individual scale is only 8 bits. 
To reduce register usage, we pack 4 scales into a single 32-bit value
and use the opSel field to select the appropriate byte during execution. 
Packing is done along the K dimension first. if there aren’t enough
values in K, we continue along the non-K dimension.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
stadlmax pushed a commit to stadlmax/triton that referenced this pull request Aug 4, 2025
This PR implements test kernel for efficient scale packing for CDNA4
arch as well as opSel for scaled MFMA instructions.

Scaled MFMA instructions expect scale operands as 32-bit values, 
even though each individual scale is only 8 bits. 
To reduce register usage, we pack 4 scales into a single 32-bit value
and use the opSel field to select the appropriate byte during execution. 
Packing is done along the K dimension first. if there aren’t enough
values in K, we continue along the non-K dimension.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
antiagainst pushed a commit that referenced this pull request Nov 6, 2025
Following #7603, this PR
implemented scale preshuffling on gfx1250 for efficient memory access
and better wmma codegen with `opSel`.

As an example, in a mxfp GEMM kernel with `BLOCK_M x BLOCK_N x BLOCK_K`,
scaleA's shape is `BLOCK_M x (BLOCK_K // 32)`. We preshuffle it to be
`(BLOCK_M // 128) x (BLOCK_K x 4)` outside the kernel for better
vectorization, and 'unshuffle' it inside the kernel to get canonical
input to `wmma_scaled` op. Same to scaleB.

Besides, 16x16x128 scaled wmma instruction reads scales only from the
first 16 lanes in a wave, which is a waste of reading capacity.
Therefore we use `opSel` to control wmma instruction to read scales from
the first or last 16 lanes in a wave. So that we can read scales with
all the lanes in a wave.

To correctly issue wmma instructions with `opSel`, we need to group 2
consecutive wmma instruction tiles in a wave. This is done by
introducing `tilesPerWarp` to `AMDWmmaEncodingAttr`, to avoid composing
linear layout in gluon kernel all the time.

This PR also includes the support for inferring padded shared layout for
MemDescReshapeOp, because in case of async/tensor load, we need to do
the 'unshuffling' on memory subview.
jwu10003 pushed a commit to jwu10003/triton that referenced this pull request Nov 7, 2025
…8576)

Following triton-lang#7603, this PR
implemented scale preshuffling on gfx1250 for efficient memory access
and better wmma codegen with `opSel`.

As an example, in a mxfp GEMM kernel with `BLOCK_M x BLOCK_N x BLOCK_K`,
scaleA's shape is `BLOCK_M x (BLOCK_K // 32)`. We preshuffle it to be
`(BLOCK_M // 128) x (BLOCK_K x 4)` outside the kernel for better
vectorization, and 'unshuffle' it inside the kernel to get canonical
input to `wmma_scaled` op. Same to scaleB.

Besides, 16x16x128 scaled wmma instruction reads scales only from the
first 16 lanes in a wave, which is a waste of reading capacity.
Therefore we use `opSel` to control wmma instruction to read scales from
the first or last 16 lanes in a wave. So that we can read scales with
all the lanes in a wave.

To correctly issue wmma instructions with `opSel`, we need to group 2
consecutive wmma instruction tiles in a wave. This is done by
introducing `tilesPerWarp` to `AMDWmmaEncodingAttr`, to avoid composing
linear layout in gluon kernel all the time.

This PR also includes the support for inferring padded shared layout for
MemDescReshapeOp, because in case of async/tensor load, we need to do
the 'unshuffling' on memory subview.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants