diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 1fd5691f41eec..881b6dad2b84d 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -408,6 +408,12 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
 /// Returns the memory space of the given MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
 
+/// Returns the strides of the MemRef if the layout map is in strided form.
+/// Both strides and offset are out params. strides must point to pre-allocated
+/// memory of length equal to the rank of the memref.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
+    MlirType type, int64_t *strides, int64_t *offset);
+
 /// Returns the memory space of the given Unranked MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirUnrankedMemrefGetMemorySpace(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 56e895d305379..820992de65906 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Support.h"
+
 #include <optional>
 
 namespace py = pybind11;
@@ -618,6 +620,18 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
             return mlirMemRefTypeGetLayout(self);
           },
           "The layout of the MemRef type.")
+      .def(
+          "get_strides_and_offset",
+          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+            int64_t offset;
+            if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
+                    self, strides.data(), &offset)))
+              throw std::runtime_error(
+                  "Failed to extract strides and offset from memref.");
+            return {strides, offset};
+          },
+          "The strides and offset of the MemRef type.")
       .def_property_readonly(
           "affine_map",
           [](PyMemRefType &self) -> PyAffineMap {
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6e645188dac86..18c9414c5d0f3 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -9,12 +9,16 @@
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
 #include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include <algorithm>
 
 using namespace mlir;
 
@@ -426,6 +430,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
+MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
+                                                    int64_t *strides,
+                                                    int64_t *offset) {
+  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+  SmallVector<int64_t> strides_;
+  if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
+    return mlirLogicalResultFailure();
+
+  (void)std::copy(strides_.begin(), strides_.end(), strides);
+  return mlirLogicalResultSuccess();
+}
+
 MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
   return wrap(UnrankedMemRefType::getTypeID());
 }
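Taken together, the C API entry point, the pybind lambda, and the new `get_strides_and_offset` method expose `getStridesAndOffset` to Python. A minimal sketch of the intended usage, assuming rebuilt bindings (the shape and the expected values below are illustrative, not part of the change):

```python
from mlir.ir import Context, Location, F32Type, MemRefType

with Context(), Location.unknown():
    # An identity-layout 10x10xf32 memref has row-major strides and no offset.
    t = MemRefType.get([10, 10], F32Type.get())
    strides, offset = t.get_strides_and_offset()
    assert strides == [10, 1] and offset == 0
```

For a memref whose layout is not expressible in strided form, the binding raises `RuntimeError`, mirroring the `MlirLogicalResult` failure at the C API level.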
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 1685124fbccdc..3af3b5ce73bc6 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -2,16 +2,31 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-# Provide a convenient name for sub-packages to resolve the main C-extension
-# with a relative import.
-from .._mlir_libs import _mlir as _cext
 from typing import (
+    Any as _Any,
+    List as _List,
+    Optional as _Optional,
     Sequence as _Sequence,
+    Tuple as _Tuple,
     Type as _Type,
     TypeVar as _TypeVar,
     Union as _Union,
 )
 
+from .._mlir_libs import _mlir as _cext
+from ..ir import (
+    ArrayAttr,
+    Attribute,
+    BoolAttr,
+    DenseI64ArrayAttr,
+    IntegerAttr,
+    IntegerType,
+    OpView,
+    Operation,
+    ShapedType,
+    Value,
+)
+
 __all__ = [
     "equally_sized_accessor",
     "get_default_loc_context",
@@ -138,3 +153,157 @@ def get_op_result_or_op_results(
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]
 VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
+
+StaticIntLike = _Union[int, IntegerAttr]
+ValueLike = _Union[Operation, OpView, Value]
+MixedInt = _Union[StaticIntLike, ValueLike]
+
+IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
+OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
+
+BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
+OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
+
+MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
+
+DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
+
+
+def _dispatch_dynamic_index_list(
+    indices: _Union[DynamicIndexList, ArrayAttr],
+) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
+    """Dispatches a list of indices to the appropriate form.
+
+    This is similar to the custom `DynamicIndexList` directive upstream:
+    provided indices may be in the form of dynamic SSA values or static values,
+    and they may be scalable (i.e., as a singleton list) or not. This function
+    dispatches each index into its respective form. It also extracts the SSA
+    values and static indices from various similar structures.
+    """
+    dynamic_indices = []
+    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
+    scalable_indices = [False] * len(indices)
+
+    # ArrayAttr: Extract index values.
+    if isinstance(indices, ArrayAttr):
+        indices = [idx for idx in indices]
+
+    def process_nonscalable_index(i, index):
+        """Processes any form of non-scalable index.
+
+        Returns False if the given index was scalable and thus remains
+        unprocessed; True otherwise.
+        """
+        if isinstance(index, int):
+            static_indices[i] = index
+        elif isinstance(index, IntegerAttr):
+            static_indices[i] = index.value  # pytype: disable=attribute-error
+        elif isinstance(index, (Operation, Value, OpView)):
+            dynamic_indices.append(index)
+        else:
+            return False
+        return True
+
+    # Process one index at a time.
+    for i, index in enumerate(indices):
+        if not process_nonscalable_index(i, index):
+            # If it wasn't processed, it must be a scalable index, which is
+            # provided as a _Sequence of one value, so extract and process that.
+            scalable_indices[i] = True
+            assert len(index) == 1
+            ret = process_nonscalable_index(i, index[0])
+            assert ret
+
+    return dynamic_indices, static_indices, scalable_indices
+
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+#   - `dynamic_values`: a list of `Value`s, potentially from op results;
+#   - `packed_values`: a value handle, potentially from an op result, associated
+#     to one or more payload operations of integer type;
+#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
+#     `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
+# If the input is in the form for `packed_values`, only that result is set and
+# the other two are empty. Otherwise, the input can be a mix of the other two
+# forms; for each dynamic value, a special value is added to `static_values`.
+def _dispatch_mixed_values(
+    values: MixedValues,
+) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
+    dynamic_values = []
+    packed_values = None
+    static_values = None
+    if isinstance(values, ArrayAttr):
+        static_values = values
+    elif isinstance(values, (Operation, Value, OpView)):
+        packed_values = values
+    else:
+        static_values = []
+        for size in values or []:
+            if isinstance(size, int):
+                static_values.append(size)
+            else:
+                static_values.append(ShapedType.get_dynamic_size())
+                dynamic_values.append(size)
+        static_values = DenseI64ArrayAttr.get(static_values)
+
+    return (dynamic_values, packed_values, static_values)
+
+
+def _get_value_or_attribute_value(
+    value_or_attr: _Union[_Any, Attribute, ArrayAttr]
+) -> _Any:
+    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
+        return value_or_attr.value
+    if isinstance(value_or_attr, ArrayAttr):
+        return _get_value_list(value_or_attr)
+    return value_or_attr
+
+
+def _get_value_list(
+    sequence_or_array_attr: _Union[_Sequence[_Any], ArrayAttr]
+) -> _Sequence[_Any]:
+    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
+
+
+def _get_int_array_attr(
+    values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
+) -> _Optional[ArrayAttr]:
+    if values is None:
+        return None
+
+    # Turn into a Python list of Python ints.
+    values = _get_value_list(values)
+
+    # Make an ArrayAttr of IntegerAttrs out of it.
+    return ArrayAttr.get(
+        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
+    )
+
+
+def _get_int_array_array_attr(
+    values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
+) -> _Optional[ArrayAttr]:
+    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
+
+    The input has to be a collection of collections of integers, where any
+    Python _Sequence and ArrayAttr are admissible collections and Python ints
+    and any IntegerAttr are admissible integers. Both levels of collections
+    are turned into ArrayAttr; the inner level is turned into IntegerAttrs of
+    i64s. If the input is None, None is returned.
+    """
+    if values is None:
+        return None
+
+    # Make sure the outer level is a list.
+    values = _get_value_list(values)
+
+    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
+    # Sequences. Make sure the nested values are all lists.
+    values = [_get_value_list(nested) for nested in values]
+
+    # Turn each nested list into an ArrayAttr.
+    values = [_get_int_array_attr(nested) for nested in values]
+
+    # Turn the outer list into an ArrayAttr.
+    return ArrayAttr.get(values)
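The moved helpers are private, but several dialect mixins rely on them, so it may help reviewers to see the dispatch behavior concretely. A sketch of both helpers in action (assumes rebuilt bindings; the `_dispatch_*` names are internal and subject to change):

```python
from mlir.ir import Context, InsertionPoint, Location, Module, IndexType
from mlir.dialects import arith
from mlir.dialects._ods_common import (
    _dispatch_dynamic_index_list,
    _dispatch_mixed_values,
)

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        v = arith.constant(IndexType.get(), 7)  # a dynamic (SSA) index
        dyn, static, scalable = _dispatch_dynamic_index_list([1, v, [4]])
        # dyn == [v]; static == [1, <dynamic sentinel>, 4];
        # scalable == [False, False, True] (a singleton list marks scalability)
        dyn, packed, static_attr = _dispatch_mixed_values([2, v, 8])
        # dyn == [v]; packed is None;
        # static_attr is a DenseI64ArrayAttr of [2, <dynamic sentinel>, 8]
```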
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0..a3d783415855e 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -1,5 +1,135 @@
 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import operator
+from itertools import accumulate
+from typing import Optional
 
 from ._memref_ops_gen import *
+from ._ods_common import _dispatch_mixed_values, MixedValues
+from .arith import ConstantOp, _is_integer_like_type
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
+def _is_constant_int_like(i):
+    return (
+        isinstance(i, Value)
+        and isinstance(i.owner.opview, ConstantOp)
+        and _is_integer_like_type(i.type)
+    )
+
+
+def _is_static_int_like(i):
+    return (
+        isinstance(i, int) and not ShapedType.is_dynamic_size(i)
+    ) or _is_constant_int_like(i)
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, offsets, static_sizes, static_strides
+):
+    source_strides, source_offset = source_memref_type.get_strides_and_offset()
+    # "Canonicalize" from tuple|list -> list.
+    offsets, static_sizes, static_strides, source_strides = map(
+        list, (offsets, static_sizes, static_strides, source_strides)
+    )
+
+    if not all(
+        all(_is_static_int_like(i) for i in s)
+        for s in [
+            static_sizes,
+            static_strides,
+            source_strides,
+        ]
+    ):
+        raise ValueError(
+            "Only inferring from python or mlir integer constant is supported."
+        )
+
+    for s in [offsets, static_sizes, static_strides]:
+        for idx, i in enumerate(s):
+            if _is_constant_int_like(i):
+                s[idx] = i.owner.opview.literal_value
+
+    if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
+        target_offset = ShapedType.get_dynamic_size()
+    else:
+        target_offset = source_offset
+        for offset, target_stride in zip(offsets, source_strides):
+            target_offset += offset * target_stride
+
+    target_strides = []
+    for source_stride, static_stride in zip(source_strides, static_strides):
+        target_strides.append(source_stride * static_stride)
+
+    # If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
+    default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
+    if target_strides == default_strides and target_offset == 0:
+        layout = None
+    else:
+        layout = StridedLayoutAttr.get(target_offset, target_strides)
+    return (
+        offsets,
+        static_sizes,
+        static_strides,
+        MemRefType.get(
+            static_sizes,
+            source_memref_type.element_type,
+            layout,
+            source_memref_type.memory_space,
+        ),
+    )
+
+
+_generated_subview = subview
+
+
+def subview(
+    source: Value,
+    offsets: MixedValues,
+    sizes: MixedValues,
+    strides: MixedValues,
+    *,
+    result_type: Optional[MemRefType] = None,
+    loc=None,
+    ip=None,
+):
+    if offsets is None:
+        offsets = []
+    if sizes is None:
+        sizes = []
+    if strides is None:
+        strides = []
+    source_strides, source_offset = source.type.get_strides_and_offset()
+    if result_type is None and all(
+        all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
+    ):
+        # If any offsets/sizes/strides are arith.constant results, they will
+        # canonicalize to Python ints, which is enough to fully infer the result type.
+        (
+            offsets,
+            sizes,
+            strides,
+            result_type,
+        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
+    elif result_type is None:
+        raise ValueError(
+            "mixed static/dynamic offset/sizes/strides requires explicit result type."
+        )
+
+    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
+    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
+    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
+
+    return _generated_subview(
+        result_type,
+        source,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
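The inference follows the usual strided-layout arithmetic: result strides are the source strides scaled by the subview strides, and the result offset is the source offset plus the dot product of the subview offsets with the source strides. A worked example matching the 10x10xi32 case exercised in the tests below:

```python
# memref<10x10xi32> has strides [10, 1] and offset 0; a subview at offsets
# [1, 1] with sizes [3, 3] and strides [1, 1] keeps strides [10, 1] and gets
# offset 1*10 + 1*1 = 11, i.e. memref<3x3xi32, strided<[10, 1], offset: 11>>.
source_strides, source_offset = [10, 1], 0
offsets, strides = [1, 1], [1, 1]
target_offset = source_offset + sum(o * s for o, s in zip(offsets, source_strides))
target_strides = [ss * st for ss, st in zip(source_strides, strides)]
assert (target_strides, target_offset) == ([10, 1], 11)
```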
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 284c93823acbd..d7b41c0bd2207 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -9,163 +9,24 @@
 try:
     from ...ir import *
     from ...dialects import transform
-    from .._ods_common import _cext as _ods_cext
+    from .._ods_common import (
+        DynamicIndexList,
+        IntOrAttrList,
+        MixedValues,
+        OptionalBoolList,
+        OptionalIntList,
+        _cext as _ods_cext,
+        _dispatch_dynamic_index_list,
+        _dispatch_mixed_values,
+        _get_int_array_array_attr,
+        _get_int_array_attr,
+        _get_value_list,
+        _get_value_or_attribute_value,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import List, Optional, Sequence, Tuple, Union, overload
-
-StaticIntLike = Union[int, IntegerAttr]
-ValueLike = Union[Operation, OpView, Value]
-MixedInt = Union[StaticIntLike, ValueLike]
-
-IntOrAttrList = Sequence[Union[IntegerAttr, int]]
-OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
-
-BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
-OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
-
-MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
-
-DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
-
-
-def _dispatch_dynamic_index_list(
-    indices: Union[DynamicIndexList, ArrayAttr],
-) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
-    """Dispatches a list of indices to the appropriate form.
-
-    This is similar to the custom `DynamicIndexList` directive upstream:
-    provided indices may be in the form of dynamic SSA values or static values,
-    and they may be scalable (i.e., as a singleton list) or not. This function
-    dispatches each index into its respective form. It also extracts the SSA
-    values and static indices from various similar structures, respectively.
-    """
-    dynamic_indices = []
-    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
-    scalable_indices = [False] * len(indices)
-
-    # ArrayAttr: Extract index values.
-    if isinstance(indices, ArrayAttr):
-        indices = [idx for idx in indices]
-
-    def process_nonscalable_index(i, index):
-        """Processes any form of non-scalable index.
-
-        Returns False if the given index was scalable and thus remains
-        unprocessed; True otherwise.
-        """
-        if isinstance(index, int):
-            static_indices[i] = index
-        elif isinstance(index, IntegerAttr):
-            static_indices[i] = index.value  # pytype: disable=attribute-error
-        elif isinstance(index, (Operation, Value, OpView)):
-            dynamic_indices.append(index)
-        else:
-            return False
-        return True
-
-    # Process each index at a time.
-    for i, index in enumerate(indices):
-        if not process_nonscalable_index(i, index):
-            # If it wasn't processed, it must be a scalable index, which is
-            # provided as a Sequence of one value, so extract and process that.
-            scalable_indices[i] = True
-            assert len(index) == 1
-            ret = process_nonscalable_index(i, index[0])
-            assert ret
-
-    return dynamic_indices, static_indices, scalable_indices
-
-
-# Dispatches `MixedValues` that all represents integers in various forms into
-# the following three categories:
-#   - `dynamic_values`: a list of `Value`s, potentially from op results;
-#   - `packed_values`: a value handle, potentially from an op result, associated
-#     to one or more payload operations of integer type;
-#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
-#     `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
-# The input is in the form for `packed_values`, only that result is set and the
-# other two are empty. Otherwise, the input can be a mix of the other two forms,
-# and for each dynamic value, a special value is added to the `static_values`.
-def _dispatch_mixed_values(
-    values: MixedValues,
-) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
-    dynamic_values = []
-    packed_values = None
-    static_values = None
-    if isinstance(values, ArrayAttr):
-        static_values = values
-    elif isinstance(values, (Operation, Value, OpView)):
-        packed_values = values
-    else:
-        static_values = []
-        for size in values or []:
-            if isinstance(size, int):
-                static_values.append(size)
-            else:
-                static_values.append(ShapedType.get_dynamic_size())
-                dynamic_values.append(size)
-        static_values = DenseI64ArrayAttr.get(static_values)
-
-    return (dynamic_values, packed_values, static_values)
-
-
-def _get_value_or_attribute_value(
-    value_or_attr: Union[any, Attribute, ArrayAttr]
-) -> any:
-    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
-        return value_or_attr.value
-    if isinstance(value_or_attr, ArrayAttr):
-        return _get_value_list(value_or_attr)
-    return value_or_attr
-
-
-def _get_value_list(
-    sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
-) -> Sequence[any]:
-    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
-
-
-def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
-    if values is None:
-        return None
-
-    # Turn into a Python list of Python ints.
-    values = _get_value_list(values)
-
-    # Make an ArrayAttr of IntegerAttrs out of it.
-    return ArrayAttr.get(
-        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
-    )
-
-
-def _get_int_array_array_attr(
-    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
-) -> ArrayAttr:
-    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
-
-    The input has to be a collection of collection of integers, where any
-    Python Sequence and ArrayAttr are admissible collections and Python ints and
-    any IntegerAttr are admissible integers. Both levels of collections are
-    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
-    If the input is None, an empty ArrayAttr is returned.
-    """
-    if values is None:
-        return None
-
-    # Make sure the outer level is a list.
-    values = _get_value_list(values)
-
-    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
-    # Sequences. Make sure the nested values are all lists.
-    values = [_get_value_list(nested) for nested in values]
-
-    # Turn each nested list into an ArrayAttr.
-    values = [_get_int_array_attr(nested) for nested in values]
-
-    # Turn the outer list into an ArrayAttr.
-    return ArrayAttr.get(values)
+from typing import List, Optional, Sequence, Union, overload
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
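Since `structured.py` now re-imports the helpers from `_ods_common`, code that previously reached them through `transform.structured` should keep working. A quick smoke check (assumes rebuilt bindings; not part of the diff):

```python
from mlir.dialects._ods_common import _dispatch_mixed_values as from_common
from mlir.dialects.transform.structured import (
    _dispatch_mixed_values as from_structured,
)

# Both names resolve to the same function object after the move.
assert from_common is from_structured
```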
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0c8a7ee282fe1..162c22aedbdc8 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -1,9 +1,11 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir.ir import *
-import mlir.dialects.func as func
+import mlir.dialects.arith as arith
 import mlir.dialects.memref as memref
 import mlir.extras.types as T
+import numpy as np
+from mlir.dialects.memref import _infer_memref_subview_result_type
+from mlir.ir import *
 
 
 def run(f):
@@ -88,3 +90,164 @@ def testMemRefAttr():
         memref.global_("objFifo_in0", T.memref(16, T.i32()))
     # CHECK: memref.global @objFifo_in0 : memref<16xi32>
     print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
+@run
+def testSubViewOpInferReturnTypeSemantics():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+            print(x.owner)
+
+            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
+            assert y.owner.verify()
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(y.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 1), 1],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(z.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
+            print(z.owner)
+
+            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
+            z = memref.subview(
+                x,
+                [s, 0],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+            print(z)
+
+            try:
+                _infer_memref_subview_result_type(
+                    x.type,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except ValueError as e:
+                # CHECK: Only inferring from python or mlir integer constant is supported
+                print(e)
+
+            try:
+                memref.subview(
+                    x,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except ValueError as e:
+                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
+                print(e)
+
+            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
+            x = memref.alloc(
+                T.memref(
+                    10,
+                    10,
+                    T.i32(),
+                    layout=layout,
+                ),
+                [],
+                [arith.constant(T.index(), 42)],
+            )
+            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
+            print(x.owner)
+            y = memref.subview(
+                x,
+                [1, 1],
+                [3, 3],
+                [1, 1],
+                result_type=T.memref(3, 3, T.i32(), layout=layout),
+            )
+            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+            print(y.owner)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
+@run
+def testSubViewOpInferReturnTypeExtensiveSlicing():
+    def check_strides_offset(memref, np_view):
+        layout = memref.type.layout
+        dtype_size_in_bytes = np_view.dtype.itemsize
+        golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist()
+        golden_offset = (
+            np_view.ctypes.data - np_view.base.ctypes.data
+        ) // dtype_size_in_bytes
+
+        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)
+
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            shape = (10, 22, 333, 4444)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+
+            # fmt: off
+            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, ...])
+            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 333, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 333, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
+            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
+            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
+            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
+            # fmt: on
+
+            # Default strides and offset mean no StridedLayoutAttr, i.e. an identity AffineMap layout.
+            assert memref.subview(
+                mem1, (0, 0, 0, 0), (10, 22, 333, 4444), (1, 1, 1, 1)
+            ).type.layout == AffineMapAttr.get(
+                AffineMap.get(
+                    4,
+                    0,
+                    [
+                        AffineDimExpr.get(0),
+                        AffineDimExpr.get(1),
+                        AffineDimExpr.get(2),
+                        AffineDimExpr.get(3),
+                    ],
+                )
+            )
+
+            shape = (7, 22, 333, 4444)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+            # fmt: off
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 333, 4444), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 4444), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
+            check_strides_offset(memref.subview(mem2, (0, 0, 100, 1000), (7, 22, 20, 20), (1, 1, 5, 50)), golden_mem[:, :, 100:200:5, 1000:2000:50])
+            # fmt: on
+
+            shape = (8, 8)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            # fmt: off
+            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
+            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
+            # fmt: on
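The golden model in `check_strides_offset` leans on numpy: `ndarray.strides` is reported in bytes and `ctypes.data` is an absolute pointer, so both are normalized to element units before comparing against the strided layout. A standalone sketch of that arithmetic (illustrative; mirrors the [1, 1] x [3, 3] x [1, 1] subview of a 10x10 i32 buffer from the tests above):

```python
import numpy as np

a = np.zeros((10, 10), dtype=np.int32)
view = a[1:4, 1:4]  # the numpy analogue of the [1, 1] x [3, 3] x [1, 1] subview
elem = view.dtype.itemsize  # 4 bytes for i32
strides = [s // elem for s in view.strides]
offset = (view.ctypes.data - a.ctypes.data) // elem
# Matches memref<3x3xi32, strided<[10, 1], offset: 11>> from the tests above.
assert (strides, offset) == ([10, 1], 11)
```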