Skip to content

Commit ce8add2

Browse files
committed
Fix lit test
1 parent 13394f1 commit ce8add2

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -755,7 +755,12 @@ template <typename Op> Op getNextOp(Value op) {
755755
}
756756

757757
bool scalePreshuffled(Value scale) {
758+
if (!scale) {
759+
return false;
760+
}
761+
758762
auto shape = cast<RankedTensorType>(scale.getType()).getShape();
763+
759764
int rank = shape.size();
760765
int blockNonK = shape[rank - 2];
761766
// 1 scale always scales 32 elements along K dim
@@ -785,7 +790,7 @@ bool scalePreshuffled(Value scale) {
785790
}
786791

787792
SmallVector<unsigned, 2> getTilesPerWarp(Value aScale, Value bScale) {
788-
if (scalePreshuffled(aScale) && scalePreshuffled(bScale)) {
793+
if (scalePreshuffled(aScale) || scalePreshuffled(bScale)) {
789794
return {2, 2};
790795
}
791796
return {1, 1};

0 commit comments

Comments (0)