
Commit 24285b9

Add padded shared layout to test_convert2d
Parent: cf09a13

5 files changed: 41 additions & 18 deletions

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -369,7 +369,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   let genVerifyDecl = 1;
 }
 
-def PaddeddSharedEncodingAttr
+def PaddedSharedEncodingAttr
     : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
                      [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "padded_shared";

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 3 additions & 1 deletion
@@ -424,6 +424,7 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
                   unsigned bitwidth, Value smemOffset, bool offsetInBytes) {
   TritonLLVMOpBuilder b(loc, rewriter);
 
+  assert((bitwidth >= 8) && "Invalid bitwidth for padded shared layout");
   Value padOffset = b.i32_val(0);
   unsigned offScale = offsetInBytes ? bitwidth / 8 : 1;
   for (auto [interval, padding] :
@@ -712,7 +713,8 @@ bool emitTransferBetweenRegistersAndShared(
     smemOffset = b.xor_(smemOffset, offset);
   if (paddedLayout) {
     // Apply the offset needed for padding.
-    Value padOffset = emitPadding(loc, rewriter, paddedLayout, /*bitwidth=*/0,
+    auto bitwidth = elemLlvmTy.getIntOrFloatBitWidth();
+    Value padOffset = emitPadding(loc, rewriter, paddedLayout, bitwidth,
                                   smemOffset, /*offsetInBytes=*/false);
     smemOffset = b.add(smemOffset, padOffset);
   }
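
The second hunk stops passing a placeholder bitwidth of 0: when emitPadding works with byte offsets, it scales the element-granularity intervals and paddings by bitwidth / 8 (the offScale above), so the element type's real bitwidth is needed and sub-byte types are rejected by the new assert. A minimal scalar sketch of the padding-offset arithmetic, assuming the per-interval shift mirrors getPaddedSize below (pad_offset is an illustrative helper, not code from this patch):

def pad_offset(smem_offset, interval_padding_pairs, bitwidth, offset_in_bytes):
    # Scalar model of the padding offset: for each (interval, padding) pair,
    # count how many full intervals precede `smem_offset` and multiply by the
    # padding inserted after each of them. Scaling interval/padding by the
    # element size in the byte-offset case is an assumption of this sketch.
    assert bitwidth >= 8, "Invalid bitwidth for padded shared layout"
    scale = bitwidth // 8 if offset_in_bytes else 1
    total = 0
    for interval, padding in interval_padding_pairs:
        log2_interval = (interval * scale).bit_length() - 1
        log2_padding = (padding * scale).bit_length() - 1
        total += (smem_offset >> log2_interval) << log2_padding
    return total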

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 0 deletions
@@ -1780,6 +1780,8 @@ int64_t PaddedSharedEncodingAttr::getPaddedSize(ArrayRef<int64_t> shape) const {
        llvm::zip_equal(getIntervals(), getPaddings())) {
     paddingSize += (unpaddedSize >> llvm::Log2_32(interval))
                    << llvm::Log2_32(padding);
+    // There is no need for padding after the last element
+    paddingSize -= padding;
   }
   return unpaddedSize + paddingSize;
 }
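
The adjustment above drops the padding that would land after the final interval, so only interior gaps are counted. A small Python sketch of the same arithmetic with a worked example (padded_size is illustrative, not part of the patch):

def padded_size(unpadded_size, interval_padding_pairs):
    # Mirrors PaddedSharedEncodingAttr::getPaddedSize: after every `interval`
    # elements, `padding` extra elements are inserted, except after the last one.
    padding_size = 0
    for interval, padding in interval_padding_pairs:
        padding_size += (unpadded_size >> (interval.bit_length() - 1)) << (padding.bit_length() - 1)
        padding_size -= padding
    return unpadded_size + padding_size

# A 64x64 tile with the [[32, 8]] layout used in the test below:
# 4096 + (4096 / 32) * 8 - 8 = 5112 elements of shared memory.
assert padded_size(64 * 64, [(32, 8)]) == 5112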

python/test/unit/language/test_core.py

Lines changed: 31 additions & 15 deletions
@@ -37,6 +37,7 @@
     is_hip_cdna3,
     is_hip_cdna4,
     is_hip_gfx12,
+    get_lds_size,
     is_xpu,
     get_arch,
     torch_float8_dtypes,
@@ -216,7 +217,7 @@ def __str__(self):
         return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
 
 
-class SharedLayout:
+class SwizzledSharedLayout:
 
     def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
         self.vec = vec
@@ -231,6 +232,19 @@ def __str__(self):
         return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
 
 
+class PaddedSharedLayout:
+
+    def __init__(self, interval_padding_pairs, order, ctas_per_cga, cta_split_num, cta_order):
+        self.interval_padding_pairs = "[" + ", ".join(f"{v[0]}:{v[1]:+d}" for v in interval_padding_pairs) + "]"
+        self.order = order
+        self.ctas_per_cga = ctas_per_cga
+        self.cta_split_num = cta_split_num
+        self.cta_order = cta_order
+
+    def __str__(self):
+        return f"#{GPU_DIALECT}.padded_shared<{self.interval_padding_pairs} {{order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+
+
 class NVMMASharedLayout:
 
     def __init__(self, swizzle, transpose, element_bit_width, ctas_per_cga, cta_split_num, cta_order):
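
The new helper only has to render the attribute's textual form for the generated MLIR. For instance, assuming GPU_DIALECT resolves to "ttg" (the prefix comes from the test harness, not this class):

layout = PaddedSharedLayout([[32, 8]], [1, 0], [1, 1], [1, 1], [0, 1])
print(layout)
# -> #ttg.padded_shared<[32:+8] {order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>
#    i.e. insert 8 padding elements after every 32 contiguous elements.
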
@@ -293,7 +307,7 @@ def warps_per_cta(layout, shape):
 
 
 def is_layout_applicable(layout) -> bool:
-    if isinstance(layout, (BlockedLayout, SharedLayout, LinearLayout)):
+    if isinstance(layout, (BlockedLayout, SwizzledSharedLayout, PaddedSharedLayout, LinearLayout)):
         return True
     elif isinstance(layout, SliceLayout):
         return is_layout_applicable(layout.parent)
@@ -6145,10 +6159,12 @@ def kernel(Out):
 
 intermediate_layouts = [
     None,
-    SharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
+    PaddedSharedLayout([[32, 8]], [1, 0], [1, 1], [1, 1], [0, 1]),
+    PaddedSharedLayout([[64, 4], [128, 8]], [1, 0], [1, 1], [1, 1], [0, 1])
 ]
 
 
@@ -6182,7 +6198,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
         scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N))
     except AssertionError:
         pytest.skip("Can't compute scratch buffer size")
-    lds_size = 65536
+    lds_size = get_lds_size()
     # consider int32 dtype in scratch buffer size,
     # because it is the largest dtype used in convert_layout in this test
     int32_size = 4
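
Querying the LDS size per target keeps the skip logic meaningful now that CDNA4 reports 163840 bytes (see get_lds_size in python/triton/_internal_testing.py below). A self-contained sketch of the budget the comments describe, using an illustrative helper and shape rather than the test's verbatim code:

def scratch_fits_in_lds(scratch_shape, lds_size, int32_size=4):
    # The scratch buffer is budgeted for int32, the largest dtype that
    # convert_layout moves through shared memory in this test.
    rows, cols = scratch_shape
    return rows * cols * int32_size <= lds_size

# A 128x128 scratch buffer needs 65536 bytes: it exactly fills a 64 KiB LDS
# and fits comfortably within CDNA4's 163840-byte LDS.
assert scratch_fits_in_lds((128, 128), 65536)
assert scratch_fits_in_lds((128, 128), 163840)
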
@@ -6258,10 +6274,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
 ]
 
 shared_layouts_3d = [
-    SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
 ]
 
 
@@ -6349,9 +6365,9 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
 ]
 
 shared_layouts = [
-    SharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]),
 ]
 
 
@@ -6502,7 +6518,7 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t
 ]
 
 shared_layouts = [
-    SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
     NVMMASharedLayout(64, False, 16, [1, 1], [1, 1], [0, 1]),
     NVMMASharedLayout(128, False, 16, [1, 1], [1, 1], [0, 1]),
 ]

python/triton/_internal_testing.py

Lines changed: 4 additions & 1 deletion
@@ -76,14 +76,17 @@ def is_hip_cdna4():
 
 def is_hip_gfx12():
     target = get_current_target()
-    print(target.arch)
     return target is not None and target.backend == 'hip' and 'gfx12' in target.arch
 
 
 def is_hip_cdna():
     return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
 
 
+def get_lds_size():
+    return 163840 if is_hip_cdna4() else 65536
+
+
 def is_xpu():
     target = get_current_target()
     return False if target is None else target.backend == "xpu"
