File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed
third_party/amd/lib/TritonAMDGPUTransforms Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -755,7 +755,12 @@ template <typename Op> Op getNextOp(Value op) {
755755}
756756
757757bool scalePreshuffled (Value scale) {
758+ if (!scale) {
759+ return false ;
760+ }
761+
758762 auto shape = cast<RankedTensorType>(scale.getType ()).getShape ();
763+
759764 int rank = shape.size ();
760765 int blockNonK = shape[rank - 2 ];
761766 // 1 scale always scales 32 elements along K dim
@@ -785,7 +790,7 @@ bool scalePreshuffled(Value scale) {
785790}
786791
787792SmallVector<unsigned , 2 > getTilesPerWarp (Value aScale, Value bScale) {
788- if (scalePreshuffled (aScale) && scalePreshuffled (bScale)) {
793+ if (scalePreshuffled (aScale) || scalePreshuffled (bScale)) {
789794 return {2 , 2 };
790795 }
791796 return {1 , 1 };
You can’t perform that action at this time.
0 commit comments