Use FlashInfer FP4 gemm. #8241
Merged after 9 commits:

- d0bdcc2 Use FlashInfer FP4 gemm. (elfiegg)
- 8ee7a70 address comment. (elfiegg)
- 8be5b78 formatting (elfiegg)
- 1c76736 format again (elfiegg)
- 8adfdf6 linting conflict? (elfiegg)
- 980d9e2 linting conflict? (elfiegg)
- b14b03a address incidental delete (elfiegg)
- cfb38c7 Merge branch 'main' into gemm (merrymercy)
- 8d1216d Merge branch 'main' into gemm (ispobock)
The diff adds a single new file (210 lines): a benchmark script that compares the existing `cutlass_scaled_fp4_mm` kernel from `sgl_kernel` against FlashInfer's `mm_fp4` with its cudnn and trtllm backends, across DeepSeek-R1-0528-FP4 weight shapes.
```python
import argparse
import csv

import torch
import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

# Largest representable magnitudes of the two formats, used to pick
# per-tensor global scales.
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def get_weight_shapes(args):
    # (N, K) weight shapes of DeepSeek-R1-0528-FP4 for the requested TP sizes.
    models_tps = args.tp_sizes

    if models_tps == [4]:
        return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]

    if models_tps == [8]:
        return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]
    return [
        [1024, 3584],
        [7168, 256],
        [7168, 2304],
        [9216, 3584],
        [512, 3584],
        [7168, 128],
        [7168, 1152],
        [4608, 3584],
    ]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
            1024, 2048, 3072, 4096, 8192, 16384,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    M = batch_size
    # K is given in packed units (two FP4 values per byte); unpack it.
    packed_k = K
    K = 2 * packed_k
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
    # Global scales map each tensor's max magnitude onto the largest value an
    # FP8 (E4M3) block scale times an FP4 (E2M1) element can represent.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)

    # alpha folds both global scales back out of the GEMM result.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")

    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
            ),
            quantiles=quantiles,
        )
    if provider == "cudnn":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
            ),
            quantiles=quantiles,
        )
    if provider == "trtllm":
        # The trtllm backend takes the interleaved block scales as uint8.
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="trtllm",
            ),
            quantiles=quantiles,
        )
    if correctness:
        # Check both FlashInfer backends against the CUTLASS baseline.
        res_cutlass = cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="cudnn",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="trtllm",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"

    if csv_file:
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([provider, M, N, K, ms])

    return ms, min_ms, max_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        # `type=torch.dtype` would fail on string input; resolve the name instead.
        type=lambda name: getattr(torch, name),
        default=torch.bfloat16,
        help="Data type (e.g. bfloat16)",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    if args.csv:
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )

    print("Benchmark finished!")
```
Review comment (nit): it would be great to have an env var letting the user control which backend to use, so that if one backend hits a bug we can fall back to the other.
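A minimal sketch of what that could look like. The env var name `SGLANG_FP4_GEMM_BACKEND` and the `fp4_gemm` wrapper are hypothetical, not part of this PR; the calls mirror the ones in the benchmark above.

```python
import os

from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm

# Hypothetical knob, not part of this PR: "cutlass", "cudnn", or "trtllm".
_FP4_BACKEND = os.environ.get("SGLANG_FP4_GEMM_BACKEND", "cutlass")


def fp4_gemm(a_fp4, b_fp4, a_scale, b_scale, alpha, dtype, out):
    """Dispatch an FP4 GEMM to the backend selected via env var."""
    if _FP4_BACKEND == "cutlass":
        return cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale, b_scale, alpha, dtype)
    # FlashInfer path; note the trtllm backend additionally expects the
    # interleaved scales as uint8 (see the benchmark above).
    mm_fp4(a_fp4, b_fp4.T, a_scale, b_scale.T, alpha, dtype, out,
           backend=_FP4_BACKEND)
    return out
```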