     is_hip_cdna3,
     is_hip_cdna4,
     is_hip_gfx12,
+    get_lds_size,
     is_xpu,
     get_arch,
     torch_float8_dtypes,
@@ -216,7 +217,7 @@ def __str__(self):
         return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
 
 
-class SharedLayout:
+class SwizzledSharedLayout:
 
     def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
         self.vec = vec
@@ -231,6 +232,19 @@ def __str__(self):
         return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
 
 
+class PaddedSharedLayout:
+
+    def __init__(self, interval_padding_pairs, order, ctas_per_cga, cta_split_num, cta_order):
+        self.interval_padding_pairs = "[" + ", ".join(f"{v[0]}:{v[1]:+d}" for v in interval_padding_pairs) + "]"
+        self.order = order
+        self.ctas_per_cga = ctas_per_cga
+        self.cta_split_num = cta_split_num
+        self.cta_order = cta_order
+
+    def __str__(self):
+        return f"#{GPU_DIALECT}.padded_shared<{self.interval_padding_pairs} {{order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+
+
 class NVMMASharedLayout:
 
     def __init__(self, swizzle, transpose, element_bit_width, ctas_per_cga, cta_split_num, cta_order):
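
For reference, a minimal sketch of what the new class renders, assuming GPU_DIALECT == "ttg" as elsewhere in this test file (the layout values here are illustrative):

```python
# Illustrative only: the {v[1]:+d} format forces a sign on the pad count,
# so [[32, 8], [64, 4]] is rendered as "[32:+8, 64:+4]".
layout = PaddedSharedLayout([[32, 8], [64, 4]], [1, 0], [1, 1], [1, 1], [0, 1])
print(layout)
# #ttg.padded_shared<[32:+8, 64:+4] {order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>
```
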
@@ -293,7 +307,7 @@ def warps_per_cta(layout, shape):
 
 
 def is_layout_applicable(layout) -> bool:
-    if isinstance(layout, (BlockedLayout, SharedLayout, LinearLayout)):
+    if isinstance(layout, (BlockedLayout, SwizzledSharedLayout, PaddedSharedLayout, LinearLayout)):
         return True
     elif isinstance(layout, SliceLayout):
         return is_layout_applicable(layout.parent)
@@ -6145,10 +6159,12 @@ def kernel(Out):
 
 intermediate_layouts = [
     None,
-    SharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
+    PaddedSharedLayout([[32, 8]], [1, 0], [1, 1], [1, 1], [0, 1]),
+    PaddedSharedLayout([[64, 4], [128, 8]], [1, 0], [1, 1], [1, 1], [0, 1])
 ]
 
 
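
The two new PaddedSharedLayout entries exercise a single interval/padding pair and a list of pairs. As a rough model of what the pairs mean (an assumption for illustration, not something this diff states), each [interval, padding] pair inserts `padding` elements after every `interval` elements of the linear offset:

```python
# Assumed padded-offset model, for illustration only.
def padded_offset(offset, interval_padding_pairs):
    padded = offset
    for interval, padding in interval_padding_pairs:
        padded += (offset // interval) * padding
    return padded

padded_offset(64, [[32, 8]])             # 64 + (64 // 32) * 8 = 80
padded_offset(128, [[64, 4], [128, 8]])  # 128 + 2 * 4 + 1 * 8 = 144
```
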
@@ -6182,7 +6198,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
         scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N))
     except AssertionError:
         pytest.skip("Can't compute scratch buffer size")
-    lds_size = 65536
+    lds_size = get_lds_size()
     # consider int32 dtype in scratch buffer size,
     # because it is the largest dtype used in convert_layout in this test
     int32_size = 4
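
Swapping the hardcoded 65536 (a 64 KiB assumption) for get_lds_size() lets the skip check track the actual LDS capacity of the target. A hedged sketch of how these values plausibly combine in the guard (the exact condition in the test may differ):

```python
import math

lds_size = get_lds_size()  # target-specific LDS capacity in bytes
int32_size = 4             # largest element type convert_layout uses here
# Skip when the int32 scratch buffer for the conversion cannot fit in LDS.
if int32_size * math.prod(scratch_shape) > lds_size:
    pytest.skip("Scratch buffer too large for LDS")
```
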
@@ -6258,10 +6274,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
 ]
 
 shared_layouts_3d = [
-    SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
 ]
 
 
@@ -6349,9 +6365,9 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
 ]
 
 shared_layouts = [
-    SharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]),
 ]
 
 
@@ -6502,7 +6518,7 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t
 ]
 
 shared_layouts = [
-    SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
     NVMMASharedLayout(64, False, 16, [1, 1], [1, 1], [0, 1]),
     NVMMASharedLayout(128, False, 16, [1, 1], [1, 1], [0, 1]),
 ]