diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 57de402ea774..1df4ff0d38d8 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2645,3 +2645,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ tt.return } } + +// ----- + +// Regression test: DotOp lowering used to treat any A operand with +// shape[1] == 1 as an outer product and reject it. That is wrong for 3D +// (batched) operands, whose reduction dimension K is shape[2]. + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [32, 1, 1], instrShape = [1, 16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}> +#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: batched_dot_3d + tt.func public @batched_dot_3d( + %arg0: tensor<32x1x32xf16, #dot_operand_a>, + %arg1: tensor<32x32x32xf16, #dot_operand_b> + ) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x1x32xf32, #mma> + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + %result = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 : + tensor<32x1x32xf16, #dot_operand_a> * tensor<32x32x32xf16, #dot_operand_b> -> tensor<32x1x32xf32, #mma> + tt.return + } +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index c414bffef73a..d0de6c36f73d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -60,15 +60,9 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { Value A = op.getA(); Value D = op.getResult(); - // Here we assume the DotOp's operands always comes from shared memory. 
- auto AShapePerCTA = getShapePerCTA(A.getType()); - size_t reduceAxis = 1; - unsigned K = AShapePerCTA[reduceAxis]; - bool isOuter = K == 1; - NvidiaMmaEncodingAttr mmaLayout = dyn_cast( cast(D.getType()).getEncoding()); - if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) { + if (mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) { if (mmaLayout.getVersionMajor() == 2) { bool isHopperF64 = computeCapability == 90 && @@ -106,14 +100,8 @@ struct WarpGroupDotOpConversion Value A = op.getA(); TypedValue D = op.getResult(); - // Here we assume the DotOp's operands always comes from shared memory. - auto AShapePerCTA = getShapePerCTA(A.getType()); - size_t reduceAxis = 1; - unsigned K = AShapePerCTA[reduceAxis]; - bool isOuter = K == 1; - auto mmaLayout = cast(D.getType().getEncoding()); - if (!isOuter && supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) { + if (supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) { return convertWGMMA(op, adaptor, getTypeConverter(), rewriter, getThreadId(rewriter, loc)); }