import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

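# Candidate tile configurations for the autotuner. Each config fixes the output
# tile (BLOCK_SIZE_M x BLOCK_SIZE_N), the K step per loop iteration
# (BLOCK_SIZE_K), and the L2-swizzling group size (GROUP_SIZE_M); num_stages
# sets the software-pipelining depth and num_warps the warps per program.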
def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    ]

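# AMD (HIP) configurations. matrix_instr_nonkdim=16 hints the backend toward
# 16x16 MFMA matrix instructions; the larger BLOCK_SIZE_K suits those shapes.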
def get_hip_autotune_config():
    sizes = [
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
    ]
    return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]

def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        return get_hip_autotune_config()

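# `@triton.autotune` benchmarks every config above on the first call and caches
# the fastest one per (M, N, K); a new shape triggers re-tuning. Note that
# PRECISION is not part of the key, so for a given shape the chosen config is
# shared across precisions.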
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    PRECISION: tl.constexpr,
):
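    # Map the 1D program id onto the 2D tile grid using grouped ordering:
    # programs walk GROUP_SIZE_M rows of tiles down a column before moving
    # right, which improves L2 reuse of the loaded A and B blocks.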
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

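    # tl.assume records facts the compiler cannot prove on its own (program ids
    # and strides are non-negative), enabling cheaper address arithmetic.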
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    # First-block offsets; the modulo keeps out-of-range rows/cols in-bounds
    # (the masked store below discards the duplicated results).
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask off the K tail on the last iteration.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # PRECISION selects how the fp32 dot is computed ("ieee" or one of the
        # bf16-emulated variants); accumulate in-place via the acc argument.
        accumulator = tl.dot(a, b, accumulator, input_precision=PRECISION)
        # Advance the pointers to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator  # already fp32; no conversion needed

    # Write the output tile with a boundary mask.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

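# For intuition: a hypothetical PyTorch reference for the "bf16x3" scheme
# (illustrative only; Triton's internal decomposition may differ). Each fp32
# operand is split into a bf16 "hi" limb plus a bf16 residual "lo" limb, and
# three bf16 products recover most of the fp32 mantissa (the lo @ lo term is
# dropped). The x6/x9 variants use more limbs / cross terms for more accuracy.
def bf16x3_matmul_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    a_hi = a.to(torch.bfloat16).to(torch.float32)
    a_lo = (a - a_hi).to(torch.bfloat16).to(torch.float32)
    b_hi = b.to(torch.bfloat16).to(torch.float32)
    b_lo = (b - b_hi).to(torch.bfloat16).to(torch.float32)
    return a_hi @ b_hi + a_hi @ b_lo + a_lo @ b_hi
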
def matmul(a, b, precision="ieee"):
    assert a.shape[1] == b.shape[0], "incompatible matmul dimensions"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    # 1D launch grid: one program per output tile.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        PRECISION=precision,
    )
    return c

precisions = ["ieee", "bf16", "bf16x3", "bf16x6", "bf16x9"]
torch.manual_seed(0)

for precision in precisions:
    a = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5
    b = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5
    triton_output = matmul(a, b, precision=precision)
    torch_output = torch.matmul(a, b)
    # print(f"triton_output_with_fp32_inputs={triton_output}")
    # print(f"torch_output_with_fp32_inputs={torch_output}")

    # A plain bf16 dot is unlikely to meet this fp32 tolerance; the x3/x6/x9
    # emulations should close most of the gap.
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
        print(f'✅ Triton and Torch match for input_precision={precision}')
    else:
        print(f'❌ Triton and Torch differ for input_precision={precision}')

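# Benchmark against the vendor BLAS on square matrices from 256 to 4096.
# Throughput is reported in TFLOPS, counting 2*M*N*K flops per matmul.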
ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'

configs = []
configs.append(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[128 * i for i in range(2, 33)],
        line_arg="provider",
        line_vals=[ref_lib.lower(), "triton-ieee", "triton-bf16", "triton-bf16x3", "triton-bf16x6", "triton-bf16x9"],
        line_names=[ref_lib, "Triton-IEEE", "Triton-BF16", "Triton-BF16x3", "Triton-BF16x6", "Triton-BF16x9"],
        # Distinct colors so the five Triton variants are distinguishable on the plot.
        styles=[("green", "-"), ("blue", "-"), ("orange", "-"), ("red", "-"), ("purple", "-"), ("brown", "-")],
        ylabel="TFLOPS",
        plot_name="matmul-performance-f32",
        args={},
    ))

@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float32)
    b = torch.randn((K, N), device=DEVICE, dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == ref_lib.lower():
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    elif provider.startswith('triton-'):
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, provider.removeprefix('triton-')), quantiles=quantiles)
    # A matmul performs 2*M*N*K flops; convert the measured times to TFLOPS.
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)

benchmark.run(show_plots=False, print_data=True)