diff --git a/aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv new file mode 100644 index 0000000000..d058eafe65 --- /dev/null +++ b/aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv @@ -0,0 +1,221 @@ +cu_num,M,N,K,q_dtype_w,kernelId,splitK,us,kernelName,tflops,bw,errRatio +80,1,9216,4096,torch.float8_e4m3fnuz,30,0,13.5714,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,5.56,2783.15,0.0 +80,2,9216,4096,torch.float8_e4m3fnuz,2,0,14.0907,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,10.72,2682.18,0.0 +80,4,9216,4096,torch.float8_e4m3fnuz,2,0,13.9599,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,21.63,2710.54,0.0 +80,8,9216,4096,torch.float8_e4m3fnuz,2,0,13.9876,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,43.18,2711.61,0.0 +80,16,9216,4096,torch.float8_e4m3fnuz,30,0,14.8257,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,81.48,2570.48,0.0 +80,32,9216,4096,torch.float8_e4m3fnuz,9,0,17.4829,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,138.19,2200.41,0.0 +80,64,9216,4096,torch.float8_e4m3fnuz,22,0,24.8908,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x64x256_1x4x1_16x16x64_default,194.12,1574.5,0.0 +80,128,9216,4096,torch.float8_e4m3fnuz,24,0,37.3315,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,258.86,1088.42,0.0 +80,256,9216,4096,torch.float8_e4m3fnuz,0,0,66.9303,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_128x128x128_1x4x1_16x16x64_default,288.77,650.17,0.0 +80,1024,9216,4096,torch.float8_e4m3fnuz,54,0,217.4525,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,355.52,279.68,0.0 +80,2048,9216,4096,torch.float8_e4m3fnuz,54,0,421.0536,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,367.22,199.23,0.0 
+80,4096,9216,4096,torch.float8_e4m3fnuz,54,0,814.1634,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,379.82,159.7,0.0 +80,4240,9216,4096,torch.float8_e4m3fnuz,54,0,857.5538,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,373.28,155.4,0.0 +80,16384,9216,4096,torch.float8_e4m3fnuz,54,0,3176.6345,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,389.39,128.08,0.0 +80,32768,9216,4096,torch.float8_e4m3fnuz,54,0,6363.4917,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,388.76,121.94,0.0 +80,1,4608,4096,torch.float8_e4m3fnuz,2,0,10.0986,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,3.74,1870.33,0.0 +80,2,4608,4096,torch.float8_e4m3fnuz,2,0,9.7346,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,7.76,1941.63,0.0 +80,4,4608,4096,torch.float8_e4m3fnuz,30,0,9.434,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,16.01,2006.32,0.0 +80,8,4608,4096,torch.float8_e4m3fnuz,30,0,10.005,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,30.18,1897.14,0.0 +80,16,4608,4096,torch.float8_e4m3fnuz,30,0,9.6456,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,62.62,1978.87,0.0 +80,32,4608,4096,torch.float8_e4m3fnuz,9,0,11.8074,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,102.31,1634.6,0.0 +80,64,4608,4096,torch.float8_e4m3fnuz,23,0,16.1455,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,149.63,1221.79,0.0 +80,128,4608,4096,torch.float8_e4m3fnuz,24,0,23.9186,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,202.01,860.35,0.0 +80,256,4608,4096,torch.float8_e4m3fnuz,24,0,37.1547,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,260.09,599.72,0.0 
+80,1024,4608,4096,torch.float8_e4m3fnuz,26,0,116.0039,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x192x128_1x4x1_16x16x64_default,333.22,280.21,0.0 +80,2048,4608,4096,torch.float8_e4m3fnuz,54,0,215.9621,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,357.98,213.64,0.0 +80,4096,4608,4096,torch.float8_e4m3fnuz,54,0,420.4927,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,367.71,174.56,0.0 +80,16384,4608,4096,torch.float8_e4m3fnuz,54,0,1606.2618,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,385.04,147.53,0.0 +80,32768,4608,4096,torch.float8_e4m3fnuz,54,0,3174.959,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,389.6,143.33,0.0 +80,1,1280,8192,torch.float8_e4m3fnuz,2,0,13.6383,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,1.54,769.63,0.0 +80,32,1280,8192,torch.float8_e4m3fnuz,2,0,13.9553,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,48.09,776.04,0.0 +80,64,1280,8192,torch.float8_e4m3fnuz,2,0,14.4176,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,93.09,775.02,0.0 +80,128,1280,8192,torch.float8_e4m3fnuz,9,0,17.7099,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,151.57,669.8,0.0 +80,192,1280,8192,torch.float8_e4m3fnuz,37,0,25.2086,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x512_1x4x1_16x16x64_default,159.73,497.85,0.0 +80,256,1280,8192,torch.float8_e4m3fnuz,23,0,25.7499,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,208.49,514.11,0.0 +80,320,1280,8192,torch.float8_e4m3fnuz,23,0,36.6913,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,182.9,379.56,0.0 +80,512,1280,8192,torch.float8_e4m3fnuz,60,0,40.6609,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x256x256_1x4x1_16x16x64_default,264.07,393.27,0.0 
+80,1024,1280,8192,torch.float8_e4m3fnuz,24,0,63.7933,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,336.63,336.96,0.0 +80,2048,1280,8192,torch.float8_e4m3fnuz,20,0,118.4338,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x256x128_1x4x1_16x16x64_default,362.65,274.46,0.0 +80,4096,1280,8192,torch.float8_e4m3fnuz,68,0,229.0696,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x256_1x4x1_16x16x64_default,374.99,238.03,0.0 +80,8192,1280,8192,torch.float8_e4m3fnuz,48,0,438.2764,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x128_1x4x1_16x16x64_default,391.99,224.89,0.0 +80,16384,1280,8192,torch.float8_e4m3fnuz,48,0,863.6412,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x128_1x4x1_16x16x64_default,397.85,216.12,0.0 +80,1,8192,1024,torch.float8_e4m3fnuz,30,0,6.807,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,2.46,1234.91,0.0 +80,32,8192,1024,torch.float8_e4m3fnuz,49,0,8.0046,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x256_1x4x1_16x16x64_default,67.07,1117.57,0.0 +80,64,8192,1024,torch.float8_e4m3fnuz,50,0,10.9114,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x64x256_1x4x1_16x16x64_default,98.41,870.9,0.0 +80,128,8192,1024,torch.float8_e4m3fnuz,24,0,16.2559,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,132.1,653.11,0.0 +80,192,8192,1024,torch.float8_e4m3fnuz,51,0,20.526,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x128x256_1x4x1_16x16x64_default,156.93,571.52,0.0 +80,256,8192,1024,torch.float8_e4m3fnuz,18,0,24.7986,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,173.19,517.98,0.0 +80,320,8192,1024,torch.float8_e4m3fnuz,18,0,27.2902,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,196.73,511.51,0.0 +80,512,8192,1024,torch.float8_e4m3fnuz,46,0,41.0662,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x128x128_1x4x1_16x16x64_default,209.17,421.31,0.0 
+80,1024,8192,1024,torch.float8_e4m3fnuz,41,0,73.8261,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,232.71,355.08,0.0 +80,2048,8192,1024,torch.float8_e4m3fnuz,41,0,130.8248,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,262.64,336.63,0.0 +80,4096,8192,1024,torch.float8_e4m3fnuz,41,0,244.4845,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,281.08,325.96,0.0 +80,8192,8192,1024,torch.float8_e4m3fnuz,41,0,468.0259,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,293.66,322.62,0.0 +80,16384,8192,1024,torch.float8_e4m3fnuz,41,0,910.2301,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,301.99,322.56,0.0 +80,16,1536,7168,torch.float8_e4m3fnuz,2,0,12.8707,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,27.37,868.16,0.0 +80,32,1536,7168,torch.float8_e4m3fnuz,2,0,13.6657,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,51.56,829.65,0.0 +80,64,1536,7168,torch.float8_e4m3fnuz,37,0,15.7756,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x512_1x4x1_16x16x64_default,89.33,739.46,0.0 +80,128,1536,7168,torch.float8_e4m3fnuz,49,0,21.3394,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x256_1x4x1_16x16x64_default,132.08,577.37,0.0 +80,256,1536,7168,torch.float8_e4m3fnuz,57,0,29.7781,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x192x256_1x4x1_16x16x64_default,189.31,457.77,0.0 +80,512,1536,7168,torch.float8_e4m3fnuz,58,0,45.9558,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x192x256_1x4x1_16x16x64_default,245.33,353.66,0.0 +80,1024,1536,7168,torch.float8_e4m3fnuz,58,0,79.0489,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x192x256_1x4x1_16x16x64_default,285.25,271.93,0.0 +80,1536,1536,7168,torch.float8_e4m3fnuz,0,0,104.6251,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_128x128x128_1x4x1_16x16x64_default,323.28,255.57,0.0 
+80,2048,1536,7168,torch.float8_e4m3fnuz,24,0,134.4692,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,335.37,237.84,0.0 +80,4096,1536,7168,torch.float8_e4m3fnuz,48,0,246.2629,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x128_1x4x1_16x16x64_default,366.25,215.03,0.0 +80,8192,1536,7168,torch.float8_e4m3fnuz,54,0,467.0737,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,386.21,203.17,0.0 +80,16384,1536,7168,torch.float8_e4m3fnuz,54,0,915.0814,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,394.26,195.37,0.0 +80,20480,1536,7168,torch.float8_e4m3fnuz,54,0,1130.2865,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,398.99,195.28,0.0 +80,16,3072,1536,torch.float8_e4m3fnuz,30,0,6.0562,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,24.93,799.42,0.0 +80,32,3072,1536,torch.float8_e4m3fnuz,37,0,7.0704,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x512_1x4x1_16x16x64_default,42.71,702.13,0.0 +80,64,3072,1536,torch.float8_e4m3fnuz,49,0,9.2954,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x256_1x4x1_16x16x64_default,64.98,560.5,0.0 +80,128,3072,1536,torch.float8_e4m3fnuz,23,0,12.03,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,100.41,473.95,0.0 +80,256,3072,1536,torch.float8_e4m3fnuz,57,0,16.8707,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x192x256_1x4x1_16x16x64_default,143.2,396.23,0.0 +80,512,3072,1536,torch.float8_e4m3fnuz,23,0,25.0877,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,192.6,344.82,0.0 +80,1024,3072,1536,torch.float8_e4m3fnuz,46,0,42.7756,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x128x128_1x4x1_16x16x64_default,225.92,294.16,0.0 +80,1536,3072,1536,torch.float8_e4m3fnuz,54,0,54.8776,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,264.14,300.94,0.0 
+80,2048,3072,1536,torch.float8_e4m3fnuz,54,0,71.0938,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,271.86,287.61,0.0 +80,4096,3072,1536,torch.float8_e4m3fnuz,54,0,127.4567,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,303.28,283.83,0.0 +80,8192,3072,1536,torch.float8_e4m3fnuz,54,0,241.298,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,320.39,280.29,0.0 +80,16384,3072,1536,torch.float8_e4m3fnuz,54,0,470.5864,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,328.57,277.41,0.0 +80,20480,3072,1536,torch.float8_e4m3fnuz,54,0,586.3499,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,329.62,276.29,0.0 +80,16,576,7168,torch.float8_e4m3fnuz,2,0,12.195,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,10.83,349.48,0.0 +80,32,576,7168,torch.float8_e4m3fnuz,2,0,12.3508,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,21.39,355.85,0.0 +80,64,576,7168,torch.float8_e4m3fnuz,2,0,12.6815,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,41.67,367.56,0.0 +80,128,576,7168,torch.float8_e4m3fnuz,30,0,12.9793,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,81.43,400.15,0.0 +80,256,576,7168,torch.float8_e4m3fnuz,37,0,15.7506,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x512_1x4x1_16x16x64_default,134.21,397.36,0.0 +80,512,576,7168,torch.float8_e4m3fnuz,9,0,22.8675,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,184.89,366.84,0.0 +80,1024,576,7168,torch.float8_e4m3fnuz,50,0,36.5975,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x64x256_1x4x1_16x16x64_default,231.05,345.61,0.0 +80,1536,576,7168,torch.float8_e4m3fnuz,58,0,45.6058,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x192x256_1x4x1_16x16x64_default,278.11,370.75,0.0 
+80,2048,576,7168,torch.float8_e4m3fnuz,131,0,67.4424,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_96x192x256_1x4x1_16x16x64_default,250.75,313.87,0.0 +80,4096,576,7168,torch.float8_e4m3fnuz,88,0,110.3714,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_160x192x128_1x4x1_16x16x64_default,306.45,346.17,0.0 +80,8192,576,7168,torch.float8_e4m3fnuz,58,0,190.2246,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x192x256_1x4x1_16x16x64_default,355.61,380.0,0.0 +80,16384,576,7168,torch.float8_e4m3fnuz,54,0,365.3289,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,370.33,384.43,0.0 +80,20480,576,7168,torch.float8_e4m3fnuz,54,0,439.7549,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,384.56,396.86,0.0 +80,16,7168,2048,torch.float8_e4m3fnuz,30,0,8.7391,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,53.75,1709.81,0.0 +80,32,7168,2048,torch.float8_e4m3fnuz,9,0,10.512,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,89.38,1446.38,0.0 +80,64,7168,2048,torch.float8_e4m3fnuz,50,0,14.8952,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x64x256_1x4x1_16x16x64_default,126.15,1055.95,0.0 +80,128,7168,2048,torch.float8_e4m3fnuz,23,0,20.7357,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,181.24,809.1,0.0 +80,256,7168,2048,torch.float8_e4m3fnuz,18,0,31.9102,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,235.54,591.48,0.0 +80,512,7168,2048,torch.float8_e4m3fnuz,48,0,54.7699,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x128_1x4x1_16x16x64_default,274.46,421.19,0.0 +80,1024,7168,2048,torch.float8_e4m3fnuz,48,0,102.267,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x128_1x4x1_16x16x64_default,293.98,307.6,0.0 +80,1536,7168,2048,torch.float8_e4m3fnuz,18,0,142.7196,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,315.98,279.19,0.0 
+80,2048,7168,2048,torch.float8_e4m3fnuz,18,0,192.4732,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,312.4,250.6,0.0 +80,4096,7168,2048,torch.float8_e4m3fnuz,41,0,358.4362,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,335.51,228.18,0.0 +80,8192,7168,2048,torch.float8_e4m3fnuz,41,0,685.4349,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,350.9,217.23,0.0 +80,16384,7168,2048,torch.float8_e4m3fnuz,41,0,1352.1887,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,355.75,209.38,0.0 +80,20480,7168,2048,torch.float8_e4m3fnuz,41,0,1690.7436,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,355.64,207.14,0.0 +80,16,4608,7168,torch.float8_e4m3fnuz,2,0,13.3851,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,78.97,2487.26,0.0 +80,32,4608,7168,torch.float8_e4m3fnuz,9,0,15.7244,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,134.44,2133.91,0.0 +80,64,4608,7168,torch.float8_e4m3fnuz,23,0,24.1164,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,175.31,1413.09,0.0 +80,128,4608,7168,torch.float8_e4m3fnuz,24,0,36.1367,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,233.99,972.07,0.0 +80,256,4608,7168,torch.float8_e4m3fnuz,24,0,59.0996,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,286.15,629.86,0.0 +80,512,4608,7168,torch.float8_e4m3fnuz,0,0,104.9786,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_128x128x128_1x4x1_16x16x64_default,322.19,394.54,0.0 +80,1024,4608,7168,torch.float8_e4m3fnuz,26,0,187.2691,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x192x128_1x4x1_16x16x64_default,361.22,265.97,0.0 +80,1536,4608,7168,torch.float8_e4m3fnuz,46,0,278.8125,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x128x128_1x4x1_16x16x64_default,363.93,208.73,0.0 
+80,2048,4608,7168,torch.float8_e4m3fnuz,54,0,358.9201,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,376.94,185.51,0.0 +80,4096,4608,7168,torch.float8_e4m3fnuz,54,0,703.9066,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,384.4,142.26,0.0 +80,8192,4608,7168,torch.float8_e4m3fnuz,54,0,1358.216,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,398.44,123.14,0.0 +80,16384,4608,7168,torch.float8_e4m3fnuz,54,0,2673.4661,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,404.84,112.76,0.0 +80,20480,4608,7168,torch.float8_e4m3fnuz,54,0,3349.1641,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,403.96,110.05,0.0 +80,16,7168,2304,torch.float8_e4m3fnuz,33,0,10.2988,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x256_1x4x1_16x16x64_default,51.31,1629.44,0.0 +80,32,7168,2304,torch.float8_e4m3fnuz,21,0,11.7542,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x256_1x4x1_16x16x64_default,89.92,1450.34,0.0 +80,64,7168,2304,torch.float8_e4m3fnuz,21,0,15.5548,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x256_1x4x1_16x16x64_default,135.9,1130.2,0.0 +80,128,7168,2304,torch.float8_e4m3fnuz,16,0,22.3182,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x64x128_1x4x1_16x16x64_default,189.44,835.42,0.0 +80,256,7168,2304,torch.float8_e4m3fnuz,46,0,33.7458,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x128x128_1x4x1_16x16x64_default,250.57,615.63,0.0 +80,512,7168,2304,torch.float8_e4m3fnuz,48,0,59.6275,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x128_1x4x1_16x16x64_default,283.62,419.85,0.0 +80,1024,7168,2304,torch.float8_e4m3fnuz,41,0,109.4005,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,309.17,306.71,0.0 +80,1536,7168,2304,torch.float8_e4m3fnuz,18,0,152.2543,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,333.22,276.34,0.0 
+80,2048,7168,2304,torch.float8_e4m3fnuz,18,0,203.1162,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,333.04,249.09,0.0 +80,4096,7168,2304,torch.float8_e4m3fnuz,41,0,382.4726,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,353.73,221.38,0.0 +80,8192,7168,2304,torch.float8_e4m3fnuz,41,0,741.673,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,364.83,206.06,0.0 +80,16384,7168,2304,torch.float8_e4m3fnuz,41,0,1464.3626,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,369.56,197.45,0.0 +80,20480,7168,2304,torch.float8_e4m3fnuz,41,0,1814.2843,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,372.85,196.94,0.0 +80,16,512,7168,torch.float8_e4m3fnuz,2,0,12.5374,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,9.37,303.18,0.0 +80,32,512,7168,torch.float8_e4m3fnuz,2,0,12.4268,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,18.9,316.43,0.0 +80,64,512,7168,torch.float8_e4m3fnuz,30,0,12.506,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,37.56,335.38,0.0 +80,128,512,7168,torch.float8_e4m3fnuz,2,0,12.6386,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,74.34,373.35,0.0 +80,256,512,7168,torch.float8_e4m3fnuz,9,0,15.4559,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,121.57,373.14,0.0 +80,512,512,7168,torch.float8_e4m3fnuz,9,0,22.6732,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,165.75,346.86,0.0 +80,1024,512,7168,torch.float8_e4m3fnuz,24,0,35.1171,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,214.03,343.38,0.0 +80,1536,512,7168,torch.float8_e4m3fnuz,99,0,48.688,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_96x128x256_1x4x1_16x16x64_default,231.56,333.82,0.0 
+80,2048,512,7168,torch.float8_e4m3fnuz,24,0,55.6177,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,270.28,367.64,0.0 +80,4096,512,7168,torch.float8_e4m3fnuz,18,0,103.2737,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,291.12,360.44,0.0 +80,8192,512,7168,torch.float8_e4m3fnuz,24,0,183.9884,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,326.81,384.69,0.0 +80,16384,512,7168,torch.float8_e4m3fnuz,24,0,339.0609,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,354.68,406.68,0.0 +80,20480,512,7168,torch.float8_e4m3fnuz,68,0,397.1766,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x256_1x4x1_16x16x64_default,378.48,431.65,0.0 +80,16,4096,512,torch.float8_e4m3fnuz,33,0,4.0949,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x256_1x4x1_16x16x64_default,16.39,546.15,0.0 +80,32,4096,512,torch.float8_e4m3fnuz,21,0,4.7315,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x256_1x4x1_16x16x64_default,28.37,502.1,0.0 +80,64,4096,512,torch.float8_e4m3fnuz,49,0,5.8289,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x256_1x4x1_16x16x64_default,46.05,455.35,0.0 +80,128,4096,512,torch.float8_e4m3fnuz,49,0,7.7762,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x256_1x4x1_16x16x64_default,69.04,412.96,0.0 +80,256,4096,512,torch.float8_e4m3fnuz,23,0,11.612,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,92.47,372.49,0.0 +80,512,4096,512,torch.float8_e4m3fnuz,45,0,17.9873,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x128x128_1x4x1_16x16x64_default,119.39,364.35,0.0 +80,1024,4096,512,torch.float8_e4m3fnuz,18,0,28.8943,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,148.64,381.05,0.0 +80,1536,4096,512,torch.float8_e4m3fnuz,18,0,38.5091,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x128_1x4x1_16x16x64_default,167.3,401.63,0.0 
+80,2048,4096,512,torch.float8_e4m3fnuz,46,0,49.318,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x128x128_1x4x1_16x16x64_default,174.17,403.97,0.0 +80,4096,4096,512,torch.float8_e4m3fnuz,41,0,86.521,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,198.56,436.3,0.0 +80,8192,4096,512,torch.float8_e4m3fnuz,41,0,159.4617,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,215.47,460.3,0.0 +80,16384,4096,512,torch.float8_e4m3fnuz,41,0,304.593,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,225.61,475.07,0.0 +80,20480,4096,512,torch.float8_e4m3fnuz,41,0,375.3411,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,228.86,480.51,0.0 +80,16,7168,256,torch.float8_e4m3fnuz,5,0,4.0505,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x256_1x4x1_16x16x64_default,14.5,510.67,0.0 +80,32,7168,256,torch.float8_e4m3fnuz,49,0,4.9536,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x256_1x4x1_16x16x64_default,23.71,464.7,0.0 +80,64,7168,256,torch.float8_e4m3fnuz,22,0,6.3078,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x64x256_1x4x1_16x16x64_default,37.24,438.96,0.0 +80,128,7168,256,torch.float8_e4m3fnuz,17,0,8.5847,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x128_1x4x1_16x16x64_default,54.72,431.32,0.0 +80,256,7168,256,torch.float8_e4m3fnuz,46,0,12.3108,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x128x128_1x4x1_16x16x64_default,76.32,452.49,0.0 +80,512,7168,256,torch.float8_e4m3fnuz,10,0,20.0727,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x256x64_1x4x1_16x16x64_default,93.61,463.62,0.0 +80,1024,7168,256,torch.float8_e4m3fnuz,41,0,34.008,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,110.51,493.33,0.0 +80,1536,7168,256,torch.float8_e4m3fnuz,45,0,47.0605,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x128x128_1x4x1_16x16x64_default,119.79,515.26,0.0 
+80,2048,7168,256,torch.float8_e4m3fnuz,41,0,59.2228,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,126.91,535.59,0.0 +80,4096,7168,256,torch.float8_e4m3fnuz,41,0,105.2494,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,142.83,585.31,0.0 +80,8192,7168,256,torch.float8_e4m3fnuz,41,0,197.2828,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,152.39,615.22,0.0 +80,16384,7168,256,torch.float8_e4m3fnuz,41,0,381.565,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,157.59,631.37,0.0 +80,20480,7168,256,torch.float8_e4m3fnuz,41,0,472.001,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,159.24,637.03,0.0 +80,1,4096,512,torch.float8_e4m3fnuz,5,0,4.0057,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x256_1x4x1_16x16x64_default,1.05,525.71,0.0 +80,1,2112,7168,torch.float8_e4m3fnuz,2,0,12.3092,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,2.46,1230.8,0.0 +80,1,4608,7168,torch.float8_e4m3fnuz,2,0,12.6122,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,5.24,2620.2,0.0 +80,1,7168,2304,torch.float8_e4m3fnuz,5,0,10.0275,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x256_1x4x1_16x16x64_default,3.29,1648.64,0.0 +80,1,512,7168,torch.float8_e4m3fnuz,30,0,12.268,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,0.6,299.82,0.0 +80,1,7168,256,torch.float8_e4m3fnuz,5,0,4.1686,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x256_1x4x1_16x16x64_default,0.88,443.7,0.0 +80,16,2112,7168,torch.float8_e4m3fnuz,2,0,12.2697,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,39.48,1248.69,0.0 +80,32,2112,7168,torch.float8_e4m3fnuz,2,0,12.5735,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,77.06,1233.02,0.0 
+80,48,4096,512,torch.float8_e4m3fnuz,33,0,5.6993,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x256_1x4x1_16x16x64_default,35.32,441.27,0.0 +80,48,2112,7168,torch.float8_e4m3fnuz,9,0,15.4811,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,93.88,1013.21,0.0 +80,48,4608,7168,torch.float8_e4m3fnuz,23,0,23.3699,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,135.68,1447.01,0.0 +80,48,7168,2304,torch.float8_e4m3fnuz,49,0,15.3526,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x256_1x4x1_16x16x64_default,103.27,1127.74,0.0 +80,48,512,7168,torch.float8_e4m3fnuz,2,0,12.4611,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,28.27,326.07,0.0 +80,48,7168,256,torch.float8_e4m3fnuz,50,0,6.0534,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x64x256_1x4x1_16x16x64_default,29.1,418.84,0.0 +80,64,2112,7168,torch.float8_e4m3fnuz,9,0,15.7138,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,123.32,1009.81,0.0 +80,80,4096,512,torch.float8_e4m3fnuz,30,0,6.5582,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,51.16,425.95,0.0 +80,80,2112,7168,torch.float8_e4m3fnuz,37,0,22.7948,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x512_1x4x1_16x16x64_default,106.26,704.12,0.0 +80,80,4608,7168,torch.float8_e4m3fnuz,57,0,29.8014,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x192x256_1x4x1_16x16x64_default,177.33,1152.32,0.0 +80,80,7168,2304,torch.float8_e4m3fnuz,45,0,18.9166,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x128x128_1x4x1_16x16x64_default,139.69,943.42,0.0 +80,80,512,7168,torch.float8_e4m3fnuz,30,0,12.5603,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x64x512_1x4x1_16x16x64_default,46.75,344.37,0.0 +80,80,7168,256,torch.float8_e4m3fnuz,6,0,7.3721,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x128x256_1x4x1_16x16x64_default,39.83,407.26,0.0 
+80,96,4096,512,torch.float8_e4m3fnuz,21,0,7.2506,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x256_1x4x1_16x16x64_default,55.53,404.48,0.0 +80,96,2112,7168,torch.float8_e4m3fnuz,37,0,23.0909,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x512_1x4x1_16x16x64_default,125.88,702.98,0.0 +80,96,4608,7168,torch.float8_e4m3fnuz,66,0,31.2288,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x192x256_1x4x1_16x16x64_default,203.07,1108.05,0.0 +80,96,7168,2304,torch.float8_e4m3fnuz,17,0,19.3679,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x128_1x4x1_16x16x64_default,163.72,935.18,0.0 +80,96,512,7168,torch.float8_e4m3fnuz,2,0,12.4871,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,56.43,356.88,0.0 +80,96,7168,256,torch.float8_e4m3fnuz,23,0,8.0488,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,43.77,402.03,0.0 +80,112,4096,512,torch.float8_e4m3fnuz,51,0,7.9797,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x128x256_1x4x1_16x16x64_default,58.87,384.98,0.0 +80,112,2112,7168,torch.float8_e4m3fnuz,9,0,22.8662,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,148.3,717.86,0.0 +80,112,4608,7168,torch.float8_e4m3fnuz,24,0,35.9043,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x128x256_1x4x1_16x16x64_default,206.07,971.06,0.0 +80,112,7168,2304,torch.float8_e4m3fnuz,23,0,22.119,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x128x256_1x4x1_16x16x64_default,167.25,830.9,0.0 +80,112,512,7168,torch.float8_e4m3fnuz,2,0,12.7783,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_16x64x512_1x4x1_16x16x64_default,64.33,359.01,0.0 +80,112,7168,256,torch.float8_e4m3fnuz,34,0,7.9613,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_16x128x256_1x4x1_16x16x64_default,51.63,435.77,0.0 +80,128,2112,7168,torch.float8_e4m3fnuz,37,0,22.8679,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_32x64x512_1x4x1_16x16x64_default,169.47,725.78,0.0 
+80,256,2112,7168,torch.float8_e4m3fnuz,50,0,36.4203,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x64x256_1x4x1_16x16x64_default,212.82,495.74,0.0 +80,512,2112,7168,torch.float8_e4m3fnuz,98,0,67.0873,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_96x192x256_1x4x1_16x16x64_default,231.07,312.6,0.0 +80,1024,2112,7168,torch.float8_e4m3fnuz,57,0,109.5452,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x192x256_1x4x1_16x16x64_default,283.03,244.69,0.0 +80,1536,2112,7168,torch.float8_e4m3fnuz,54,0,147.1439,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,316.06,221.8,0.0 +80,2048,2112,7168,torch.float8_e4m3fnuz,26,0,184.2233,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_64x192x128_1x4x1_16x16x64_default,336.59,208.82,0.0 +80,4096,2112,7168,torch.float8_e4m3fnuz,54,0,325.7229,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,380.74,189.73,0.0 +80,8192,2112,7168,torch.float8_e4m3fnuz,54,0,632.8361,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,391.94,171.39,0.0 +80,16384,2112,7168,torch.float8_e4m3fnuz,54,0,1251.1568,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,396.49,161.28,0.0 +80,32768,4096,512,torch.float8_e4m3fnuz,41,0,592.8636,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,231.82,484.61,0.0 +80,32768,2112,7168,torch.float8_e4m3fnuz,54,0,2480.8163,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,399.92,156.57,0.0 +80,32768,4608,7168,torch.float8_e4m3fnuz,54,0,5335.9575,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x192x128_1x4x1_16x16x64_default,405.67,106.8,0.0 +80,32768,7168,2304,torch.float8_e4m3fnuz,41,0,2910.7761,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,371.84,193.0,0.0 +80,32768,512,7168,torch.float8_e4m3fnuz,48,0,634.4121,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_64x256x128_1x4x1_16x16x64_default,379.12,428.91,0.0 
+80,32768,7168,256,torch.float8_e4m3fnuz,41,0,747.496,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x2_128x128x64_1x4x1_16x16x64_default,160.88,642.12,0.0 diff --git a/aiter/configs/a8w8_bpreshuffle_cktile_untuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_cktile_untuned_gemm.csv new file mode 100644 index 0000000000..9054e3064c --- /dev/null +++ b/aiter/configs/a8w8_bpreshuffle_cktile_untuned_gemm.csv @@ -0,0 +1,221 @@ +M,N,K,q_dtype_w +1,9216,4096,torch.float8_e4m3fnuz +2,9216,4096,torch.float8_e4m3fnuz +4,9216,4096,torch.float8_e4m3fnuz +8,9216,4096,torch.float8_e4m3fnuz +16,9216,4096,torch.float8_e4m3fnuz +32,9216,4096,torch.float8_e4m3fnuz +64,9216,4096,torch.float8_e4m3fnuz +128,9216,4096,torch.float8_e4m3fnuz +256,9216,4096,torch.float8_e4m3fnuz +1024,9216,4096,torch.float8_e4m3fnuz +2048,9216,4096,torch.float8_e4m3fnuz +4096,9216,4096,torch.float8_e4m3fnuz +4240,9216,4096,torch.float8_e4m3fnuz +16384,9216,4096,torch.float8_e4m3fnuz +32768,9216,4096,torch.float8_e4m3fnuz +1,4608,4096,torch.float8_e4m3fnuz +2,4608,4096,torch.float8_e4m3fnuz +4,4608,4096,torch.float8_e4m3fnuz +8,4608,4096,torch.float8_e4m3fnuz +16,4608,4096,torch.float8_e4m3fnuz +32,4608,4096,torch.float8_e4m3fnuz +64,4608,4096,torch.float8_e4m3fnuz +128,4608,4096,torch.float8_e4m3fnuz +256,4608,4096,torch.float8_e4m3fnuz +1024,4608,4096,torch.float8_e4m3fnuz +2048,4608,4096,torch.float8_e4m3fnuz +4096,4608,4096,torch.float8_e4m3fnuz +16384,4608,4096,torch.float8_e4m3fnuz +32768,4608,4096,torch.float8_e4m3fnuz +1,1280,8192,torch.float8_e4m3fnuz +32,1280,8192,torch.float8_e4m3fnuz +64,1280,8192,torch.float8_e4m3fnuz +128,1280,8192,torch.float8_e4m3fnuz +192,1280,8192,torch.float8_e4m3fnuz +256,1280,8192,torch.float8_e4m3fnuz +320,1280,8192,torch.float8_e4m3fnuz +512,1280,8192,torch.float8_e4m3fnuz +1024,1280,8192,torch.float8_e4m3fnuz +2048,1280,8192,torch.float8_e4m3fnuz +4096,1280,8192,torch.float8_e4m3fnuz +8192,1280,8192,torch.float8_e4m3fnuz +16384,1280,8192,torch.float8_e4m3fnuz 
+1,8192,1024,torch.float8_e4m3fnuz +32,8192,1024,torch.float8_e4m3fnuz +64,8192,1024,torch.float8_e4m3fnuz +128,8192,1024,torch.float8_e4m3fnuz +192,8192,1024,torch.float8_e4m3fnuz +256,8192,1024,torch.float8_e4m3fnuz +320,8192,1024,torch.float8_e4m3fnuz +512,8192,1024,torch.float8_e4m3fnuz +1024,8192,1024,torch.float8_e4m3fnuz +2048,8192,1024,torch.float8_e4m3fnuz +4096,8192,1024,torch.float8_e4m3fnuz +8192,8192,1024,torch.float8_e4m3fnuz +16384,8192,1024,torch.float8_e4m3fnuz +16,1536,7168,torch.float8_e4m3fnuz +32,1536,7168,torch.float8_e4m3fnuz +64,1536,7168,torch.float8_e4m3fnuz +128,1536,7168,torch.float8_e4m3fnuz +256,1536,7168,torch.float8_e4m3fnuz +512,1536,7168,torch.float8_e4m3fnuz +1024,1536,7168,torch.float8_e4m3fnuz +1536,1536,7168,torch.float8_e4m3fnuz +2048,1536,7168,torch.float8_e4m3fnuz +4096,1536,7168,torch.float8_e4m3fnuz +8192,1536,7168,torch.float8_e4m3fnuz +16384,1536,7168,torch.float8_e4m3fnuz +20480,1536,7168,torch.float8_e4m3fnuz +16,3072,1536,torch.float8_e4m3fnuz +32,3072,1536,torch.float8_e4m3fnuz +64,3072,1536,torch.float8_e4m3fnuz +128,3072,1536,torch.float8_e4m3fnuz +256,3072,1536,torch.float8_e4m3fnuz +512,3072,1536,torch.float8_e4m3fnuz +1024,3072,1536,torch.float8_e4m3fnuz +1536,3072,1536,torch.float8_e4m3fnuz +2048,3072,1536,torch.float8_e4m3fnuz +4096,3072,1536,torch.float8_e4m3fnuz +8192,3072,1536,torch.float8_e4m3fnuz +16384,3072,1536,torch.float8_e4m3fnuz +20480,3072,1536,torch.float8_e4m3fnuz +16,576,7168,torch.float8_e4m3fnuz +32,576,7168,torch.float8_e4m3fnuz +64,576,7168,torch.float8_e4m3fnuz +128,576,7168,torch.float8_e4m3fnuz +256,576,7168,torch.float8_e4m3fnuz +512,576,7168,torch.float8_e4m3fnuz +1024,576,7168,torch.float8_e4m3fnuz +1536,576,7168,torch.float8_e4m3fnuz +2048,576,7168,torch.float8_e4m3fnuz +4096,576,7168,torch.float8_e4m3fnuz +8192,576,7168,torch.float8_e4m3fnuz +16384,576,7168,torch.float8_e4m3fnuz +20480,576,7168,torch.float8_e4m3fnuz +16,7168,2048,torch.float8_e4m3fnuz 
+32,7168,2048,torch.float8_e4m3fnuz +64,7168,2048,torch.float8_e4m3fnuz +128,7168,2048,torch.float8_e4m3fnuz +256,7168,2048,torch.float8_e4m3fnuz +512,7168,2048,torch.float8_e4m3fnuz +1024,7168,2048,torch.float8_e4m3fnuz +1536,7168,2048,torch.float8_e4m3fnuz +2048,7168,2048,torch.float8_e4m3fnuz +4096,7168,2048,torch.float8_e4m3fnuz +8192,7168,2048,torch.float8_e4m3fnuz +16384,7168,2048,torch.float8_e4m3fnuz +20480,7168,2048,torch.float8_e4m3fnuz +16,4608,7168,torch.float8_e4m3fnuz +32,4608,7168,torch.float8_e4m3fnuz +64,4608,7168,torch.float8_e4m3fnuz +128,4608,7168,torch.float8_e4m3fnuz +256,4608,7168,torch.float8_e4m3fnuz +512,4608,7168,torch.float8_e4m3fnuz +1024,4608,7168,torch.float8_e4m3fnuz +1536,4608,7168,torch.float8_e4m3fnuz +2048,4608,7168,torch.float8_e4m3fnuz +4096,4608,7168,torch.float8_e4m3fnuz +8192,4608,7168,torch.float8_e4m3fnuz +16384,4608,7168,torch.float8_e4m3fnuz +20480,4608,7168,torch.float8_e4m3fnuz +16,7168,2304,torch.float8_e4m3fnuz +32,7168,2304,torch.float8_e4m3fnuz +64,7168,2304,torch.float8_e4m3fnuz +128,7168,2304,torch.float8_e4m3fnuz +256,7168,2304,torch.float8_e4m3fnuz +512,7168,2304,torch.float8_e4m3fnuz +1024,7168,2304,torch.float8_e4m3fnuz +1536,7168,2304,torch.float8_e4m3fnuz +2048,7168,2304,torch.float8_e4m3fnuz +4096,7168,2304,torch.float8_e4m3fnuz +8192,7168,2304,torch.float8_e4m3fnuz +16384,7168,2304,torch.float8_e4m3fnuz +20480,7168,2304,torch.float8_e4m3fnuz +16,512,7168,torch.float8_e4m3fnuz +32,512,7168,torch.float8_e4m3fnuz +64,512,7168,torch.float8_e4m3fnuz +128,512,7168,torch.float8_e4m3fnuz +256,512,7168,torch.float8_e4m3fnuz +512,512,7168,torch.float8_e4m3fnuz +1024,512,7168,torch.float8_e4m3fnuz +1536,512,7168,torch.float8_e4m3fnuz +2048,512,7168,torch.float8_e4m3fnuz +4096,512,7168,torch.float8_e4m3fnuz +8192,512,7168,torch.float8_e4m3fnuz +16384,512,7168,torch.float8_e4m3fnuz +20480,512,7168,torch.float8_e4m3fnuz +16,4096,512,torch.float8_e4m3fnuz +32,4096,512,torch.float8_e4m3fnuz 
+64,4096,512,torch.float8_e4m3fnuz +128,4096,512,torch.float8_e4m3fnuz +256,4096,512,torch.float8_e4m3fnuz +512,4096,512,torch.float8_e4m3fnuz +1024,4096,512,torch.float8_e4m3fnuz +1536,4096,512,torch.float8_e4m3fnuz +2048,4096,512,torch.float8_e4m3fnuz +4096,4096,512,torch.float8_e4m3fnuz +8192,4096,512,torch.float8_e4m3fnuz +16384,4096,512,torch.float8_e4m3fnuz +20480,4096,512,torch.float8_e4m3fnuz +16,7168,256,torch.float8_e4m3fnuz +32,7168,256,torch.float8_e4m3fnuz +64,7168,256,torch.float8_e4m3fnuz +128,7168,256,torch.float8_e4m3fnuz +256,7168,256,torch.float8_e4m3fnuz +512,7168,256,torch.float8_e4m3fnuz +1024,7168,256,torch.float8_e4m3fnuz +1536,7168,256,torch.float8_e4m3fnuz +2048,7168,256,torch.float8_e4m3fnuz +4096,7168,256,torch.float8_e4m3fnuz +8192,7168,256,torch.float8_e4m3fnuz +16384,7168,256,torch.float8_e4m3fnuz +20480,7168,256,torch.float8_e4m3fnuz +1,4096,512,torch.float8_e4m3fnuz +1,2112,7168,torch.float8_e4m3fnuz +1,4608,7168,torch.float8_e4m3fnuz +1,7168,2304,torch.float8_e4m3fnuz +1,512,7168,torch.float8_e4m3fnuz +1,7168,256,torch.float8_e4m3fnuz +16,2112,7168,torch.float8_e4m3fnuz +32,2112,7168,torch.float8_e4m3fnuz +48,4096,512,torch.float8_e4m3fnuz +48,2112,7168,torch.float8_e4m3fnuz +48,4608,7168,torch.float8_e4m3fnuz +48,7168,2304,torch.float8_e4m3fnuz +48,512,7168,torch.float8_e4m3fnuz +48,7168,256,torch.float8_e4m3fnuz +64,2112,7168,torch.float8_e4m3fnuz +80,4096,512,torch.float8_e4m3fnuz +80,2112,7168,torch.float8_e4m3fnuz +80,4608,7168,torch.float8_e4m3fnuz +80,7168,2304,torch.float8_e4m3fnuz +80,512,7168,torch.float8_e4m3fnuz +80,7168,256,torch.float8_e4m3fnuz +96,4096,512,torch.float8_e4m3fnuz +96,2112,7168,torch.float8_e4m3fnuz +96,4608,7168,torch.float8_e4m3fnuz +96,7168,2304,torch.float8_e4m3fnuz +96,512,7168,torch.float8_e4m3fnuz +96,7168,256,torch.float8_e4m3fnuz +112,4096,512,torch.float8_e4m3fnuz +112,2112,7168,torch.float8_e4m3fnuz +112,4608,7168,torch.float8_e4m3fnuz +112,7168,2304,torch.float8_e4m3fnuz 
+112,512,7168,torch.float8_e4m3fnuz +112,7168,256,torch.float8_e4m3fnuz +128,2112,7168,torch.float8_e4m3fnuz +256,2112,7168,torch.float8_e4m3fnuz +512,2112,7168,torch.float8_e4m3fnuz +1024,2112,7168,torch.float8_e4m3fnuz +1536,2112,7168,torch.float8_e4m3fnuz +2048,2112,7168,torch.float8_e4m3fnuz +4096,2112,7168,torch.float8_e4m3fnuz +8192,2112,7168,torch.float8_e4m3fnuz +16384,2112,7168,torch.float8_e4m3fnuz +32768,4096,512,torch.float8_e4m3fnuz +32768,2112,7168,torch.float8_e4m3fnuz +32768,4608,7168,torch.float8_e4m3fnuz +32768,7168,2304,torch.float8_e4m3fnuz +32768,512,7168,torch.float8_e4m3fnuz +32768,7168,256,torch.float8_e4m3fnuz diff --git a/aiter/jit/core.py b/aiter/jit/core.py index ca4423820e..504b506513 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -152,6 +152,12 @@ def get_config_file(env_name, default_file, tuned_file_name): "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE", f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv", ) + +AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE = os.getenv( + "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE", + f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv", +) + AITER_CONFIG_GEMM_A8W8_BLOCKSCALE = os.getenv( "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE", f"{AITER_ROOT_DIR}/aiter/configs/a8w8_blockscale_tuned_gemm.csv", @@ -192,6 +198,11 @@ def get_config_file(env_name, default_file, tuned_file_name): AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE, "a8w8_bpreshuffle_tuned_gemm", ) +AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE = get_config_file( + "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE", + AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE, + "a8w8_bpreshuffle_cktile_tuned_gemm", +) AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE = get_config_file( "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE", AITER_CONFIG_GEMM_A8W8_BLOCKSCALE, diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index c91f972b16..1d1e2f3e9f 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -288,6 
+288,25 @@ "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_deepgemm/gen_instances.py --working_path {{}}'" }, + "module_gemm_a8w8_bpreshuffle_cktile": { + "srcs": [ + "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu'", + "f'{AITER_CSRC_DIR}/py_itfs_cu/gemm_common.cu'", + "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": "None", + "extra_include": [ + "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/include'", + "f'{CK_DIR}/example/ck_tile/18_flatmm'" + ], + "is_python_module": "True", + "is_standalone": "False", + "verbose": "False", + "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE}'" + }, "module_gemm_a8w8_asm": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_asm_pybind.cu'", @@ -582,6 +601,24 @@ "is_standalone": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune'" }, + "module_gemm_a8w8_bpreshuffle_cktile_tune": { + "srcs": [ + "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu'", + "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": "None", + "extra_include": [ + "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/include'", + "f'{CK_DIR}/example/ck_tile/18_flatmm'" + ], + "verbose": "False", + "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", + "is_python_module": "True", + "is_standalone": "False", + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune'" + }, "module_aiter_operator": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/aiter_operator_pybind.cu'", diff --git 
a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index db81c45f38..6cd2da4758 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -13,6 +13,7 @@ AITER_CONFIG_GEMM_A8W8_FILE, AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE, AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE, + AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE, AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE, AITER_LOG_TUNED_CONFIG, ) @@ -75,6 +76,30 @@ def gemm_a8w8_bpreshuffle_ck( ) -> torch.Tensor: ... +def gen_gemm_a8w8_bpreshuffle_cktile_fake_tensors( + XQ: torch.Tensor, + WQ: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + Out: torch.Tensor, +) -> torch.Tensor: + return Out + + +@compile_ops( + "module_gemm_a8w8_bpreshuffle_cktile", + fc_name="gemm_a8w8_bpreshuffle_cktile", + gen_fake=gen_gemm_a8w8_bpreshuffle_cktile_fake_tensors, +) +def gemm_a8w8_bpreshuffle_cktile( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + out: Tensor, +) -> Tensor: ... + + def gen_gemm_a8w8_asm_fake_tensors( XQ: Tensor, # A:[M, K] i8 WQ: Tensor, # B:[N, K] i8 -> shuffle layout(32,16) @@ -249,21 +274,32 @@ def get_bpreshuffle_GEMM_config( q_dtype_w: torch.dtype, tuned_file=f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv", ): - if not hasattr(get_bpreshuffle_GEMM_config, "bpreshuffle_gemm_dict"): + # Use dict to cache configs for different files + if not hasattr(get_bpreshuffle_GEMM_config, "file_cache"): + get_bpreshuffle_GEMM_config.file_cache = {} + + # Load file if not cached + if tuned_file not in get_bpreshuffle_GEMM_config.file_cache: asmGemmDictDf = pd.read_csv(tuned_file).drop_duplicates() - get_bpreshuffle_GEMM_config.bpreshuffle_gemm_dict = asmGemmDictDf.set_index( + get_bpreshuffle_GEMM_config.file_cache[tuned_file] = asmGemmDictDf.set_index( ["cu_num", "M", "N", "K", "q_dtype_w"] ).to_dict("index") + cu_num = get_cu_num() - config = get_bpreshuffle_GEMM_config.bpreshuffle_gemm_dict.get( - (cu_num, M, N, K, str(q_dtype_w)), None - ) - if config 
is not None: - if AITER_LOG_TUNED_CONFIG: - logger.info( - f"shape M:{M}, N:{N}, K:{K} q_dtype_w:{q_dtype_w} is tuned, in {tuned_file}!" - ) - else: + padded_M = M + config = None + for gl in [None, 0, 1]: + padded_M = M if gl is None else get_padded_m(M, N, K, gl) + config = get_bpreshuffle_GEMM_config.file_cache[tuned_file].get( + (cu_num, padded_M, N, K, str(q_dtype_w)), None + ) + if config is not None: + if AITER_LOG_TUNED_CONFIG: + logger.info( + f"shape M:{M}, N:{N}, K:{K} q_dtype_w:{q_dtype_w}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned, in {tuned_file}!" + ) + break + if config is None: logger.info( f"shape is M:{M}, N:{N}, K:{K}, q_dtype_w:{q_dtype_w}, not found tuned config in {tuned_file}, will use default config!" ) @@ -406,9 +442,6 @@ def gemm_a8w8_bpreshuffle( n = WQ.shape[0] k = XQ.shape[-1] - get_bpreshuffle_GEMM_config( - m, n, k, dtypes.fp8, AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE - ) # if ( # ck_config is None # and dtype == dtypes.bf16 @@ -421,7 +454,40 @@ def gemm_a8w8_bpreshuffle( assert WQ.dtype == dtypes.fp8, "gemm_a8w8_bpreshuffle only support fp8 now" assert bias is None, "gemm_a8w8_bpreshuffle does not support bias now" Y = torch.empty(m, n, dtype=dtype, device=XQ.device) - return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y) + + # CKTile only supports bf16 dtype + if dtype == dtypes.bf16: + cktile_config = get_bpreshuffle_GEMM_config( + m, n, k, dtypes.fp8, AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE + ) + else: + cktile_config = None + + ck_config = get_bpreshuffle_GEMM_config( + m, n, k, dtypes.fp8, AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE + ) + if cktile_config is not None and ck_config is not None: + cktile_time = cktile_config.get("us", float("inf")) + ck_time = ck_config.get("us", float("inf")) + + if AITER_LOG_TUNED_CONFIG: + logger.info( + f"Both CKTile and CK configs found for M:{m}, N:{n}, K:{k} - " + f"CKTile time: {cktile_time:.6f}us, CK time: {ck_time:.6f}us" + ) + + if cktile_time <= ck_time: + if 
AITER_LOG_TUNED_CONFIG: + logger.info(f"Using CKTile implementation (faster)") + return gemm_a8w8_bpreshuffle_cktile(XQ, WQ, x_scale, w_scale, Y) + else: + if AITER_LOG_TUNED_CONFIG: + logger.info(f"Using CK implementation (faster)") + return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y) + else: + if AITER_LOG_TUNED_CONFIG: + logger.info(f"default Using CK implementation") + return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y) def gemm_a8w8_blockscale_fake( @@ -612,3 +678,18 @@ def gemm_a8w8_blockscale_bpreshuffle_tune( kernelId: int = 0, splitK: int = 0, ) -> torch.Tensor: ... + + +@compile_ops( + "module_gemm_a8w8_bpreshuffle_cktile_tune", + fc_name="gemm_a8w8_bpreshuffle_cktile_tune", +) +def gemm_a8w8_bpreshuffle_cktile_tune( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + out: Tensor, + kernelId: int, + splitK: int = 0, +) -> Tensor: ... diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/README.md b/csrc/cktile_gemm_a8w8_bpreshuffle/README.md new file mode 100644 index 0000000000..afe6d25d13 --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/README.md @@ -0,0 +1,18 @@ +# CKTILE gemm a8w8 bpreshuffle tune + +1. Install aiter: +`python3 setup.py develop` + +2. Tune gemm a8w8: + First add GEMM shapes in `aiter/configs/a8w8_bpreshuffle_cktile_untuned_gemm.csv`, then run the following cmd to start tuning, please wait a few minutes as it will build gemm_a8w8_bpreshuffle_cktile_tune via jit: +`FLATMM_HIP_CLANG_PATH=/data/llvm-project/build/bin/ python3 csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.py -i aiter/configs/a8w8_bpreshuffle_cktile_untuned_gemm.csv -o aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv` +If you want to use split K kernels, you can add the `-k` parameter at the end, notice that should change `bias` to `bias/(2^k)`. +You can find the results of the tuning in `aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv`. + +3. 
Test the performance, modify the test instance in `op_tests/test_gemm_a8w8.py` and run it, please wait a few minutes as it will build gemm_a8w8_bpreshuffle_cktile kernels in `aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv` via jit: +`FLATMM_HIP_CLANG_PATH=/data/llvm-project/build/bin/ python3 op_tests/test_gemm_a8w8.py` + + +## More +If you want to re-install gemm_a8w8_bpreshuffle_cktile, you should remove `aiter/jit/module_gemm_a8w8_bpreshuffle_cktile.so` and `aiter/jit/build/module_gemm_a8w8_bpreshuffle_cktile` first. +If you use flag `PREBUILD_KERNELS=1` when you install aiter, it will build gemm a8w8 kernels in `aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv` by default. If you want to use the new result of gemm_a8w8_bpreshuffle_cktile_tune, please remove `build` and `*.so` first, then re-intall aiter after finishing tune. diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu new file mode 100755 index 0000000000..54727af540 --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_a8w8_bpreshuffle_cktile_common.cuh" +#include "gemm_a8w8_bpreshuffle_cktile_lookup.h" +#include "gemm_a8w8_bpreshuffle_cktile_manifest.h" +#include "gemm_common.h" +#include + +using RowwiseKernel = std::function; + +// Define a custom hash function for std::tuple +struct IntTupleHash +{ + size_t operator()(const std::tuple& t) const + { + auto hash1 = std::hash{}(std::get<0>(t)); + auto hash2 = std::hash{}(std::get<1>(t)); + auto hash3 = std::hash{}(std::get<2>(t)); + return hash1 ^ hash2 ^ hash3; + } +}; + +using RowwiseKernelMap = std::unordered_map, RowwiseKernel, IntTupleHash>; + +template +RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) +{ + // Use default kernel for all architectures + return a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_128x128x128_1x4x1_16x16x64_default< + DDataType, EDataType>; +} + +// Helper function to return the next largest power of 2 +static constexpr int nextPow2(unsigned int num) +{ + if(num <= 1) + return 1; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +RowwiseKernel rowwise_dispatch(int M, int N, int K) +{ + // For a given shape, either find the best kernel via lookup or heuristic. + // For many small M shapes, we bucket them to the next largest kernel. + // This is fine since kernels are padded anyway. + + static const auto lookup = [] { + if constexpr(std::is_same_v) + { + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType, F16)}; + } + else if constexpr(std::is_same_v) + { + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType, B16)}; + } + else + { + static_assert(false, "rowwise_dispatch used with unsupported dtype!"); + } + }(); + + // First check if this shape(M,N,K) is available in the direct lookup. + auto it = lookup.find({M, N, K}); + // If we found an optimal kernel, use it. 
+ if(it != lookup.end()) + { + return it->second; + } + + int padded_m = M; + + // Fine-grained search + padded_m = getPaddedM(M, N, K, 0); + // Second check if this shape(padded_m,N,K) is available in the direct lookup. + it = lookup.find({padded_m, N, K}); + // If we found an optimal kernel, use it. + if(it != lookup.end()) + { + return it->second; + } + + // Coarse-grained search + padded_m = getPaddedM(M, N, K, 1); + // Third check if this shape(padded_m,N,K) is available in the direct lookup. + it = lookup.find({padded_m, N, K}); + // If we found an optimal kernel, use it. + if(it != lookup.end()) + { + return it->second; + } + + // Otherwise, use heuristics. + return rowwise_heuristic_dispatch(M, N, K); +} + +torch::Tensor gemm_a8w8_bpreshuffle_cktile(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) +{ + TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); + TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); + + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + + if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y); + } + else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y); + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + return Y; +} diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_common.py b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_common.py new file mode 100644 index 0000000000..670e73894a --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_common.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+from dataclasses import dataclass +import os +import sys + +this_dir = os.path.dirname(os.path.abspath(__file__)) +AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../") +if os.path.exists(os.path.join(AITER_CORE_DIR, "aiter_meta")): + AITER_CORE_DIR = os.path.join(AITER_CORE_DIR, "aiter/jit/utils") # pip install mode +else: + AITER_CORE_DIR = os.path.abspath( + f"{this_dir}/../../aiter/jit/utils" + ) # develop mode +sys.path.insert(0, AITER_CORE_DIR) + +from chip_info import get_gfx # noqa: E402 + + +@dataclass +class kernelInstance: + sTransposeC: bool + sUseStructuredSparsity: bool + sTileParitionerGroupNum: int + sTileParitionerM01: int + sNumWaveGroups: int + sDoubleSmemBuffer: bool + PadM: bool + PadN: bool + PadK: bool + BlockPerCu: int + MTile: int + NTile: int + KTile: int + MWarp: int + NWarp: int + KWarp: int + MWTile: int + NWTile: int + KWTile: int + sScheduler: str + + @property + def name(self) -> str: + return ("_").join( + [ + "a8w8_bpreshuffle_cktile", + ("x").join( + map( + lambda x: str(x), + [ + self.sTransposeC, + self.sUseStructuredSparsity, + self.sTileParitionerGroupNum, + self.sTileParitionerM01, + self.sNumWaveGroups, + self.sDoubleSmemBuffer, + self.PadM, + self.PadN, + self.PadK, + self.BlockPerCu, + ], + ) + ), + ("x").join(map(lambda x: str(x), [self.MTile, self.NTile, self.KTile])), + ("x").join(map(lambda x: str(x), [self.MWarp, self.NWarp, self.KWarp])), + ("x").join( + map(lambda x: str(x), [self.MWTile, self.NWTile, self.KWTile]) + ), + self.sScheduler.lower(), + ] + ) + + +# fmt: off +# kernels_list_str = ''' +kernels_list_942 = { + 0: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 1: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 2: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 64, 512, 1, 4, 1, 16, 16, 64, "Default"), + 3: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 128, 512, 1, 4, 1, 16, 16, 64, "Default"), + 4: 
kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 512, 1, 4, 1, 16, 16, 64, "Default"), + 5: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 6: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 7: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 8: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 512, 256, 1, 4, 1, 16, 16, 64, "Default"), + 9: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 512, 1, 4, 1, 16, 16, 64, "Default"), + 10: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 64, 1, 4, 1, 16, 16, 64, "Default"), + 11: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 12: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 13: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 128, 64, 1, 4, 1, 16, 16, 64, "Default"), + 14: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 15: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 16: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 17: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 18: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 19: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 20: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 21: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 22: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 23: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 24: kernelInstance( 0, 0, 8, 4, 1, 
0, 0, 0, 0, 1, 64, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 25: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 26: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 27: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 28: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 29: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 30: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 64, 512, 1, 4, 1, 16, 16, 64, "Default"), + 31: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 128, 512, 1, 4, 1, 16, 16, 64, "Default"), + 32: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 512, 1, 4, 1, 16, 16, 64, "Default"), + 33: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 34: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 35: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 36: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 512, 256, 1, 4, 1, 16, 16, 64, "Default"), + 37: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 64, 512, 1, 4, 1, 16, 16, 64, "Default"), + 38: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 256, 64, 1, 4, 1, 16, 16, 64, "Default"), + 39: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 40: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 41: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 128, 64, 1, 4, 1, 16, 16, 64, "Default"), + 42: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 43: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 44: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 
64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 45: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 46: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 47: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 48: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 49: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 50: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 51: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 52: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 53: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 54: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 55: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 56: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 57: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 58: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 59: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 60: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 61: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 512, 256, 1, 4, 1, 16, 16, 64, "Default"), + 62: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 64, 1, 4, 1, 16, 16, 64, "Default"), + 63: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 512, 1, 4, 1, 16, 16, 64, "Default"), + 64: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 256, 1, 4, 1, 16, 
16, 64, "Default"), + 65: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 66: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 67: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 68: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 69: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 70: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 512, 256, 1, 4, 1, 16, 16, 64, "Default"), + 71: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 256, 64, 1, 4, 1, 16, 16, 64, "Default"), + 72: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 512, 1, 4, 1, 16, 16, 64, "Default"), + 73: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 74: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 160, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 75: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 76: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 77: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 78: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 79: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 80: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 81: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 82: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 83: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 84: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 
85: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 86: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 87: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 88: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 160, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 89: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 90: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 91: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 92: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 93: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 94: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 95: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 96: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 192, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 97: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 98: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 99: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 100: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 101: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 102: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 192, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 103: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 104: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 105: 
kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 106: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 107: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 108: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 109: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 110: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 111: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 112: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 113: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 114: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 115: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), + 116: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 117: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 118: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 119: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 120: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 121: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 160, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 122: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 123: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), + 124: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 125: 
kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 126: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 127: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 128: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 129: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 192, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 130: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 131: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 132: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), + 133: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + 134: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + 135: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 192, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 136: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), + 137: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 138: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), + +} +# ''' + +default_kernels_dict_942 = { + (-1): kernelInstance(0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), + (-2):kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 64, 512, 1, 4, 1, 16, 16, 64, "Default"), + (-3):kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 512, 1, 4, 1, 16, 16, 64, "Default"), + (-4):kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 64, 1, 4, 1, 16, 16, 64, "Default"), + (-5):kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 128, 64, 1, 4, 1, 16, 16, 64, "Default"), + (-6):kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 
64, 128, 1, 4, 1, 16, 16, 64, "Default"), + (-7):kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), +} + +kernels_list_950 = { + 0: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 1: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 2: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 64, 512, 1, 4, 1, 16, 16, 128, "Default"), + 3: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 128, 512, 1, 4, 1, 16, 16, 128, "Default"), + 4: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 512, 1, 4, 1, 16, 16, 128, "Default"), + 5: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 6: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 7: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 8: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 512, 256, 1, 4, 1, 16, 16, 128, "Default"), + 9: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 512, 1, 4, 1, 16, 16, 128, "Default"), + 10: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 11: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 12: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 13: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 14: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 15: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 16: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 17: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 18: kernelInstance( 0, 0, 8, 4, 1, 0, 
0, 0, 0, 1, 64, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 19: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 20: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 21: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 22: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 23: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 24: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 25: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 26: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 27: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 28: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 29: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 30: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 64, 512, 1, 4, 1, 16, 16, 128, "Default"), + 31: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 128, 512, 1, 4, 1, 16, 16, 128, "Default"), + 32: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 512, 1, 4, 1, 16, 16, 128, "Default"), + 33: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 34: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 35: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 36: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 512, 256, 1, 4, 1, 16, 16, 128, "Default"), + 37: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 64, 512, 1, 4, 1, 16, 16, 128, "Default"), + 38: kernelInstance( 0, 0, 8, 4, 1, 0, 
0, 0, 0, 2, 192, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 39: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 40: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 41: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 42: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 43: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 44: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 45: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 46: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 47: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 48: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 49: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 50: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 51: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 52: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 53: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 54: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 55: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 56: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 57: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 58: kernelInstance( 0, 0, 8, 4, 1, 0, 
0, 0, 0, 1, 64, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 59: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 60: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 61: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 512, 256, 1, 4, 1, 16, 16, 128, "Default"), + 62: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 63: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 512, 1, 4, 1, 16, 16, 128, "Default"), + 64: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 16, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 65: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 66: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 67: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 68: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 64, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 69: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 32, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 70: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 512, 256, 1, 4, 1, 16, 16, 128, "Default"), + 71: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 72: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 512, 1, 4, 1, 16, 16, 128, "Default"), + 73: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 16, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 74: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 160, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 75: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 76: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 77: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 78: kernelInstance( 0, 0, 8, 4, 1, 
0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 79: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 80: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 81: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 82: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 83: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 84: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 85: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 86: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 87: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 88: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 160, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 89: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 90: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 91: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 92: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 93: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 94: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 95: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 96: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 192, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 97: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 98: kernelInstance( 0, 0, 8, 4, 
1, 0, 0, 0, 0, 1, 96, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 99: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 100: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 101: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 102: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 192, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 103: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 104: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 105: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 106: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 107: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 108: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 109: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 110: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 111: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 112: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 113: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 114: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 115: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), + 116: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 117: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 118: 
kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 119: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 120: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 121: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 160, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 122: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 123: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), + 124: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + 125: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + 126: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 127: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 128: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 129: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 192, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 130: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), + 131: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), + 132: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), + + +} + +default_kernels_dict_950 = { + (-1): kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0,1, 128, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), + (-2): kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0,1, 16, 64, 512, 1, 4, 1, 16, 16, 128, "Default"), + (-3): kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0,1, 32, 64, 512, 1, 4, 1, 16, 16, 128, "Default"), + (-4): kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0,1, 128, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), +} + +# fmt: on + +arch = get_gfx() +if 
arch == "gfx942": + kernels_list = kernels_list_942 + default_kernels_dict = default_kernels_dict_942 +else: + kernels_list = kernels_list_950 + default_kernels_dict = default_kernels_dict_950 diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu new file mode 100644 index 0000000000..b7658841ff --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_bpreshuffle_cktile_common.cuh" +#include "gemm_a8w8_bpreshuffle_cktile_lookup.h" +#include "gemm_a8w8_bpreshuffle_cktile_manifest.h" +#include "py_itfs_common.h" +#include + +using RowwiseKernel = std::function; + +// For certain high priority shapes, we directly use the best kernel rather +// than use heuristics. +using RowwiseKernelMap = std::unordered_map; + +// Helper function to return the next largest power of 2 +static constexpr int nextPow2(unsigned int num) +{ + if(num <= 1) + return 1; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +RowwiseKernel rowwise_dispatch(int id) +{ + // For a given shape, either find the best kernel via lookup or heuristic. + // For many small M shapes, we bucket them to the next largest kernel. + // This is fine since kernels are padded anyway. + + // First check if this shape is available in the direct lookup. + static const auto lookup = [] { + if constexpr(std::is_same_v) + { + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType, F16)}; + } + else if constexpr(std::is_same_v) + { + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType, B16)}; + } + else + { + static_assert(false, "rowwise_dispatch used with unsupported dtype!"); + } + }(); + + TORCH_CHECK(id < lookup.size(), + "Kernel id " + std::to_string(id) + + " is out of range! 
(lookup.size()=" + std::to_string(lookup.size()) + ")"); + auto it = lookup.find(id); + // If we found an optimal kernel, use it. + if(it != lookup.end()) + { + return it->second; + } + // Otherwise, use heuristics. + return lookup.find(0)->second; +} + +torch::Tensor gemm_a8w8_bpreshuffle_cktile_tune(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y, + int kernelId, + int splitK) +{ + TORCH_CHECK(XQ.dtype() == torch_fp8 && XQ.dtype() == WQ.dtype(), + "Weights and activations should both be fp8!"); + TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); + std::optional bias = std::nullopt; + + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + int KBatch = std::pow(2, splitK); + + if(Y.dtype() == at::ScalarType::BFloat16) + { + rowwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y); + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + return Y; +} diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.py b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.py new file mode 100755 index 0000000000..177c72f6f5 --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+import os +import aiter +import pandas as pd +import torch +import torch.nn.functional as F +from aiter import dtypes +from aiter.jit.core import AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE +from aiter.utility.base_tuner import GemmCommonTuner +from aiter.ops.shuffle import shuffle_weight +from gemm_a8w8_bpreshuffle_cktile_common import kernels_list +import argparse +from aiter.utility.mp_tuner import mp_tuner + + +def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): + x = x.to(dtypes.fp32) * x_scale + weight = weight.to(dtypes.fp32) * w_scale + out = F.linear(x, weight) + if bias is not None: + out = out.to(bias) + bias + return out.to(dtype) + + +def run_gemm_a8w8_bpreshuffle_cktile( + x, weight, x_scale, w_scale, out, kernel_id, splitK=0 +): + aiter.gemm_a8w8_bpreshuffle_cktile_tune( + x, weight, x_scale, w_scale, out, kernel_id, splitK + ) + return out + + +def generate_data( + m, n, k, seed, dtype=dtypes.bf16, q_dtype_w=dtypes.fp8, device="cuda" +): + torch.manual_seed(seed) + x = torch.randn((m, k), dtype=dtype, device=device) + weight = torch.randn((n, k), dtype=dtype, device=device) + x, x_scale = aiter.pertoken_quant(x, quant_dtype=q_dtype_w) + weight, w_scale = aiter.pertoken_quant(weight, quant_dtype=q_dtype_w) + bias_f32 = None + weight_shuffle = shuffle_weight(weight, layout=(16, 16)) + out = torch.empty(m, n, dtype=dtype, device=device) + return x, weight_shuffle, x_scale, w_scale, out, weight, bias_f32 + + +class GemmA8W8BpreShuffleCktileTuner(GemmCommonTuner): + ARG_DEFAULTS = { + **GemmCommonTuner.ARG_DEFAULTS, + "tune_file": f"{AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE}", + "untune_file": "aiter/configs/a8w8_bpreshuffle_cktile_untuned_gemm.csv", + } + + def _setup_specific_arguments(self): + pass + + def calculate(self, results, bpes=(1, 1, 2)): + ## bpes = (inbpe, w_bpe, outbpe) + return super().calculate(results, bpes=bpes) + + def getKernelName(self, kernelId): + if kernelId < 0 or kernelId > len(kernels_list): + return 
None + return kernels_list[kernelId].name + + def get_cktile_gemm_a8w8_bpreshuffle_tune_task( + self, + info_keys, + useSplitK, + seed, + ): + (cu_num, M, N, K, q_dtype_w) = info_keys + if eval(q_dtype_w) != dtypes.fp8: + print( + f"Warning: q_dtype_w only support {dtypes.fp8}, actual q_dtype_w is {q_dtype_w}!" + ) + return [] + kernels_num = len(kernels_list) + gemm_a8w8_idx = [0, 1, 2, 3, 4] # input index in generate_data + ref_data_idx = [0, 5, 2, 3, 6] + tasks_ck = [] + for i in range(kernels_num): + kernel = kernels_list[i] + maxsplitK = ( + aiter.compute_gemm_SplitK( + M, + N, + K, + kernel.MPerBLOCK, + kernel.NPerBLOCK, + kernel.KPerBLOCK, + ) + if useSplitK + else 0 + ) + for splitK in range(maxsplitK + 1): + info = (info_keys, i, splitK, "") + tasks_ck.append( + ( + info, + generate_data, + (M, N, K, seed, dtypes.bf16, eval(q_dtype_w)), + run_gemm_a8w8_bpreshuffle_cktile, + ( + gemm_a8w8_idx, + i, + splitK, + ), + {}, + run_torch, + ( + ref_data_idx, + dtypes.bf16, + ), + {}, + None, + 1e-2, + 0.01, + ) + ) + return tasks_ck + + def tune( + self, + untunedf, + tunedf, + args, + ): + issorted = args.sort + useSplitK = args.splitK + mp_num = args.mp + shape_grouped = False + errRatio = args.errRatio + cu_num = self.get_cu_num() + task = [] + tasks_data = [] # [(kernel_nums, datas)] + seed = 10000 + for i in range(len(untunedf)): + M = untunedf.loc[i, "M"] + N = untunedf.loc[i, "N"] + K = untunedf.loc[i, "K"] + q_dtype_w = untunedf.loc[i, "q_dtype_w"] + seed = seed + 1 + total_kernel_nums = 0 + kernels_num = len(kernels_list) + info_keys = (cu_num, M, N, K, q_dtype_w) + task.extend( + self.get_cktile_gemm_a8w8_bpreshuffle_tune_task( + info_keys, + useSplitK, + seed, + ) + ) + + total_kernel_nums = len(task) + + tasks_data.append((total_kernel_nums, ())) + ret = [] + if task: + ret = mp_tuner(task, tasks_data, mp_num, False, shape_grouped, errRatio) + + return ret + + +if __name__ == "__main__": + ## use default key and resultList + key = ["cu_num", "M", "N", 
"K", "q_dtype_w"] + tuner = GemmA8W8BpreShuffleCktileTuner( + "GemmA8W8BpreShuffleCktileTuner", + key=key, + description="gen API for gemm a8w8 bpreshuffle cktile kernel", + ) + + args = tuner.parse_args() + tuner.run(args, False) diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py b/csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py new file mode 100755 index 0000000000..656a273d80 --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +import os +import sys +from dataclasses import dataclass +import copy +from pathlib import Path +import pandas as pd +import argparse +import shutil +import torch +from gemm_a8w8_bpreshuffle_cktile_common import ( + kernelInstance, + kernels_list, + default_kernels_dict, +) + + +""" + +gemm_a8w8_bpreshuffle_cktile instance gen + +""" + + +class gemm_a8w8_bpreshuffle_cktile_codegen: + def __init__(self, working_path, istune=False): + self.working_path = working_path + self.impl_path = os.path.join(working_path, "impl") + self.instances_path = os.path.join(working_path, "instances") + self.istune = istune + + def gen_instance(self, k: kernelInstance): + INSTANCE_IMPL = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_bpreshuffle_cktile_common.cuh" + +template +torch::Tensor +{k.name}( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &Y + ) +{{{{ + // The smallest kernel we have available. Works well for memory bound shapes. + + // Check if this input needs to be padded. 
+ int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % {k.MTile} != 0) || (N % {k.NTile} != 0) || (K % ({k.KTile}) != 0); + if (pad) + {{{{ + // pad + {{INSTANCE_CONTENT_pad}} + // pad + }}}} + else + {{{{ + // no pad + {{INSTANCE_CONTENT_nopad}} + // no pad + }}}} +}}}} + +""" + + INSTANCE_CONTENT_nobias = f"""using FlatmmInstance = CustomConfig< + DDataType, EDataType, + {k.sTransposeC},{k.sUseStructuredSparsity}, {k.sTileParitionerGroupNum}, + {k.sTileParitionerM01}, {k.sNumWaveGroups}, {k.sDoubleSmemBuffer}, + {k.PadM}, {k.PadN}, {k.PadK}, + {k.BlockPerCu}, + {k.MTile}, {k.NTile}, {k.KTile}, + {k.MWarp}, {k.NWarp}, {k.KWarp}, + {k.MWTile}, {k.NWTile}, {k.KWTile}, + ck_tile::GemmPipelineScheduler::{k.sScheduler}>; + // Run kernel instance. + return gemm_a8w8_bpreshuffle_cktile_impl(XQ, WQ, x_scale, w_scale, Y); +""" + if self.istune: + INSTANCE_IMPL_str = INSTANCE_IMPL.format( + INSTANCE_CONTENT_pad=( + INSTANCE_CONTENT_nobias.format(GemmSpec="MNKPadding") + ), + INSTANCE_CONTENT_nopad=( + INSTANCE_CONTENT_nobias.format(GemmSpec="Default") + ), + ) + else: + INSTANCE_IMPL_str = INSTANCE_IMPL.format( + INSTANCE_CONTENT_pad=INSTANCE_CONTENT_nobias.format( + GemmSpec="MNKPadding" + ), + INSTANCE_CONTENT_nopad=INSTANCE_CONTENT_nobias.format( + GemmSpec="Default" + ), + ) + + Path(os.path.join(self.impl_path, f"{k.name}.cuh")).write_text( + INSTANCE_IMPL_str + ) + + INSTANCE_template = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "impl/{name}.cuh" + +template torch::Tensor +{name}<{dtypes}>( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &Y + ); + +""" + INSTANCE_dFP32_eBF16 = INSTANCE_template.format(name=k.name, dtypes="F32, B16") + INSTANCE_dFP32_eFP16 = INSTANCE_template.format(name=k.name, dtypes="F32, F16") + # TODO: dFP8_eFP8 + + if self.istune: + Path( + os.path.join(self.instances_path, f"{k.name}_dFP32_eBF16.cpp") + ).write_text(INSTANCE_dFP32_eBF16) + else: + Path( + os.path.join(self.instances_path, f"{k.name}_dFP32_eBF16.cpp") + ).write_text(INSTANCE_dFP32_eBF16) + Path( + os.path.join(self.instances_path, f"{k.name}_dFP32_eFP16.cpp") + ).write_text(INSTANCE_dFP32_eFP16) + + def gen_lookup_dict(self, kernels_dict): + LOOKUP_head = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifdef USE_ROCM + +#define GENERATE_LOOKUP_TABLE(DTYPE, ETYPE) \\ + { \\""" + + LOOKUP_template = """ + {{{MNK}, \\ + {kernel_name}}}, \\""" + + LOOKUP_end = """ + } + +#endif // USE_ROCM +""" + with open( + os.path.join(self.working_path, "gemm_a8w8_bpreshuffle_cktile_lookup.h"), + "w", + ) as f: + f.write(LOOKUP_head) + for mnk, k in kernels_dict.items(): + # print((", ").join(map(lambda x: str(x), list(mnk))), ":", k.name) + if not self.istune and (isinstance(mnk, tuple) and mnk[0] > 0): + f.write( + LOOKUP_template.format( + MNK="{" + + (", ").join(map(lambda x: str(x), list(mnk))) + + "}", + kernel_name=k.name, + ) + ) + elif self.istune and isinstance(mnk, int): + f.write(LOOKUP_template.format(MNK=mnk, kernel_name=k.name)) + f.write(LOOKUP_end) + + def gen_manifest_head(self, kernels_dict): + MAINFEST_head = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#ifdef USE_ROCM + +#include + +#include +""" + MAINFEST_template = """ +template +torch::Tensor +{kernel_name}( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &Y); +""" + MAINFEST_end = """ + +#endif // USE_ROCM +""" + + with open( + os.path.join(self.working_path, "gemm_a8w8_bpreshuffle_cktile_manifest.h"), + "w", + ) as f: + f.write(MAINFEST_head) + for mnk, k in kernels_dict.items(): + f.write(MAINFEST_template.format(kernel_name=k.name)) + f.write(MAINFEST_end) + + def gen_instances(self, kernels_dict): + if os.path.exists(self.impl_path): + shutil.rmtree(self.impl_path) + os.mkdir(self.impl_path) + if os.path.exists(self.instances_path): + shutil.rmtree(self.instances_path) + os.mkdir(self.instances_path) + + for mnk, k in kernels_dict.items(): + self.gen_instance(k) + + self.gen_lookup_dict(kernels_dict) + self.gen_manifest_head(kernels_dict) + + +def get_tune_dict(tune_dict_csv): + tune_dict = default_kernels_dict + if os.path.exists(tune_dict_csv): + tune_df = pd.read_csv(tune_dict_csv) + if torch.cuda.is_available(): + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + tune_df = tune_df[tune_df["cu_num"] == cu_num].reset_index() + for i in range(len(tune_df)): + M = tune_df.loc[i, "M"] + N = tune_df.loc[i, "N"] + K = tune_df.loc[i, "K"] + kid = tune_df.loc[i, "kernelId"] + if kid < 0 or kid >= len(kernels_list): + continue + tune_dict[(M, N, K)] = kernels_list[kid] + return tune_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CKTILE gemm a8w8 kernel", + ) + + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated", + ) + + parser.add_argument( + "-f", + "--tune_file", + 
default="aiter/configs/a8w8_bpreshuffle_cktile_tuned_gemm.csv", + required=False, + help="tune_file includes the result after running gemm_a8w8_bpreshuffle_cktile_tune.py", + ) + + parser.add_argument( + "--tune", action="store_true", required=False, help="generate tune instances" + ) + + args = parser.parse_args() + codegen = gemm_a8w8_bpreshuffle_cktile_codegen(args.working_path, args.tune) + + if args.tune: + codegen.gen_instances(kernels_list) + else: + codegen.gen_instances(get_tune_dict(args.tune_file)) diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile.h b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile.h new file mode 100644 index 0000000000..2eb83d065f --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile.h @@ -0,0 +1,20 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include + +torch::Tensor gemm_a8w8_bpreshuffle_cktile( + torch::Tensor &XQ, // [M, K] + torch::Tensor &WQ, // [N, K] -> [N/128, K*128] + torch::Tensor &x_scale, // [K/128, M] + torch::Tensor &w_scale, // [K/128, N/128] + torch::Tensor &out // Out:[M, N] fp16 +); +torch::Tensor gemm_a8w8_bpreshuffle_cktile_tune( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &out, + int kernelId, + int splitK); diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh new file mode 100644 index 0000000000..4a039422c2 --- /dev/null +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh @@ -0,0 +1,386 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#ifdef USE_ROCM + +#undef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include "flatmm_basic.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using FP8 = ck_tile::fp8_t; +using F32 = float; +using B16 = ck_tile::bf16_t; +using ADataType = typename GemmBasicTypeConfig::ADataType; +using BDataType = typename GemmBasicTypeConfig::BDataType; +using CDataType = ck_tile::bf16_t; +using AccDataType = typename GemmBasicTypeConfig::AccDataType; +using ALayout = ck_tile::tensor_layout::gemm::RowMajor; +using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; +using CLayout = ck_tile::tensor_layout::gemm::RowMajor; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s) +{ + using CodegenFlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool 
has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; + + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::FlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n" + << "Shape: " << CodegenFlatmmShape::GetName() << "\n" + << "problem: " << CodegenPipelineProblem::GetName() << "\n" + << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n" + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 
2 : 1; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} +template +struct CreateTileConfig +{ + static constexpr bool TransposeC = sTransposeC; + static constexpr bool UseStructuredSparsity = sUseStructuredSparsity; + static constexpr int TileParitionerGroupNum = sTileParitionerGroupNum; + static constexpr int TileParitionerM01 = sTileParitionerM01; + static constexpr ck_tile::index_t NumWaveGroups = sNumWaveGroups; + static constexpr bool DoubleSmemBuffer = sDoubleSmemBuffer; + static constexpr bool kPadM = PadM; + static constexpr bool kPadN = PadN; + static constexpr bool kPadK = 
PadK; + static constexpr int kBlockPerCu = BlockPerCu; + static constexpr int M_Tile = MTile; + static constexpr int N_Tile = NTile; + static constexpr int K_Tile = KTile; + static constexpr int M_Warp = MWarp; + static constexpr int N_Warp = NWarp; + static constexpr int K_Warp = KWarp; + static constexpr int M_Warp_Tile = MWTile; + static constexpr int N_Warp_Tile = NWTile; + static constexpr int K_Warp_Tile = KWTile; + static constexpr auto Scheduler = sScheduler; +}; + +template +using CustomConfig = CreateTileConfig; + +template +__forceinline__ torch::Tensor +gemm_a8w8_bpreshuffle_cktile_impl(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& out // Out:[M, N] fp16 +) +{ + TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); + TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = ck_tile::bf16_t;; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using DsDataType = ck_tile::tuple<>; + using ALayout = ck_tile::tensor_layout::gemm::RowMajor; + using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayout = ck_tile::tensor_layout::gemm::RowMajor; + using DsLayout = ck_tile::tuple<>; + using CDEElementWise = ck_tile::element_wise::PassThrough; + int m = XQ.size(0); + int n = out.size(1); + int k = XQ.size(1); + + using ScaleM = typename ck_tile::FlatmmScalePointer<1>; + using ScaleN = typename ck_tile::FlatmmScalePointer<1>; + + + ck_tile::ScaleFlatmmHostArgs args; + args.a_ptr = (void*)XQ.data_ptr(); + args.b_ptr = (void*)WQ.data_ptr(); + args.scale_m = ck_tile::FlatmmScalePointer<1>{reinterpret_cast(x_scale.data_ptr()),m}; + args.scale_n = ck_tile::FlatmmScalePointer<1>{reinterpret_cast(w_scale.data_ptr()),n}; + args.e_ptr = (void*)out.data_ptr(); + + 
args.k_batch = 1; + args.M = m; + args.N = n; + args.K = k; + args.stride_A = k; + args.stride_B = k; + args.stride_C = n; + args.stride_E = n; + + const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(XQ)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + ck_tile::stream_config naive_config{stream}; + flatmm_calc(args, naive_config); + + return out; +} + +#endif // USE_ROCM diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 8cab4d1dfd..88e0e96dc0 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -524,6 +524,29 @@ namespace py = pybind11; py::arg("Out"), \ py::arg("kernelId") = 0, \ py::arg("splitK") = 0); + +#define GEMM_A8W8_BPRESHUFFLE_CKTILE_PYBIND \ + m.def("gemm_a8w8_bpreshuffle_cktile", \ + &gemm_a8w8_bpreshuffle_cktile, \ + "gemm_a8w8_bpreshuffle_cktile", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("x_scale"), \ + py::arg("w_scale"), \ + py::arg("Out")); + +#define GEMM_A8W8_BPRESHUFFLE_CKTILE_TUNE_PYBIND \ + m.def("gemm_a8w8_bpreshuffle_cktile_tune", \ + &gemm_a8w8_bpreshuffle_cktile_tune, \ + "gemm_a8w8_bpreshuffle_cktile_tune", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("x_scale"), \ + py::arg("w_scale"), \ + py::arg("Out"), \ + py::arg("kernelId") = 0, \ + py::arg("splitK") = 0); + #define MHA_BWD_ASM_PYBIND \ m.def("fmha_v3_bwd", \ &aiter::torch_itfs::fmha_v3_bwd, \ diff --git a/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu new file mode 100644 index 0000000000..b453764779 --- /dev/null +++ b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+#include "gemm_a8w8_bpreshuffle_cktile.h" +#include "rocm_ops.hpp" +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { GEMM_A8W8_BPRESHUFFLE_CKTILE_PYBIND; } diff --git a/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu new file mode 100644 index 0000000000..aaa0ba69f7 --- /dev/null +++ b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "gemm_a8w8_bpreshuffle_cktile.h" +#include "rocm_ops.hpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { GEMM_A8W8_BPRESHUFFLE_CKTILE_TUNE_PYBIND; } diff --git a/csrc/rocm_ops.cpp b/csrc/rocm_ops.cpp index 7f89db3c93..ab2c0ce3f5 100644 --- a/csrc/rocm_ops.cpp +++ b/csrc/rocm_ops.cpp @@ -23,6 +23,7 @@ #include "gemm_a8w8.h" #include "gemm_a8w8_blockscale.h" #include "gemm_a8w8_bpreshuffle.h" +#include "gemm_a8w8_bpreshuffle_cktile.h" #include "gemm_common.h" #include "hipbsolgemm.cuh" #include "mla.h"