Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/aie/Dialect/AIEX/IR/AIEX.td
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ def AIE_NpuDmaMemcpyNdOp: AIEX_Op<"npu.dma_memcpy_nd", [
}];

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def AIE_NpuDmaWaitOp: AIEX_Op<"npu.dma_wait", []> {
Expand Down
88 changes: 88 additions & 0 deletions lib/Dialect/AIEX/IR/AIEXDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,94 @@ bool AIEX::NpuDmaMemcpyNdOp::isLinearTransferWithoutTransformation() {
inputStrides[1] == 0 && inputStrides[2] == 0);
}

// Canonicalization pattern: rewrite a contiguous row-major access pattern to
// the canonical linear form [s3, 1, 1, N][st3, 0, 0, 1].
//
// Using outermost-first index notation (matching the IR syntax), a 4D access
// [s3, s2, s1, s0][st3, st2, st1, st0] is a contiguous linear scan when:
// st0 == 1
// s1 == 1 || st1 == s0 (stride irrelevant when size is 1)
// s2 == 1 || st2 == s0 * s1
// yielding a total of N = s0 * s1 * s2 contiguous elements. The repeat
// dimension s3 / stride st3 is unchanged by the fold.
//
// This fold is semantically valid and does not trade one verifier error for
// another: in the resulting linear form, isLinearTransferWithout-
// Transformation() returns true, so verifyStridesWraps() skips the 10-bit
// d0 wrap-size check. Linear mode uses a much wider transfer-length
// register, so the supported N is far larger than the strided-mode wrap
// limit (though still finite on hardware).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are still limits but they are very large.

namespace {
/// Canonicalization pattern: folds d0/d1/d2 of a contiguous row-major access
/// pattern into a single linear dimension, [s3, 1, 1, N][st3, 0, 0, 1].
///
/// Using innermost-first indexing (index 0 = d0, index 3 = repeat), the
/// transfer is a contiguous linear scan of N = s0 * s1 * s2 elements when:
///   st0 == 1
///   s1 == 1 || st1 == s0        (a stride is irrelevant when its size is 1)
///   s2 == 1 || st2 == s0 * s1
/// The repeat dimension (s3 / st3) is carried over unchanged.
struct LinearizeContiguousTransfer
    : public mlir::OpRewritePattern<AIEX::NpuDmaMemcpyNdOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(AIEX::NpuDmaMemcpyNdOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Zero padding is specified per dimension (before/after each d0/d1/d2
    // iteration). Collapsing those dimensions would relocate where the zeros
    // are emitted, so the fold is only sound for unpadded transfers.
    if (op.getD0ZeroBefore() != 0 || op.getD1ZeroBefore() != 0 ||
        op.getD2ZeroBefore() != 0 || op.getD0ZeroAfter() != 0 ||
        op.getD1ZeroAfter() != 0 || op.getD2ZeroAfter() != 0)
      return mlir::failure();

    // Only constant sizes/strides can be analysed statically.
    auto isConst = [](mlir::OpFoldResult v) {
      return mlir::getConstantIntValue(v).has_value();
    };
    if (!llvm::all_of(op.getMixedSizes(), isConst) ||
        !llvm::all_of(op.getMixedStrides(), isConst))
      return mlir::failure();

    // Skip ops that are already in canonical linear form (idempotence).
    if (op.isLinearTransferWithoutTransformation())
      return mlir::failure();

    // getMixedSizes/Strides return outermost-first; reverse to
    // innermost-first so index 0 = d0 (innermost) and index 3 = repeat.
    llvm::SmallVector<int64_t, 4> sizes = llvm::map_to_vector(
        llvm::reverse(op.getMixedSizes()), [](mlir::OpFoldResult s) {
          return mlir::getConstantIntValue(s).value();
        });
    llvm::SmallVector<int64_t, 4> strides = llvm::map_to_vector(
        llvm::reverse(op.getMixedStrides()), [](mlir::OpFoldResult s) {
          return mlir::getConstantIntValue(s).value();
        });
    // Defensive: the index arithmetic below assumes the op's 4D form.
    if (sizes.size() != 4 || strides.size() != 4)
      return mlir::failure();

    // Require a contiguous row-major scan. A stride is only constrained when
    // its corresponding size is > 1 (a never-applied stride is irrelevant).
    if (strides[0] != 1)
      return mlir::failure();
    if (sizes[1] > 1 && strides[1] != sizes[0])
      return mlir::failure();
    if (sizes[2] > 1 && strides[2] != sizes[0] * sizes[1])
      return mlir::failure();

    // Fold d0/d1/d2 into one linear count; keep the repeat dimension intact.
    // Build directly in outermost-first order for the replacement op.
    int64_t numElements = sizes[0] * sizes[1] * sizes[2];
    llvm::SmallVector<int64_t, 4> newSizesOuter = {sizes[3], 1, 1,
                                                   numElements};
    llvm::SmallVector<int64_t, 4> newStridesOuter = {strides[3], 0, 0, 1};

    // Preserve all other attributes (offsets, packet, metadata, etc.)
    // exactly; the padding attributes are all zero/absent at this point.
    rewriter.replaceOpWithNewOp<AIEX::NpuDmaMemcpyNdOp>(
        op, op.getMemref(),
        /*offsets=*/op.getOffsets(),
        /*sizes=*/mlir::ValueRange{},
        /*strides=*/mlir::ValueRange{},
        mlir::DenseI64ArrayAttr::get(op.getContext(), op.getStaticOffsets()),
        mlir::DenseI64ArrayAttr::get(op.getContext(), newSizesOuter),
        mlir::DenseI64ArrayAttr::get(op.getContext(), newStridesOuter),
        op.getPacketAttr(), op.getMetadata(), op.getIdAttr(),
        op.getIssueTokenAttr(), op.getD0ZeroBeforeAttr(),
        op.getD1ZeroBeforeAttr(), op.getD2ZeroBeforeAttr(),
        op.getD0ZeroAfterAttr(), op.getD1ZeroAfterAttr(),
        op.getD2ZeroAfterAttr(), op.getBurstLengthAttr());
    return mlir::success();
  }
};
} // namespace

