|
24 | 24 | #include "mlir/IR/BuiltinAttributes.h" |
25 | 25 | #include "mlir/IR/BuiltinTypes.h" |
26 | 26 | #include "mlir/IR/TypeUtilities.h" |
| 27 | +#include "mlir/IR/ValueRange.h" |
| 28 | +#include "llvm/ADT/STLExtras.h" |
27 | 29 | #include "llvm/ADT/StringSwitch.h" |
28 | 30 |
|
29 | 31 | #include <cassert> |
30 | 32 |
|
31 | 33 | namespace mlir { |
| 34 | +//===----------------------------------------------------------------------===// |
| 35 | +// Patterns and helpers used by both the KHR and the NV lowering paths. |
| 36 | +//===----------------------------------------------------------------------===// |
| 37 | + |
32 | 38 | /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op |
33 | 39 | /// when the elementwise op directly supports with cooperative matrix type. |
34 | 40 | /// Returns false if cannot. |
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder, |
77 | 83 | return false; |
78 | 84 | } |
79 | 85 |
|
| 86 | +bool allOperandsHaveSameCoopMatrixType(ValueRange operands) { |
| 87 | + assert(!operands.empty()); |
| 88 | + if (!llvm::all_equal( |
| 89 | + llvm::map_range(operands, [](Value v) { return v.getType(); }))) |
| 90 | + return false; |
| 91 | + |
| 92 | + return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>( |
| 93 | + operands.front().getType()); |
| 94 | +} |
| 95 | + |
| 96 | +namespace { |
| 97 | +/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative |
| 98 | +/// matrix ops. |
| 99 | +struct WmmaConstantOpToSPIRVLowering final |
| 100 | + : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> { |
| 101 | + using OpConversionPattern::OpConversionPattern; |
| 102 | + |
| 103 | + LogicalResult |
| 104 | + matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor, |
| 105 | + ConversionPatternRewriter &rewriter) const override { |
| 106 | + assert(adaptor.getOperands().size() == 1); |
| 107 | + Value cst = adaptor.getOperands().front(); |
| 108 | + auto coopType = getTypeConverter()->convertType(op.getType()); |
| 109 | + if (!coopType) |
| 110 | + return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| 111 | + |
| 112 | + rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst); |
| 113 | + return success(); |
| 114 | + } |
| 115 | +}; |
| 116 | + |
| 117 | +/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
| 118 | +/// the default case. |
| 119 | +struct WmmaElementwiseOpToSPIRVDefaultLowering final |
| 120 | + : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
| 121 | + using OpConversionPattern::OpConversionPattern; |
| 122 | + |
| 123 | + LogicalResult |
| 124 | + matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, |
| 125 | + ConversionPatternRewriter &rewriter) const override { |
| 126 | + // All operands should be of cooperative matrix types. |
| 127 | + if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { |
| 128 | + return rewriter.notifyMatchFailure(op, |
| 129 | + "not all operands are coop matrices"); |
| 130 | + } |
| 131 | + |
| 132 | + auto coopType = getTypeConverter()->convertType(op.getType()); |
| 133 | + if (!coopType) |
| 134 | + return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| 135 | + |
| 136 | + return success( |
| 137 | + createElementwiseOp(rewriter, op, coopType, adaptor.getOperands())); |
| 138 | + } |
| 139 | +}; |
| 140 | + |
| 141 | +/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
| 142 | +/// matrix times scalar case. |
| 143 | +struct WmmaElementwiseOpToSPIRVScalarMulLowering final |
| 144 | + : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
| 145 | + using OpConversionPattern::OpConversionPattern; |
| 146 | + |
| 147 | + LogicalResult |
| 148 | + matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, |
| 149 | + ConversionPatternRewriter &rewriter) const override { |
| 150 | + if (adaptor.getOperands().size() != 2) |
| 151 | + return failure(); |
| 152 | + |
| 153 | + // All operands should be of cooperative matrix types. |
| 154 | + if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { |
| 155 | + return rewriter.notifyMatchFailure(op, |
| 156 | + "not all operands are coop matrices"); |
| 157 | + } |
| 158 | + |
| 159 | + if (op.getOpType() != gpu::MMAElementwiseOp::MULF) |
| 160 | + return failure(); |
| 161 | + |
| 162 | + // Use the original operands to check whether one of the operands is a splat |
| 163 | + // scalar value. |
| 164 | + Value lhs = op.getOperands().front(); |
| 165 | + Value rhs = op.getOperands().back(); |
| 166 | + Value splat = nullptr; |
| 167 | + Value matrix = nullptr; |
| 168 | + if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
| 169 | + splat = adaptor.getOperands().front(); |
| 170 | + matrix = adaptor.getOperands().back(); |
| 171 | + } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
| 172 | + matrix = adaptor.getOperands().front(); |
| 173 | + splat = adaptor.getOperands().back(); |
| 174 | + } |
| 175 | + if (!splat || !matrix) |
| 176 | + return rewriter.notifyMatchFailure(op, "no splat operand"); |
| 177 | + |
| 178 | + // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops. |
| 179 | + Value scalar; |
| 180 | + auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>(); |
| 181 | + if (!cc) { |
| 182 | + return rewriter.notifyMatchFailure(op, |
| 183 | + "splat is not a composite construct"); |
| 184 | + } |
| 185 | + |
| 186 | + assert(cc.getConstituents().size() == 1); |
| 187 | + scalar = cc.getConstituents().front(); |
| 188 | + |
| 189 | + auto coopType = getTypeConverter()->convertType(op.getType()); |
| 190 | + if (!coopType) |
| 191 | + return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| 192 | + rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>( |
| 193 | + op, coopType, ValueRange{matrix, scalar}); |
| 194 | + return success(); |
| 195 | + } |
| 196 | +}; |
| 197 | +} // namespace |
| 198 | + |
80 | 199 | //===----------------------------------------------------------------------===// |
81 | 200 | // SPV_KHR_cooperative_matrix |
82 | 201 | //===----------------------------------------------------------------------===// |
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final |
262 | 381 | } |
263 | 382 | }; |
264 | 383 |
|
265 | | -/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix |
266 | | -/// ops. |
267 | | -struct WmmaConstantOpToSPIRVLowering final |
268 | | - : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> { |
269 | | - using OpConversionPattern::OpConversionPattern; |
270 | | - |
271 | | - LogicalResult |
272 | | - matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp, |
273 | | - OpAdaptor adaptor, |
274 | | - ConversionPatternRewriter &rewriter) const override { |
275 | | - Value cst = adaptor.getOperands()[0]; |
276 | | - auto coopType = convertMMAToSPIRVCoopMatrixNVType( |
277 | | - cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType())); |
278 | | - rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( |
279 | | - subgroupMmaConstantMatrixOp, coopType, cst); |
280 | | - return success(); |
281 | | - } |
282 | | -}; |
283 | | - |
284 | | -/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
285 | | -/// the default case. |
286 | | -struct WmmaElementwiseOpToSPIRVDefaultLowering final |
287 | | - : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
288 | | - using OpConversionPattern::OpConversionPattern; |
289 | | - |
290 | | - LogicalResult |
291 | | - matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp, |
292 | | - OpAdaptor adaptor, |
293 | | - ConversionPatternRewriter &rewriter) const override { |
294 | | - // All operands should be of cooperative matrix types. |
295 | | - for (Value operand : adaptor.getOperands()) { |
296 | | - if (!isa<spirv::CooperativeMatrixNVType>(operand.getType())) |
297 | | - return failure(); |
298 | | - } |
299 | | - auto coopType = convertMMAToSPIRVCoopMatrixNVType( |
300 | | - cast<gpu::MMAMatrixType>(elementwiseOp.getType())); |
301 | | - return success(createElementwiseOp(rewriter, elementwiseOp, coopType, |
302 | | - adaptor.getOperands())); |
303 | | - } |
304 | | -}; |
305 | | - |
306 | | -/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
307 | | -/// matrix times scalar case. |
308 | | -struct WmmaElementwiseOpToSPIRVScalarMulLowering final |
309 | | - : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
310 | | - using OpConversionPattern::OpConversionPattern; |
311 | | - |
312 | | - LogicalResult |
313 | | - matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp, |
314 | | - OpAdaptor adaptor, |
315 | | - ConversionPatternRewriter &rewriter) const override { |
316 | | - if (adaptor.getOperands().size() != 2) |
317 | | - return failure(); |
318 | | - // All operands should be of cooperative matrix types. |
319 | | - for (Value operand : adaptor.getOperands()) { |
320 | | - if (!isa<spirv::CooperativeMatrixNVType>(operand.getType())) |
321 | | - return failure(); |
322 | | - } |
323 | | - |
324 | | - if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF) |
325 | | - return failure(); |
326 | | - |
327 | | - // Use the original operands to check whether one of the operands is a splat |
328 | | - // scalar value. |
329 | | - Value lhs = elementwiseOp.getOperands().front(); |
330 | | - Value rhs = elementwiseOp.getOperands().back(); |
331 | | - Value splat = nullptr; |
332 | | - Value matrix = nullptr; |
333 | | - if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
334 | | - splat = adaptor.getOperands().front(); |
335 | | - matrix = adaptor.getOperands().back(); |
336 | | - } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
337 | | - matrix = adaptor.getOperands().front(); |
338 | | - splat = adaptor.getOperands().back(); |
339 | | - } |
340 | | - if (!splat || !matrix) |
341 | | - return failure(); |
342 | | - |
343 | | - // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops. |
344 | | - Value scalar = nullptr; |
345 | | - auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>(); |
346 | | - if (!cc) |
347 | | - return failure(); |
348 | | - assert(cc.getConstituents().size() == 1); |
349 | | - scalar = cc.getConstituents().front(); |
350 | | - |
351 | | - auto coopType = convertMMAToSPIRVCoopMatrixNVType( |
352 | | - cast<gpu::MMAMatrixType>(elementwiseOp.getType())); |
353 | | - rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>( |
354 | | - elementwiseOp, coopType, ValueRange{matrix, scalar}); |
355 | | - return success(); |
356 | | - } |
357 | | -}; |
358 | | - |
359 | 384 | } // namespace |
360 | 385 | } // namespace nv |
361 | 386 | } // namespace mlir |
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( |
389 | 414 | using namespace mlir; |
390 | 415 | MLIRContext *context = patterns.getContext(); |
391 | 416 | patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering, |
392 | | - khr::WmmaStoreOpToSPIRVLowering>(converter, context); |
| 417 | + khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering, |
| 418 | + WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context); |
| 419 | + // Give the following patterns higher benefit to prevail over the default one. |
| 420 | + patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context, |
| 421 | + /*benefit=*/2); |
393 | 422 | } |
394 | 423 |
|
395 | 424 | void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns( |
396 | 425 | SPIRVTypeConverter &converter, RewritePatternSet &patterns) { |
397 | 426 | using namespace mlir; |
398 | 427 | MLIRContext *context = patterns.getContext(); |
399 | | - patterns |
400 | | - .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering, |
401 | | - nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering, |
402 | | - nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context); |
| 428 | + patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering, |
| 429 | + nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering, |
| 430 | + WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context); |
403 | 431 | // Give the following patterns higher benefit to prevail over the default one. |
404 | | - patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, |
405 | | - context, |
406 | | - /*benefit=*/2); |
| 432 | + patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context, |
| 433 | + /*benefit=*/2); |
407 | 434 | } |
0 commit comments