Implement Scale Preshuffling and opSel on GFX1250 #8576
Conversation
…CK_K < 128 (triton-lang#231)
- Supported using get_wmma_scale_layout to deduce scale layout
- Supported layout inference of MemDescReshapeOp
- Supported the case when BLOCK_K < 128 for FA
Co-authored-by: Kyle Wang
As discussed before, why don't we do this in the frontend, as we do in Nvidia?
cc @ThomasRaoux @antiagainst
The preshuffle scheme is done in the kernel with …
Hi @lezcano, sorry for the misleading description. I've updated the PR message and title to provide more context. Please take a look.
Ah, sorry, my mistake. I just went by the name and the changes in the compiler and thought that it was still being implemented in the middle-end. Will review properly tomorrow!
Nice! We can now do with Padded layouts virtually anything that we could do with non-padded layouts (other than use them in our swizzling algorithm, but that's fine).
Could I get a bit more context on what tilesPerWarp means for the layouts where it's used, what the context for it is, and why it wasn't needed before but is needed now? Are the layouts on gfx1250 more flexible than they were in previous generations?
```cpp
} else if (auto padded = dyn_cast<PaddedSharedEncodingAttr>(srcEnc)) {
  LinearLayout ll = padded.getLinearComponent();
  LinearLayout dst = reshapeLayout(ctx, ll, dstShape);
  SmallVector<std::pair<unsigned, unsigned>> intervalPads;
  auto intervals = padded.getIntervals();
  auto paddings = padded.getPaddings();
  for (auto [interval, padding] : llvm::zip(intervals, paddings)) {
    intervalPads.emplace_back(interval, padding);
  }
  dstEnc = PaddedSharedEncodingAttr::get(ctx, intervalPads, dst);
  return success();
```
@ThomasRaoux @antiagainst I had said before that reshaping a padded layout is not correct in general, but thinking a bit more about it, luckily for us, I think that it is!
The idea is the following:
As per usual, when thinking about shared memory layouts, given that they often do not cover all the shared memory with data (and this is very much the case here with the padding), and given that we do not want to put the same data in two different parts of shared memory, it is best to think about them through their inverse. In this case, consider P : dim0 x dim1 -> Offset, a padded layout that maps a given matrix into shared memory. P can be decomposed as P = P' o L, where L : dim0 x dim1 -> Offset is a linear map and P' : Offset -> Offset is non-linear. More explicitly, L is given by (the inverse of) getLinearComponent, and P' is given by
triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp, lines 469 to 486 at 2b29c3d:
```cpp
Value emitPadding(Location loc, RewriterBase &rewriter,
                  triton::gpu::PaddedSharedEncodingAttr layout,
                  unsigned bitwidth, Value smemOffset, bool offsetInBytes) {
  TritonLLVMOpBuilder b(loc, rewriter);
  assert((bitwidth >= 8) && "Invalid bitwidth for padded shared layout");
  Value padOffset = b.i32_val(0);
  unsigned offScale = offsetInBytes ? bitwidth / 8 : 1;
  for (auto [interval, padding] :
       llvm::zip_equal(layout.getIntervals(), layout.getPaddings())) {
    unsigned intervalScaled = offScale * interval;
    unsigned paddingScaled = offScale * padding;
    Value iVal = b.i32_val(llvm::Log2_32(intervalScaled));
    Value pVal = b.i32_val(llvm::Log2_32(paddingScaled));
    padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
  }
  return padOffset;
}
```
which is indeed a function that takes an offset and returns an offset.
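In more explicit terms (a compact restatement of the shifts above; it is exact when the intervals and paddings are powers of two, which the Log2_32 calls assume), emitPadding computes the additive padding term, so

$$P'(o) \;=\; o + \sum_i \left\lfloor \frac{o}{\mathrm{interval}_i} \right\rfloor \cdot \mathrm{padding}_i$$

Only the 1-D offset and the (interval, padding) pairs appear here; the logical tensor shape does not.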
Given this representation, we see that P depends on interval and padding, while L does not. Even more, it should be clear that we can apply any operation that acts on the logical tensor (what we call views, that is, transposes and reshapes, and, with enough care, even subviews) just by applying it to the linear part of the layout and keeping the paddings and intervals the same!
Ah, is this because the padding doesn't have to happen at the end of a row? Does that mean the padding is unaffected by reshapes?
This is because the padding just depends on the offset of a given element in shared memory, not on the shape of the logical tensor. You can see this is indeed the case because the emitPadding function above does not take the shape of the tensor as an input; it just takes the memory offset.
And yes, padding is unaffected by reshapes. You just reshape the internal linear layout as done here and you are golden.
Makes sense; yes, with this representation it should be easy. One downside I realize now is that the allocation size depends on the layout, which means reinterpret_cast will have extra restrictions on padding.
+1. My mental model regarding the padded shared layout is that we have the n-D logical indexing part + the 1-D physical offset padding part, which is a nice conceptual boundary for reasoning about lots of transformations.
Sure! It's conceptually similar to the tilesPerWarp in AMDMfmaEncodingAttr. Normally A0/A1 and B0/B1 are independent, so we can tile the block with warps just like we are doing right now.
But if we want to use opSel, we need to make sure the wmma instructions for A0xB0, A0xB1, A1xB0, A1xB1 are issued back-to-back in the same warp, with the correct opSel. That's why we need to tile the block like this:
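In place of the original drawings, here is a rough, hypothetical sketch (plain C++; the 2x2 warp grid and the tileOwner helper are illustrative assumptions, not the actual layout code) contrasting the default "strided" repetition with a tilesPerWarp-style contiguous grouping, which keeps the A0xB0, A0xB1, A1xB0, A1xB1 tiles in one warp:

```cpp
// Hypothetical illustration only: which warp owns the wmma tile at (ti, tj)
// in the block's tile grid, assuming a 2x2 warp grid.
#include <cassert>

// tilesPerWarpM/N == 1 models the default "strided" repetition; 2 models the
// contiguous grouping discussed here.
int tileOwner(int ti, int tj, int tilesPerWarpM, int tilesPerWarpN) {
  const int warpsM = 2, warpsN = 2; // assumed 2x2 warp grid
  int wi = (ti / tilesPerWarpM) % warpsM;
  int wj = (tj / tilesPerWarpN) % warpsN;
  return wi * warpsN + wj;
}

int main() {
  // Strided default: vertically adjacent tiles land in different warps.
  assert(tileOwner(0, 0, 1, 1) != tileOwner(1, 0, 1, 1));
  // tilesPerWarp = {2, 2}: the 2x2 group of tiles around (0, 0), i.e. the
  // A0xB0, A0xB1, A1xB0, A1xB1 products, all belong to warp 0.
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      assert(tileOwner(i, j, 2, 2) == 0);
  return 0;
}
```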
Would we ever not want to tile the block as you last described? Meaning, could we just make that tiling the default and forget about the other one?
My understanding is that this puts more constraints on the block size. For other cases, we still enjoy the flexibility.
I don't think that is the case? The layout I am proposing, for a given shape, figures out how many blocks (as per your drawing) each warp needs to own along a given dimension, and makes sure that for a given warp these are contiguous, rather than broadcasting the initial tile.
Hi @lezcano, you are right. I was thinking in terms of the current implementation, but we can definitely calculate the number of repetitions ahead of time and then broadcast in the latter way. I didn't find any problems with it. @antiagainst @zhanglx13 wdyt?
We have the "strided" repetition behavior as the default for warp everywhere I believe. This would be a flip specifically for WMMA/MFMA layout to have a different default, which can be a catch. We want this contiguous behavior for scaled wmma (and mfma16) but not others. So flipping everything for it seems more than needed--it can have implication over subtle aspect like downstream llvm instruction scheduling and register alloc so need to be a bit careful. My preference would be have this control via tilesPerWarp for consistent default layout warp behavior and opt in where needed. |
Following #7603, this PR implements scale preshuffling on gfx1250 for efficient memory access and better wmma codegen with opSel.

As an example, in an mxfp GEMM kernel with block shape BLOCK_M x BLOCK_N x BLOCK_K, scaleA's shape is BLOCK_M x (BLOCK_K // 32). We preshuffle it to (BLOCK_M // 128) x (BLOCK_K x 4) outside the kernel for better vectorization, and 'unshuffle' it inside the kernel to get the canonical input to the wmma_scaled op. The same applies to scaleB.

Besides, the 16x16x128 scaled wmma instruction reads scales only from the first 16 lanes in a wave, which wastes read capacity. Therefore we use opSel to control whether a wmma instruction reads its scales from the first or the last 16 lanes in a wave, so that scales are read with all the lanes in a wave.

To correctly issue wmma instructions with opSel, we need to group 2 consecutive wmma instruction tiles in a wave. This is done by introducing tilesPerWarp to AMDWmmaEncodingAttr, to avoid composing linear layouts in the gluon kernel all the time.

This PR also includes support for inferring a padded shared layout for MemDescReshapeOp, because in the case of async/tensor loads we need to do the 'unshuffling' on a memory subview.
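As a quick sanity check of the shapes above (a standalone sketch, not the kernel code; the concrete block sizes are arbitrary example values), the preshuffled view holds exactly the same number of scale elements as the canonical one:

```cpp
// Standalone shape check for the preshuffle described above (illustration
// only; the block sizes are arbitrary example values).
#include <cassert>

int main() {
  const int BLOCK_M = 256, BLOCK_K = 256;
  // Canonical scaleA shape inside the kernel: BLOCK_M x (BLOCK_K // 32).
  const int canonicalElems = BLOCK_M * (BLOCK_K / 32);
  // Preshuffled shape outside the kernel: (BLOCK_M // 128) x (BLOCK_K * 4).
  const int preshuffledElems = (BLOCK_M / 128) * (BLOCK_K * 4);
  // Both views cover the same scale elements; only the layout differs.
  assert(canonicalElems == preshuffledElems);
  return 0;
}
```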