void AIEX::NpuDmaMemcpyNdOp::getCanonicalizationPatterns(
    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
  // Register the fold of contiguous row-major transfers into the canonical
  // linear form.
  patterns.insert<LinearizeContiguousTransfer>(context);
}

// Helper method to check if a requested burst length is supported by the target
// model. Returns an error message if the burst length is not supported or an
// empty option otherwise.
Expand Down
250 changes: 250 additions & 0 deletions test/dialect/AIEX/canonicalize_linear.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
//===- canonicalize_linear.mlir --------------------------------*- MLIR -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (C) 2026, Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//
//
// Tests for NpuDmaMemcpyNdOp canonicalization: contiguous row-major access
// patterns are folded into the canonical linear form [s3,1,1,N][st3,0,0,1].
//
// This is the fix for github.com/Xilinx/mlir-aie/issues/2825.
//
// All tests use static literal sizes/strides so that:
// (a) canonicalization sees constant values and can fire, and
// (b) the pre-canonicalization op is in-bounds for the verifier.
//
// Positive tests (@fold_*) pin the folded form; negative tests (@no_fold_*)
// pin that genuinely strided patterns are left untouched.
//
//===----------------------------------------------------------------------===//

// RUN: aie-opt --canonicalize --split-input-file %s | FileCheck %s

// -----

