Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/aie/Dialect/AIEX/IR/AIEX.td
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ def AIE_NpuDmaMemcpyNdOp: AIEX_Op<"npu.dma_memcpy_nd", [
}];

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def AIE_NpuDmaWaitOp: AIEX_Op<"npu.dma_wait", []> {
Expand Down
88 changes: 88 additions & 0 deletions lib/Dialect/AIEX/IR/AIEXDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,94 @@ bool AIEX::NpuDmaMemcpyNdOp::isLinearTransferWithoutTransformation() {
return isLinearTransfer(inputSizes, inputStrides);
}

// Canonicalization pattern: rewrite a contiguous row-major access pattern to
// the canonical linear form [s3, 1, 1, N][st3, 0, 0, 1].
//
// Using outermost-first index notation (matching the IR syntax), a 4D access
// [s3, s2, s1, s0][st3, st2, st1, st0] is a contiguous linear scan when:
// st0 == 1
// s1 == 1 || st1 == s0 (stride irrelevant when size is 1)
// s2 == 1 || st2 == s0 * s1
// yielding a total of N = s0 * s1 * s2 contiguous elements. The repeat
// dimension s3 / stride st3 is unchanged by the fold.
//
// This fold is always semantically valid and never introduces new hardware
// limit violations: in the resulting linear form, isLinearTransferWithout-
// Transformation() returns true, so verifyStridesWraps() skips the 10-bit
// d0 wrap-size check. The hardware uses a wider transfer-length register in
// linear mode, so arbitrarily large N is supported.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are still limits but they are very large.

namespace {
struct LinearizeContiguousTransfer
: public mlir::OpRewritePattern<AIEX::NpuDmaMemcpyNdOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(AIEX::NpuDmaMemcpyNdOp op,
mlir::PatternRewriter &rewriter) const override {
// Only constant sizes/strides can be analysed statically.
if (!llvm::all_of(op.getMixedSizes(), [](mlir::OpFoldResult s) {
return mlir::getConstantIntValue(s).has_value();
}))
return mlir::failure();
if (!llvm::all_of(op.getMixedStrides(), [](mlir::OpFoldResult s) {
return mlir::getConstantIntValue(s).has_value();
}))
return mlir::failure();

// Skip ops that are already in canonical linear form.
if (op.isLinearTransferWithoutTransformation())
return mlir::failure();

// getMixedSizes/Strides return outermost-first; reverse to innermost-first
// so index 0 = d0 (innermost) and index 3 = repeat (outermost).
llvm::SmallVector<int64_t, 4> sizes = llvm::map_to_vector(
llvm::reverse(op.getMixedSizes()), [](mlir::OpFoldResult s) {
return mlir::getConstantIntValue(s).value();
});
llvm::SmallVector<int64_t, 4> strides = llvm::map_to_vector(
llvm::reverse(op.getMixedStrides()), [](mlir::OpFoldResult s) {
return mlir::getConstantIntValue(s).value();
});

// Require a contiguous row-major scan. A stride is only constrained when
// its corresponding size is > 1 (a never-applied stride is irrelevant).
if (strides[0] != 1)
return mlir::failure();
if (sizes[1] > 1 && strides[1] != sizes[0])
return mlir::failure();
if (sizes[2] > 1 && strides[2] != sizes[0] * sizes[1])
return mlir::failure();

// Fold d0/d1/d2 into one linear count; keep the repeat dimension intact.
// Build directly in outermost-first order for the replacement op.
int64_t N = sizes[0] * sizes[1] * sizes[2];
llvm::SmallVector<int64_t, 4> newSizesOuter = {sizes[3], 1, 1, N};
llvm::SmallVector<int64_t, 4> newStridesOuter = {strides[3], 0, 0, 1};

// Preserve all other attributes (offsets, packet, metadata, etc.) exactly.
rewriter.replaceOpWithNewOp<AIEX::NpuDmaMemcpyNdOp>(
op, op.getMemref(),
/*offsets=*/op.getOffsets(),
/*sizes=*/mlir::ValueRange{},
/*strides=*/mlir::ValueRange{},
mlir::DenseI64ArrayAttr::get(op.getContext(), op.getStaticOffsets()),
mlir::DenseI64ArrayAttr::get(op.getContext(), newSizesOuter),
mlir::DenseI64ArrayAttr::get(op.getContext(), newStridesOuter),
op.getPacketAttr(), op.getMetadata(), op.getIdAttr(),
op.getIssueTokenAttr(), op.getD0ZeroBeforeAttr(),
op.getD1ZeroBeforeAttr(), op.getD2ZeroBeforeAttr(),
op.getD0ZeroAfterAttr(), op.getD1ZeroAfterAttr(),
op.getD2ZeroAfterAttr(), op.getBurstLengthAttr());
return mlir::success();
}
};
} // namespace

// Registers canonicalization patterns for npu.dma_memcpy_nd: currently only
// LinearizeContiguousTransfer, which folds contiguous row-major access
// patterns into the canonical linear form [s3, 1, 1, N][st3, 0, 0, 1].
void AIEX::NpuDmaMemcpyNdOp::getCanonicalizationPatterns(
    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
  patterns.add<LinearizeContiguousTransfer>(context);
}

