Skip to content

Commit d3c3f1d

Browse files
committed
[Codegen] Use DMA for LHS/RHS only in scaled matmul
* For now, remove the blanket guard that disabled DMA for all scaled matmuls.
* Use DMA (UseGlobalLoadDMAAttr) for LHS/RHS operands.
* Fix lowering of DMA copy.
1 parent 5c5a70d commit d3c3f1d

6 files changed

Lines changed: 173 additions & 29 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -784,15 +784,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
784784
lhsScaleType,
785785
rhsScaleType};
786786

787-
// TODO(#22119): We don't use global load DMA for scaled matmuls, because
788-
// compilation doesn't support it. Once this is fixed, we should use global
789-
// load DMA here when possible.
790787
Location loc = operands[0].getLoc();
791-
if (scaled && useDirectLoad) {
792-
mlir::emitWarning(loc) << "direct load (global load DMA) is not yet "
793-
"supported for scaled matmuls, ignoring";
794-
useDirectLoad = false;
795-
}
796788

797789
// Accumulator needs shared memory if:
798790
// - Padding requires C promotion, OR
@@ -910,18 +902,28 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
910902
if (scaled) {
911903
promotionList.append({2, 3});
912904
auto defaultConfigAttr = IREE::GPU::DerivedThreadConfigAttr::get(context);
913-
// TODO(#23329): Do not swizzle shapes that have no bank conflicts.
914-
FailureOr<Attribute> lhsSwizzleAttr =
915-
getXorShuffleAttr(context, defaultConfigAttr, target, kind,
916-
schedule->kTileSizes, kMMAOperandLhs);
917-
FailureOr<Attribute> rhsSwizzleAttr =
918-
getXorShuffleAttr(context, defaultConfigAttr, target, kind,
919-
schedule->kTileSizes, kMMAOperandRhs);
920-
if (failed(lhsSwizzleAttr) || failed(rhsSwizzleAttr)) {
921-
promotionArray = {};
922-
} else {
923-
promotionArray = {*lhsSwizzleAttr, *rhsSwizzleAttr, defaultConfigAttr,
905+
if (useDirectLoad) {
906+
// Use DMA for LHS/RHS (operands 0,1) and thread-based copy for scale
907+
// operands (2,3). Scale operands use a different mapping level than DMA
908+
// copies, so mixing DMA for all operands would prevent loop fusion in
909+
// GPUFuseAndHoistParallelLoops (see #22119).
910+
Attribute useGlobalDma = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
911+
promotionArray = {useGlobalDma, useGlobalDma, defaultConfigAttr,
924912
defaultConfigAttr};
913+
} else {
914+
// TODO(#23329): Do not swizzle shapes that have no bank conflicts.
915+
FailureOr<Attribute> lhsSwizzleAttr =
916+
getXorShuffleAttr(context, defaultConfigAttr, target, kind,
917+
schedule->kTileSizes, kMMAOperandLhs);
918+
FailureOr<Attribute> rhsSwizzleAttr =
919+
getXorShuffleAttr(context, defaultConfigAttr, target, kind,
920+
schedule->kTileSizes, kMMAOperandRhs);
921+
if (failed(lhsSwizzleAttr) || failed(rhsSwizzleAttr)) {
922+
promotionArray = {};
923+
} else {
924+
promotionArray = {*lhsSwizzleAttr, *rhsSwizzleAttr, defaultConfigAttr,
925+
defaultConfigAttr};
926+
}
925927
}
926928
}
927929
if ((!mustBeAligned || couldNeedPadding) && cPromoteIfPadding) {

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ iree_lit_test_suite(
3939
"pipeline_igemm_tile_and_fuse.mlir",
4040
"pipeline_igemm_tile_and_fuse_gfx950.mlir",
4141
"pipeline_lower_to_llvmgpu.mlir",
42+
"pipeline_scaled_matmul_dma.mlir",
4243
"pipeline_scaled_truncation_gfx950.mlir",
4344
"pipeline_tile_and_fuse.mlir",
4445
"pipeline_tile_and_fuse_gfx950.mlir",

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ iree_lit_test_suite(
3434
"pipeline_igemm_tile_and_fuse.mlir"
3535
"pipeline_igemm_tile_and_fuse_gfx950.mlir"
3636
"pipeline_lower_to_llvmgpu.mlir"
37+
"pipeline_scaled_matmul_dma.mlir"
3738
"pipeline_scaled_truncation_gfx950.mlir"
3839
"pipeline_tile_and_fuse.mlir"
3940
"pipeline_tile_and_fuse_gfx950.mlir"

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \
1010
// RUN: --remarks-filter=".*" %s 2>&1 | FileCheck %s --check-prefix=CHECK-REMARKS
1111

12+
// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx950 \
13+
// RUN: --iree-codegen-llvmgpu-use-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
14+
// RUN: --iree-codegen-llvmgpu-use-igemm=false --iree-llvmgpu-use-direct-load=true --iree-llvmgpu-prefetch-num-stages=2 \
15+
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s \
16+
// RUN: | FileCheck %s --check-prefix=CHECK-DIRECT-LOAD
17+
1218
// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx950 \
1319
// RUN: --iree-codegen-llvmgpu-use-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
1420
// RUN: --iree-codegen-llvmgpu-use-igemm=false --iree-llvmgpu-use-direct-load=true --iree-llvmgpu-prefetch-num-stages=2 \
@@ -53,17 +59,25 @@ func.func @scaled_matmul(
5359
// CHECK-SAME: subgroup = [4, 8, 0, 0]
5460
// CHECK-SAME: workgroup = [256, 256, 0, 0]
5561

62+
// With --iree-llvmgpu-use-direct-load, LHS/RHS get use_global_load_dma while
63+
// scales keep derived_thread_config.
64+
// CHECK-DIRECT-LOAD-LABEL: func.func @scaled_matmul
65+
// CHECK-DIRECT-LOAD: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
66+
// CHECK-DIRECT-LOAD-SAME: promotion_types = [#iree_gpu.use_global_load_dma, #iree_gpu.use_global_load_dma, #iree_gpu.derived_thread_config, #iree_gpu.derived_thread_config]
67+
5668
// CHECK-REMARKS: [Analysis] SharedMemoryUsage
5769
// CHECK-REMARKS-SAME: Category:deduceMMASchedule
5870
// CHECK-REMARKS-SAME: Remark=34816
5971

72+
// TODO(#22119): With direct-load, no cache swizzle on LHS/RHS so shared
73+
// memory increases. This needs to be addressed.
6074
// CHECK-REMARKS-DIRECT-LOAD-2: [Analysis] SharedMemoryUsage
6175
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Category:deduceMMASchedule
62-
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=34816
76+
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=69632
6377

6478
// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
6579
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
66-
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=34816
80+
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=104448
6781

6882
// -----
6983

@@ -105,11 +119,11 @@ func.func @scaled_matmul_with_batch(
105119

106120
// CHECK-REMARKS-DIRECT-LOAD-2: [Analysis] SharedMemoryUsage
107121
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Category:deduceMMASchedule
108-
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=34816
122+
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=69632
109123

110124
// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
111125
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
112-
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=34816
126+
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=104448
113127

114128
// -----
115129

@@ -179,11 +193,11 @@ func.func @scaled_matmul_with_dynamic_batch(
179193

180194
// CHECK-REMARKS-DIRECT-LOAD-2: [Analysis] SharedMemoryUsage
181195
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Category:deduceMMASchedule
182-
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=26112
196+
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=52224
183197

184198
// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
185199
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
186-
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=26112
200+
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=78336
187201

188202
// -----
189203

@@ -225,11 +239,11 @@ func.func @small_scaled_matmul(
225239

226240
// CHECK-REMARKS-DIRECT-LOAD-2: [Analysis] SharedMemoryUsage
227241
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Category:deduceMMASchedule
228-
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=2176
242+
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=4352
229243

230244
// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
231245
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
232-
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=2176
246+
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=6528
233247

234248
// -----
235249

@@ -346,11 +360,11 @@ func.func @scaled_matmul_accumulate(
346360

347361
// CHECK-REMARKS-DIRECT-LOAD-2: [Analysis] SharedMemoryUsage
348362
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Category:deduceMMASchedule
349-
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=157184
363+
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=109056
350364

351365
// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
352366
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
353-
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=157184
367+
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=130816
354368

355369
// -----
356370

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx950 \
2+
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target{for-rocdl=true})))))" %s | FileCheck %s
3+
4+
// Test: Scaled matmul (f4E2M1FN * f4E2M1FN with f8E8M0FNU scales) compiles
5+
// through the full pipeline with DMA config. This validates that the pipeline
6+
// handles sub-byte types correctly, including the narrow type emulation for
7+
// gather_to_lds ops.
8+
9+
#pipeline_layout = #hal.pipeline.layout<bindings = [
10+
#hal.pipeline.binding<storage_buffer, ReadOnly>,
11+
#hal.pipeline.binding<storage_buffer, ReadOnly>,
12+
#hal.pipeline.binding<storage_buffer, ReadOnly>,
13+
#hal.pipeline.binding<storage_buffer, ReadOnly>,
14+
#hal.pipeline.binding<storage_buffer>
15+
]>
16+
#translation_info = #iree_codegen.translation_info<pipeline =
17+
LLVMGPUTileAndFuse
18+
workgroup_size = [512, 1, 1]
19+
subgroup_size = 64,
20+
{
21+
gpu_pipeline_options = #iree_gpu.pipeline_options<
22+
prefetch_num_stages = 2,
23+
no_reduce_shared_memory_bank_conflicts = true>
24+
}
25+
>
26+
#config = #iree_gpu.lowering_config<{
27+
mma_kind = #iree_gpu.scaled_mma_layout<
28+
intrinsic = MFMA_SCALE_F32_16x16x128_B32,
29+
lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN,
30+
acc_elem_type = f32>,
31+
promote_operands = [0, 1, 2, 3],
32+
promotion_types = [
33+
#iree_gpu.use_global_load_dma,
34+
#iree_gpu.use_global_load_dma,
35+
#iree_gpu.derived_thread_config,
36+
#iree_gpu.derived_thread_config],
37+
reduction = [0, 0, 1, 1],
38+
subgroup = [4, 8, 0, 0],
39+
workgroup = [256, 256, 0, 0]
40+
}>
41+
#lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
42+
#rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
43+
#scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)>
44+
#scale_n = affine_map<(M, N, Ko, Kb) -> (N, Ko)>
45+
#out_map = affine_map<(M, N, Ko, Kb) -> (M, N)>
46+
hal.executable public @main {
47+
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
48+
hal.executable.export public @scaled_matmul_dma ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
49+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
50+
hal.return %x, %y, %z : index, index, index
51+
}
52+
builtin.module {
53+
func.func @scaled_matmul_dma()
54+
attributes {translation_info = #translation_info} {
55+
%cst = arith.constant 0.000000e+00 : f32
56+
%c0 = arith.constant 0 : index
57+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512x32xf4E2M1FN>>
58+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512x32xf4E2M1FN>>
59+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512xf8E8M0FNU>>
60+
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512xf8E8M0FNU>>
61+
%4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(4) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1024x1024xf32>>
62+
%A = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1024, 512, 32], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512x32xf4E2M1FN>> -> tensor<1024x512x32xf4E2M1FN>
63+
%B = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [1024, 512, 32], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512x32xf4E2M1FN>> -> tensor<1024x512x32xf4E2M1FN>
64+
%A_scales = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512xf8E8M0FNU>> -> tensor<1024x512xf8E8M0FNU>
65+
%B_scales = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x512xf8E8M0FNU>> -> tensor<1024x512xf8E8M0FNU>
66+
%empty = tensor.empty() : tensor<1024x1024xf32>
67+
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
68+
%result = linalg.generic {
69+
indexing_maps = [#lhs_map, #rhs_map, #scale_m, #scale_n, #out_map],
70+
iterator_types = ["parallel", "parallel", "reduction", "reduction"]
71+
} ins(%A, %B, %A_scales, %B_scales : tensor<1024x512x32xf4E2M1FN>, tensor<1024x512x32xf4E2M1FN>, tensor<1024x512xf8E8M0FNU>, tensor<1024x512xf8E8M0FNU>) outs(%fill : tensor<1024x1024xf32>) attrs = {lowering_config = #config} {
72+
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %out: f32):
73+
%s1 = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
74+
%s2 = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
75+
%m = arith.mulf %s1, %s2 : f32
76+
%r = arith.addf %out, %m : f32
77+
linalg.yield %r : f32
78+
} -> tensor<1024x1024xf32>
79+
iree_tensor_ext.dispatch.tensor.store %result, %4, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1] : tensor<1024x1024xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1024x1024xf32>>
80+
return
81+
}
82+
}
83+
}
84+
}
85+
86+
// Verify pipeline completes and produces scaled MFMA compute ops.
87+
// LHS/RHS are promoted to workgroup shared memory and scales use thread-based
88+
// copies. The compute uses 16x16x128 scaled MFMA instructions.
89+
90+
// CHECK-LABEL: func.func @scaled_matmul_dma
91+
// CHECK-DAG: memref.alloc() : memref<{{.*}}xf8E8M0FNU, #gpu.address_space<workgroup>>
92+
// CHECK-DAG: memref.alloc() : memref<{{.*}}xf4E2M1FN, #gpu.address_space<workgroup>>
93+
// CHECK: scf.forall
94+
// CHECK: scf.for
95+
// CHECK: amdgpu.scaled_mfma 16x16x128

