|
1 | 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
2 | 2 | # See https://llvm.org/LICENSE.txt for license information. |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 4 | +from typing import Optional, Sequence |
4 | 5 |
|
5 | 6 | from ._memref_ops_gen import * |
| 7 | +from ..ir import Value, ShapedType, MemRefType, StridedLayoutAttr |
| 8 | + |
| 9 | + |
| 10 | +def _infer_memref_subview_result_type( |
| 11 | + source_memref_type, static_offsets, static_sizes, static_strides |
| 12 | +): |
| 13 | + source_strides, source_offset = source_memref_type.strides_and_offset |
| 14 | + target_offset = source_offset |
| 15 | + for static_offset, target_stride in zip(static_offsets, source_strides): |
| 16 | + target_offset += static_offset * target_stride |
| 17 | + |
| 18 | + target_strides = [] |
| 19 | + for source_stride, static_stride in zip(source_strides, static_strides): |
| 20 | + target_strides.append(source_stride * static_stride) |
| 21 | + |
| 22 | + layout = StridedLayoutAttr.get(target_offset, target_strides) |
| 23 | + return MemRefType.get( |
| 24 | + static_sizes, |
| 25 | + source_memref_type.element_type, |
| 26 | + layout, |
| 27 | + source_memref_type.memory_space, |
| 28 | + ) |
| 29 | + |
| 30 | + |
| 31 | +_generated_subview = subview |
| 32 | + |
| 33 | + |
| 34 | +def subview( |
| 35 | + source: Value, |
| 36 | + offsets: Optional[Sequence[Value]] = None, |
| 37 | + strides: Optional[Sequence[Value]] = None, |
| 38 | + static_offsets: Optional[Sequence[int]] = None, |
| 39 | + static_sizes: Optional[Sequence[int]] = None, |
| 40 | + static_strides: Optional[Sequence[int]] = None, |
| 41 | + *, |
| 42 | + loc=None, |
| 43 | + ip=None, |
| 44 | +): |
| 45 | + if offsets is None: |
| 46 | + offsets = [] |
| 47 | + if static_offsets is None: |
| 48 | + static_offsets = [] |
| 49 | + if strides is None: |
| 50 | + strides = [] |
| 51 | + if static_strides is None: |
| 52 | + static_strides = [] |
| 53 | + assert static_sizes, f"this convenience method only handles static sizes" |
| 54 | + sizes = [] |
| 55 | + S = ShapedType.get_dynamic_size() |
| 56 | + if offsets and static_offsets: |
| 57 | + assert all(s == S for s in static_offsets) |
| 58 | + if strides and static_strides: |
| 59 | + assert all(s == S for s in static_strides) |
| 60 | + result_type = _infer_memref_subview_result_type( |
| 61 | + source.type, static_offsets, static_sizes, static_strides |
| 62 | + ) |
| 63 | + return _generated_subview( |
| 64 | + result_type, |
| 65 | + source, |
| 66 | + offsets, |
| 67 | + sizes, |
| 68 | + strides, |
| 69 | + static_offsets, |
| 70 | + static_sizes, |
| 71 | + static_strides, |
| 72 | + loc=loc, |
| 73 | + ip=ip, |
| 74 | + ) |
0 commit comments