// Helper method to check if a requested burst length is supported by the target
// model. Returns an error message if the burst length is not supported or an
// empty option otherwise.
Expand Down
250 changes: 250 additions & 0 deletions test/dialect/AIEX/canonicalize_linear.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
//===- canonicalize_linear.mlir --------------------------------*- MLIR -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (C) 2026, Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//
//
// Tests for NpuDmaMemcpyNdOp canonicalization: contiguous row-major access
// patterns are folded into the canonical linear form [s3,1,1,N][st3,0,0,1].
//
// This is the fix for github.com/Xilinx/mlir-aie/issues/2825.
//
// All tests use static literal sizes/strides so that:
// (a) canonicalization sees constant values and can fire, and
// (b) the pre-canonicalization op is in-bounds for the verifier.
//
//===----------------------------------------------------------------------===//

// RUN: aie-opt --canonicalize --split-input-file %s | FileCheck %s

// -----

// Basic 2D fold: sizes=[1,1,2,512] strides=[0,0,512,1] ->
// sizes=[1,1,1,1024] strides=[0,0,0,1]
//
// Motivating case from issue #2825: in production K can exceed 1023 (the d0
// wrap limit). After folding, N is encoded in the wider linear-mode transfer
// length register, so no limit applies.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_2d
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_2d(%arg0 : memref<2x512xi32>) {
      // Contiguous rows (stride1 = 512 = row length): expect the canonical
      // linear fold to [1, 1, 1, 1024][0, 0, 0, 1].
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// 3D fold: sizes=[1,3,4,5] strides=[0,20,5,1] ->
// sizes=[1,1,1,60] strides=[0,0,0,1]

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_3d
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 60][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_3d(%arg0 : memref<3x4x5xi32>) {
      // Fully contiguous 3D scan (stride1 = 5, stride2 = 20 = 4*5): expect
      // the fold to the single linear count [1, 1, 1, 60][0, 0, 0, 1].
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 3, 4, 5][0, 20, 5, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<3x4x5xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Already in canonical linear form: the pattern must not fire (idempotent).

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @already_linear
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 4096][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @already_linear(%arg0 : memref<4096xi32>) {
      // Already in canonical form: the CHECK above asserts the op is
      // unchanged, i.e. the pattern is idempotent and does not loop.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 1, 4096][0, 0, 0, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<4096xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Non-contiguous: stride1 (3) != size0 (4) — must NOT be folded.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @no_fold_strided
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 2, 4][0, 0, 3, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @no_fold_strided(%arg0 : memref<32xi32>) {
      // stride1=3 != size0=4: genuinely strided rows, cannot fold.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 4][0, 0, 3, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<32xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Repeat dimension (s3 > 1) is preserved through the fold.
// sizes=[2,1,2,4] strides=[4096,0,4,1] -> sizes=[2,1,1,8] strides=[4096,0,0,1]

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_with_repeat
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][2, 1, 1, 8][4096, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_with_repeat(%arg0 : memref<8192xi32>) {
      // Outermost repeat (size3=2, stride3=4096) must survive the fold:
      // only d0/d1/d2 collapse, giving [2, 1, 1, 8][4096, 0, 0, 1].
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][2, 1, 2, 4][4096, 0, 4, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<8192xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// bf16 element type — motivating case from issue #2825.
// sizes=[1,1,2,512] strides=[0,0,512,1] -> sizes=[1,1,1,1024] strides=[0,0,0,1]
// In production K can be 1024+ (exceeding the d0 limit); the fold moves the
// total count into the wider linear-mode transfer-length register.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_bf16
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_bf16(%arg0 : memref<2x512xbf16>) {
      // Same shape as @fold_2d but with bf16 elements: the fold operates on
      // element counts, so the result is identical — [1, 1, 1, 1024].
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xbf16>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Non-unit inner stride: stride0=2 means elements are not unit-stride.
// Must NOT be folded.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @no_fold_inner_stride
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 2, 4][0, 0, 4, 2]
module {
  aie.device(npu1) {
    aie.runtime_sequence @no_fold_inner_stride(%arg0 : memref<32xi32>) {
      // stride0=2: skips every other element, not a linear scan.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 4][0, 0, 4, 2])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<32xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Wrong stride2: size2 > 1 but stride2 != size0 * size1 — must NOT be folded.
// (stride1 is correct, only stride2 is wrong.)

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @no_fold_stride2
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 2, 3, 4][0, 7, 4, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @no_fold_stride2(%arg0 : memref<64xi32>) {
      // stride2=7 != size0*size1=4*3=12: non-contiguous outer loop, cannot fold.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 2, 3, 4][0, 7, 4, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<64xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Nonzero static offset is preserved unchanged through the fold.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_with_offset
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 4][1, 1, 1, 1024][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_with_offset(%arg0 : memref<2048xi32>) {
      // Offset of 4 elements; sizes/strides fold as normal. The CHECK above
      // asserts the offset [0, 0, 0, 4] is carried through unchanged.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 4][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2048xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// packet attribute is preserved after the fold.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_packet
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
// CHECK-SAME: packet = <pkt_type = 0, pkt_id = 5>
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_packet(%arg0 : memref<2x512xi32>) {
      // The packet attribute must be carried through the fold unchanged
      // (the CHECK above matches its printed form after canonicalization).
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1],
                              packet = <pkt_id = 5, pkt_type = 0>)
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// issue_token attribute is preserved after the fold.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_issue_token
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
// CHECK-SAME: issue_token = true
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_issue_token(%arg0 : memref<2x512xi32>) {
      // issue_token = true must be preserved so waiters still see a token
      // after the access pattern is linearized.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64, issue_token = true } : memref<2x512xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}
Loading