// Basic 2D fold: sizes=[1,1,2,512] strides=[0,0,512,1] ->
//                sizes=[1,1,1,1024] strides=[0,0,0,1]
//
// Motivating case from issue #2825: in production K can exceed 1023 (the d0
// wrap limit). After folding, N is encoded in the wider linear-mode transfer
// length register, so no limit applies.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_2d
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_2d(%arg0 : memref<2x512xi32>) {
      // 2 rows x 512 contiguous elements collapse to one 1024-element scan.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// 3D fold: sizes=[1,3,4,5] strides=[0,20,5,1] ->
//          sizes=[1,1,1,60] strides=[0,0,0,1]

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_3d
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 60][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_3d(%arg0 : memref<3x4x5xi32>) {
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 3, 4, 5][0, 20, 5, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<3x4x5xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Already in canonical linear form: the pattern must not fire (idempotent).

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @already_linear
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 4096][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @already_linear(%arg0 : memref<4096xi32>) {
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 1, 4096][0, 0, 0, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<4096xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Non-contiguous: stride1 (3) != size0 (4) — must NOT be folded.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @no_fold_strided
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 2, 4][0, 0, 3, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @no_fold_strided(%arg0 : memref<32xi32>) {
      // stride1=3 != size0=4: genuinely strided rows, cannot fold.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 4][0, 0, 3, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<32xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Repeat dimension (s3 > 1) is preserved through the fold.
// sizes=[2,1,2,4] strides=[4096,0,4,1] -> sizes=[2,1,1,8] strides=[4096,0,0,1]

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_with_repeat
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][2, 1, 1, 8][4096, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_with_repeat(%arg0 : memref<8192xi32>) {
      // Outer repeat (size 2, stride 4096) must survive the fold untouched.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][2, 1, 2, 4][4096, 0, 4, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<8192xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// bf16 element type — motivating case from issue #2825.
// sizes=[1,1,2,512] strides=[0,0,512,1] -> sizes=[1,1,1,1024] strides=[0,0,0,1]
// In production K can be 1024+ (exceeding the d0 limit); the fold moves the
// total count into the wider linear-mode transfer-length register.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_bf16
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_bf16(%arg0 : memref<2x512xbf16>) {
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xbf16>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Non-unit inner stride: stride0=2 means elements are not unit-stride.
// Must NOT be folded.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @no_fold_inner_stride
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 2, 4][0, 0, 4, 2]
module {
  aie.device(npu1) {
    aie.runtime_sequence @no_fold_inner_stride(%arg0 : memref<32xi32>) {
      // stride0=2: skips every other element, not a linear scan.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 4][0, 0, 4, 2])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<32xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Wrong stride2: size2 > 1 but stride2 != size0 * size1 — must NOT be folded.
// (stride1 is correct, only stride2 is wrong.)

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @no_fold_stride2
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 2, 3, 4][0, 7, 4, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @no_fold_stride2(%arg0 : memref<64xi32>) {
      // stride2=7 != size0*size1=4*3=12: non-contiguous outer loop, cannot fold.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 2, 3, 4][0, 7, 4, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<64xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// Nonzero static offset is preserved unchanged through the fold.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_with_offset
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 4][1, 1, 1, 1024][0, 0, 0, 1]
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_with_offset(%arg0 : memref<2048xi32>) {
      // Offset of 4 elements; sizes/strides fold as normal.
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 4][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2048xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// packet attribute is preserved after the fold.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_packet
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
// CHECK-SAME: packet = <pkt_type = 0, pkt_id = 5>
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_packet(%arg0 : memref<2x512xi32>) {
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1],
        packet = <pkt_id = 5, pkt_type = 0>)
        { metadata = @of_fromMem, id = 0 : i64 } : memref<2x512xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}

// -----

// issue_token attribute is preserved after the fold.

// CHECK-LABEL: aie.device(npu1)
// CHECK: aie.runtime_sequence @fold_issue_token
// CHECK: aiex.npu.dma_memcpy_nd
// CHECK-SAME: [0, 0, 0, 0][1, 1, 1, 1024][0, 0, 0, 1]
// CHECK-SAME: issue_token = true
module {
  aie.device(npu1) {
    aie.runtime_sequence @fold_issue_token(%arg0 : memref<2x512xi32>) {
      aiex.npu.dma_memcpy_nd (%arg0[0, 0, 0, 0][1, 1, 2, 512][0, 0, 512, 1])
        { metadata = @of_fromMem, id = 0 : i64, issue_token = true } : memref<2x512xi32>
    }
    %tile = aie.tile(0, 0)
    aie.shim_dma_allocation @of_fromMem (%tile, MM2S, 0)
  }
}
Loading