[AMD] Scale preshuffling and opSel implementation #7603
```diff
@@ -439,8 +439,15 @@ struct DotOpMFMAConversionHelper {
         results = b.zext(i32_ty, b.bitcast(vec, i8_ty));
       }
     }
+    if (2 == kBase)
+      // This case can occur during scale tensor packing when there aren't
+      // enough elements to fill all 4 opSel slots. For example, with an A
+      // tensor of size 16x256 and using 16x16x128 block sizes, we end up with
+      // only 2 elements to pack, resulting in a kBase of 2.
+      results = b.zext(i32_ty, b.bitcast(vec, i16_ty));
     if (4 == kBase)
-      // This is for int8 on pre- CDNA3 GPUs
+      // This is for int8 on pre- CDNA3 GPUs and scale tensors on CDNA4 GPUs
       results = b.bitcast(vec, i32_ty);
     if (8 == kBase)
       results = b.bitcast(vec, i64_ty);
```

> **Collaborator** (on the `if (2 == kBase)` branch): Can you add some comments to explain this case?
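For context, here is a minimal standalone sketch (not code from this PR) of the arithmetic behind the example in the new comment. The repetition counts are an assumption derived from the shapes the comment cites, and the `min(4, ...)` packing rule mirrors `scaleAKBase` introduced later in this diff:

```cpp
// Sketch only: why a 16x256 A tensor with 16x16x128 (M x N x K) MFMA
// blocks yields kBase == 2 for the packed scales.
#include <algorithm>
#include <cassert>

int main() {
  int numRepM = 16 / 16;   // 1 repetition along M
  int numRepK = 256 / 128; // 2 repetitions along K
  // Up to 4 scales can share one 32-bit value (one per opSel slot), but
  // only numRepK * numRepM == 2 are available here, so kBase == 2 and the
  // two i8 scales are bitcast to i16 and zero-extended to i32.
  int kBase = std::min(4, numRepK * numRepM);
  assert(kBase == 2);
}
```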
```diff
@@ -465,6 +472,11 @@ struct DotOpMFMAConversionHelper {
     auto elems = unpackLLElements(loc, value, rewriter);
     // number of kBase-element vectors
     int numVecInKBase = kRepInKWidth * kWidth / kBase;
+    if (numVecInKBase == 0) {
+      numVecInKBase = 1;
+      nonKRep /= kBase / (kRepInKWidth * kWidth);
+      assert(nonKRep > 0 && "nonKrep too small");
+    }
     ValueTable dotOpVals;

     SmallVector<int64_t> strides =
```

> **Collaborator** (on the new `assert`): Do we still need this assert?
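A minimal sketch of what this fallback does, under assumed values (`kWidth = kRepInKWidth = 1`, `kBase = 4`, eight non-K reps; none of these come from the PR itself): when a single K rep cannot fill one packed vector, packing borrows repetitions from the non-K dimension, so `nonKRep` shrinks accordingly.

```cpp
// Sketch only: the repacking arithmetic behind the numVecInKBase == 0 branch.
#include <cassert>
#include <cstdio>

int main() {
  int kRepInKWidth = 1, kWidth = 1, kBase = 4;
  int nonKRep = 8; // repetitions along the non-K dimension (assumed)

  int numVecInKBase = kRepInKWidth * kWidth / kBase; // 1 / 4 == 0
  if (numVecInKBase == 0) {
    numVecInKBase = 1; // still emit one packed vector per group
    // Each packed vector consumes kBase / (K elements per rep) repetitions
    // from the non-K dimension: 8 / (4 / 1) leaves 2 groups.
    nonKRep /= kBase / (kRepInKWidth * kWidth);
    assert(nonKRep > 0 && "nonKRep too small");
  }
  std::printf("numVecInKBase=%d nonKRep=%d\n", numVecInKBase, nonKRep);
}
```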
```diff
@@ -544,17 +556,19 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
   Value generateScaledMFMAOp(StringRef intrinsicName, Value valA, Value valB,
                              Value valC, Value valScaleA, Value valScaleB,
-                             Type elemTypeA, Type elemTypeB) const {
+                             Type elemTypeA, Type elemTypeB, int opSelA,
+                             int opSelB) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto resType = valC.getType();
     Value zeroFlag = b.i32_val(0);
+    Value valOpSelA = b.i32_val(opSelA);
+    Value valOpSelB = b.i32_val(opSelB);
     OperationState loweredOp(loc, intrinsicName);
     int32_t cbsz = getMfmaF8F6F4MatrixFormat(elemTypeA);
     int32_t blgp = getMfmaF8F6F4MatrixFormat(elemTypeB);
     assert((cbsz != -1) && (blgp != -1));
     loweredOp.addTypes(resType);
     loweredOp.addOperands({valA, valB, valC, b.i32_val(cbsz), b.i32_val(blgp),
-                           zeroFlag, valScaleA, zeroFlag, valScaleB});
+                           valOpSelA, valScaleA, valOpSelB, valScaleB});
     return rewriter.create(loweredOp)->getResult(0);
   }
```
```diff
@@ -636,8 +650,6 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     // better way to get it when adapting other data types. Similar to
     // scaleKBase
     constexpr int scaleKWidth = 1;
-    constexpr int scaleKBase = 1;

     Value loadedA = adaptor.getA();
     Value loadedB = adaptor.getB();
     Value loadedAScale = adaptor.getAScale();
```
```diff
@@ -650,6 +662,27 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     auto numRepB = repA[0];
     assert(repA[0] == repB[0]);

+    // Scaled MFMA instructions expect scale operands as 32-bit values,
+    // even though each individual scale is only 8 bits. To reduce register
+    // usage, we pack 4 scales into a single 32-bit value and use the opSel
+    // field to select the appropriate byte during execution. Packing is done
+    // along the K dimension first; if there aren't enough values in K, we
+    // continue along the non-K dimension.
+    // TODO: Support opSel selection for constant scales stored in SGPRs.
+    const int scaleAKBase =
+        isAScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK * numRepM));
+    const int scaleBKBase =
+        isBScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK * numRepN));
+
+    int akPackedVals =
+        isAScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK));
+    int bkPackedVals =
+        isBScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK));
+
+    assert(scaleAKBase % akPackedVals == 0 && scaleBKBase % bkPackedVals == 0);
+    int aNonKPackedVals = scaleAKBase / akPackedVals;
+    int bNonKPackedVals = scaleBKBase / bkPackedVals;
+
     auto operandA = getValuesFromDotOperandLayoutStruct(
         loadedA, numRepB, numRepM, numRepK, aKWidth, aKBase,
         aTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
```
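The comment block above is the heart of the change. As a rough standalone illustration of the packing scheme it describes (not the PR's code, and in practice the hardware does the byte select via the opSel field; the slot order, with slot 0 in the lowest byte, is an assumption):

```cpp
// Sketch only: four 8-bit scales share one 32-bit value; opSel picks a byte.
#include <cstdint>
#include <cstdio>

// Pack up to four 8-bit scales, assuming slot 0 is the lowest byte.
uint32_t packScales(const uint8_t s[4]) {
  return uint32_t(s[0]) | uint32_t(s[1]) << 8 | uint32_t(s[2]) << 16 |
         uint32_t(s[3]) << 24;
}

// What the opSel field effectively does during execution: select one byte.
uint8_t selectScale(uint32_t packed, int opSel) {
  return uint8_t(packed >> (8 * opSel));
}

int main() {
  uint8_t scales[4] = {0x7e, 0x7f, 0x80, 0x81}; // example e8m0 scales
  uint32_t packed = packScales(scales);
  for (int opSel = 0; opSel < 4; ++opSel)
    std::printf("opSel=%d -> 0x%02x\n", opSel, selectScale(packed, opSel));
}
```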
```diff
@@ -664,13 +697,13 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     if (existBothScales) {
       auto aScaleTensorTy = cast<RankedTensorType>(aScale.getType());
       operandAScale = getValuesFromDotOperandLayoutStruct(
-          loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleKBase,
+          loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleAKBase,
           aScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false,
           isAScaleConstant);

       auto bScaleTensorTy = cast<RankedTensorType>(bScale.getType());
       operandBScale = getValuesFromDotOperandLayoutStruct(
-          loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleKBase,
+          loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleBKBase,
           bScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false,
           isBScaleConstant);
     }
```
```diff
@@ -731,18 +764,29 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
       for (innerK = 0; innerK < innerKBound; innerK++) {
         int k = is2Step ? outerK : innerK;
         if (existBothScales) {
+          int akScale = k / akPackedVals;
+          int bkScale = k / bkPackedVals;
+          int opSelA = 0, opSelB = 0;
+
+          int mScale = m / aNonKPackedVals;
+          int nScale = n / bNonKPackedVals;
+          opSelA = (m * numRepK + k) % (aNonKPackedVals * akPackedVals);
+          opSelB = (n * numRepK + k) % (bNonKPackedVals * bkPackedVals);
+
           if (mfmaLayout.getIsTransposed()) {
             acc = generateScaledMFMAOp(
                 intrinsicName, operandB[{b, n, k}], operandA[{b, m, k}],
-                acc, operandBScale[{b, n, k}], operandAScale[{b, m, k}],
+                acc, operandBScale[{b, nScale, bkScale}],
+                operandAScale[{b, mScale, akScale}],
                 maybeMfmaIntrinsic->bElementType,
-                maybeMfmaIntrinsic->aElementType);
+                maybeMfmaIntrinsic->aElementType, opSelB, opSelA);
           } else {
             acc = generateScaledMFMAOp(
                 intrinsicName, operandA[{b, m, k}], operandB[{b, n, k}],
-                acc, operandAScale[{b, m, k}], operandBScale[{b, n, k}],
+                acc, operandAScale[{b, mScale, akScale}],
+                operandBScale[{b, nScale, bkScale}],
                 maybeMfmaIntrinsic->aElementType,
-                maybeMfmaIntrinsic->bElementType);
+                maybeMfmaIntrinsic->bElementType, opSelA, opSelB);
           }
         } else {
           if (mfmaLayout.getIsTransposed()) {
```
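To see how the new indices cooperate, here is a standalone sketch (not PR code) with assumed rep counts (`numRepM = numRepK = 2`, A scale not constant): all four (m, k) MFMA issues read the same packed scale value and differ only in the opSel slot. The index formulas mirror the A-operand arithmetic in the hunk above.

```cpp
// Sketch only: mapping (m, k) repetitions to one packed scale plus opSel.
#include <algorithm>
#include <cstdio>

int main() {
  int numRepM = 2, numRepK = 2;
  int scaleAKBase = std::min(4, numRepK * numRepM); // 4 scales per i32
  int akPackedVals = std::min(4, numRepK);          // 2 packed along K
  int aNonKPackedVals = scaleAKBase / akPackedVals; // 2 packed along M
  for (int m = 0; m < numRepM; ++m)
    for (int k = 0; k < numRepK; ++k) {
      int mScale = m / aNonKPackedVals;
      int akScale = k / akPackedVals;
      int opSelA = (m * numRepK + k) % (aNonKPackedVals * akPackedVals);
      // All four (m, k) pairs land on scale[0][0], each through a different
      // opSel byte: (0,0)->0, (0,1)->1, (1,0)->2, (1,1)->3.
      std::printf("m=%d k=%d -> scale[%d][%d], opSelA=%d\n", m, k, mScale,
                  akScale, opSelA);
    }
}
```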