diff --git a/include/aie/Dialect/AIEX/IR/AIEX.td b/include/aie/Dialect/AIEX/IR/AIEX.td
index e09d67dd85c..9a978f9fcb2 100644
--- a/include/aie/Dialect/AIEX/IR/AIEX.td
+++ b/include/aie/Dialect/AIEX/IR/AIEX.td
@@ -709,6 +709,7 @@ def AIE_NpuDmaMemcpyNdOp: AIEX_Op<"npu.dma_memcpy_nd", [
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def AIE_NpuDmaWaitOp: AIEX_Op<"npu.dma_wait", []> {
diff --git a/lib/Dialect/AIEX/IR/AIEXDialect.cpp b/lib/Dialect/AIEX/IR/AIEXDialect.cpp
index ded09b22ea5..a3dc1bd0f9c 100644
--- a/lib/Dialect/AIEX/IR/AIEXDialect.cpp
+++ b/lib/Dialect/AIEX/IR/AIEXDialect.cpp
@@ -368,6 +368,94 @@ bool AIEX::NpuDmaMemcpyNdOp::isLinearTransferWithoutTransformation() {
   return isLinearTransfer(inputSizes, inputStrides);
 }
 
+// Canonicalization pattern: rewrite a contiguous row-major access pattern to
+// the canonical linear form [s3, 1, 1, N][st3, 0, 0, 1].
+//
+// Using outermost-first index notation (matching the IR syntax), a 4D access
+// [s3, s2, s1, s0][st3, st2, st1, st0] is a contiguous linear scan when:
+//   st0 == 1
+//   s1 == 1 || st1 == s0            (stride irrelevant when size is 1)
+//   s2 == 1 || st2 == s0 * s1
+// yielding a total of N = s0 * s1 * s2 contiguous elements. The repeat
+// dimension s3 / stride st3 is unchanged by the fold.
+//
+// This fold is always semantically valid and never introduces new hardware
+// limit violations: in the resulting linear form, isLinearTransferWithout-
+// Transformation() returns true, so verifyStridesWraps() skips the 10-bit
+// d0 wrap-size check. The hardware uses a wider transfer-length register in
+// linear mode, so arbitrarily large N is supported.
+namespace {
+struct LinearizeContiguousTransfer
+    : public mlir::OpRewritePattern<AIEX::NpuDmaMemcpyNdOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(AIEX::NpuDmaMemcpyNdOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Only constant sizes/strides can be analysed statically.
+    if (!llvm::all_of(op.getMixedSizes(), [](mlir::OpFoldResult s) {
+          return mlir::getConstantIntValue(s).has_value();
+        }))
+      return mlir::failure();
+    if (!llvm::all_of(op.getMixedStrides(), [](mlir::OpFoldResult s) {
+          return mlir::getConstantIntValue(s).has_value();
+        }))
+      return mlir::failure();
+
+    // Skip ops that are already in canonical linear form.
+    if (op.isLinearTransferWithoutTransformation())
+      return mlir::failure();
+
+    // getMixedSizes/Strides return outermost-first; reverse to innermost-first
+    // so index 0 = d0 (innermost) and index 3 = repeat (outermost).
+    llvm::SmallVector<int64_t> sizes = llvm::map_to_vector(
+        llvm::reverse(op.getMixedSizes()), [](mlir::OpFoldResult s) {
+          return mlir::getConstantIntValue(s).value();
+        });
+    llvm::SmallVector<int64_t> strides = llvm::map_to_vector(
+        llvm::reverse(op.getMixedStrides()), [](mlir::OpFoldResult s) {
+          return mlir::getConstantIntValue(s).value();
+        });
+
+    // Require a contiguous row-major scan. A stride is only constrained when
+    // its corresponding size is > 1 (a never-applied stride is irrelevant).
+    if (strides[0] != 1)
+      return mlir::failure();
+    if (sizes[1] > 1 && strides[1] != sizes[0])
+      return mlir::failure();
+    if (sizes[2] > 1 && strides[2] != sizes[0] * sizes[1])
+      return mlir::failure();
+
+    // Fold d0/d1/d2 into one linear count; keep the repeat dimension intact.
+    // Build directly in outermost-first order for the replacement op.
+    int64_t N = sizes[0] * sizes[1] * sizes[2];
+    llvm::SmallVector<int64_t> newSizesOuter = {sizes[3], 1, 1, N};
+    llvm::SmallVector<int64_t> newStridesOuter = {strides[3], 0, 0, 1};
+
+    // Preserve all other attributes (offsets, packet, metadata, etc.) exactly.
+    rewriter.replaceOpWithNewOp<AIEX::NpuDmaMemcpyNdOp>(
+        op, op.getMemref(),
+        /*offsets=*/op.getOffsets(),
+        /*sizes=*/mlir::ValueRange{},
+        /*strides=*/mlir::ValueRange{},
+        mlir::DenseI64ArrayAttr::get(op.getContext(), op.getStaticOffsets()),
+        mlir::DenseI64ArrayAttr::get(op.getContext(), newSizesOuter),
+        mlir::DenseI64ArrayAttr::get(op.getContext(), newStridesOuter),
+        op.getPacketAttr(), op.getMetadata(), op.getIdAttr(),
+        op.getIssueTokenAttr(), op.getD0ZeroBeforeAttr(),
+        op.getD1ZeroBeforeAttr(), op.getD2ZeroBeforeAttr(),
+        op.getD0ZeroAfterAttr(), op.getD1ZeroAfterAttr(),
+        op.getD2ZeroAfterAttr(), op.getBurstLengthAttr());
+    return mlir::success();
+  }
+};
+} // namespace
+
+void AIEX::NpuDmaMemcpyNdOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  patterns.add<LinearizeContiguousTransfer>(context);
+}
+
 // Helper method to check if a requested burst length is supported by the target
 // model. Returns an error message if the burst length is not supported or an
 // empty option otherwise.