tests/e2e/matmul/CMakeLists.txt

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2624,6 +2624,37 @@ iree_generated_e2e_runner_test(
26242624
"requires-gpu-cdna4"
26252625
)
26262626

2627+
iree_generated_e2e_runner_test(
2628+
NAME
2629+
e2e_matmul_cdna4_mxfp4_dma
2630+
TEST_TYPE
2631+
matmul
2632+
GENERATOR
2633+
"generate_e2e_matmul_tests.py"
2634+
GENERATOR_ARGS
2635+
"--lhs_rhs_type=f4E2M1FN"
2636+
"--acc_type=f32"
2637+
"--mx_scale_type=f8E8M0FNU"
2638+
"--mx_block_size=32"
2639+
"--shapes=easy_large_static"
2640+
"--transpose_rhs"
2641+
TEST_RUNNER
2642+
iree_tools_testing_e2e_iree-e2e-matmul-test
2643+
TARGET_BACKENDS
2644+
"rocm"
2645+
DRIVERS
2646+
"hip"
2647+
COMPILER_FLAGS
2648+
${IREE_HIP_TEST_COMPILER_FLAGS}
2649+
"--iree-llvmgpu-use-direct-load"
2650+
LABELS
2651+
"noasan"
2652+
"nomsan"
2653+
"notsan"
2654+
"noubsan"
2655+
"requires-gpu-cdna4"
2656+
)
2657+
26272658
endif()
26282659

26292660

0 commit comments

Comments
 (0)