@@ -3640,36 +3640,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
36403640}
36413641
36423642//===----------------------------------------------------------------------===//
3643- // NVVM dot.accumulate.4way Op
3643+ // NVVM dot.accumulate Ops
36443644//===----------------------------------------------------------------------===//
36453645
3646- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8 ", 1 , "s8 ">;
3647- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8 ", 0 , "u8 ">;
3646+ def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED ", 0 , "unsigned ">;
3647+ def DotAccumulateSigned : I32EnumAttrCase<"SIGNED ", 1 , "signed ">;
36483648
3649- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType ",
3650- "NVVM DotAccumulate4WayType ",
3651- [DotAccumulate4WayS8, DotAccumulate4WayU8 ]> {
3649+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType ",
3650+ "NVVM DotAccumulateType ",
3651+ [DotAccumulateSigned, DotAccumulateUnsigned ]> {
36523652 let cppNamespace = "::mlir::NVVM";
36533653 let genSpecializedAttr = 0;
36543654}
36553655
3656- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3656+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
36573657 let assemblyFormat = "`<` $value `>`";
36583658}
36593659
36603660def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3661- let summary = "Four-way byte dot product-accumulate instruction. ";
3661+ let summary = "Four-way byte dot product-accumulate instruction";
36623662 let description = [{
36633663 Performs a four-way byte dot-product which is accumulated in a 32-bit
36643664 result.
36653665 Operand `a` and `b` are vectors of 4 bytes between which the dot product is
36663666 computed.
3667+
36673668 The `a_type` and `b_type` attributes specify the type of the elements in `a`
36683669 and `b` respectively.
3669- If `a_type` or `b_type` is `s8 `, then the elements in the corresponding
3670+ If `a_type` or `b_type` is `signed`, then the elements in the corresponding
36703671 vector are sign-extended to 32-bit before the dot product is computed.
3671- If `a_type` or `b_type` is `u8`, then the elements in the corresponding
3672- vector are zero-extended to 32-bit instead.
3672+ If `a_type` or `b_type` is `unsigned`, then the elements in the
3673+ corresponding vector are zero-extended to 32-bit instead.
3674+
36733675 Operand `c` is a 32-bit integer to which the result is accumulated. It is
36743676 treated as holding a signed integer if any of `a_type` or `b_type` is `signed`.
36753677
@@ -3678,9 +3680,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
36783680
36793681 let arguments = (ins
36803682 VectorOfLengthAndType<[4], [I8]>:$a,
3681- DotAccumulate4WayTypeAttr :$a_type,
3683+ DotAccumulateTypeAttr :$a_type,
36823684 VectorOfLengthAndType<[4], [I8]>:$b,
3683- DotAccumulate4WayTypeAttr :$b_type,
3685+ DotAccumulateTypeAttr :$b_type,
36843686 I32:$c
36853687 );
36863688
@@ -3689,17 +3691,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
36893691 let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
36903692
36913693 let extraClassDeclaration = [{
3692- static llvm::Intrinsic::ID
3693- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3694- NVVM::DotAccumulate4WayType b_type);
3695- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3694+ static mlir::NVVM::IDArgPair
3695+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3696+ llvm::IRBuilderBase &builder);
36963697 }];
36973698
36983699 string llvmBuilder = [{
3699- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3700- llvm::Value* argA = op.getPackedArg($a, builder);
3701- llvm::Value* argB = op.getPackedArg($b, builder);
3702- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
3700+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3701+ *op, moduleTranslation, builder);
3702+ $res = createIntrinsicCall(builder, id, args);
37033703 }];
37043704}
37053705
0 commit comments