9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -858,7 +858,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
DefaultValuedOptionalAttr<
AffineMapArrayAttr,
"BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
>:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
@@ -884,9 +888,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
$_state.addOperands(operands);
$_state.addAttribute("cast", cast);
$_state.addAttributes(attributes);
$_state.addTypes(resultTensorTypes);
(void)$_state.addRegion();
25 changes: 16 additions & 9 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3951,11 +3951,18 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
RegionBuilderHelper helper(b, block);
SmallVector<Value> yields;

TypeFn castVal = TypeFn::cast_signed;
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
return attr.getName() == "cast";
});
if (castIter != attrs.end()) {
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
castVal = attr.getValue();
}

auto toType = block.getArgument(2).getType();
Value castValA =
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
Value castValB =
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
Value addVal =
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
@@ -4004,11 +4011,6 @@ ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}

void BatchMatmulOp::print(OpAsmPrinter &p) {
SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);

SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
BatchMatmulOp::getDefaultIndexingMaps(getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -4018,6 +4020,11 @@ void BatchMatmulOp::print(OpAsmPrinter &p) {
[&](Attribute attr) { p.printAttribute(attr); });
p << "]";
}

SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
}

/// Verify the user defined indexing maps.
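The net effect of the print() change is that a non-default indexing_maps attribute is now printed before the ins/outs operand lists rather than after them. Below is a round-trip sketch of the new order, assuming a Python package built with this patch and the upstream dialects registered on the default Context; the map aliases and shapes are illustrative, mirroring the transposed-B cases in the tests:

    # Round-trip sketch: parse a batch_matmul with explicit (transposed-B)
    # indexing maps and print it back; after this change the maps print
    # before ins(...)/outs(...), matching the updated CHECK lines below.
    from mlir.ir import Context, Module

    SRC = """
    #mapA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
    #mapB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
    #mapC = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
    func.func @rt(%a: memref<2x4x8xf32>, %b: memref<2x12x8xf32>, %c: memref<2x4x12xf32>) {
      linalg.batch_matmul indexing_maps = [#mapA, #mapB, #mapC]
          ins(%a, %b : memref<2x4x8xf32>, memref<2x12x8xf32>)
          outs(%c : memref<2x4x12xf32>)
      return
    }
    """

    with Context():
        # Prints the op with `indexing_maps = [...]` leading; the default
        # (non-transposed) maps would be elided entirely.
        print(Module.parse(SRC))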
41 changes: 25 additions & 16 deletions mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,7 +149,8 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)


def matmul(
def create_op(
op_type,
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
@@ -161,7 +162,7 @@ def matmul(
init = _get_op_result_or_value(outs[0])
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

op = MatmulOp(
op = op_type(
result_tensors=result_types,
inputs=ins,
outputs=[init],
@@ -172,24 +173,32 @@ def matmul(
return op


def matmul(
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)


def batch_matmul(
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)


def contract(
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
ins = [_get_op_result_or_value(input) for input in ins]
if len(outs) > 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

op = ContractOp(
result_tensors=result_types,
inputs=ins,
outputs=[init],
indexing_maps=indexing_maps,
cast=cast,
return create_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
fill_builtin_region(op.operation)
return op
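With matmul, batch_matmul, and contract all routed through the shared create_op helper above, the optional cast keyword now reaches each wrapper. A minimal sketch of batch_matmul with a non-default cast, assuming TypeFn is re-exported from mlir.dialects.linalg by the generated enum bindings; the shapes and element types are illustrative:

    from mlir.ir import (
        Context,
        Location,
        Module,
        InsertionPoint,
        F16Type,
        F32Type,
        RankedTensorType,
    )
    from mlir.dialects import func, linalg

    with Context(), Location.unknown():
        module = Module.create()
        f16, f32 = F16Type.get(), F32Type.get()
        with InsertionPoint(module.body):
            a_t = RankedTensorType.get((2, 4, 8), f16)
            b_t = RankedTensorType.get((2, 8, 12), f16)
            c_t = RankedTensorType.get((2, 4, 12), f32)

            @func.FuncOp.from_py_func(a_t, b_t, c_t)
            def batch_matmul_cast(A, B, C):
                # cast=... populates the op's new `cast` attribute; leaving
                # it out keeps the default TypeFn.cast_signed behavior.
                return linalg.batch_matmul(
                    A, B, outs=(C,), cast=linalg.TypeFn.cast_unsigned
                )

        print(module)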
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1497,7 +1497,7 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1520,7 +1520,7 @@ func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %ar
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1543,7 +1543,7 @@ func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1566,7 +1566,7 @@ func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1:
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }

@@ -1622,7 +1622,7 @@ func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: me
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
100 changes: 100 additions & 0 deletions mlir/test/python/dialects/linalg/ops.py
@@ -466,3 +466,103 @@ def matmul_as_contract_op(
)

print(module)


# CHECK-LABEL: TEST: testBatchMatmulOp
@run
def testBatchMatmulOp():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
a_shape = (2, 4, 8)
b_shape = (2, 8, 12)
b_transposed_shape = (2, 12, 8)
c_shape = (2, 4, 12)

dimBatch = ir.AffineDimExpr.get(0)
dimM = ir.AffineDimExpr.get(1)
dimN = ir.AffineDimExpr.get(2)
dimK = ir.AffineDimExpr.get(3)

# CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
# CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
# CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN])

# CHECK: func.func @batch_matmul_op(
@func.FuncOp.from_py_func(
# CHECK-SAME: %[[A:.*]]: tensor<2x4x8xf32>,
RankedTensorType.get(a_shape, f32),
# CHECK-SAME: %[[Amem:.*]]: memref<2x4x8xf32>,
MemRefType.get(a_shape, f32),
# CHECK-SAME: %[[B:.*]]: tensor<2x8x12xf32>,
RankedTensorType.get(b_shape, f32),
# CHECK-SAME: %[[Bmem:.*]]: memref<2x8x12xf32>,
MemRefType.get(b_shape, f32),
# CHECK-SAME: %[[BTrans:.*]]: tensor<2x12x8xf32>,
RankedTensorType.get(b_transposed_shape, f32),
# CHECK-SAME: %[[BTransmem:.*]]: memref<2x12x8xf32>,
MemRefType.get(b_transposed_shape, f32),
# CHECK-SAME: %[[C:.*]]: tensor<2x4x12xf32>,
RankedTensorType.get(c_shape, f32),
# CHECK-SAME: %[[Cmem:.*]]: memref<2x4x12xf32>)
MemRefType.get(c_shape, f32),
)
def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
# CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
res = linalg.BatchMatmulOp(
result_tensors=(C.type,),
inputs=(A, B),
outputs=(C,),
)
linalg.fill_builtin_region(res.operation)
# CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
res = linalg.batch_matmul(A, B, outs=(C,))

# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
res = linalg.BatchMatmulOp(
result_tensors=(C.type,),
inputs=(A, Btransposed),
outputs=(C,),
indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(res.operation)
# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
res = linalg.batch_matmul(
A,
Btransposed,
outs=(C,),
indexing_maps=[a_map, b_transposed_map, c_map],
)

# CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
res = linalg.BatchMatmulOp(
result_tensors=[],
inputs=(Amem, Bmem),
outputs=(Cmem,),
)
linalg.fill_builtin_region(res.operation)
# CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))

# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
res = linalg.BatchMatmulOp(
result_tensors=[],
inputs=(Amem, Btransposedmem),
outputs=(Cmem,),
indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(res.operation)
# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
linalg.batch_matmul(
Amem,
Btransposedmem,
outs=(Cmem,),
indexing_maps=[a_map, b_transposed_map, c_map],
)

print(module)