diff --git a/test/dialect/AIEX/canonicalize_linear.mlir b/test/dialect/AIEX/canonicalize_linear.mlir
new file mode 100644
index 00000000000..69c70bf52ef
--- /dev/null
+++ b/test/dialect/AIEX/canonicalize_linear.mlir
@@ -0,0 +1,250 @@
+//===- canonicalize_linear.mlir --------------------------------*- MLIR -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Copyright (C) 2026, Advanced Micro Devices, Inc.
+//
+//===----------------------------------------------------------------------===//
+//
+// Tests for NpuDmaMemcpyNdOp canonicalization: contiguous row-major access
+// patterns are folded into the canonical linear form [s3,1,1,N][st3,0,0,1].
+//
+// This is the fix for github.com/Xilinx/mlir-aie/issues/2825.
+//
+// All tests use static literal sizes/strides so that:
+//   (a) canonicalization sees constant values and can fire, and
+//   (b) the pre-canonicalization op is in-bounds for the verifier.
+//
+//===----------------------------------------------------------------------===//
+
+// RUN: aie-opt --canonicalize --split-input-file %s | FileCheck %s
+
+// -----
+
+// Basic 2D fold: sizes=[1,1,2,512] strides=[0,0,512,1] ->
+//                sizes=[1,1,1,1024] strides=[0,0,0,1]
+//
+// Motivating case from issue #2825: in production K can exceed 1023 (the d0
+// wrap limit). After folding, N is encoded in the wider linear-mode transfer
+// length register, so no limit applies.
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @fold_2d
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @fold_2d(%arg0 : memref<2x512xi32>) {
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// 3D fold: sizes=[1,3,4,5] strides=[0,20,5,1] ->
+//          sizes=[1,1,1,60] strides=[0,0,0,1]
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @fold_3d
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 1, 60][0, 0, 0, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @fold_3d(%arg0 : memref<3x4x5xi32>) {
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 3, 4, 5][0, 20, 5, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<3x4x5xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// Already in canonical linear form: the pattern must not fire (idempotent).
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @already_linear
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 1, 4096][0, 0, 0, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @already_linear(%arg0 : memref<4096xi32>) {
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 1, 4096][0, 0, 0, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<4096xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// Non-contiguous: stride1 (3) != size0 (4) — must NOT be folded.
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @no_fold_strided
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 2, 4][0, 0, 3, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @no_fold_strided(%arg0 : memref<32xi32>) {
+      // stride1=3 != size0=4: genuinely strided rows, cannot fold.
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 4][0, 0, 3, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<32xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// Repeat dimension (s3 > 1) is preserved through the fold.
+// sizes=[2,1,2,4] strides=[4096,0,4,1] -> sizes=[2,1,1,8] strides=[4096,0,0,1]
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @fold_with_repeat
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][2, 1, 1, 8][4096, 0, 0, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @fold_with_repeat(%arg0 : memref<8192xi32>) {
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][2, 1, 2, 4][4096, 0, 4, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<8192xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// bf16 element type — motivating case from issue #2825.
+// sizes=[1,1,2,512] strides=[0,0,512,1] -> sizes=[1,1,1,1024] strides=[0,0,0,1]
+// In production K can be 1024+ (exceeding the d0 limit); the fold moves the
+// total count into the wider linear-mode transfer-length register.
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @fold_bf16
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @fold_bf16(%arg0 : memref<2x512xbf16>) {
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xbf16>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// Non-unit inner stride: stride0=2 means elements are not unit-stride.
+// Must NOT be folded.
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @no_fold_inner_stride
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 2, 4][0, 0, 4, 2]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @no_fold_inner_stride(%arg0 : memref<32xi32>) {
+      // stride0=2: skips every other element, not a linear scan.
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 4][0, 0, 4, 2])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<32xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// Wrong stride2: size2 > 1 but stride2 != size0 * size1 — must NOT be folded.
+// (stride1 is correct, only stride2 is wrong.)
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @no_fold_stride2
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 2, 3, 4][0, 7, 4, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @no_fold_stride2(%arg0 : memref<64xi32>) {
+      // stride2=7 != size0*size1=4*3=12: non-contiguous outer loop, cannot fold.
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 2, 3, 4][0, 7, 4, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<64xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// Nonzero static offset is preserved unchanged through the fold.
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @fold_with_offset
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 4][1, 1, 1, 1024][0, 0, 0, 1]
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @fold_with_offset(%arg0 : memref<2048xi32>) {
+      // Offset of 4 elements; sizes/strides fold as normal.
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 4][1, 1, 2, 512][0, 0, 512, 1])
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<2048xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// packet attribute is preserved after the fold.
+// NOTE(review): the packet literal was lost in extraction; <pkt_type = 0,
+// pkt_id = 2> is a representative reconstruction — confirm against the
+// original patch.
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @fold_packet
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
+// CHECK-SAME:  packet = <pkt_type = 0, pkt_id = 2>
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @fold_packet(%arg0 : memref<2x512xi32>) {
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1],
+                              packet = <pkt_type = 0, pkt_id = 2>)
+        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}
+
+// -----
+
+// issue_token attribute is preserved after the fold.
+
+// CHECK-LABEL: aie.device(npu1)
+// CHECK:       aie.runtime_sequence @fold_issue_token
+// CHECK:       aiex.npu.dma_memcpy_nd
+// CHECK-SAME:  [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
+// CHECK-SAME:  issue_token = true
+module {
+  aie.device(npu1) {
+    aie.runtime_sequence @fold_issue_token(%arg0 : memref<2x512xi32>) {
+      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
+        { metadata = @of_fromMem, id = 0 : i64, issue_token = true } : memref<2x512xi32>
+    }
+    %tile = aie.tile(0, 0)
+    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
+  }
+}