Conversation

@knwng
Contributor

@knwng knwng commented Oct 28, 2025

Following #7603, this PR implements scale preshuffling on gfx1250 for efficient memory access and better wmma codegen with opSel.

As an example, in a mxfp GEMM kernel with BLOCK_M x BLOCK_N x BLOCK_K, scaleA's shape is BLOCK_M x (BLOCK_K // 32). We preshuffle it to (BLOCK_M // 128) x (BLOCK_K * 4) outside the kernel for better vectorization, and 'unshuffle' it inside the kernel to recover the canonical input to the wmma_scaled op. The same applies to scaleB.
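
For illustration only, here is a minimal NumPy sketch of this shape bookkeeping. The BLOCK_M/BLOCK_K values and the reshape/transpose order are placeholders (the real permutation is chosen to match the wmma operand layout), but it shows that the preshuffled shape holds the same number of elements and that the in-kernel 'unshuffle' is the exact inverse of the host-side preshuffle:

import numpy as np

# Placeholder block sizes, for illustration only.
BLOCK_M, BLOCK_K = 256, 256
scale_a = np.arange(BLOCK_M * (BLOCK_K // 32)).reshape(BLOCK_M, BLOCK_K // 32)

def preshuffle(s):
    # Host side: pack 128 rows of scales into one row of the shuffled tensor.
    # The transpose order is a placeholder, not the actual pattern used by the kernel.
    m, k = s.shape  # (BLOCK_M, BLOCK_K // 32)
    return s.reshape(m // 128, 128, k).transpose(0, 2, 1).reshape(m // 128, k * 128)

def unshuffle(s, m, k):
    # Device side: inverse view transform back to the canonical (BLOCK_M, BLOCK_K // 32) shape.
    return s.reshape(m // 128, k, 128).transpose(0, 2, 1).reshape(m, k)

shuffled = preshuffle(scale_a)
assert shuffled.shape == (BLOCK_M // 128, BLOCK_K * 4)  # (BLOCK_K // 32) * 128 == BLOCK_K * 4
assert np.array_equal(unshuffle(shuffled, BLOCK_M, BLOCK_K // 32), scale_a)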

In addition, the 16x16x128 scaled wmma instruction reads scales only from the first 16 lanes of a wave, which wastes read capacity. We therefore use opSel to control whether the wmma instruction reads scales from the first or the last 16 lanes of a wave, so that all lanes in the wave can be used to hold scales.

To issue wmma instructions with the correct opSel, we need to group 2 consecutive wmma instruction tiles in a wave. This is done by introducing tilesPerWarp to AMDWmmaEncodingAttr, so that Gluon kernels do not have to compose a linear layout by hand every time.

This PR also adds support for inferring the padded shared layout for MemDescReshapeOp, because in the case of async/tensor loads, we need to do the 'unshuffling' on a shared memory subview.

knwng and others added 5 commits October 24, 2025 18:41
…CK_K < 128 (triton-lang#231)

- Supported using get_wmma_scale_layout to deduce scale layout
- Supported layout inference of MemDescReshapeOp
- Supported the case when BLOCK_K < 128 for FA

---------

Co-authored-by: Kyle Wang <[email protected]>
Contributor

@lezcano lezcano left a comment

As discussed before, why don't we do this in the frontend, as we do in Nvidia?
cc @ThomasRaoux @antiagainst

@antiagainst
Collaborator

The preshuffle scheme is done in the kernel with .reshape(...) and .trans(...) etc. if you check the test_gluon_gfx1250.py file, so it follows the NVIDIA flow. The changes to the compiler are meant to 1) support opSel when generating scaled wmma intrinsics for better codegen, and 2) make writing wmma layouts easier with tilesPerWarp, to avoid always resorting to linear layouts in Gluon. I guess @knwng you should be explicit and provide more context in the pull request message regarding these changes.

@knwng
Contributor Author

knwng commented Oct 28, 2025

Hi @lezcano, sorry for the confusion. I've updated the PR message and title to provide more context. Please take a look.

@knwng knwng changed the title Support Scale Preshuffling on GFX1250 Implement Scale Preshuffling and opSel on GFX1250 Oct 28, 2025
@lezcano
Contributor

lezcano commented Oct 29, 2025

Ah, sorry, my mistake. I just went by the title and the changes in the compiler and thought that it was still being implemented in the middle-end. Will review properly tomorrow!

@antiagainst antiagainst marked this pull request as ready for review October 29, 2025 03:40
Contributor

@lezcano lezcano left a comment

Nice! We can now do virtually anything with padded layouts that we could do with non-padded layouts (other than use them in our swizzling algorithm, but that's fine).

Could I get a bit more context on what tilesPerWarp means for the layouts where it's used, what the context for it is, and why it wasn't needed before but is needed now? Are the layouts on gfx1250 more flexible than they were in previous generations?

Comment on lines +609 to +619
} else if (auto padded = dyn_cast<PaddedSharedEncodingAttr>(srcEnc)) {
  LinearLayout ll = padded.getLinearComponent();
  LinearLayout dst = reshapeLayout(ctx, ll, dstShape);
  SmallVector<std::pair<unsigned, unsigned>> intervalPads;
  auto intervals = padded.getIntervals();
  auto paddings = padded.getPaddings();
  for (auto [interval, padding] : llvm::zip(intervals, paddings)) {
    intervalPads.emplace_back(interval, padding);
  }
  dstEnc = PaddedSharedEncodingAttr::get(ctx, intervalPads, dst);
  return success();
Contributor

@ThomasRaoux @antiagainst I had said before that reshaping a padded layout is not correct in general, but thinking a bit more about it, luckily for us, I think that it is!
The idea is the following:
As per usual, when thinking about shared memory layouts, given that they often do not cover all of shared memory with data (and this is very much the case here with the padding), and given that we do not want to put the same data in two different parts of shared memory, it is best to think about them through their inverse. In this case, consider P : dim0 x dim1 -> Offset, a padded layout that maps a given matrix into shared memory. P can be decomposed as P = P' o L, where L : dim0 x dim1 -> Offset is a linear map and P' : Offset -> Offset is non-linear. More explicitly, L is given by (the inverse of) getLinearComponent, and P' is given by

Value emitPadding(Location loc, RewriterBase &rewriter,
                  triton::gpu::PaddedSharedEncodingAttr layout,
                  unsigned bitwidth, Value smemOffset, bool offsetInBytes) {
  TritonLLVMOpBuilder b(loc, rewriter);
  assert((bitwidth >= 8) && "Invalid bitwidth for padded shared layout");
  Value padOffset = b.i32_val(0);
  unsigned offScale = offsetInBytes ? bitwidth / 8 : 1;
  for (auto [interval, padding] :
       llvm::zip_equal(layout.getIntervals(), layout.getPaddings())) {
    unsigned intervalScaled = offScale * interval;
    unsigned paddingScaled = offScale * padding;
    Value iVal = b.i32_val(llvm::Log2_32(intervalScaled));
    Value pVal = b.i32_val(llvm::Log2_32(paddingScaled));
    padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
  }
  return padOffset;
}

which is indeed a function that takes an offset and returns an offset.

Given this representation, we see that P' depends on the intervals and paddings, while L does not. What's more, it should be clear that we can apply any operation that acts on the logical tensor (what we call views, that is, transposes and reshapes, and with enough care even subviews) simply by applying it to the linear part of the layout, keeping the paddings and intervals the same!
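
To make the decomposition concrete, here is a minimal Python sketch of the padding part P', mirroring the arithmetic of emitPadding for element (not byte) offsets; the interval/padding values below are made up for the example:

def pad_offset(offset: int, intervals_paddings) -> int:
    # Python mirror of emitPadding for element offsets: after every `interval`
    # elements of linear offset, insert `padding` extra elements.
    # Both are assumed to be powers of two.
    pad = 0
    for interval, padding in intervals_paddings:
        log_interval = interval.bit_length() - 1  # Log2_32(interval)
        log_padding = padding.bit_length() - 1    # Log2_32(padding)
        pad += (offset >> log_interval) << log_padding
    return pad

# P(x) = L(x) + pad_offset(L(x)). With interval=64 and padding=4:
assert pad_offset(63, [(64, 4)]) == 0   # first group of 64 elements: no padding yet
assert pad_offset(64, [(64, 4)]) == 4   # second group starts 4 elements later
assert pad_offset(130, [(64, 4)]) == 8  # third group: two paddings accumulated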

Collaborator

Ah, this is because the padding doesn't have to happen at the end of a row? Does that mean the padding is unaffected by reshapes?

Contributor

@lezcano lezcano Oct 29, 2025

This is because padding just depends on the offset of a given element in shared memory, not on the shape of the logical tensor. You can see this is indeed the case because the emitPadding function above does not take the shape of the tensor as an input; it just takes the memory offset.

And yes, padding is unaffected by reshapes. You just reshape the internal linear layout as done here and you are golden.

Collaborator

Makes sense; yes, with this representation it should be easy. One downside I realize now is that the allocation size depends on the layout, which means reinterpret_cast will have extra restrictions on padding.

Collaborator

+1. My mental model for the padded shared layout is that we have an n-D logical indexing part plus a 1-D physical offset padding part, which is a nice conceptual boundary for reasoning about lots of transformations.

@knwng
Contributor Author

knwng commented Oct 29, 2025

Could I get a bit more context on what tilesPerWarp means for the layouts where it's used, what the context for it is, and why it wasn't needed before but is needed now? Are the layouts on gfx1250 more flexible than they were in previous generations?

Sure! It's conceptually similar to the tilesPerWarp in AMDMfmaEncodingAttr.

Normally, A0/A1 and B0/B1 are independent, so we can tile the block with warps just like we do right now:

   B0 B1 B2 B3
A0 w0 w1 w0 w1
A1 w2 w3 w2 w3
A2 w0 w1 w0 w1
A3 w2 w3 w2 w3

But if we want to use opSel, they are no longer independent, in the sense that we put scaleA0 in lanes [0, 15] of vregSA0 and scaleA1 in lanes [16, 31] of vregSA0. Similarly, scaleB0 is in lanes [0, 15] of vregSB0 and scaleB1 is in lanes [16, 31] of vregSB0. So A0 and A1 both need to get their scales from vregSA0, but from different lanes. The same holds for B0 and B1.

So we need to make sure the wmma instructions for A0xB0, A0xB1, A1xB0, A1xB1 are issued back-to-back in the same warp, with the correct opSel settings; concretely:

wmma A0, B0, vregSA0, vregSB0, opSelA=0, opSelB=0
wmma A0, B1, vregSA0, vregSB0, opSelA=0, opSelB=1
wmma A1, B0, vregSA0, vregSB0, opSelA=1, opSelB=0
wmma A1, B1, vregSA0, vregSB0, opSelA=1, opSelB=1

That's why we need to tile the block like this:

   B0 B1 B2 B3
A0 w0 w0 w1 w1
A1 w0 w0 w1 w1
A2 w2 w2 w3 w3
A3 w2 w2 w3 w3
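
To make the difference concrete, here is a small hypothetical Python sketch (not the compiler's actual layout code) that reproduces both warp assignments above for a 4x4 grid of wmma tiles and warpsPerCTA = [2, 2]; warp_of_tile is an illustrative helper, not an existing API:

def warp_of_tile(i, j, warps_per_cta=(2, 2), tiles_per_warp=(1, 1)):
    # Map a wmma tile (i, j) to a warp id.
    # tiles_per_warp = (1, 1) gives the default 'strided' repetition, where the
    # 2x2 warp grid is broadcast over the tile grid; (2, 2) keeps each warp's
    # 2x2 group of tiles contiguous, as required for the opSel pairing.
    wi = (i // tiles_per_warp[0]) % warps_per_cta[0]
    wj = (j // tiles_per_warp[1]) % warps_per_cta[1]
    return wi * warps_per_cta[1] + wj

for tpw in [(1, 1), (2, 2)]:
    print(f"tilesPerWarp={tpw}")
    for i in range(4):
        print(" ".join(f"w{warp_of_tile(i, j, tiles_per_warp=tpw)}" for j in range(4)))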

@lezcano
Contributor

lezcano commented Oct 29, 2025

Would we ever not want to tile the block as in your latter diagram? Meaning, could we just make that tiling the default and forget about the other one?

@knwng
Contributor Author

knwng commented Oct 29, 2025

Would we ever not want to tile the block as in your latter diagram? Meaning, could we just make that tiling the default and forget about the other one?

My understanding is that this puts more constraints on the block size. For other cases, we still want to keep the flexibility.

@lezcano
Contributor

lezcano commented Oct 29, 2025

I don't think that is the case? The layout I am proposing, for a given shape, figures out how many blocks (as per your drawing) each warp needs to own along a given dimension, and makes sure that for a given warp these are contiguous, rather than broadcasting the initial tile.
In what sense would this put more constraints on the block size? Is it not just a different broadcasting order?

@knwng
Contributor Author

knwng commented Oct 31, 2025

I don't think that is the case? The layout I am proposing, for a given shape, figures out how many blocks (as per your drawing) each warp needs to own along a given dimension, and makes sure that for a given warp these are contiguous, rather than broadcasting the initial tile. In what sense would this put more constraints on the block size? Is it not just a different broadcasting order?

Hi @lezcano, you are right. I was thinking in terms of the current implementation, but we can definitely calculate the number of repetitions ahead of time and then broadcast in the latter way. I don't see any problems with it. @antiagainst @zhanglx13 wdyt?

@antiagainst
Collaborator

We have the "strided" repetition behavior as the default for warps everywhere, I believe. This would be a flip specifically for the WMMA/MFMA layout to have a different default, which can be a catch. We want this contiguous behavior for scaled wmma (and mfma16) but not for others, so flipping everything for it seems more than needed -- it can have implications on subtle aspects like downstream LLVM instruction scheduling and register allocation, so we need to be a bit careful. My preference would be to control this via tilesPerWarp for a consistent default warp layout behavior and opt in where needed.
