diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index b39dce2659a5..a378bc6baa5a 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,36 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 + import os +import sys import zipfile -MAX_SIZE_MB = 250 +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB +# Note that we have 400 MiB quota, please use it wisely. +# See https://github.com/pypi/support/issues/3792 . +# Please also sync the value with the one in Dockerfile. +VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400)) def print_top_10_largest_files(zip_file): + """Print the top 10 largest files in the given zip file.""" with zipfile.ZipFile(zip_file, 'r') as z: file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] file_sizes.sort(key=lambda x: x[1], reverse=True) for f, size in file_sizes[:10]: - print(f"{f}: {size/(1024*1024)} MBs uncompressed.") + print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.") def check_wheel_size(directory): + """Check the size of .whl files in the given directory.""" for root, _, files in os.walk(directory): - for f in files: - if f.endswith(".whl"): - wheel_path = os.path.join(root, f) - wheel_size = os.path.getsize(wheel_path) - wheel_size_mb = wheel_size / (1024 * 1024) - if wheel_size_mb > MAX_SIZE_MB: - print( - f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) " - f"compare to the allowed size ({MAX_SIZE_MB} MB).") + for file_name in files: + if file_name.endswith(".whl"): + wheel_path = os.path.join(root, file_name) + wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) + if wheel_size_mb > VLLM_MAX_SIZE_MB: + print(f"Not allowed: Wheel {wheel_path} is larger " + f"({wheel_size_mb:.2f} MB) than the limit " + f"({VLLM_MAX_SIZE_MB} MB).") print_top_10_largest_files(wheel_path) return 1 else: print(f"Wheel {wheel_path} is within the allowed size " - f"({wheel_size_mb} MB).") + f"({wheel_size_mb:.2f} MB).") return 0 if 
__name__ == "__main__": - import sys - sys.exit(check_wheel_size(sys.argv[1])) + if len(sys.argv) < 2: + print("Usage: python check-wheel-size.py <directory>") + sys.exit(1) + + directory = sys.argv[1] + sys.exit(check_wheel_size(directory)) \ No newline at end of file diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py new file mode 100644 index 000000000000..36e1b6c01326 --- /dev/null +++ b/.buildkite/generate_index.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os + +template = """<!DOCTYPE html>
+<html>
+    <body>
+    <h1>Links for vLLM</h1>
+        <a href="../{wheel_html_escaped}">{wheel}</a><br/>
+    </body>
+</html>
+""" + +parser = argparse.ArgumentParser() +parser.add_argument("--wheel", help="The wheel path.", required=True) +args = parser.parse_args() + +filename = os.path.basename(args.wheel) + +with open("index.html", "w") as f: + print(f"Generated index.html for {args.wheel}") + # cloudfront requires escaping the '+' character + f.write( + template.format(wheel=filename, + wheel_html_escaped=filename.replace("+", "%2B"))) diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml index 15268395ec68..d70ecb2a7e7b 100644 --- a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml +++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml @@ -9,3 +9,4 @@ tasks: value: 0.664 limit: 1000 num_fewshot: 5 +trust_remote_code: True \ No newline at end of file diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml new file mode 100644 index 000000000000..0ecfc01ef049 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.764 + - name: "exact_match,flexible-extract" + value: 0.764 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml new file mode 100644 index 000000000000..042458659839 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml @@ -0,0 +1,11 @@ +# bash
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 +model_name: "HandH1998/QQQ-Llama-3-8b-g128" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.419 + - name: "exact_match,flexible-extract" + value: 0.416 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml new file mode 100644 index 000000000000..78347f63fa79 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 +model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.356 + - name: "exact_match,flexible-extract" + value: 0.358 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml new file mode 100644 index 000000000000..4ef8b5c3709b --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1 +model_name: "mgoin/Minitron-4B-Base-FP8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.231 + - name: "exact_match,flexible-extract" + value: 0.22 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml deleted file mode 100644 index a0466748ea71..000000000000 --- a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml +++ /dev/null @@ -1,11 
+0,0 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1 -model_name: "nvidia/Minitron-4B-Base" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.252 - - name: "exact_match,flexible-extract" - value: 0.252 -limit: 1000 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml new file mode 100644 index 000000000000..2928d75ce446 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2 +model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.6353 + - name: "exact_match,flexible-extract" + value: 0.637 +limit: null +num_fewshot: null diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index e4df4b547aa5..6057229ac50f 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,9 +1,10 @@ Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml -Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml +Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml -Minitron-4B-Base.yaml +Minitron-4B-Base-FP8.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml Qwen2-1.5B-Instruct-FP8W8.yaml +Meta-Llama-3-8B-QQQ.yaml diff --git 
a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh index fdb8ec5393b3..a67fc89d54e6 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -2,7 +2,7 @@ # We can use this script to compute baseline accuracy on GSM for transformers. # # Make sure you have lm-eval-harness installed: -# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@9516087b81a61d0e220b22cc1b75be76de23bc10 +# pip install lm-eval==0.4.4 usage() { echo`` @@ -41,6 +41,6 @@ while getopts "m:b:l:f:" OPT; do done lm_eval --model hf \ - --model_args pretrained=$MODEL,parallelize=True \ - --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ - --batch_size $BATCH_SIZE + --model_args "pretrained=$MODEL,parallelize=True" \ + --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ + --batch_size "$BATCH_SIZE" diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index de841d959a4e..65be3c5d93b2 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -3,7 +3,7 @@ # We use this for fp8, which HF does not support. 
# # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.3 +# pip install lm-eval==0.4.4 usage() { echo`` @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ - --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ - --batch_size $BATCH_SIZE + --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \ + --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ + --batch_size "$BATCH_SIZE" diff --git a/.buildkite/lm-eval-harness/run-tests.sh b/.buildkite/lm-eval-harness/run-tests.sh index b4fdde6dab42..26f33b744289 100644 --- a/.buildkite/lm-eval-harness/run-tests.sh +++ b/.buildkite/lm-eval-harness/run-tests.sh @@ -30,7 +30,7 @@ while getopts "c:t:" OPT; do done # Parse list of configs. -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG" for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" do diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 7fdce7b53bd7..4ae23eff62f3 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 """ LM eval harness on model to compare vs HF baseline computed offline. 
Configs are found in configs/$MODEL.yaml @@ -12,9 +13,10 @@ import lm_eval import numpy +import pytest import yaml -RTOL = 0.02 +RTOL = 0.05 TEST_DATA_FILE = os.environ.get( "LM_EVAL_TEST_DATA_FILE", ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") @@ -23,9 +25,12 @@ def launch_lm_eval(eval_config): + trust_remote_code = eval_config.get('trust_remote_code', False) + model_args = f"pretrained={eval_config['model_name']}," \ f"tensor_parallel_size={TP_SIZE}," \ - f"add_bos_token=true" + f"add_bos_token=true," \ + f"trust_remote_code={trust_remote_code}" results = lm_eval.simple_evaluate( model="vllm", @@ -42,14 +47,23 @@ def test_lm_eval_correctness(): eval_config = yaml.safe_load( Path(TEST_DATA_FILE).read_text(encoding="utf-8")) + if eval_config[ + "model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501 + pytest.skip("FBGEMM is currently failing on main.") + # Launch eval requests. results = launch_lm_eval(eval_config) # Confirm scores match ground truth. + success = True for task in eval_config["tasks"]: for metric in task["metrics"]: ground_truth = metric["value"] measured_value = results["results"][task["name"]][metric["name"]] print(f'{task["name"]} | {metric["name"]}: ' f'ground_truth={ground_truth} | measured={measured_value}') - assert numpy.isclose(ground_truth, measured_value, rtol=RTOL) + success = success and numpy.isclose( + ground_truth, measured_value, rtol=RTOL) + + # Assert at the end, print all scores even on failure for debugging. + assert success diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index c1aebaf5b3bb..d3f5fc5cd4ce 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -1,15 +1,13 @@ # vLLM benchmark suite - ## Introduction This directory contains two sets of benchmark for vllm. 
+ - Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance - Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm. - -See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results. - +See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results. ## Performance benchmark quick overview @@ -19,35 +17,28 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan **For benchmarking developers**: please try your best to constraint the duration of benchmarking to about 1 hr so that it won't take forever to run. - ## Nightly benchmark quick overview -**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. +**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. **Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy. **Benchmarking Duration**: about 3.5hrs. - - ## Trigger the benchmark Performance benchmark will be triggered when: - A PR being merged into vllm. -- Every commit for those PRs with `perf-benchmarks` label. +- Every commit for those PRs with `perf-benchmarks` label AND `ready` label. Nightly benchmark will be triggered when: -- Every commit for those PRs with `nightly-benchmarks` label. - - - +- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. 
## Performance benchmark details -See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. - +See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. -#### Latency test +### Latency test Here is an example of one test inside `latency-tests.json`: @@ -67,23 +58,25 @@ Here is an example of one test inside `latency-tests.json`: ``` In this example: -- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. -- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-benchmarks-suite.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` + +- The `test_name` attribute is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. +- The `parameters` attribute controls the command line arguments to be used for `benchmark_latency.py`. Please use an underscore `_` instead of a dash `-` when specifying the command line arguments; `run-performance-benchmarks.sh` will convert the underscores to dashes when feeding the arguments to `benchmark_latency.py`.
For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly. WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file. +### Throughput test -#### Throughput test The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`. The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot. -#### Serving test +### Serving test + We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example: -``` +```json [ { "test_name": "serving_llama8B_tp1_sharegpt", @@ -108,6 +101,7 @@ We test the throughput by using `benchmark_serving.py` with request rate = inf t ``` Inside this example: + - The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`. - The `server-parameters` includes the command line arguments for vLLM server. - The `client-parameters` includes the command line arguments for `benchmark_serving.py`. @@ -117,36 +111,33 @@ The number of this test is less stable compared to the delay and latency benchma WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. 
-#### Visualizing the results +### Visualizing the results + The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. If you do not see the table, please wait till the benchmark finish running. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. - - ## Nightly test details See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. +### Workflow -#### Workflow - -- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. +- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. - Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container. - The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark. - At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite. -#### Nightly tests +### Nightly tests In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is highly similar to performance benchmark. 
-#### Docker containers +### Docker containers The docker containers for benchmarking are specified in `nightly-pipeline.yaml`. WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`. WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git). - diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 02c0ee534d72..4259514940d3 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -1,5 +1,6 @@ steps: - label: "Wait for container to be ready" + key: wait-for-container-image agents: queue: A100 plugins: @@ -8,20 +9,27 @@ steps: containers: - image: badouralix/curl-jq command: - - sh - - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - - wait + - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - label: "Cleanup H100" + agents: + queue: H100 + depends_on: ~ + command: docker system prune -a --volumes --force + - label: "A100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: A100 + depends_on: wait-for-container-image + if: build.branch == "main" plugins: - kubernetes: podSpec: priorityClassName: perf-benchmark containers: - - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT command: - - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: limits: nvidia.com/gpu: 8 @@ -42,20 +50,135 @@ steps: - name: devshm emptyDir: medium: Memory + + - label: "H200" + # skip: 
"use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H200 + depends_on: wait-for-container-image + if: build.branch == "main" + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: 4,5,6,7 + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN + + #- block: "Run H100 Benchmark" + #key: block-h100 + #depends_on: ~ + - label: "H100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H100 + depends_on: wait-for-container-image + if: build.branch == "main" + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: all # see CUDA_VISIBLE_DEVICES for actual GPUs used + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN + + # Premerge benchmark + - label: "A100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: A100 + depends_on: wait-for-container-image + if: build.branch != "main" + plugins: + - kubernetes: + podSpec: + priorityClassName: perf-benchmark + containers: + - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - name: devshm + mountPath: /dev/shm + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + 
nodeSelector: + nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB + volumes: + - name: devshm + emptyDir: + medium: Memory + + - label: "H200" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H200 + depends_on: wait-for-container-image + if: build.branch != "main" plugins: - - docker#v5.11.0: + - docker#v5.12.0: image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT command: - bash - - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh mount-buildkite-agent: true propagate-environment: true ipc: host - gpus: all + gpus: 4,5,6,7 + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface environment: - VLLM_USAGE_SOURCE - HF_TOKEN + #- block: "Run H100 Benchmark" + #key: block-h100 + #depends_on: ~ + + - label: "H100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H100 + depends_on: wait-for-container-image + if: build.branch != "main" + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: all # see CUDA_VISIBLE_DEVICES for actual GPUs used + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN diff --git a/.buildkite/nightly-benchmarks/nightly-annotation.md b/.buildkite/nightly-benchmarks/nightly-annotation.md new file mode 100644 index 000000000000..e43ea765f155 --- /dev/null +++ b/.buildkite/nightly-benchmarks/nightly-annotation.md @@ -0,0 +1,27 @@ + +## Description + +This file contains the downloading link for benchmarking results. 
+ +- [benchmarking pipeline](artifact://nightly-pipeline.yaml) +- [benchmarking results](artifact://results.zip) +- [benchmarking code](artifact://nightly-benchmarks.zip) + +Please download the visualization scripts in the post + +## Results reproduction + +- Find the docker image we use in `benchmarking pipeline` +- Start a container from that image, and inside the container: + - Download `nightly-benchmarks.zip`. + - In the same folder, run the following code: + + ```console + export HF_TOKEN=<your HF token> + apt update + apt install -y git + unzip nightly-benchmarks.zip + VLLM_SOURCE_CODE_LOC=./ bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh + ``` + +And the results will be inside `./benchmarks/results`. diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index c3d3cbf47396..5f003f42f07c 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -1,45 +1,39 @@ # Nightly benchmark -The main goal of this benchmarking is two-fold: -- Performance clarity: Provide clarity on which one (vllm, tensorrt-llm, lmdeploy and tgi) leads in performance in what workload. -- Reproducible: one can run the exact same set of benchmarking commands inside the exact same docker by following reproducing instructions in [reproduce.md](). - - -## Docker images - -We benchmark vllm, tensorrt-llm, lmdeploy and tgi using the following docker images: -- vllm/vllm-openai:v0.5.0.post1 -- nvcr.io/nvidia/tritonserver:24.04-trtllm-python-py3 -- openmmlab/lmdeploy:v0.5.0 -- ghcr.io/huggingface/text-generation-inference:2.1 - - - - -## Hardware - -One AWS node with 8x NVIDIA A100 GPUs. - - -## Workload description - -We benchmark vllm, tensorrt-llm, lmdeploy and tgi using the following workload: - -- Input length: randomly sample 500 prompts from ShareGPT dataset (with fixed random seed). -- Output length: the corresponding output length of these 500 prompts.
-- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Average QPS (query per second): 4 for the small model (llama-3 8B) and 2 for other two models. For each QPS, the arrival time of each query is determined using a random Poisson process (with fixed random seed). -- Evaluation metrics: Throughput (higher the better), TTFT (time to the first token, lower the better), ITL (inter-token latency, lower the better). - - - -## Plots - -In the following plots, the dot shows the mean and the error bar shows the standard error of the mean. Value 0 means that the corresponding benchmark crashed. - -Benchmarking results - -## Results - -{nightly_results_benchmarking_table} +This benchmark aims to: + +- Provide performance clarity: show which engine (vllm, tensorrt-llm, lmdeploy and SGLang) leads in performance under which workload. +- Be reproducible: one can run the exact same set of benchmarking commands inside the exact same docker by following the reproduction instructions. + +Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html), scroll to the end. + +Latest reproduction guide: [github issue link](https://github.com/vllm-project/vllm/issues/8176) + +## Setup + +- Docker images: + - vLLM: `vllm/vllm-openai:v0.6.2` + - SGLang: `lmsysorg/sglang:v0.3.2-cu121` + - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` + - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` + - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* + - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark.
+- Hardware + - 8x Nvidia A100 GPUs +- Workload: + - Dataset + - ShareGPT dataset + - Prefill-heavy dataset (on average 462 input tokens, 16 tokens as output) + - Decode-heavy dataset (on average 462 input tokens, 256 output tokens) + - Check [nightly-tests.json](tests/nightly-tests.json) for the concrete configuration of datasets we use. + - Models: llama-3 8B, llama-3 70B. + - We do not use llama 3.1 as it is incompatible with trt-llm r24.07 ([issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105)). + - Average QPS (query per second): 2, 4, 8, 16, 32 and inf. + - Queries are randomly sampled, and arrival patterns are determined via Poisson process, but all with fixed random seed. + - Evaluation metrics: Throughput (higher the better), TTFT (time to the first token, lower the better), ITL (inter-token latency, lower the better). + +## Known issues + +- TRT-LLM crashes with Llama 3.1 8B [issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105). +- TGI does not support `ignore-eos` flag. diff --git a/.buildkite/nightly-benchmarks/nightly-pipeline.yaml b/.buildkite/nightly-benchmarks/nightly-pipeline.yaml index 6e399bb936fb..199517e8b067 100644 --- a/.buildkite/nightly-benchmarks/nightly-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/nightly-pipeline.yaml @@ -13,7 +13,7 @@ common_pod_spec: &common_pod_spec common_container_settings: &common_container_settings command: - - bash .buildkite/nightly-benchmarks/run-nightly-suite.sh + - bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh resources: limits: nvidia.com/gpu: 8 @@ -37,7 +37,10 @@ common_container_settings: &common_container_settings steps: - block: ":rocket: Ready for comparing vllm against alternatives? This will take 4 hours."
- - label: "A100 trt benchmark" + + + + - label: "A100 vllm step 10" priority: 100 agents: queue: A100 @@ -46,7 +49,21 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: nvcr.io/nvidia/tritonserver:24.04-trtllm-python-py3 + - image: vllm/vllm-openai:v0.6.2 + <<: *common_container_settings + + + + - label: "A100 sglang benchmark" + priority: 100 + agents: + queue: A100 + plugins: + - kubernetes: + podSpec: + <<: *common_pod_spec + containers: + - image: lmsysorg/sglang:v0.3.2-cu121 <<: *common_container_settings - label: "A100 lmdeploy benchmark" @@ -58,11 +75,13 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: openmmlab/lmdeploy:v0.5.0 + - image: openmmlab/lmdeploy:v0.6.1-cu12 <<: *common_container_settings - - - label: "A100 vllm benchmark" + + + + - label: "A100 trt llama-8B" priority: 100 agents: queue: A100 @@ -71,10 +90,25 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: vllm/vllm-openai:latest + - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 <<: *common_container_settings + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_HOME + value: /root/.cache/huggingface + - name: VLLM_SOURCE_CODE_LOC + value: /workspace/build/buildkite/vllm/performance-benchmark + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: TEST_SELECTOR + value: "llama8B" - - label: "A100 tgi benchmark" + + - label: "A100 trt llama-70B" priority: 100 agents: queue: A100 @@ -83,12 +117,54 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: ghcr.io/huggingface/text-generation-inference:2.1 + - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 <<: *common_container_settings + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_HOME + value: /root/.cache/huggingface + - name: VLLM_SOURCE_CODE_LOC + value: /workspace/build/buildkite/vllm/performance-benchmark + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: 
TEST_SELECTOR + value: "llama70B" + + + # FIXME(Kuntai): uncomment this after NVIDIA gives us their test docker image + # - label: "A100 trt benchmark" + # priority: 100 + # agents: + # queue: A100 + # plugins: + # - kubernetes: + # podSpec: + # <<: *common_pod_spec + # containers: + # - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 + # <<: *common_container_settings + + + # FIXME(Kuntai): uncomment this after TGI supports `--ignore-eos`. + # - label: "A100 tgi benchmark" + # priority: 100 + # agents: + # queue: A100 + # plugins: + # - kubernetes: + # podSpec: + # <<: *common_pod_spec + # containers: + # - image: ghcr.io/huggingface/text-generation-inference:2.2.0 + # <<: *common_container_settings - wait - - label: "Plot" + - label: "Collect the results" priority: 100 agents: queue: A100 @@ -117,4 +193,4 @@ steps: name: hf-token-secret key: token - - wait \ No newline at end of file + - block: ":rocket: check the results!" \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md new file mode 100644 index 000000000000..cacaef986c65 --- /dev/null +++ b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md @@ -0,0 +1,56 @@ + +## Latency tests + +- Input length: 32 tokens. +- Output length: 128 tokens. +- Batch size: fixed (8). +- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- Evaluation metrics: end-to-end latency (mean, median, p99). + +{latency_tests_markdown_table} + +## Throughput tests + +- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). +- Output length: the corresponding output length of these 200 prompts. +- Batch size: dynamically determined by vllm to achieve maximum throughput. +- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- Evaluation metrics: throughput. 
+ +{throughput_tests_markdown_table} + +## Serving tests + +- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). +- Output length: the corresponding output length of these 200 prompts. +- Batch size: dynamically determined by vllm and the arrival pattern of the requests. +- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). +- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- We also added a speculative decoding test for llama-3 70B, under QPS 2 +- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). + +{serving_tests_markdown_table} + +## json version of the benchmarking tables + +This section contains the data of the markdown tables above in JSON format. +You can load the benchmarking tables into pandas dataframes as follows: + +```python +import json +import pandas as pd + +benchmarking_results_json = """The json string""" +benchmarking_results = json.loads(benchmarking_results_json) +latency_results = pd.DataFrame.from_dict(benchmarking_results["latency"]) +throughput_results = pd.DataFrame.from_dict(benchmarking_results["throughput"]) +serving_results = pd.DataFrame.from_dict(benchmarking_results["serving"]) +``` + +The json string for all benchmarking tables: + +```json +{benchmarking_results_in_json_string} +``` + +You can also check the raw experiment data in the Artifact tab of the Buildkite page. 
diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh deleted file mode 100644 index 1a88d038b4b5..000000000000 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ /dev/null @@ -1,380 +0,0 @@ -#!/bin/bash - -# This script should be run inside the CI process -# This script assumes that we are already inside the vllm/ directory -# Benchmarking results will be available inside vllm/benchmarks/results/ - -# Do not set -e, as the mixtral 8x22B model tends to crash occasionally -# and we still want to see other benchmarking results even when mixtral crashes. -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -check_hf_token() { - # check if HF_TOKEN is available and valid - if [[ -z "$HF_TOKEN" ]]; then - echo "Error: HF_TOKEN is not set." - exit 1 - elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then - echo "Error: HF_TOKEN does not start with 'hf_'." - exit 1 - else - echo "HF_TOKEN is set and valid." - fi -} - -ensure_sharegpt_downloaded() { - local FILE=ShareGPT_V3_unfiltered_cleaned_split.json - if [ ! -f "$FILE" ]; then - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE - else - echo "$FILE already exists." 
- fi -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - # wait for vllm server to start - # return 1 if vllm server crashes - timeout 1200 bash -c ' - until curl -X POST localhost:8000/v1/completions; do - sleep 1 - done' && return 0 || return 1 -} - -kill_gpu_processes() { - # kill all processes on GPU. - pids=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader) - if [ -z "$pids" ]; then - echo "No GPU processes found." - else - for pid in $pids; do - kill -9 "$pid" - echo "Killed process with PID: $pid" - done - - echo "All GPU processes have been killed." - fi - - # waiting for GPU processes to be fully killed - # loop while nvidia-smi returns any processes - while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do - sleep 1 - echo "Waiting for GPU processes to be killed" - done - - # remove vllm config file - rm -rf ~/.config/vllm - - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. 
- echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - # Check if buildkite-agent is available in the PATH or at /workspace/buildkite-agent - if command -v buildkite-agent >/dev/null 2>&1; then - BUILDKITE_AGENT_COMMAND="buildkite-agent" - elif [ -f /workspace/buildkite-agent ]; then - BUILDKITE_AGENT_COMMAND="/workspace/buildkite-agent" - else - echo "buildkite-agent binary not found. Skip uploading the results." - return 0 - fi - - # Use the determined command to annotate and upload artifacts - $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md - $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" -} - -run_latency_tests() { - # run latency tests using `benchmark_latency.py` - # $1: a json file specifying latency test cases - - local latency_test_file - latency_test_file=$1 - - # Iterate over latency tests - jq -c '.[]' "$latency_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - if [[ ! "$test_name" =~ ^latency_ ]]; then - echo "In latency-test.json, test_name must start with \"latency_\"." - exit 1 - fi - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - # get arguments - latency_params=$(echo "$params" | jq -r '.parameters') - latency_args=$(json2args "$latency_params") - - # check if there is enough GPU to run the test - tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." 
- continue - fi - - latency_command="python3 benchmark_latency.py \ - --output-json $RESULTS_FOLDER/${test_name}.json \ - $latency_args" - - echo "Running test case $test_name" - echo "Latency command: $latency_command" - - # recoding benchmarking command ang GPU command - jq_output=$(jq -n \ - --arg latency "$latency_command" \ - --arg gpu "$gpu_type" \ - '{ - latency_command: $latency, - gpu_type: $gpu - }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" - - # run the benchmark - eval "$latency_command" - - kill_gpu_processes - - done -} - - -run_throughput_tests() { - # run throughput tests using `benchmark_throughput.py` - # $1: a json file specifying throughput test cases - - local throughput_test_file - throughput_test_file=$1 - - # Iterate over throughput tests - jq -c '.[]' "$throughput_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - if [[ ! "$test_name" =~ ^throughput_ ]]; then - echo "In throughput-test.json, test_name must start with \"throughput_\"." - exit 1 - fi - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - # get arguments - throughput_params=$(echo "$params" | jq -r '.parameters') - throughput_args=$(json2args "$throughput_params") - - # check if there is enough GPU to run the test - tp=$(echo $throughput_params | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." 
- continue - fi - - throughput_command="python3 benchmark_throughput.py \ - --output-json $RESULTS_FOLDER/${test_name}.json \ - $throughput_args" - - echo "Running test case $test_name" - echo "Throughput command: $throughput_command" - # recoding benchmarking command ang GPU command - jq_output=$(jq -n \ - --arg command "$throughput_command" \ - --arg gpu "$gpu_type" \ - '{ - throughput_command: $command, - gpu_type: $gpu - }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" - - # run the benchmark - eval "$throughput_command" - - kill_gpu_processes - - done -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - if [[ ! "$test_name" =~ ^serving_ ]]; then - echo "In serving-test.json, test_name must start with \"serving_\"." - exit 1 - fi - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.server_parameters') - client_params=$(echo "$params" | jq -r '.client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." 
- continue - fi - - # check if server model and client model is aligned - server_model=$(echo "$server_params" | jq -r '.model') - client_model=$(echo "$client_params" | jq -r '.model') - if [[ $server_model != "$client_model" ]]; then - echo "Server model and client model must be the same. Skip testcase $testname." - continue - fi - - server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ - $server_args" - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - eval "$server_command" & - server_pid=$! - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "vllm server is up and running." - else - echo "" - echo "vllm failed to start within the timeout period." - fi - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu - }') - echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill -9 $server_pid - kill_gpu_processes - done -} - -main() { - check_gpus - check_hf_token - - # dependencies - (which wget && which curl) || (apt-get update && apt-get install -y wget curl) - (which jq) || (apt-get update && apt-get -y install jq) - - # get the current IP address, required by benchmark_serving.py - export 
VLLM_HOST_IP=$(hostname -I | awk '{print $1}') - # turn of the reporting of the status of each request, to clean up the terminal output - export VLLM_LOG_LEVEL="WARNING" - - # prepare for benchmarking - cd benchmarks || exit 1 - ensure_sharegpt_downloaded - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - # benchmarking - run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json - run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json - run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json - - - # postprocess benchmarking results - pip install tabulate pandas - python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py - - upload_to_buildkite -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/run-nightly-suite.sh b/.buildkite/nightly-benchmarks/run-nightly-suite.sh deleted file mode 100644 index 627a3e697157..000000000000 --- a/.buildkite/nightly-benchmarks/run-nightly-suite.sh +++ /dev/null @@ -1,76 +0,0 @@ -#!/bin/bash - -set -o pipefail -set -x - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -check_hf_token() { - # check if HF_TOKEN is available and valid - if [[ -z "$HF_TOKEN" ]]; then - echo "Error: HF_TOKEN is not set." - exit 1 - elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then - echo "Error: HF_TOKEN does not start with 'hf_'." - exit 1 - else - echo "HF_TOKEN is set and valid." 
- fi -} - -main() { - - check_gpus - check_hf_token - - df -h - - (which wget && which curl) || (apt-get update && apt-get install -y wget curl) - (which jq) || (apt-get update && apt-get -y install jq) - - cd $VLLM_SOURCE_CODE_LOC/benchmarks - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - - - # run lmdeploy - if which lmdeploy >/dev/null; then - echo "lmdeploy is available, redirect to run-lmdeploy-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh - exit 0 - fi - - # run tgi - if [ -e /tgi-entrypoint.sh ]; then - echo "tgi is available, redirect to run-tgi-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh - exit 0 - fi - - # run trt - if which trtllm-build >/dev/null; then - echo "trtllm is available, redirect to run-trt-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh - exit 0 - fi - - # run vllm - if [ -e /vllm-workspace ]; then - echo "vllm is available, redirect to run-vllm-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh - exit 0 - fi - -} - -main "$@" \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 534ecf17930e..1030ec24e8d7 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import json import os from pathlib import Path @@ -56,7 +58,7 @@ def read_markdown(file): if os.path.exists(file): - with open(file, "r") as f: + with open(file) as f: return f.read() + "\n" else: return f"{file} not found.\n" @@ -75,15 +77,20 @@ def results_to_json(latency, throughput, serving): # collect results for test_file in results_folder.glob("*.json"): 
- with open(test_file, "r") as f: + with open(test_file) as f: raw_result = json.loads(f.read()) if "serving" in str(test_file): # this result is generated via `benchmark_serving.py` # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands"), "r") as f: - command = json.loads(f.read()) + try: + with open(test_file.with_suffix(".commands")) as f: + command = json.loads(f.read()) + except OSError as e: + print(e) + continue + raw_result.update(command) # update the test name of this result @@ -97,8 +104,13 @@ def results_to_json(latency, throughput, serving): # this result is generated via `benchmark_latency.py` # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands"), "r") as f: - command = json.loads(f.read()) + try: + with open(test_file.with_suffix(".commands")) as f: + command = json.loads(f.read()) + except OSError as e: + print(e) + continue + raw_result.update(command) # update the test name of this result @@ -119,8 +131,13 @@ def results_to_json(latency, throughput, serving): # this result is generated via `benchmark_throughput.py` # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands"), "r") as f: - command = json.loads(f.read()) + try: + with open(test_file.with_suffix(".commands")) as f: + command = json.loads(f.read()) + except OSError as e: + print(e) + continue + raw_result.update(command) # update the test name of this result @@ -157,6 +174,18 @@ def results_to_json(latency, throughput, serving): throughput_results, serving_results) + for df in [latency_results, serving_results, throughput_results]: + if df.empty: + continue + + # Sort all dataframes by their respective "Test name" columns + df.sort_values(by="Test name", inplace=True) + + # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", + # we want to turn it into "8xGPUTYPE" + df["GPU"] = df["GPU"].apply( + lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}") + # get 
markdown tables latency_md_table = tabulate(latency_results, headers='keys', @@ -174,8 +203,8 @@ def results_to_json(latency, throughput, serving): # document the result with open(results_folder / "benchmark_results.md", "w") as f: - results = read_markdown( - "../.buildkite/nightly-benchmarks/tests/descriptions.md") + results = read_markdown("../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md") results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, diff --git a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py index 68ac5909e595..5e17b79d26a1 100644 --- a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py +++ b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import argparse from transformers import AutoTokenizer diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py new file mode 100644 index 000000000000..0ff95a0911b1 --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import json +from pathlib import Path + +import numpy as np +import pandas as pd +from tabulate import tabulate + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description= + 'Parse command line arguments for summary-nightly-results script.') + parser.add_argument('--results-folder', + type=str, + required=True, + help='The folder where the results are stored.') + parser.add_argument('--description', + type=str, + required=True, + help='Description of the results.') + + args = parser.parse_args() + return args + + +def get_perf(df, method, model, metric): + + means = [] + + for qps in [2, 4, 8, 16, "inf"]: + target = df['Test 
name'].str.contains(model) + target = target & df['Engine'].str.contains(method) + target = target & df['Test name'].str.contains("qps_" + str(qps)) + filtered_df = df[target] + + if filtered_df.empty: + means.append(0.) + else: + means.append(filtered_df[metric].values[0]) + + return np.array(means) + + +def get_perf_w_std(df, method, model, metric): + + if metric in ["TTFT", "ITL"]: + mean = get_perf(df, method, model, "Mean " + metric + " (ms)") + mean = mean.tolist() + std = get_perf(df, method, model, "Std " + metric + " (ms)") + if std.mean() == 0: + std = None + success = get_perf(df, method, model, "Successful req.") + if std is not None: + std = std / np.sqrt(success) + std = std.tolist() + + else: + assert metric == "Tput" + mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf( + df, method, model, "Output Tput (tok/s)") + mean = mean.tolist() + std = None + + return mean, std + + +def main(args): + results_folder = Path(args.results_folder) + + results = [] + + # collect results + for test_file in results_folder.glob("*_nightly_results.json"): + with open(test_file) as f: + results = results + json.loads(f.read()) + + # generate markdown table + df = pd.DataFrame.from_dict(results) + + md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) + + with open(args.description) as f: + description = f.read() + + description = description.format( + nightly_results_benchmarking_table=md_table) + + with open("nightly_results.md", "w") as f: + f.write(description) + + +if __name__ == '__main__': + args = parse_arguments() + main(args) diff --git a/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py index 18bcc3a8714c..e5f179a0f5b6 100644 --- a/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py +++ b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + from 
lmdeploy.serve.openai.api_client import APIClient api_client = APIClient("http://localhost:8000") diff --git a/.buildkite/nightly-benchmarks/scripts/launch-server.sh b/.buildkite/nightly-benchmarks/scripts/launch-server.sh new file mode 100644 index 000000000000..fb5063db8694 --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/launch-server.sh @@ -0,0 +1,228 @@ +#!/bin/bash + +# Currently FP8 benchmark is NOT enabled. + +set -x +server_params=$1 +common_params=$2 + +json2args() { + # transforms the JSON string to command line args, and '_' is replaced to '-' + # example: + # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } + # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 + local json_string=$1 + local args=$( + echo "$json_string" | jq -r ' + to_entries | + map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | + join(" ") + ' + ) + echo "$args" +} + +launch_trt_server() { + + model_path=$(echo "$common_params" | jq -r '.model') + model_name="${model_path#*/}" + model_type=$(echo "$server_params" | jq -r '.model_type') + model_dtype=$(echo "$server_params" | jq -r '.model_dtype') + model_tp_size=$(echo "$common_params" | jq -r '.tp') + max_batch_size=$(echo "$server_params" | jq -r '.max_batch_size') + max_input_len=$(echo "$server_params" | jq -r '.max_input_len') + max_seq_len=$(echo "$server_params" | jq -r '.max_seq_len') + max_num_tokens=$(echo "$server_params" | jq -r '.max_num_tokens') + trt_llm_version=$(echo "$server_params" | jq -r '.trt_llm_version') + + # create model caching directory + cd ~ + rm -rf models + mkdir -p models + cd models + models_dir=$(pwd) + trt_model_path=${models_dir}/${model_name}-trt-ckpt + trt_engine_path=${models_dir}/${model_name}-trt-engine + + # clone tensorrt backend + cd / + rm -rf tensorrtllm_backend + git clone https://github.com/triton-inference-server/tensorrtllm_backend.git + git lfs install + cd tensorrtllm_backend + git checkout 
"$trt_llm_version" + git submodule update --init --recursive + + # build trtllm engine + cd /tensorrtllm_backend + cd "./tensorrt_llm/examples/${model_type}" + python3 convert_checkpoint.py \ + --model_dir "${model_path}" \ + --dtype "${model_dtype}" \ + --tp_size "${model_tp_size}" \ + --output_dir "${trt_model_path}" + trtllm-build \ + --checkpoint_dir "${trt_model_path}" \ + --use_fused_mlp \ + --reduce_fusion disable \ + --workers 8 \ + --gpt_attention_plugin "${model_dtype}" \ + --gemm_plugin "${model_dtype}" \ + --tp_size "${model_tp_size}" \ + --max_batch_size "${max_batch_size}" \ + --max_input_len "${max_input_len}" \ + --max_seq_len "${max_seq_len}" \ + --max_num_tokens "${max_num_tokens}" \ + --output_dir "${trt_engine_path}" + + # handle triton protobuf files and launch triton server + cd /tensorrtllm_backend + mkdir triton_model_repo + cp -r all_models/inflight_batcher_llm/* triton_model_repo/ + cd triton_model_repo + rm -rf ./tensorrt_llm/1/* + cp -r "${trt_engine_path}"/* ./tensorrt_llm/1 + python3 ../tools/fill_template.py -i tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1,decoupled_mode:true,batching_strategy:inflight_fused_batching,batch_scheduler_policy:guaranteed_no_evict,exclude_input_in_output:true,triton_max_batch_size:2048,max_queue_delay_microseconds:0,max_beam_width:1,max_queue_size:2048,enable_kv_cache_reuse:false + python3 ../tools/fill_template.py -i preprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5" + python3 ../tools/fill_template.py -i postprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false" + python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:"$max_batch_size" + python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt 
"triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:False,bls_instance_count:1" + cd /tensorrtllm_backend + python3 scripts/launch_triton_server.py \ + --world_size="${model_tp_size}" \ + --model_repo=/tensorrtllm_backend/triton_model_repo & + +} + +launch_tgi_server() { + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + port=$(echo "$common_params" | jq -r '.port') + server_args=$(json2args "$server_params") + + if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then + echo "Key 'fp8' exists in common params." + server_command="/tgi-entrypoint.sh \ + --model-id $model \ + --num-shard $tp \ + --port $port \ + --quantize fp8 \ + $server_args" + else + echo "Key 'fp8' does not exist in common params." + server_command="/tgi-entrypoint.sh \ + --model-id $model \ + --num-shard $tp \ + --port $port \ + $server_args" + fi + + echo "Server command: $server_command" + eval "$server_command" & + +} + +launch_lmdeploy_server() { + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + port=$(echo "$common_params" | jq -r '.port') + server_args=$(json2args "$server_params") + + server_command="lmdeploy serve api_server $model \ + --tp $tp \ + --server-port $port \ + $server_args" + + # run the server + echo "Server command: $server_command" + bash -c "$server_command" & +} + +launch_sglang_server() { + + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + port=$(echo "$common_params" | jq -r '.port') + server_args=$(json2args "$server_params") + + if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then + echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." 
+ model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') + server_command="python3 \ + -m sglang.launch_server \ + --tp $tp \ + --model-path $model \ + --port $port \ + $server_args" + else + echo "Key 'fp8' does not exist in common params." + server_command="python3 \ + -m sglang.launch_server \ + --tp $tp \ + --model-path $model \ + --port $port \ + $server_args" + fi + + # run the server + echo "Server command: $server_command" + eval "$server_command" & +} + +launch_vllm_server() { + + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + port=$(echo "$common_params" | jq -r '.port') + server_args=$(json2args "$server_params") + + if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then + echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." + model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') + server_command="python3 \ + -m vllm.entrypoints.openai.api_server \ + -tp $tp \ + --model $model \ + --port $port \ + $server_args" + else + echo "Key 'fp8' does not exist in common params." 
+ server_command="python3 \ + -m vllm.entrypoints.openai.api_server \ + -tp $tp \ + --model $model \ + --port $port \ + $server_args" + fi + + # run the server + echo "Server command: $server_command" + eval "$server_command" & +} + +main() { + + if [[ "$CURRENT_LLM_SERVING_ENGINE" == "trt" ]]; then + launch_trt_server + fi + + if [[ "$CURRENT_LLM_SERVING_ENGINE" == "tgi" ]]; then + launch_tgi_server + fi + + if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then + launch_lmdeploy_server + fi + + if [[ "$CURRENT_LLM_SERVING_ENGINE" == "sglang" ]]; then + launch_sglang_server + fi + + if [[ "$CURRENT_LLM_SERVING_ENGINE" == *"vllm"* ]]; then + launch_vllm_server + fi +} + +main diff --git a/.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh b/.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh deleted file mode 100644 index f8262653a662..000000000000 --- a/.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/bin/bash - - -server_params=$1 -common_params=$2 - - - -model_path=$(echo "$common_params" | jq -r '.model') -model_name="${model_path#*/}" -model_type=$(echo "$server_params" | jq -r '.model_type') -model_dtype=$(echo "$server_params" | jq -r '.model_dtype') -model_tp_size=$(echo "$common_params" | jq -r '.tp') -max_batch_size=$(echo "$server_params" | jq -r '.max_batch_size') -max_input_len=$(echo "$server_params" | jq -r '.max_input_len') -max_output_len=$(echo "$server_params" | jq -r '.max_output_len') -trt_llm_version=$(echo "$server_params" | jq -r '.trt_llm_version') - -cd ~ -rm -rf models -mkdir -p models -cd models -models_dir=$(pwd) -trt_model_path=${models_dir}/${model_name}-trt-ckpt -trt_engine_path=${models_dir}/${model_name}-trt-engine - -cd ~ -rm -rf tensorrt-demo -git clone https://github.com/neuralmagic/tensorrt-demo.git -cd tensorrt-demo -tensorrt_demo_dir=$(pwd) - -# make sure the parameter inside tensorrt_demo is consistent to envvar -sed -i.bak "/key: 
\"tokenizer_dir\"/,/string_value:/s|string_value: \".*\"|string_value: \"$model_path\"|" ./triton_model_repo/postprocessing/config.pbtxt -sed -i.bak "/key: \"tokenizer_dir\"/,/string_value:/s|string_value: \".*\"|string_value: \"$model_path\"|" ./triton_model_repo/preprocessing/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/ensemble/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/preprocessing/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/postprocessing/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/tensorrt_llm_bls/config.pbtxt - - -cd / -rm -rf tensorrtllm_backend -git clone https://github.com/triton-inference-server/tensorrtllm_backend.git -git lfs install -cd tensorrtllm_backend -git checkout $trt_llm_version -tensorrtllm_backend_dir=$(pwd) -git submodule update --init --recursive -cp -r ${tensorrt_demo_dir}/triton_model_repo ${tensorrtllm_backend_dir}/ - -cd /tensorrtllm_backend -cd ./tensorrt_llm/examples/${model_type} - - -if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then - - echo "Key 'fp8' exists in common params. Use quantize.py instead of convert_checkpoint.py" - echo "Reference: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/llama/README.md" - python ../quantization/quantize.py \ - --model_dir ${model_path} \ - --dtype ${model_dtype} \ - --tp_size ${model_tp_size} \ - --output_dir ${trt_model_path} \ - --qformat fp8 \ - --kv_cache_dtype fp8 \ - --calib_size 2 - -else - - echo "Key 'fp8' does not exist in common params. 
Use convert_checkpoint.py" - python3 convert_checkpoint.py \ - --model_dir ${model_path} \ - --dtype ${model_dtype} \ - --tp_size ${model_tp_size} \ - --output_dir ${trt_model_path} - -fi - - - -trtllm-build \ ---checkpoint_dir=${trt_model_path} \ ---gpt_attention_plugin=${model_dtype} \ ---gemm_plugin=${model_dtype} \ ---remove_input_padding=enable \ ---paged_kv_cache=enable \ ---tp_size=${model_tp_size} \ ---max_batch_size=${max_batch_size} \ ---max_input_len=${max_input_len} \ ---max_output_len=${max_output_len} \ ---max_num_tokens=${max_output_len} \ ---opt_num_tokens=${max_output_len} \ ---output_dir=${trt_engine_path} - -cd /tensorrtllm_backend/triton_model_repo -rm -rf ./tensorrt_llm/1/* -cp -r ${trt_engine_path}/* ./tensorrt_llm/1 -cd /tensorrtllm_backend -python3 scripts/launch_triton_server.py \ ---world_size=${model_tp_size} \ ---model_repo=/tensorrtllm_backend/triton_model_repo & \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh index 1168912c6e22..69b6b146b354 100644 --- a/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh +++ b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh @@ -8,6 +8,7 @@ main() { (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get update && apt-get -y install jq) + (which zip) || (apt-get install -y zip) if [ ! -f /workspace/buildkite-agent ]; then echo "buildkite-agent binary not found. Skip plotting the results." 
@@ -15,26 +16,63 @@ main() { fi # initial annotation - description="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-descriptions.md" + #description="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-descriptions.md" # download results - cd $VLLM_SOURCE_CODE_LOC/benchmarks + cd "$VLLM_SOURCE_CODE_LOC/benchmarks" mkdir -p results/ /workspace/buildkite-agent artifact download 'results/*nightly_results.json' results/ ls ls results/ - # generate figures - python3 -m pip install tabulate pandas matplotlib - python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ - --description $description \ - --results-folder results/ + # upload benchmark results + zip -r results.zip results/ + /workspace/buildkite-agent artifact upload "results.zip" + + # upload benchmarking scripts + cd "$VLLM_SOURCE_CODE_LOC/" + zip -r nightly-benchmarks.zip .buildkite/ benchmarks/ + /workspace/buildkite-agent artifact upload "nightly-benchmarks.zip" + + cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/" + # upload benchmarking pipeline + /workspace/buildkite-agent artifact upload "nightly-pipeline.yaml" + + cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/" + /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly-annotation.md + + + + # The figures should be generated by a separate process outside the CI/CD pipeline + + # # generate figures + # python3 -m pip install tabulate pandas matplotlib + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py \ + # --description $description \ + # --results-folder results/ + + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ + # --description $description \ + # --results-folder results/ \ + # --dataset sharegpt + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ + # --description 
$description \ + # --results-folder results/ \ + # --dataset sonnet_2048_128 + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ + # --description $description \ + # --results-folder results/ \ + # --dataset sonnet_128_2048 - # upload results and figures - /workspace/buildkite-agent artifact upload "nightly_results.png" - /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-pipeline.yaml - /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/tests/nightly-tests.json - /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly_results.md + # # upload results and figures + # /workspace/buildkite-agent artifact upload "nightly_results*.png" + # /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-pipeline.yaml + # /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/tests/nightly-tests.json + # /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly_results.md } -main "$@" \ No newline at end of file +main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py deleted file mode 100644 index e5cfcc64a9b2..000000000000 --- a/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py +++ /dev/null @@ -1,135 +0,0 @@ -import argparse -import json -import math -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -from tabulate import tabulate - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description= - 'Parse command line arguments for summary-nightly-results script.') - parser.add_argument('--results-folder', - type=str, - required=True, - help='The folder where the results are stored.') - 
parser.add_argument('--description', - type=str, - required=True, - help='Description of the results.') - - args = parser.parse_args() - return args - - -def main(args): - bar_colors = ['#56B4E9', '#009E73', '#D55E00', '#E69F00'] - results_folder = Path(args.results_folder) - - results = [] - - # collect results - for test_file in results_folder.glob("*_nightly_results.json"): - with open(test_file, "r") as f: - results = results + json.loads(f.read()) - - # generate markdown table - df = pd.DataFrame.from_dict(results) - - md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) - - with open(args.description, "r") as f: - description = f.read() - - description = description.format( - nightly_results_benchmarking_table=md_table) - - with open("nightly_results.md", "w") as f: - f.write(description) - - plt.rcParams.update({'font.size': 20}) - - # plot results - fig, axes = plt.subplots(3, 3, figsize=(16, 14)) - fig.subplots_adjust(hspace=1) - methods = ["vllm", "trt", "lmdeploy", "tgi"] - for i, model in enumerate(["llama8B", "llama70B", "mixtral8x7B"]): - for j, metric in enumerate(["TTFT", "ITL"]): - means, stds = [], [] - for method in methods: - target = df['Test name'].str.contains(model) - target = target & df['Engine'].str.contains(method) - filtered_df = df[target] - - if filtered_df.empty: - means.append(0.) - stds.append(0.) 
- else: - means.append(filtered_df[f"Mean {metric} (ms)"].values[0]) - std = filtered_df[f"Std {metric} (ms)"].values[0] - success = filtered_df["Successful req."].values[0] - stds.append(std / math.sqrt(success)) - - print(model, metric) - print(means, stds) - - ax = axes[i, j + 1] - - bars = ax.bar( - ["vllm", "trt", "lmdeploy", "tgi"], - means, - yerr=stds, - capsize=10, - ) - for idx, bar in enumerate(bars): - bar.set_color(bar_colors[idx]) - ax.set_ylim(bottom=0) - - ax.set_ylabel(f"{metric} (ms)") - ax.set_title(f"{model} {metric}") - ax.grid(axis='y') - - metric = "Tput" - j = 0 - if True: - tputs = [] - for method in methods: - target = df['Test name'].str.contains(model) - target = target & df['Engine'].str.contains(method) - filtered_df = df[target] - - if filtered_df.empty: - tputs.append(0.) - else: - input_tput = filtered_df["Input Tput (tok/s)"].values[0] - output_tput = filtered_df["Output Tput (tok/s)"].values[0] - tputs.append(input_tput + output_tput) - - print(model, metric) - print(tputs) - - ax = axes[i, j] - - bars = ax.bar( - ["vllm", "trt", "lmdeploy", "tgi"], - tputs, - ) - for idx, bar in enumerate(bars): - bar.set_color(bar_colors[idx]) - - ax.set_ylim(bottom=0) - - ax.set_ylabel("Tput (token/s)") - ax.set_title(f"{model} {metric}") - ax.grid(axis='y') - - fig.tight_layout() - fig.savefig("nightly_results.png", bbox_inches='tight', dpi=400) - - -if __name__ == '__main__': - args = parse_arguments() - main(args) diff --git a/.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh deleted file mode 100644 index d6f112aaa42f..000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh +++ /dev/null @@ -1,218 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." 
- else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - pkill lmdeploy || true - # waiting for GPU processes to be fully killed - sleep 10 - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - # wait for vllm server to start - # return 1 if vllm server crashes - timeout 1200 bash -c ' - until curl -s localhost:8000/v1/completions > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." 
- continue - fi - - # append lmdeploy to the test name - test_name=lmdeploy_$test_name - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.lmdeploy_server_parameters') - client_params=$(echo "$params" | jq -r '.lmdeploy_client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - # prepare tokenizer - rm -rf /tokenizer_cache - mkdir /tokenizer_cache - python ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \ - --model "$model" \ - --cachedir /tokenizer_cache - - server_command="lmdeploy serve api_server $model \ - --tp $tp \ - --server-port $port \ - $server_args" - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - bash -c "$server_command" & - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "lmdeploy server is up and running." - else - echo "" - echo "lmdeploy failed to start within the timeout period." 
- break - fi - - # get model name - model_name=$(python ../.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py) - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend lmdeploy \ - --tokenizer /tokenizer_cache \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - --model \"$model_name\" \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "lmdeploy" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." 
- return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - - -main() { - - check_gpus - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - python -m pip install transformers==4.41.2 - - export CURRENT_LLM_SERVING_ENGINE=lmdeploy - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - python -m pip install tabulate pandas - python $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh new file mode 100644 index 000000000000..4d01a314adc4 --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh @@ -0,0 +1,462 @@ +#!/bin/bash + +set -o pipefail +set -x + +check_gpus() { + # check the number of GPUs and GPU type. + declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) + if [[ $gpu_count -gt 0 ]]; then + echo "GPU found." + else + echo "Need at least 1 GPU to run benchmarking." + exit 1 + fi + declare -g gpu_type="$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')" + echo "GPU type is $gpu_type" +} + +check_hf_token() { + # check if HF_TOKEN is available and valid + if [[ -z "$HF_TOKEN" ]]; then + echo "Error: HF_TOKEN is not set." + exit 1 + elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then + echo "Error: HF_TOKEN does not start with 'hf_'." + exit 1 + else + echo "HF_TOKEN is set and valid." + fi +} + + +upload_to_buildkite() { + # upload the benchmarking results to buildkite + + # if the agent binary is not found, skip uploading the results, exit 0 + if [ ! 
-f /workspace/buildkite-agent ]; then + echo "buildkite-agent binary not found. Skip uploading the results." + return 0 + fi + # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md + /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" +} + + +get_current_llm_serving_engine() { + + if which lmdeploy >/dev/null; then + echo "Container: lmdeploy" + export CURRENT_LLM_SERVING_ENGINE=lmdeploy + return + fi + + if [ -e /tgi-entrypoint.sh ]; then + echo "Container: tgi" + export CURRENT_LLM_SERVING_ENGINE=tgi + return + fi + + if which trtllm-build >/dev/null; then + echo "Container: tensorrt-llm" + export CURRENT_LLM_SERVING_ENGINE=trt + return + fi + + if [ -e /sgl-workspace ]; then + echo "Container: sglang" + export CURRENT_LLM_SERVING_ENGINE=sglang + return + fi + + if [ -e /vllm-workspace ]; then + echo "Container: vllm" + # move to a completely irrelevant directory, to avoid import vllm from current folder + export CURRENT_LLM_SERVING_ENGINE=vllm + + return + fi +} + +json2args() { + # transforms the JSON string to command line args, and '_' is replaced to '-' + # example: + # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } + # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 + local json_string=$1 + local args=$( + echo "$json_string" | jq -r ' + to_entries | + map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | + join(" ") + ' + ) + echo "$args" +} + +kill_gpu_processes() { + pkill -f python + pkill -f python3 + pkill -f tritonserver + pkill -f pt_main_thread + pkill -f text-generation + pkill -f lmdeploy + + while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do + sleep 1 + done +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + timeout 1200 bash -c ' + until curl -s 
localhost:8000/v1/completions > /dev/null; do + sleep 1 + done' && return 0 || return 1 +} + +ensure_installed() { + # Ensure that the given command is installed by apt-get + local cmd=$1 + if ! which "$cmd" >/dev/null; then + apt-get update && apt-get install -y "$cmd" + fi +} + +run_serving_tests() { + # run serving tests using `benchmark_serving.py` + # $1: a json file specifying serving test cases + + local serving_test_file + serving_test_file=$1 + + # Iterate over serving tests + jq -c '.[]' "$serving_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." + continue + fi + + # prepend the current serving engine to the test name + test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name} + + # get common parameters + common_params=$(echo "$params" | jq -r '.common_parameters') + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + reuse_server=$(echo "$common_params" | jq -r '.reuse_server') + + # get client and server arguments + server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters") + client_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_client_parameters") + client_args=$(json2args "$client_params") + qps_list=$(echo "$params" | jq -r '.qps_list') + qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') + echo "Running over qps list $qps_list" + + # check if there is enough GPU to run the test + if [[ $gpu_count -lt $tp ]]; then + echo "Required num-shard $tp 
but only $gpu_count GPU found. Skip testcase $test_name." + continue + fi + + if [[ $reuse_server == "true" ]]; then + echo "Reuse previous server for test case $test_name" + else + kill_gpu_processes + bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \ + "$server_params" "$common_params" + fi + + if wait_for_server; then + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE server is up and running." + else + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period." + break + fi + + # prepare tokenizer + # this is required for lmdeploy. + cd "$VLLM_SOURCE_CODE_LOC/benchmarks" + rm -rf /tokenizer_cache + mkdir /tokenizer_cache + python3 ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \ + --model "$model" \ + --cachedir /tokenizer_cache + cd "$VLLM_SOURCE_CODE_LOC/benchmarks" + + + # change model name for lmdeploy (it will not follow standard hf name) + if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then + model=$(python ../.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py) + fi + + # iterate over different QPS + for qps in $qps_list; do + # remove the surrounding single quote from qps + if [[ "$qps" == *"inf"* ]]; then + echo "qps was $qps" + qps="inf" + echo "now qps is $qps" + fi + + new_test_name=$test_name"_qps_"$qps + + backend=$CURRENT_LLM_SERVING_ENGINE + + if [[ $backend = "trt" ]]; then + backend="tensorrt-llm" + fi + + if [[ "$backend" == *"vllm"* ]]; then + backend="vllm" + fi + + if [[ "$dataset_name" = "sharegpt" ]]; then + + client_command="python3 benchmark_serving.py \ + --backend $backend \ + --tokenizer /tokenizer_cache \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --num-prompts $num_prompts \ + --port $port \ + --save-result \ + --result-dir $RESULTS_FOLDER \ + --result-filename ${new_test_name}.json \ + --request-rate $qps \ + --ignore-eos \ + $client_args" + + elif [[ "$dataset_name" = "sonnet" ]]; then + + 
sonnet_input_len=$(echo "$common_params" | jq -r '.sonnet_input_len') + sonnet_output_len=$(echo "$common_params" | jq -r '.sonnet_output_len') + sonnet_prefix_len=$(echo "$common_params" | jq -r '.sonnet_prefix_len') + + client_command="python3 benchmark_serving.py \ + --backend $backend \ + --tokenizer /tokenizer_cache \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --num-prompts $num_prompts \ + --sonnet-input-len $sonnet_input_len \ + --sonnet-output-len $sonnet_output_len \ + --sonnet-prefix-len $sonnet_prefix_len \ + --port $port \ + --save-result \ + --result-dir $RESULTS_FOLDER \ + --result-filename ${new_test_name}.json \ + --request-rate $qps \ + --ignore-eos \ + $client_args" + + else + + echo "The dataset name must be either 'sharegpt' or 'sonnet'. Got $dataset_name." + exit 1 + + fi + + + + echo "Running test case $test_name with qps $qps" + echo "Client command: $client_command" + + eval "$client_command" + + server_command="None" + + # record the benchmarking commands + jq_output=$(jq -n \ + --arg server "$server_command" \ + --arg client "$client_command" \ + --arg gpu "$gpu_type" \ + --arg engine "$CURRENT_LLM_SERVING_ENGINE" \ + '{ + server_command: $server, + client_command: $client, + gpu_type: $gpu, + engine: $engine + }') + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" + + done + + done + + kill_gpu_processes +} + +run_genai_perf_tests() { + # run genai-perf tests + + # $1: a json file specifying genai-perf test cases + local genai_perf_test_file + genai_perf_test_file=$1 + + # Iterate over genai-perf tests + jq -c '.[]' "$genai_perf_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." 
+ continue + fi + + # prepend the current serving engine to the test name + test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name} + + # get common parameters + common_params=$(echo "$params" | jq -r '.common_parameters') + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + reuse_server=$(echo "$common_params" | jq -r '.reuse_server') + + # get client and server arguments + server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters") + qps_list=$(echo "$params" | jq -r '.qps_list') + qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') + echo "Running over qps list $qps_list" + + # check if there is enough GPU to run the test + if [[ $gpu_count -lt $tp ]]; then + echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name." + continue + fi + + if [[ $reuse_server == "true" ]]; then + echo "Reuse previous server for test case $test_name" + else + kill_gpu_processes + bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \ + "$server_params" "$common_params" + fi + + if wait_for_server; then + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE server is up and running." + else + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period." + break + fi + + # iterate over different QPS + for qps in $qps_list; do + # remove the surrounding single quote from qps + if [[ "$qps" == *"inf"* ]]; then + echo "qps was $qps" + qps=$num_prompts + echo "now qps is $qps" + fi + + new_test_name=$test_name"_qps_"$qps + backend=$CURRENT_LLM_SERVING_ENGINE + + if [[ "$backend" == *"vllm"* ]]; then + backend="vllm" + fi + #TODO: add output dir. 
+ client_command="genai-perf profile \ + -m $model \ + --service-kind openai \ + --backend vllm \ + --endpoint-type chat \ + --streaming \ + --url localhost:$port \ + --request-rate $qps \ + --num-prompts $num_prompts \ + " + + echo "Client command: $client_command" + + eval "$client_command" + + #TODO: process/record outputs + done + done + + kill_gpu_processes + +} + +prepare_dataset() { + + # download sharegpt dataset + cd "$VLLM_SOURCE_CODE_LOC/benchmarks" + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + + # duplicate sonnet by 4x, to allow benchmarking with input length 2048 + cd "$VLLM_SOURCE_CODE_LOC/benchmarks" + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + +} + +main() { + + # check if the environment variable is successfully injected from yaml + + check_gpus + check_hf_token + get_current_llm_serving_engine + + pip install -U transformers + + pip install -r requirements/dev.txt + which genai-perf + + # check storage + df -h + + ensure_installed wget + ensure_installed curl + ensure_installed jq + # genai-perf dependency + ensure_installed libb64-0d + + prepare_dataset + + cd "$VLLM_SOURCE_CODE_LOC/benchmarks" + declare -g RESULTS_FOLDER=results/ + mkdir -p $RESULTS_FOLDER + BENCHMARK_ROOT="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/" + + # run the test + run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json" + + # run genai-perf tests + run_genai_perf_tests "$BENCHMARK_ROOT/tests/genai-perf-tests.json" + mv artifacts/ $RESULTS_FOLDER/ + + # upload benchmark results to buildkite + python3 -m pip install tabulate pandas + python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py" + upload_to_buildkite + +} + +main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh new file mode 100644 index 
000000000000..4cd449b141ec --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -0,0 +1,385 @@ +#!/bin/bash + +# This script should be run inside the CI process +# This script assumes that we are already inside the vllm/ directory +# Benchmarking results will be available inside vllm/benchmarks/results/ + +# Do not set -e, as the mixtral 8x22B model tends to crash occasionally +# and we still want to see other benchmarking results even when mixtral crashes. +set -x +set -o pipefail + +check_gpus() { + # check the number of GPUs and GPU type. + declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) + if [[ $gpu_count -gt 0 ]]; then + echo "GPU found." + else + echo "Need at least 1 GPU to run benchmarking." + exit 1 + fi + declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}') + echo "GPU type is $gpu_type" +} + +check_hf_token() { + # check if HF_TOKEN is available and valid + if [[ -z "$HF_TOKEN" ]]; then + echo "Error: HF_TOKEN is not set." + exit 1 + elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then + echo "Error: HF_TOKEN does not start with 'hf_'." + exit 1 + else + echo "HF_TOKEN is set and valid." + fi +} + +ensure_sharegpt_downloaded() { + local FILE=ShareGPT_V3_unfiltered_cleaned_split.json + if [ ! -f "$FILE" ]; then + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE + else + echo "$FILE already exists." 
+ fi +} + +json2args() { + # transforms the JSON string to command line args, and '_' is replaced to '-' + # example: + # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } + # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 + local json_string=$1 + local args=$( + echo "$json_string" | jq -r ' + to_entries | + map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | + join(" ") + ' + ) + echo "$args" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + timeout 1200 bash -c ' + until curl -X POST localhost:8000/v1/completions; do + sleep 1 + done' && return 0 || return 1 +} + +kill_processes_launched_by_current_bash() { + # Kill all python processes launched from current bash script + current_shell_pid=$$ + processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}') + if [ -n "$processes" ]; then + echo "Killing the following processes matching '$1':" + echo "$processes" + echo "$processes" | xargs kill -9 + else + echo "No processes found matching '$1'." + fi +} + +kill_gpu_processes() { + + ps -aux + lsof -t -i:8000 | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + + + # wait until GPU memory usage smaller than 1GB + while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do + sleep 1 + done + + # remove vllm config file + rm -rf ~/.config/vllm + +} + +upload_to_buildkite() { + # upload the benchmarking results to buildkite + + # if the agent binary is not found, skip uploading the results, exit 0 + # Check if buildkite-agent is available in the PATH or at /workspace/buildkite-agent + if command -v buildkite-agent >/dev/null 2>&1; then + BUILDKITE_AGENT_COMMAND="buildkite-agent" + elif [ -f /workspace/buildkite-agent ]; then + BUILDKITE_AGENT_COMMAND="/workspace/buildkite-agent" + else + echo "buildkite-agent binary not found. 
Skip uploading the results." + return 0 + fi + + # Use the determined command to annotate and upload artifacts + $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < "$RESULTS_FOLDER/benchmark_results.md" + $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" +} + +run_latency_tests() { + # run latency tests using `benchmark_latency.py` + # $1: a json file specifying latency test cases + + local latency_test_file + latency_test_file=$1 + + # Iterate over latency tests + jq -c '.[]' "$latency_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + if [[ ! "$test_name" =~ ^latency_ ]]; then + echo "In latency-test.json, test_name must start with \"latency_\"." + exit 1 + fi + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." + continue + fi + + # get arguments + latency_params=$(echo "$params" | jq -r '.parameters') + latency_args=$(json2args "$latency_params") + + # check if there is enough GPU to run the test + tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." 
+ continue + fi + + latency_command="python3 benchmark_latency.py \ + --output-json $RESULTS_FOLDER/${test_name}.json \ + $latency_args" + + echo "Running test case $test_name" + echo "Latency command: $latency_command" + + # recoding benchmarking command ang GPU command + jq_output=$(jq -n \ + --arg latency "$latency_command" \ + --arg gpu "$gpu_type" \ + '{ + latency_command: $latency, + gpu_type: $gpu + }') + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" + + # run the benchmark + eval "$latency_command" + + kill_gpu_processes + + done +} + +run_throughput_tests() { + # run throughput tests using `benchmark_throughput.py` + # $1: a json file specifying throughput test cases + + local throughput_test_file + throughput_test_file=$1 + + # Iterate over throughput tests + jq -c '.[]' "$throughput_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + if [[ ! "$test_name" =~ ^throughput_ ]]; then + echo "In throughput-test.json, test_name must start with \"throughput_\"." + exit 1 + fi + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." + continue + fi + + # get arguments + throughput_params=$(echo "$params" | jq -r '.parameters') + throughput_args=$(json2args "$throughput_params") + + # check if there is enough GPU to run the test + tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size') + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." 
+ continue + fi + + throughput_command="python3 benchmark_throughput.py \ + --output-json $RESULTS_FOLDER/${test_name}.json \ + $throughput_args" + + echo "Running test case $test_name" + echo "Throughput command: $throughput_command" + # recoding benchmarking command ang GPU command + jq_output=$(jq -n \ + --arg command "$throughput_command" \ + --arg gpu "$gpu_type" \ + '{ + throughput_command: $command, + gpu_type: $gpu + }') + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" + + # run the benchmark + eval "$throughput_command" + + kill_gpu_processes + + done +} + +run_serving_tests() { + # run serving tests using `benchmark_serving.py` + # $1: a json file specifying serving test cases + + local serving_test_file + serving_test_file=$1 + + # Iterate over serving tests + jq -c '.[]' "$serving_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + if [[ ! "$test_name" =~ ^serving_ ]]; then + echo "In serving-test.json, test_name must start with \"serving_\"." + exit 1 + fi + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." + continue + fi + + # get client and server arguments + server_params=$(echo "$params" | jq -r '.server_parameters') + client_params=$(echo "$params" | jq -r '.client_parameters') + server_args=$(json2args "$server_params") + client_args=$(json2args "$client_params") + qps_list=$(echo "$params" | jq -r '.qps_list') + qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') + echo "Running over qps list $qps_list" + + # check if there is enough GPU to run the test + tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." 
+ continue + fi + + # check if server model and client model is aligned + server_model=$(echo "$server_params" | jq -r '.model') + client_model=$(echo "$client_params" | jq -r '.model') + if [[ $server_model != "$client_model" ]]; then + echo "Server model and client model must be the same. Skip testcase $test_name." + continue + fi + + server_command="python3 \ + -m vllm.entrypoints.openai.api_server \ + $server_args" + + # run the server + echo "Running test case $test_name" + echo "Server command: $server_command" + bash -c "$server_command" & + server_pid=$! + + # wait until the server is alive + if wait_for_server; then + echo "" + echo "vllm server is up and running." + else + echo "" + echo "vllm failed to start within the timeout period." + fi + + # iterate over different QPS + for qps in $qps_list; do + # remove the surrounding single quote from qps + if [[ "$qps" == *"inf"* ]]; then + echo "qps was $qps" + qps="inf" + echo "now qps is $qps" + fi + + new_test_name=$test_name"_qps_"$qps + + # pass the tensor parallel size to the client so that it can be displayed + # on the benchmark dashboard + client_command="python3 benchmark_serving.py \ + --save-result \ + --result-dir $RESULTS_FOLDER \ + --result-filename ${new_test_name}.json \ + --request-rate $qps \ + --metadata "tensor_parallel_size=$tp" \ + $client_args" + + echo "Running test case $test_name with qps $qps" + echo "Client command: $client_command" + + bash -c "$client_command" + + # record the benchmarking commands + jq_output=$(jq -n \ + --arg server "$server_command" \ + --arg client "$client_command" \ + --arg gpu "$gpu_type" \ + '{ + server_command: $server, + client_command: $client, + gpu_type: $gpu + }') + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" + + done + + # clean up + kill -9 $server_pid + kill_gpu_processes + done +} + +main() { + check_gpus + check_hf_token + + # Set to v1 to run v1 benchmark + if [[ "${ENGINE_VERSION:-v0}" == "v1" ]]; then + export 
VLLM_USE_V1=1 + fi + + # dependencies + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get update && apt-get -y install jq) + (which lsof) || (apt-get update && apt-get install -y lsof) + + # get the current IP address, required by benchmark_serving.py + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + # turn of the reporting of the status of each request, to clean up the terminal output + export VLLM_LOGGING_LEVEL="WARNING" + + # prepare for benchmarking + cd benchmarks || exit 1 + ensure_sharegpt_downloaded + declare -g RESULTS_FOLDER=results/ + mkdir -p $RESULTS_FOLDER + QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ + + # benchmarking + run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json + run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json + run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json + + # postprocess benchmarking results + pip install tabulate pandas + python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py + + upload_to_buildkite +} + +main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh deleted file mode 100644 index fed03654f8b7..000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh +++ /dev/null @@ -1,216 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." 
- exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - pkill text-generation || true - # waiting for GPU processes to be fully killed - sleep 10 - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - timeout 1200 bash -c ' - until curl -s localhost:8000/generate_stream > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." 
- continue - fi - - # append tgi to the test name - test_name=tgi_$test_name - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.tgi_server_parameters') - client_params=$(echo "$params" | jq -r '.tgi_client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then - echo "Key 'fp8' exists in common params." - server_command="/tgi-entrypoint.sh \ - --model-id $model \ - --num-shard $tp \ - --port $port \ - --quantize fp8 \ - $server_args" - else - echo "Key 'fp8' does not exist in common params." - server_command="/tgi-entrypoint.sh \ - --model-id $model \ - --num-shard $tp \ - --port $port \ - $server_args" - fi - - - - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - eval "$server_command" & - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "tgi server is up and running." - else - echo "" - echo "tgi failed to start within the timeout period." 
- break - fi - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend tgi \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "tgi" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - - - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." 
- return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - -main() { - - check_gpus - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - export CURRENT_LLM_SERVING_ENGINE=tgi - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - python -m pip install tabulate pandas - python $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh deleted file mode 100644 index 4a82b9ec64d7..000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh +++ /dev/null @@ -1,214 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - pkill tritonserver || true - # waiting for GPU processes to be fully killed - sleep 20 - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. 
- echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - timeout 1200 bash -c ' - until curl -s localhost:8000/generate_stream > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." 
- continue - fi - - # append trt to the test name - test_name=trt_$test_name - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.trt_server_parameters') - client_params=$(echo "$params" | jq -r '.trt_client_parameters') - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required model_tp_size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - - - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - - echo "Running test case $test_name" - bash ../.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh "$server_params" "$common_params" - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "trt server is up and running." - else - echo "" - echo "trt failed to start within the timeout period." 
- break - fi - - # prepare tokenizer - cd $VLLM_SOURCE_CODE_LOC/benchmarks - rm -rf /tokenizer_cache - mkdir /tokenizer_cache - python ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \ - --model "$model" \ - --cachedir /tokenizer_cache - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend tensorrt-llm \ - --tokenizer /tokenizer_cache \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - server_command="" - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "trt" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." 
- return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - - -main() { - - check_gpus - - - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - # update transformers package, to make sure mixtral tokenizer is available - python -m pip install transformers -U - - export CURRENT_LLM_SERVING_ENGINE=trt - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - python -m pip install tabulate pandas - python $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh deleted file mode 100644 index 663045b8a912..000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh +++ /dev/null @@ -1,221 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - # kill all processes on GPU. - pkill pt_main_thread - sleep 10 - - # remove vllm config file - rm -rf ~/.config/vllm - - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. 
- echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - # wait for vllm server to start - # return 1 if vllm server crashes - timeout 1200 bash -c ' - until curl -s localhost:8000/v1/completions > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." 
- continue - fi - - # append vllm to the test name - test_name=vllm_$test_name - - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.vllm_server_parameters') - client_params=$(echo "$params" | jq -r '.vllm_client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then - echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." - model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') - server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ - -tp $tp \ - --model $model \ - --port $port \ - $server_args" - else - echo "Key 'fp8' does not exist in common params." - server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ - -tp $tp \ - --model $model \ - --port $port \ - $server_args" - fi - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - eval "$server_command" & - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "vllm server is up and running." - else - echo "" - echo "vllm failed to start within the timeout period." 
- break - fi - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend vllm \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "vllm" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." 
- return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - -main() { - - check_gpus - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - export CURRENT_LLM_SERVING_ENGINE=vllm - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - - python3 -m pip install tabulate pandas - python3 $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py index 782d1ef9aab9..62ee5e10b509 100644 --- a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py +++ b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import datetime import json import os @@ -17,10 +19,17 @@ "request_throughput": "Tput (req/s)", "mean_ttft_ms": "Mean TTFT (ms)", "std_ttft_ms": "Std TTFT (ms)", + "median_ttft_ms": "Median TTFT (ms)", "mean_itl_ms": "Mean ITL (ms)", "std_itl_ms": "Std ITL (ms)", - "input_throughput": "Input Tput (tok/s)", + "median_itl_ms": "Median ITL (ms)", + "mean_tpot_ms": "Mean TPOT (ms)", + "std_tpot_ms": "Std TPOT (ms)", + "median_tpot_ms": "Median TPOT (ms)", + "total_token_throughput": "Total Token Tput (tok/s)", "output_throughput": "Output Tput (tok/s)", + "total_input_tokens": "Total input tokens", + "total_output_tokens": "Total output tokens", "engine": "Engine", } @@ -29,11 +38,11 @@ # collect results for test_file in results_folder.glob("*.json"): - with open(test_file, "r") as f: + with open(test_file) as f: raw_result = json.loads(f.read()) # attach the benchmarking command to raw_result - with 
open(test_file.with_suffix(".commands"), "r") as f: + with open(test_file.with_suffix(".commands")) as f: command = json.loads(f.read()) raw_result.update(command) diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh index c785e6a0da62..50e1ab024220 100644 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -1,10 +1,16 @@ #!/bin/sh -TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) -URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-postmerge-repo:pull" | jq -r .token) +if [[ "$BUILDKITE_BRANCH" == "main" ]]; then + URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-postmerge-repo/manifests/$BUILDKITE_COMMIT" +else + URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +fi + +TIMEOUT_SECONDS=10 retries=0 while [ $retries -lt 1000 ]; do - if [ $(curl -s -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then + if [ "$(curl -s --max-time "$TIMEOUT_SECONDS" -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" "$URL")" -eq 200 ]; then exit 0 fi @@ -14,4 +20,4 @@ while [ $retries -lt 1000 ]; do sleep 5 done -exit 1 \ No newline at end of file +exit 1 diff --git a/.buildkite/nightly-benchmarks/tests/descriptions.md b/.buildkite/nightly-benchmarks/tests/descriptions.md deleted file mode 100644 index 891e4917070d..000000000000 --- a/.buildkite/nightly-benchmarks/tests/descriptions.md +++ /dev/null @@ -1,67 +0,0 @@ - -## Latency tests - -This test suite aims to test vllm's end-to-end latency under a controlled setup. - -- Input length: 32 tokens. -- Output length: 128 tokens. -- Batch size: fixed (8). 
-- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Evaluation metrics: end-to-end latency (mean, median, p99). - -### Latency benchmarking results - -{latency_tests_markdown_table} - -## Throughput tests - -This test suite aims to test vllm's throughput. - -- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). -- Output length: the corresponding output length of these 200 prompts. -- Batch size: dynamically determined by vllm to achieve maximum throughput. -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Evaluation metrics: throughput. - -### Throughput benchmarking results - -{throughput_tests_markdown_table} - -## Serving tests - -This test suite aims to test vllm's real serving metrics. - -- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). -- Output length: the corresponding output length of these 200 prompts. -- Batch size: dynamically determined by vllm and the arrival pattern of the requests. -- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). - -### Serving benchmarking results - -{serving_tests_markdown_table} - -## json version of the benchmarking tables - -This section contains the data of the markdown tables above in JSON format. 
-You can load the benchmarking tables into pandas dataframes as follows: - -```python -import json -import pandas as pd - -benchmarking_results_json = """The json string""" -benchmarking_results = json.loads(benchmarking_results_json) -latency_results = pd.DataFrame.from_dict(benchmarking_results["latency"]) -throughput_results = pd.DataFrame.from_dict(benchmarking_results["throughput"]) -serving_results = pd.DataFrame.from_dict(benchmarking_results["serving"]) -``` - -The json string for all benchmarking tables: -```json -{benchmarking_results_in_json_string} -``` - -You can also check the raw experiment data in the Artifact tab of the Buildkite page. - diff --git a/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json b/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json new file mode 100644 index 000000000000..edbe9f2df0ce --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json @@ -0,0 +1,23 @@ +[ + { + "test_name": "llama8B_tp1_genai_perf", + "qps_list": [4,8,16,32], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "tp": 1, + "port": 8000, + "num_prompts": 500, + "reuse_server": false + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "genai_perf_input_parameters": { + } + } +] \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests.json b/.buildkite/nightly-benchmarks/tests/latency-tests.json index 06488cd79110..7762a239f96a 100644 --- a/.buildkite/nightly-benchmarks/tests/latency-tests.json +++ b/.buildkite/nightly-benchmarks/tests/latency-tests.json @@ -2,7 +2,7 @@ { "test_name": "latency_llama8B_tp1", "parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "load_format": "dummy", "num_iters_warmup": 5, @@ -12,7 +12,7 @@ { 
"test_name": "latency_llama70B_tp4", "parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, "load_format": "dummy", "num-iters-warmup": 5, @@ -29,4 +29,4 @@ "num-iters": 15 } } -] \ No newline at end of file +] diff --git a/.buildkite/nightly-benchmarks/tests/nightly-tests.json b/.buildkite/nightly-benchmarks/tests/nightly-tests.json index f250833c6271..fda1a7a3ec53 100644 --- a/.buildkite/nightly-benchmarks/tests/nightly-tests.json +++ b/.buildkite/nightly-benchmarks/tests/nightly-tests.json @@ -1,16 +1,18 @@ [ { - "test_name": "llama8B_tp1", - "qps_list": [4], + "test_name": "llama8B_tp1_sharegpt", + "qps_list": [4,8,16,32,"inf"], "common_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3-8B-Instruct", "tp": 1, "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", "num_prompts": 500, - "port": 8000 + "port": 8000, + "reuse_server": false }, "lmdeploy_server_parameters": { + "dtype": "bfloat16" }, "lmdeploy_client_parameters": { }, @@ -21,34 +23,158 @@ }, "trt_server_parameters": { "model_type": "llama", - "model_dtype": "float16", - "max_batch_size": 256, + "model_dtype": "bfloat16", + "max_batch_size": 2048, "max_input_len": 4096, - "max_output_len": 4096, - "trt_llm_version": "r24.04" + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" }, "trt_client_parameters": { "endpoint": "/v2/models/ensemble/generate_stream" + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "enable_torch_compile": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { + } + }, + { + "test_name": "llama8B_tp1_sonnet_512_16", + "qps_list": 
[4,8,16,32,"inf"], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "tp": 1, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", + "num_prompts": 500, + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 16, + "sonnet_prefix_len": 50, + "reuse_server": true + }, + "lmdeploy_server_parameters": { + "dtype": "bfloat16" + }, + "lmdeploy_client_parameters": { + }, + "tgi_server_parameters": { + }, + "tgi_client_parameters": { + "endpoint": "/generate_stream" + }, + "trt_server_parameters": { + "model_type": "llama", + "model_dtype": "bfloat16", + "max_batch_size": 2048, + "max_input_len": 4096, + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" + }, + "trt_client_parameters": { + "endpoint": "/v2/models/ensemble/generate_stream" + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "enable_torch_compile": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { + } + }, + { + "test_name": "llama8B_tp1_sonnet_512_256", + "qps_list": [4,8,16,32,"inf"], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "tp": 1, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", + "num_prompts": 500, + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 256, + "sonnet_prefix_len": 50, + "reuse_server": true + }, + "lmdeploy_server_parameters": { + "dtype": "bfloat16" + }, + "lmdeploy_client_parameters": { + }, + "tgi_server_parameters": { + }, + "tgi_client_parameters": { + "endpoint": "/generate_stream" + }, + "trt_server_parameters": { + "model_type": "llama", + "model_dtype": "bfloat16", + "max_batch_size": 2048, + "max_input_len": 4096, + "max_seq_len": 6144, + "max_num_tokens": 16384, + 
"trt_llm_version": "v0.11.0" }, + "trt_client_parameters": { + "endpoint": "/v2/models/ensemble/generate_stream" + }, "vllm_server_parameters": { "disable_log_stats": "", - "disable_log_requests": "" + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" }, "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "enable_torch_compile": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { } }, { - "test_name": "llama70B_tp4", - "qps_list": [2], + "test_name": "llama70B_tp4_sharegpt", + "qps_list": [4,8,16,32,"inf"], "common_parameters": { "model": "meta-llama/Meta-Llama-3-70B-Instruct", "tp": 4, "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", "num_prompts": 500, - "port": 8000 + "port": 8000, + "reuse_server": false }, "lmdeploy_server_parameters": { + "dtype": "bfloat16" }, "lmdeploy_client_parameters": { }, @@ -59,34 +185,50 @@ }, "trt_server_parameters": { "model_type": "llama", - "model_dtype": "float16", - "max_batch_size": 256, + "model_dtype": "bfloat16", + "max_batch_size": 2048, "max_input_len": 4096, - "max_output_len": 4096, - "trt_llm_version": "r24.04" + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" }, "trt_client_parameters": { "endpoint": "/v2/models/ensemble/generate_stream" - }, + }, "vllm_server_parameters": { "disable_log_stats": "", - "disable_log_requests": "" + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" }, "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { } }, { - "test_name": "mixtral8x7B_tp2", - "qps_list": [2], + "test_name": "llama70B_tp4_sonnet_512_16", + "qps_list": [4,8,16,32,"inf"], "common_parameters": { - "model": 
"mistralai/Mixtral-8x7B-Instruct-v0.1", - "tp": 2, - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "tp": 4, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", "num_prompts": 500, - "port": 8000 + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 16, + "sonnet_prefix_len": 50, + "reuse_server": true }, "lmdeploy_server_parameters": { + "dtype": "bfloat16" }, "lmdeploy_client_parameters": { }, @@ -97,20 +239,85 @@ }, "trt_server_parameters": { "model_type": "llama", - "model_dtype": "float16", - "max_batch_size": 256, + "model_dtype": "bfloat16", + "max_batch_size": 2048, "max_input_len": 4096, - "max_output_len": 4096, - "trt_llm_version": "r24.04" + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" }, "trt_client_parameters": { "endpoint": "/v2/models/ensemble/generate_stream" + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "vllm_client_parameters": { }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { + } + }, + { + "test_name": "llama70B_tp4_sonnet_512_256", + "qps_list": [4,8,16,32,"inf"], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "tp": 4, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", + "num_prompts": 500, + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 256, + "sonnet_prefix_len": 50, + "reuse_server": true + }, + "lmdeploy_server_parameters": { + "dtype": "bfloat16" + }, + "lmdeploy_client_parameters": { + }, + "tgi_server_parameters": { + }, + "tgi_client_parameters": { + "endpoint": "/generate_stream" + }, + "trt_server_parameters": { + "model_type": "llama", + "model_dtype": "bfloat16", + "max_batch_size": 
2048, + "max_input_len": 4096, + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" + }, + "trt_client_parameters": { + "endpoint": "/v2/models/ensemble/generate_stream" + }, "vllm_server_parameters": { "disable_log_stats": "", - "disable_log_requests": "" + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" }, "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { } } ] \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json index 300af0524d7c..415171e268b0 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests.json @@ -3,7 +3,7 @@ "test_name": "serving_llama8B_tp1_sharegpt", "qps_list": [1, 4, 16, "inf"], "server_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "swap_space": 16, "disable_log_stats": "", @@ -11,7 +11,7 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -22,7 +22,7 @@ "test_name": "serving_llama70B_tp4_sharegpt", "qps_list": [1, 4, 16, "inf"], "server_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, "swap_space": 16, "disable_log_stats": "", @@ -30,7 +30,7 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": 
"./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -60,17 +60,16 @@ "test_name": "serving_llama70B_tp4_sharegpt_specdecode", "qps_list": [2], "server_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "disable_log_requests": "", "tensor_parallel_size": 4, "swap_space": 16, "speculative_model": "turboderp/Qwama-0.5B-Instruct", "num_speculative_tokens": 4, - "speculative_draft_tensor_parallel_size": 1, - "use_v2_block_manager": "" + "speculative_draft_tensor_parallel_size": 1 }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests.json b/.buildkite/nightly-benchmarks/tests/throughput-tests.json index 41ac13574870..9bc87cbcd2bc 100644 --- a/.buildkite/nightly-benchmarks/tests/throughput-tests.json +++ b/.buildkite/nightly-benchmarks/tests/throughput-tests.json @@ -2,7 +2,7 @@ { "test_name": "throughput_llama8B_tp1", "parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "load_format": "dummy", "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -13,7 +13,7 @@ { "test_name": "throughput_llama70B_tp4", "parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, "load_format": "dummy", "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -32,4 +32,4 @@ "backend": "vllm" } } -] \ No newline at end of file +] diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 5be9a553dddd..18f582b6e4c9 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,19 +1,88 @@ steps: - - label: "Build wheel - CUDA 
{{matrix.cuda_version}}" + - label: "Build wheel - CUDA 12.4" agents: - queue: cpu_queue + queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - # rename the files to change linux -> manylinux1 - - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done" - - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/" - - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/" + - "bash .buildkite/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + + - label: "Build wheel - CUDA 12.1" + agents: + queue: cpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + + # Note(simon): We can always build CUDA 11.8 wheel to ensure the build is working. + # However, this block can be uncommented to save some compute hours. 
+ # - block: "Build CUDA 11.8 wheel" + # key: block-build-cu118-wheel + + - label: "Build wheel - CUDA 11.8" + # depends_on: block-build-cu118-wheel + agents: + queue: cpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + + - block: "Build release image" + depends_on: ~ + key: block-release-image-build + + - label: "Build release image" + depends_on: block-release-image-build + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + + - label: "Build and publish TPU release image" + depends_on: ~ + if: build.env("NIGHTLY") == "1" + agents: + queue: tpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f Dockerfile.tpu ." + - "docker push vllm/vllm-tpu:nightly" + - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" + plugins: + - docker-login#v3.0.0: + username: vllm + password-env: DOCKERHUB_TOKEN + env: + DOCKER_BUILDKIT: "1" + + - input: "Provide Release version here" + fields: + - text: "What is the release version?" 
+ key: "release-version" + + - block: "Build CPU release image" + key: block-cpu-release-image-build + depends_on: ~ + + - label: "Build and publish CPU release image" + depends_on: block-cpu-release-image-build + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain -f Dockerfile.cpu ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" - matrix: - setup: - cuda_version: - - "11.8.0" - - "12.1.0" diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh old mode 100644 new mode 100755 index 77e451354caf..0680bae13ddb --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,5 +1,7 @@ +#!/bin/bash + # This script runs test inside the corresponding ROCm docker container. -set -ex +set -o pipefail # Print ROCm version echo "--- Confirming Clean Initial State" @@ -31,8 +33,8 @@ cleanup_docker() { echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." # Remove dangling images (those that are not tagged and not used by any container) docker image prune -f - # Remove unused volumes - docker volume prune -f + # Remove unused volumes / force the system prune for old images as well. + docker volume prune -f && docker system prune --force --filter "until=72h" --all echo "Docker images and volumes cleanup completed." else echo "Disk usage is below $threshold%. No cleanup needed." 
@@ -55,31 +57,124 @@ while true; do done echo "--- Pulling container" -docker login registry-1.docker.io -u alexeivivanovamd -p ${DH_TOKEN} -image_name="rocmshared/vllm-ci:${BUILDKITE_COMMIT}" +image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}" container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" -docker pull ${image_name} +docker pull "${image_name}" remove_docker_container() { - docker rm -f ${container_name} || docker image rm -f ${image_name} || true + docker rm -f "${container_name}" || docker image rm -f "${image_name}" || true } trap remove_docker_container EXIT echo "--- Running container" HF_CACHE="$(realpath ~)/huggingface" -mkdir -p ${HF_CACHE} +mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" -docker run \ +commands=$@ +echo "Commands:$commands" +#ignore certain kernels tests +if [[ $commands == *" kernels "* ]]; then + commands="${commands} \ + --ignore=kernels/test_attention_selector.py \ + --ignore=kernels/test_blocksparse_attention.py \ + --ignore=kernels/test_causal_conv1d.py \ + --ignore=kernels/test_cutlass.py \ + --ignore=kernels/test_encoder_decoder_attn.py \ + --ignore=kernels/test_flash_attn.py \ + --ignore=kernels/test_flashinfer.py \ + --ignore=kernels/test_int8_quant.py \ + --ignore=kernels/test_machete_gemm.py \ + --ignore=kernels/test_mamba_ssm.py \ + --ignore=kernels/test_marlin_gemm.py \ + --ignore=kernels/test_moe.py \ + --ignore=kernels/test_prefix_prefill.py \ + --ignore=kernels/test_rand.py \ + --ignore=kernels/test_sampler.py \ + --ignore=kernels/test_cascade_flash_attn.py \ + --ignore=kernels/test_mamba_mixer2.py \ + --ignore=kernels/test_aqlm.py \ + --ignore=kernels/test_machete_mm.py \ + --ignore=kernels/test_mha_attn.py \ + --ignore=kernels/test_block_fp8.py \ + --ignore=kernels/test_permute_cols.py" +fi + +#ignore certain Entrypoints/openai tests +if [[ $commands == *" entrypoints/openai "* ]]; then + commands=${commands//" entrypoints/openai "/" entrypoints/openai \ + 
--ignore=entrypoints/openai/test_audio.py \ + --ignore=entrypoints/openai/test_chat.py \ + --ignore=entrypoints/openai/test_shutdown.py \ + --ignore=entrypoints/openai/test_completion.py \ + --ignore=entrypoints/openai/test_sleep.py \ + --ignore=entrypoints/openai/test_models.py \ + --ignore=entrypoints/openai/test_prompt_validation.py "} +fi + +#ignore certain Entrypoints/llm tests +if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then + commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "} +fi + +# --ignore=entrypoints/openai/test_encoder_decoder.py \ +# --ignore=entrypoints/openai/test_embedding.py \ +# --ignore=entrypoints/openai/test_oot_registration.py +# --ignore=entrypoints/openai/test_accuracy.py \ +# --ignore=entrypoints/openai/test_models.py <= Fails on MI250 but passes on MI300 as of 2025-03-13 + + +PARALLEL_JOB_COUNT=8 +# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. +if [[ $commands == *"--shard-id="* ]]; then + # assign job count as the number of shards used + commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} + for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do + # assign shard-id for each shard + commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "} + echo "Shard ${GPU} commands:$commands_gpu" + docker run \ --device /dev/kfd --device /dev/dri \ --network host \ --shm-size=16gb \ --rm \ + -e HIP_VISIBLE_DEVICES="${GPU}" \ -e HF_TOKEN \ - -v ${HF_CACHE}:${HF_MOUNT} \ - -e HF_HOME=${HF_MOUNT} \ - --name ${container_name} \ - ${image_name} \ - /bin/bash -c "${@}" - + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ + -v "${HF_CACHE}:${HF_MOUNT}" \ + -e "HF_HOME=${HF_MOUNT}" \ + --name "${container_name}_${GPU}" \ + "${image_name}" \ + /bin/bash -c "${commands_gpu}" \ + |& while read -r line; do echo ">>Shard $GPU: $line"; done & + PIDS+=($!) 
+ done + #wait for all processes to finish and collect exit codes + for pid in "${PIDS[@]}"; do + wait "${pid}" + STATUS+=($?) + done + for st in "${STATUS[@]}"; do + if [[ ${st} -ne 0 ]]; then + echo "One of the processes failed with $st" + exit "${st}" + fi + done +else + docker run \ + --device /dev/kfd --device /dev/dri \ + --network host \ + --shm-size=16gb \ + --rm \ + -e HIP_VISIBLE_DEVICES=0 \ + -e HF_TOKEN \ + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ + -v "${HF_CACHE}:${HF_MOUNT}" \ + -e "HF_HOME=${HF_MOUNT}" \ + --name "${container_name}" \ + "${image_name}" \ + /bin/bash -c "${commands}" +fi diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index cbf6dda677c5..1641c1faa9d6 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -1,3 +1,5 @@ +#!/bin/bash + # This script is run by buildkite to run the benchmarks and upload the results to buildkite set -ex diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh new file mode 100755 index 000000000000..bc06838d804f --- /dev/null +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# This script build the CPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Setup cleanup +remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; } +trap remove_docker_container EXIT +remove_docker_container + +# Try building the docker image +docker build -t cpu-test -f Dockerfile.ppc64le . + diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 45bc8eb2f847..e45e184852f2 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -1,40 +1,90 @@ +#!/bin/bash + # This script build the CPU docker image and run the offline inference inside the container. # It serves a sanity check for compilation and basic model usage. 
set -ex +# allow to bind to different cores +CORE_RANGE=${CORE_RANGE:-48-95} +NUMA_NODE=${NUMA_NODE:-1} + # Try building the docker image -numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu . -numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test-"$BUILDKITE_BUILD_NUMBER" -f Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 -f Dockerfile.cpu . # Setup cleanup -remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; } +remove_docker_container() { set -e; docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; } trap remove_docker_container EXIT remove_docker_container # Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ - --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ - --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2 - -# offline inference -docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" - -# Run basic model test -docker exec cpu-test bash -c " - pip install pytest Pillow protobuf - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported - -# online inference -docker exec cpu-test bash -c " - export VLLM_CPU_KVCACHE_SPACE=10 - 
export VLLM_CPU_OMP_THREADS_BIND=48-92 - python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & - timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 - python3 benchmarks/benchmark_serving.py \ - --backend vllm \ - --dataset-name random \ - --model facebook/opt-125m \ - --num-prompts 20 \ - --endpoint /v1/completions \ - --tokenizer facebook/opt-125m" +docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ + --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER" +docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ + --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 + +function cpu_tests() { + set -e + export NUMA_NODE=$2 + export BUILDKITE_BUILD_NUMBER=$3 + + # offline inference + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c " + set -e + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + + # Run basic model test + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pip install -r vllm/requirements/test.txt + pip install -r vllm/requirements/cpu.txt + pytest -v -s tests/models/decoder_only/language -m cpu_model + pytest -v -s tests/models/embedding/language -m cpu_model + pytest -v -s tests/models/encoder_decoder/language -m cpu_model + pytest -v -s tests/models/decoder_only/audio_language -m cpu_model + pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" + + # Run compressed-tensor test + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pytest -s -v \ + 
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" + + # Run AWQ test + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pytest -s -v \ + tests/quantization/test_ipex_quant.py" + + # Run chunked-prefill and prefix-cache test + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pytest -s -v -k cpu_model \ + tests/basic_correctness/test_chunked_prefill.py" + + # online serving + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + export VLLM_CPU_KVCACHE_SPACE=10 + export VLLM_CPU_OMP_THREADS_BIND=$1 + python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half & + timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 + python3 benchmarks/benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --model facebook/opt-125m \ + --num-prompts 20 \ + --endpoint /v1/completions \ + --tokenizer facebook/opt-125m" + + # Run multi-lora tests + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pytest -s -v \ + tests/lora/test_qwen2vl.py" +} + +# All of CPU tests are expected to be finished less than 40 mins. +export -f cpu_tests +timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE $BUILDKITE_BUILD_NUMBER" diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh new file mode 100644 index 000000000000..5c004b47778f --- /dev/null +++ b/.buildkite/run-gh200-test.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# This script build the GH200 docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. 
+set -ex + +# Skip the new torch installation during build since we are using the specified version for arm64 in the Dockerfile +python3 use_existing_torch.py + +# Try building the docker image +DOCKER_BUILDKIT=1 docker build . \ + --target vllm-openai \ + --platform "linux/arm64" \ + -t gh200-test \ + --build-arg max_jobs=66 \ + --build-arg nvcc_threads=2 \ + --build-arg RUN_WHEEL_CHECK=false \ + --build-arg torch_cuda_arch_list="9.0+PTX" \ + --build-arg vllm_fa_cmake_gpu_arches="90-real" + +# Setup cleanup +remove_docker_container() { docker rm -f gh200-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and test offline inference +docker run -e HF_TOKEN -e VLLM_WORKER_MULTIPROC_METHOD=spawn -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' + python3 examples/offline_inference/basic/generate.py --model meta-llama/Llama-3.2-1B +' diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh new file mode 100644 index 000000000000..f83eb927aae4 --- /dev/null +++ b/.buildkite/run-hpu-test.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# This script builds the HPU docker image and runs the offline inference inside the container. +# It serves as a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t hpu-test-env -f Dockerfile.hpu . + +# Setup cleanup +# certain versions of HPU software stack have a bug that can +# override the exit code of the script, so we need to use +# separate remove_docker_container and remove_docker_container_and_exit +# functions, while other platforms only need one remove_docker_container +# function. 
+EXITCODE=1 +remove_docker_container() { docker rm -f hpu-test || true; } +remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; } +trap remove_docker_container_and_exit EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m +EXITCODE=$? diff --git a/.buildkite/run-multi-node-test.sh b/.buildkite/run-multi-node-test.sh index 7ac4dcc4c786..530bf90a855f 100755 --- a/.buildkite/run-multi-node-test.sh +++ b/.buildkite/run-multi-node-test.sh @@ -14,7 +14,7 @@ DOCKER_IMAGE=$4 shift 4 COMMANDS=("$@") -if [ ${#COMMANDS[@]} -ne $NUM_NODES ]; then +if [ ${#COMMANDS[@]} -ne "$NUM_NODES" ]; then echo "The number of commands must be equal to the number of nodes." echo "Number of nodes: $NUM_NODES" echo "Number of commands: ${#COMMANDS[@]}" @@ -23,7 +23,7 @@ fi echo "List of commands" for command in "${COMMANDS[@]}"; do - echo $command + echo "$command" done start_network() { @@ -36,7 +36,7 @@ start_nodes() { for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu)) GPU_DEVICES+=$(($DEVICE_NUM)) - if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then + if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then GPU_DEVICES+=',' fi done @@ -49,17 +49,20 @@ start_nodes() { # 3. map the huggingface cache directory to the container # 3. 
assign ip addresses to the containers (head node: 192.168.10.10, worker nodes: # starting from 192.168.10.11) - docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN -v ~/.cache/huggingface:/root/.cache/huggingface --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE /bin/bash -c "tail -f /dev/null" + docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN \ + -v ~/.cache/huggingface:/root/.cache/huggingface --name "node$node" \ + --network docker-net --ip 192.168.10.$((10 + $node)) --rm "$DOCKER_IMAGE" \ + /bin/bash -c "tail -f /dev/null" # organize containers into a ray cluster - if [ $node -eq 0 ]; then + if [ "$node" -eq 0 ]; then # start the ray head node - docker exec -d node$node /bin/bash -c "ray start --head --port=6379 --block" + docker exec -d "node$node" /bin/bash -c "ray start --head --port=6379 --block" # wait for the head node to be ready sleep 10 else # start the ray worker nodes, and connect them to the head node - docker exec -d node$node /bin/bash -c "ray start --address=192.168.10.10:6379 --block" + docker exec -d "node$node" /bin/bash -c "ray start --address=192.168.10.10:6379 --block" fi done @@ -79,22 +82,22 @@ run_nodes() { for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu)) GPU_DEVICES+=$(($DEVICE_NUM)) - if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then + if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then GPU_DEVICES+=',' fi done GPU_DEVICES+='"' echo "Running node$node with GPU devices: $GPU_DEVICES" - if [ $node -ne 0 ]; then - docker exec -d node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}" + if [ "$node" -ne 0 ]; then + docker exec -d "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}" else - docker exec node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}" + docker exec "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}" fi done } cleanup() { for node in $(seq 0 $(($NUM_NODES-1))); do - docker stop 
node$node + docker stop "node$node" done docker network rm docker-net } diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh index 252c0f7fecd1..ad5ae6f41574 100644 --- a/.buildkite/run-neuron-test.sh +++ b/.buildkite/run-neuron-test.sh @@ -1,6 +1,20 @@ +#!/bin/bash + # This script build the Neuron docker image and run the API server inside the container. # It serves a sanity check for compilation and basic model usage. set -e +set -v + +image_name="neuron/vllm-ci" +container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" + +HF_CACHE="$(realpath ~)/huggingface" +mkdir -p "${HF_CACHE}" +HF_MOUNT="/root/.cache/huggingface" + +NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" +mkdir -p "${NEURON_COMPILE_CACHE_URL}" +NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" # Try building the docker image aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com @@ -11,41 +25,30 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then last_build=$(cat /tmp/neuron-docker-build-timestamp) current_time=$(date +%s) if [ $((current_time - last_build)) -gt 86400 ]; then - docker system prune -f - echo $current_time > /tmp/neuron-docker-build-timestamp + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes / force the system prune for old images as well. + docker volume prune -f && docker system prune -f + echo "$current_time" > /tmp/neuron-docker-build-timestamp fi else - echo $(date +%s) > /tmp/neuron-docker-build-timestamp + date "+%s" > /tmp/neuron-docker-build-timestamp fi -docker build -t neuron -f Dockerfile.neuron . +docker build -t "${image_name}" -f Dockerfile.neuron . 
# Setup cleanup -remove_docker_container() { docker rm -f neuron || true; } +remove_docker_container() { + docker image rm -f "${image_name}" || true; +} trap remove_docker_container EXIT -remove_docker_container # Run the image -docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \ - --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 & - -# Wait for the server to start -wait_for_server_to_start() { - timeout=300 - counter=0 - - while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do - sleep 1 - counter=$((counter + 1)) - if [ $counter -ge $timeout ]; then - echo "Timeout after $timeout seconds" - break - fi - done -} -wait_for_server_to_start - -# Test a simple prompt -curl -X POST -H "Content-Type: application/json" \ - localhost:8000/generate \ - -d '{"prompt": "San Francisco is a"}' +docker run --rm -it --device=/dev/neuron0 --network bridge \ + -v "${HF_CACHE}:${HF_MOUNT}" \ + -e "HF_HOME=${HF_MOUNT}" \ + -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ + -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ + --name "${container_name}" \ + ${image_name} \ + /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys" diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh deleted file mode 100755 index 70e56596c4a8..000000000000 --- a/.buildkite/run-openvino-test.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This script build the OpenVINO docker image and run the offline inference inside the container. -# It serves a sanity check for compilation and basic model usage. 
-set -ex - -# Try building the docker image -docker build -t openvino-test -f Dockerfile.openvino . - -# Setup cleanup -remove_docker_container() { docker rm -f openvino-test || true; } -trap remove_docker_container EXIT -remove_docker_container - -# Run the image and launch offline inference -docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh deleted file mode 100644 index 4aabd123ae23..000000000000 --- a/.buildkite/run-tpu-test.sh +++ /dev/null @@ -1,16 +0,0 @@ -set -e - -# Build the docker image. -docker build -f Dockerfile.tpu -t vllm-tpu . - -# Set up cleanup. -remove_docker_container() { docker rm -f tpu-test || true; } -trap remove_docker_container EXIT -# Remove the container that might not be cleaned up in the previous run. -remove_docker_container - -# For HF_TOKEN. -source /etc/environment -# Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \ - python3 /workspace/vllm/examples/offline_inference_tpu.py diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh new file mode 100755 index 000000000000..f0f53d3b716d --- /dev/null +++ b/.buildkite/run-tpu-v1-test.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +set -e + +# Build the docker image. +docker build -f Dockerfile.tpu -t vllm-tpu . + +# Set up cleanup. +remove_docker_container() { docker rm -f tpu-test || true; } +trap remove_docker_container EXIT +# Remove the container that might not be cleaned up in the previous run. +remove_docker_container + +# For HF_TOKEN. +source /etc/environment +# Run a simple end-to-end example. 
+docker run --privileged --net host --shm-size=16G -it \ + -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ + vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ + && python3 -m pip install pytest \ + && python3 -m pip install lm_eval[api]==0.4.4 \ + && export VLLM_USE_V1=1 \ + && export VLLM_XLA_CHECK_RECOMPILATION=1 \ + && echo TEST_1 \ + && pytest /workspace/vllm/tests/tpu/test_compilation.py \ + && echo TEST_2 \ + && pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \ + && echo TEST_3 \ + && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \ + && echo TEST_4 \ + && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ + && echo TEST_5 \ + && python3 /workspace/vllm/examples/offline_inference/tpu.py" \ + + +# TODO: This test fails because it uses RANDOM_SEED sampling +# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ + diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh index 22a7e76937a7..3a0e6bdb2caa 100644 --- a/.buildkite/run-xpu-test.sh +++ b/.buildkite/run-xpu-test.sh @@ -1,14 +1,31 @@ +#!/bin/bash + # This script build the CPU docker image and run the offline inference inside the container. # It serves a sanity check for compilation and basic model usage. set -ex +image_name="xpu/vllm-ci:${BUILDKITE_COMMIT}" +container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" + # Try building the docker image -docker build -t xpu-test -f Dockerfile.xpu . +docker build -t ${image_name} -f Dockerfile.xpu . 
# Setup cleanup -remove_docker_container() { docker rm -f xpu-test || true; } +remove_docker_container() { + docker rm -f "${container_name}" || true; + docker image rm -f "${image_name}" || true; + docker system prune -f || true; +} trap remove_docker_container EXIT -remove_docker_container -# Run the image and launch offline inference -docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test python3 examples/offline_inference.py +# Run the image and test offline inference/tensor parallel +docker run \ + --device /dev/dri \ + -v /dev/dri/by-path:/dev/dri/by-path \ + --entrypoint="" \ + --name "${container_name}" \ + "${image_name}" \ + sh -c ' + VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m + VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 +' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9ec9ec12bfcf..217f869f1f3c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -2,286 +2,652 @@ # adding a new command to an existing step. See different options here for examples. # This script will be feed into Jinja template in `test-template-aws.j2` at -# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 +# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 # to generate the final pipeline yaml file. +# Documentation +# label(str): the name of the test. emoji allowed. +# fast_check(bool): whether to run this on each commit on fastcheck pipeline. +# fast_check_only(bool): run this test on fastcheck pipeline only +# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's a scheduled nightly run. +# command(str): the single command to run for tests. incompatible with commands. +# commands(list): the list of commands to run for test. incompatible with command. 
+# mirror_hardwares(list): the list of hardware to run the test on as well. currently only supports [amd] +# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 +# num_gpus(int): override the number of GPUs for the test. defaults to 1 GPU. currently supports 2 and 4. +# num_nodes(int): whether to simulate a multi-node setup by launching multiple containers on one host, +# in this case, commands must be specified. the first command runs on the first host, the second +# command runs on the second host. +# working_dir(str): specify the place where the command should execute, defaults to /vllm-workspace/tests +# source_file_dependencies(list): the list of prefixes that opt the test in; if empty, the test will always run. + +# When adding a test +# - If the test belongs to an existing group, add it there +# - If the test is short, add it to any existing step +# - If the test takes more than 10min, then it is okay to create a new step. +# Note that all steps execute in parallel. 
steps: -- label: Async Engine, Inputs, Utils, Worker Test +##### fast check tests ##### + +- label: Documentation Build # 2min + working_dir: "/vllm-workspace/test_docs/docs" fast_check: true - fast_check_only: true + no_gpu: True + commands: + - pip install -r ../../requirements/docs.txt + - SPHINXOPTS=\"-W\" make html + # Check API reference (if it fails, you may have missing mock imports) + - grep \"sig sig-object py\" build/html/api/inference_params.html + +- label: Async Engine, Inputs, Utils, Worker Test # 24min + source_file_dependencies: + - vllm/ + - tests/mq_llm_engine + - tests/async_engine + - tests/test_inputs + - tests/multimodal + - tests/test_utils + - tests/worker + - tests/standalone_tests/lazy_imports.py commands: - - pytest -v -s async_engine # Async Engine + - python3 standalone_tests/lazy_imports.py + - pytest -v -s mq_llm_engine # MQLLMEngine + - pytest -v -s async_engine # AsyncLLMEngine + - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker -- label: Metrics, Tracing Test - fast_check: true - fast_check_only: true +- label: Python-only Installation Test + source_file_dependencies: + - tests/standalone_tests/python_only_compile.sh + - setup.py commands: - - pytest -v -s metrics # Metrics - - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" # Tracing - - pytest -v -s tracing + - bash standalone_tests/python_only_compile.sh -- label: Regression Test - mirror_hardwares: [amd] - fast_check: true - command: pytest -v -s test_regression.py - working_dir: "/vllm-workspace/tests" # optional - -- label: AsyncEngine Test +- label: Basic Correctness Test # 30min #mirror_hardwares: [amd] - command: pytest -v -s async_engine - -- label: Basic Correctness Test - mirror_hardwares: [amd] fast_check: true + source_file_dependencies: 
+ - vllm/ + - tests/basic_correctness/test_basic_correctness + - tests/basic_correctness/test_cpu_offload + - tests/basic_correctness/test_preemption + - tests/basic_correctness/test_cumem.py commands: - # This flashinfer installation will fail on AMD ROCm, so it is set as optional. - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s basic_correctness/test_cumem.py - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py + - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py + +- label: Chunked Prefill Test + source_file_dependencies: + - vllm/ + - tests/basic_correctness/test_chunked_prefill + commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Core Test +- label: Core Test # 10min mirror_hardwares: [amd] fast_check: true + source_file_dependencies: + - vllm/core + - vllm/distributed + - tests/core commands: - pytest -v -s core -- label: Distributed Comm Ops Test - #mirror_hardwares: [amd] +- label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" - num_gpus: 2 - commands: - - pytest -v -s distributed/test_comm_ops.py - - pytest -v -s distributed/test_shm_broadcast.py - -- label: 2 Node Tests (4 GPUs in total) - working_dir: "/vllm-workspace/tests" - num_gpus: 2 - num_nodes: 2 - commands: - - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py - - VLLM_MULTI_NODE=1 pytest 
-v -s distributed/test_pipeline_parallel.py - - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py - -- label: Distributed Tests (2 GPUs) + fast_check: true mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" - num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/entrypoints/llm + - tests/entrypoints/openai + - tests/entrypoints/test_chat_utils + - tests/entrypoints/offline_mode commands: - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s 
distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py - -- label: Distributed Tests (4 GPUs) - #mirror_hardwares: [amd] + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process + - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ + - pytest -v -s entrypoints/test_chat_utils.py + - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with 
other tests + +- label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 - fast_check: true + source_file_dependencies: + - vllm/distributed/ + - vllm/core/ + - tests/distributed/test_utils + - tests/distributed/test_pynccl + - tests/spec_decode/e2e/test_integration_dist_tp4 + - tests/compile/test_basic_correctness + - examples/offline_inference/rlhf.py + - examples/offline_inference/rlhf_colocate.py + - tests/examples/offline_inference/data_parallel.py commands: + # test with tp=2 and external_dp=2 + - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with internal dp + - python3 ../examples/offline_inference/data_parallel.py + - pytest -v -s distributed/test_utils.py + - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. - # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. 
- - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py + # TODO: create a dedicated test section for multi-GPU example tests + # when we have multiple distributed example tests + - pushd ../examples/offline_inference + - VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py + - VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - popd + +- label: Metrics, Tracing Test # 10min + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/metrics + - tests/tracing + commands: + - pytest -v -s metrics + - "pip install \ + 'opentelemetry-sdk>=1.26.0,<1.27.0' \ + 'opentelemetry-api>=1.26.0,<1.27.0' \ + 'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'" + - pytest -v -s tracing -- label: Pipeline Parallelism Test - working_dir: "/vllm-workspace/tests" - num_gpus: 4 +##### fast check tests ##### +##### 1 GPU test ##### + +- label: Regression Test # 5min + mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/test_regression commands: - - pytest -v -s distributed/test_pipeline_parallel.py + - pip install modelscope + - pytest -v -s test_regression.py + working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test +- label: Engine Test # 10min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/engine + - tests/tokenization + - tests/test_sequence + - tests/test_config + - tests/test_logger commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py # OOM in the CI unless we run this 
separately - pytest -v -s tokenization -- label: Entrypoints Test - fast_check: true +- label: V1 Test + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # split the test to avoid interference + - pytest -v -s v1/core + - pytest -v -s v1/entrypoints + - pytest -v -s v1/engine + - pytest -v -s v1/entrypoints + - pytest -v -s v1/sample + - pytest -v -s v1/worker + - pytest -v -s v1/structured_output + - pytest -v -s v1/test_stats.py + - pytest -v -s v1/test_utils.py + - pytest -v -s v1/test_oracle.py + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. + - pytest -v -s v1/e2e + # Integration test for streaming correctness (requires special branch). + - pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api + - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine + +- label: Examples Test # 25min + working_dir: "/vllm-workspace/examples" + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/entrypoints + - examples/ + commands: + - pip install tensorizer # for tensorizer test + - python3 offline_inference/basic/generate.py --model facebook/opt-125m + - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 + - python3 offline_inference/basic/chat.py + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_embedding.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 + - VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors 
/tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference/encoder_decoder.py + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 + - python3 offline_inference/basic/classify.py + - python3 offline_inference/basic/embed.py + - python3 offline_inference/basic/score.py + - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 + +- label: Prefix Caching Test # 9min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/prefix_caching + commands: + - pytest -v -s prefix_caching +- label: Samplers Test # 36min + source_file_dependencies: + - vllm/model_executor/layers + - vllm/sampling_metadata.py + - tests/samplers + - tests/conftest.py commands: - - pytest -v -s entrypoints/llm - - pytest -v -s entrypoints/openai + - pytest -v -s samplers + - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers -- label: Examples Test - working_dir: "/vllm-workspace/examples" +- label: LogitsProcessor Test # 5min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/model_executor/layers + - vllm/model_executor/guided_decoding + - tests/test_logits_processor + - tests/model_executor/test_guided_processors commands: - # install tensorizer for tensorize_vllm_model.py - - pip install awscli tensorizer - - python3 offline_inference.py - - python3 cpu_offload.py - - python3 offline_inference_with_prefix.py - - python3 llm_engine_example.py - - python3 offline_inference_vision_language.py - - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - -- label: Inputs Test - #mirror_hardwares: [amd] + - pytest -v -s test_logits_processor.py + - pytest -v -s model_executor/test_guided_processors.py + +- label: Speculative decoding tests # 40min + source_file_dependencies: + - 
vllm/spec_decode + - tests/spec_decode + - vllm/model_executor/models/eagle.py commands: - - pytest -v -s test_inputs.py - - pytest -v -s multimodal + - pytest -v -s spec_decode/e2e/test_multistep_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py + - pytest -v -s spec_decode/e2e/test_eagle_correctness.py -# - label: Kernels Test %N -# #mirror_hardwares: [amd] -# commands: -# - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl -# - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT -# parallelism: 4 +- label: LoRA Test %N # 15min each + mirror_hardwares: [amd] + source_file_dependencies: + - vllm/lora + - tests/lora + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py + parallelism: 4 + +- label: PyTorch Fullgraph Smoke Test # 9min + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_basic_correctness.py + # these tests need to be separated, cannot combine + - pytest -v -s compile/piecewise/test_simple.py + - pytest -v -s compile/piecewise/test_toy_llama.py + - pytest -v -s compile/test_pass_manager.py + +- label: PyTorch Fullgraph Test # 18min + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph.py -- label: Models Test - #mirror_hardwares: [amd] +- label: Kernels Test %N # 1h each + mirror_hardwares: [amd] + source_file_dependencies: + - csrc/ + - vllm/attention + - tests/kernels commands: - - pip install 
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - pytest -v -s models -m \"not vlm\" + - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 4 -- label: Vision Language Models Test +- label: Tensorizer Test # 11min mirror_hardwares: [amd] + soft_fail: true + source_file_dependencies: + - vllm/model_executor/model_loader + - tests/tensorizer_loader commands: - - pytest -v -s models -m vlm + - apt-get update && apt-get install -y curl libsodium23 + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s tensorizer_loader -- label: Prefix Caching Test +- label: Benchmarks # 9min + working_dir: "/vllm-workspace/.buildkite" mirror_hardwares: [amd] + source_file_dependencies: + - benchmarks/ commands: - - pytest -v -s prefix_caching + - bash run-benchmarks.sh -- label: Samplers Test - #mirror_hardwares: [amd] - command: pytest -v -s samplers +- label: Quantization Test # 33min + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + - tests/quantization + command: VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization -- label: LogitsProcessor Test - mirror_hardwares: [amd] - command: pytest -v -s test_logits_processor.py +- label: LM Eval Small Models # 53min + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-small.txt -t 1 -- label: Utils Test +- label: OpenAI API correctness + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/ + - vllm/model_executor/models/whisper.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ + +- label: Encoder Decoder tests # 5min + source_file_dependencies: + - vllm/ + - tests/encoder_decoder + commands: + - pytest -v 
-s encoder_decoder + +- label: OpenAI-Compatible Tool Use # 20 min + fast_check: false + mirror_hardwares: [ amd ] + source_file_dependencies: + - vllm/ + - tests/tool_use commands: - - pytest -v -s test_utils.py - - pytest -v -s test_embedded_commit.py + - pytest -v -s tool_use -- label: Worker Test - mirror_hardwares: [amd] - command: pytest -v -s worker +##### models test ##### + +- label: Basic Models Test # 24min + source_file_dependencies: + - vllm/ + - tests/models + commands: + - pytest -v -s models/test_transformers.py + - pytest -v -s models/test_registry.py + # V1 Test: https://github.com/vllm-project/vllm/issues/14531 + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -- label: Speculative decoding tests +- label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/language + - tests/models/embedding/language + - tests/models/encoder_decoder/language + commands: + - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' + - pytest -v -s models/embedding/language -m core_model + +- label: Language Models Test (Extended) # 1h10min + optional: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/language + - tests/models/embedding/language + - tests/models/encoder_decoder/language commands: - # See https://github.com/vllm-project/vllm/issues/5152 - - export VLLM_ATTENTION_BACKEND=XFORMERS - - pytest -v -s spec_decode - -# - label: LoRA Test %N -# #mirror_hardwares: [amd] -# command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py -# parallelism: 4 - -# - label: LoRA Long Context (Distributed) -# #mirror_hardwares: [amd] -# num_gpus: 4 -# # This test runs llama 13B, so it is required to run on 4 GPUs. 
-# commands: -# # FIXIT: find out which code initialize cuda before running the test -# # before the fix, we need to use spawn to test it -# - export VLLM_WORKER_MULTIPROC_METHOD=spawn -# - pytest -v -s -x lora/test_long_context.py - -- label: Tensorizer Test + - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' + - pytest -v -s models/embedding/language -m 'not core_model' + +- label: Multi-Modal Models Test (Standard) # 40min #mirror_hardwares: [amd] - fast_check: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/audio_language + - tests/models/decoder_only/vision_language + - tests/models/embedding/vision_language + - tests/models/encoder_decoder/audio_language + - tests/models/encoder_decoder/vision_language commands: - - apt-get install -y curl libsodium23 - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s tensorizer_loader + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal + - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' + - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' + - pytest -v -s models/embedding/vision_language -m core_model + - pytest -v -s models/encoder_decoder/audio_language -m core_model + - pytest -v -s models/encoder_decoder/language -m core_model + - pytest -v -s models/encoder_decoder/vision_language -m core_model + +- label: Multi-Modal Models Test (Extended) 1 # 48m + optional: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/audio_language + - tests/models/decoder_only/vision_language + - tests/models/embedding/vision_language + - tests/models/encoder_decoder/vision_language + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' + - pytest -v -s 
models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model' + # HACK - run phi3v tests separately to sidestep this transformers bug + # https://github.com/huggingface/transformers/issues/34307 + - pytest -v -s models/decoder_only/vision_language/test_phi3v.py + - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s models/embedding/vision_language -m 'not core_model' + - pytest -v -s models/encoder_decoder/language -m 'not core_model' + - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' + +- label: Multi-Modal Models Test (Extended) 2 # 38m + optional: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/vision_language + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model' -- label: Metrics Test - mirror_hardwares: [amd] - command: pytest -v -s metrics +# This test is used only in PR development phase to test individual models and should never run on main +- label: Custom Models Test + optional: true + commands: + - echo 'Testing custom models...' + # PR authors can temporarily add commands below to test individual models + # e.g. 
pytest -v -s models/encoder_decoder/vision_language/test_mllama.py + # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* -- label: Quantization Test +##### 1 GPU test ##### +##### multi gpus test ##### + +- label: Distributed Comm Ops Test # 7min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed + - tests/distributed + commands: + - pytest -v -s distributed/test_comm_ops.py + - pytest -v -s distributed/test_shm_broadcast.py + +- label: 2 Node Tests (4 GPUs in total) # 16min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + num_nodes: 2 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + commands: + - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py + - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py + - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + +- label: Distributed Tests (2 GPUs) # 40min #mirror_hardwares: [amd] - command: pytest -v -s quantization - -- label: Tracing Test - commands: - - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" - - pytest -v -s tracing - -- label: Benchmarks - working_dir: "/vllm-workspace/.buildkite" - mirror_hardwares: [amd] + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: 
+ - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + - vllm/compilation + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/model_runner.py + - entrypoints/llm/test_collective_rpc.py commands: - - pip install aiohttp - - bash run-benchmarks.sh + - VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py + - pytest -v -s ./compile/test_basic_correctness.py + - pytest -v -s ./compile/test_wrapper.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' + # this test fails consistently. 
+ # TODO: investigate and fix + # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py + - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py + +- label: Plugin Tests (2 GPUs) # 40min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/plugins/ + - tests/plugins/ + commands: + # begin platform plugin tests, all the code in-between runs on dummy platform + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s plugins_tests/test_platform_plugins.py + - pip uninstall vllm_add_dummy_platform -y + # end platform plugin tests + # other tests continue here: + - pytest -v -s plugins_tests/test_scheduler_plugins.py + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s models/test_oot_registration.py # it needs a clean process + +- label: Multi-step Tests (4 GPUs) # 36min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/model_executor/layers/sampler.py + - vllm/sequence.py + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/multi_step_worker.py + - vllm/worker/model_runner_base.py + - vllm/worker/model_runner.py + - vllm/worker/multi_step_model_runner.py + - vllm/engine + - tests/multi_step + commands: + # this test is quite flaky + # TODO: investigate and fix. 
+ # - pytest -v -s multi_step/test_correctness_async_llm.py + - pytest -v -s multi_step/test_correctness_llm.py -- label: LM Eval Small Models - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" +- label: Pipeline Parallelism Test # 45min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-small.txt -t 1 + - pytest -v -s distributed/test_pp_cudagraph.py + - pytest -v -s distributed/test_pipeline_parallel.py -- label: LM Eval Large Models - gpu: a100 +- label: LoRA TP Test (Distributed) num_gpus: 4 - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - vllm/lora + - tests/lora commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large.txt -t 4 + # FIXIT: find out which code initialize cuda before running the test + # before the fix, we need to use spawn to test it + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # This test runs llama 13B, so it is required to run on 4 GPUs. + - pytest -v -s -x lora/test_long_context.py + # There is some Tensor Parallelism related processing logic in LoRA that + # requires multi-GPU testing for validation. 
+ - pytest -v -s -x lora/test_chatglm3_tp.py + - pytest -v -s -x lora/test_llama_tp.py + - pytest -v -s -x lora/test_minicpmv_tp.py + - pytest -v -s -x lora/test_transfomers_model.py -- label: Documentation Build - working_dir: "/vllm-workspace/test_docs/docs" - fast_check: true - no_gpu: True + +- label: Weight Loading Multiple GPU Test # 33min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/weight_loading commands: - - pip install -r requirements-docs.txt - - SPHINXOPTS=\"-W\" make html + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt + -- label: Distributed Tests (A100) +##### multi gpus test ##### +##### A100 test ##### + +- label: Distributed Tests (A100) # optional gpu: a100 + optional: true num_gpus: 4 - commands: + source_file_dependencies: + - vllm/ + commands: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER 
TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - pytest -v -s -x lora/test_mixtral.py + +- label: LM Eval Large Models # optional + gpu: a100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-large.txt -t 4 diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh new file mode 100644 index 000000000000..a681f8927060 --- /dev/null +++ b/.buildkite/upload-wheels.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash + +set -ex + +# Assume wheels are in artifacts/dist/*.whl +wheel_files=(artifacts/dist/*.whl) + +# Check that exactly one wheel is found +if [[ ${#wheel_files[@]} -ne 1 ]]; then + echo "Error: Expected exactly one wheel file in artifacts/dist/, but found ${#wheel_files[@]}" + exit 1 +fi + +# Get the single wheel file +wheel="${wheel_files[0]}" + +# Rename 'linux' to 'manylinux1' in the wheel filename +new_wheel="${wheel/linux/manylinux1}" +mv -- "$wheel" "$new_wheel" +wheel="$new_wheel" + +# Extract the version from the wheel +version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2) +echo "Version: $version" + +normal_wheel="$wheel" # Save the original wheel filename + +# If the version contains "dev", rename it to v1.0.0.dev for consistency +if [[ $version == *dev* ]]; then + suffix="${version##*.}" + if [[ $suffix == cu* ]]; then + new_version="1.0.0.dev+${suffix}" + else + new_version="1.0.0.dev" + fi + new_wheel="${wheel/$version/$new_version}" + # use cp to keep both files in the artifacts directory + cp -- "$wheel" "$new_wheel" + wheel="$new_wheel" + 
version="$new_version" +fi + +# Upload the wheel to S3 +python3 .buildkite/generate_index.py --wheel "$normal_wheel" + +# generate index for this commit +aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" +aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" + +if [[ $normal_wheel == *"cu118"* ]]; then + # if $normal_wheel matches cu118, do not upload the index.html + echo "Skipping index files for cu118 wheels" +elif [[ $normal_wheel == *"cu121"* ]]; then + # if $normal_wheel matches cu121, do not upload the index.html + echo "Skipping index files for cu121 wheels" +else + # only upload index.html for cu124 wheels (default wheels) + aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" + aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" +fi + +# generate index for nightly +aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" +aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" + +if [[ $normal_wheel == *"cu118"* ]]; then + # if $normal_wheel matches cu118, do not upload the index.html + echo "Skipping index files for cu118 wheels" +elif [[ $normal_wheel == *"cu121"* ]]; then + # if $normal_wheel matches cu121, do not upload the index.html + echo "Skipping index files for cu121 wheels" +else + # only upload index.html for cu124 wheels (default wheels) + aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" +fi + +aws s3 cp "$wheel" "s3://vllm-wheels/$version/" \ No newline at end of file diff --git a/.dockerignore b/.dockerignore index 5cfe0dcb065d..3863656915d0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,33 @@ +/.venv +/build +dist vllm/*.so + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +.mypy_cache + +# Distribution / packaging +.Python +/build/ +cmake-build-*/ +CMakeUserPresets.json +develop-eggs/ +/dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg 
+*.egg +MANIFEST diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000000..860c5c6cd537 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,41 @@ +# See https://help.github.com/articles/about-codeowners/ +# for more info about CODEOWNERS file + +# This lists cover the "core" components of vLLM that require careful review +/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth +/vllm/model_executor/guided_decoding @mgoin @russellb +/vllm/multimodal @DarkLight1337 @ywang96 +CMakeLists.txt @tlrmchlsmth + +# vLLM V1 +/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat +/vllm/v1/structured_output @mgoin @russellb + +# Test ownership +/.buildkite/lm-eval-harness @mgoin @simon-mo +/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo +/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac +/tests/distributed/test_multi_node_assignment.py @youkaichao +/tests/distributed/test_pipeline_parallel.py @youkaichao +/tests/distributed/test_same_node.py @youkaichao +/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo +/tests/entrypoints/llm/test_guided_generate.py @mgoin @russellb +/tests/kernels @tlrmchlsmth @WoosukKwon +/tests/model_executor/test_guided_processors.py @mgoin @russellb +/tests/models @DarkLight1337 @ywang96 +/tests/multi_step @alexm-redhat @comaniac 
+/tests/multimodal @DarkLight1337 @ywang96 +/tests/prefix_caching @comaniac @KuntaiDu +/tests/quantization @mgoin @robertgshaw2-redhat +/tests/spec_decode @njhill @LiuXiaoxuanPKU +/tests/test_inputs.py @DarkLight1337 @ywang96 +/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb +/tests/v1/structured_output @mgoin @russellb +/tests/weight_loading @mgoin @youkaichao diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 71f4e520135d..d1f6105a4716 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1,2 @@ github: [vllm-project] -open_collective: [vllm] +open_collective: vllm diff --git a/.github/ISSUE_TEMPLATE/100-documentation.yml b/.github/ISSUE_TEMPLATE/100-documentation.yml index 501c0aa48b88..74d397b231ac 100644 --- a/.github/ISSUE_TEMPLATE/100-documentation.yml +++ b/.github/ISSUE_TEMPLATE/100-documentation.yml @@ -20,3 +20,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/200-installation.yml b/.github/ISSUE_TEMPLATE/200-installation.yml index df41ade8c3c0..590e56c13781 100644 --- a/.github/ISSUE_TEMPLATE/200-installation.yml +++ b/.github/ISSUE_TEMPLATE/200-installation.yml @@ -38,3 +38,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. 
+ required: true diff --git a/.github/ISSUE_TEMPLATE/300-usage.yml b/.github/ISSUE_TEMPLATE/300-usage.yml index 54763af1058f..004798a388a6 100644 --- a/.github/ISSUE_TEMPLATE/300-usage.yml +++ b/.github/ISSUE_TEMPLATE/300-usage.yml @@ -36,3 +36,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml deleted file mode 100644 index ce980c3f4a01..000000000000 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ /dev/null @@ -1,86 +0,0 @@ -name: 🐛 Bug report -description: Raise an issue here if you find a bug. -title: "[Bug]: " -labels: ["bug"] - -body: -- type: markdown - attributes: - value: > - #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). -- type: textarea - attributes: - label: Your current environment - description: | - Please run the following and paste the output below. - ```sh - wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py - # For security purposes, please feel free to check the contents of collect_env.py before running it. - python collect_env.py - ``` - It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. 
- value: | - ```text - The output of `python collect_env.py` - ``` - validations: - required: true -- type: textarea - attributes: - label: 🐛 Describe the bug - description: | - Please provide a clear and concise description of what the bug is. - - If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: - - ```python - from vllm import LLM, SamplingParams - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - llm = LLM(model="facebook/opt-125m") - - outputs = llm.generate(prompts, sampling_params) - - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - ``` - - If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. - - Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. - - Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. - - If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. 
- placeholder: | - A clear and concise description of what the bug is. - - ```python - # Sample code to reproduce the problem - ``` - - ``` - The error message you got, with the full traceback. - ``` - validations: - required: true -- type: markdown - attributes: - value: > - ⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the models' output: - - - Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc). - - - If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect. - - Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml new file mode 100644 index 000000000000..d4113da8b5b8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -0,0 +1,98 @@ +name: 🐛 Bug report +description: Raise an issue here if you find a bug. +title: "[Bug]: " +labels: ["bug"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: Your current environment + description: | + Please run the following and paste the output below. + ```sh + wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py + # For security purposes, please feel free to check the contents of collect_env.py before running it. + python collect_env.py + ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. + value: | +
+ The output of `python collect_env.py` + + ```text + Your output of `python collect_env.py` here + ``` + +
+ validations: + required: true +- type: textarea + attributes: + label: 🐛 Describe the bug + description: | + Please provide a clear and concise description of what the bug is. + + If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: + + ```python + from vllm import LLM, SamplingParams + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM(model="facebook/opt-125m") + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + ``` + + If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. + + Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + + Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. + + If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. + placeholder: | + A clear and concise description of what the bug is. 
+ + ```python + # Sample code to reproduce the problem + ``` + + ``` + The error message you got, with the full traceback. + ``` + validations: + required: true +- type: markdown + attributes: + value: > + ⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the models' output: + + - Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc). + + - If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect. + + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/500-feature request.yml b/.github/ISSUE_TEMPLATE/500-feature request.yml deleted file mode 100644 index 47a90628c76c..000000000000 --- a/.github/ISSUE_TEMPLATE/500-feature request.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: 🚀 Feature request -description: Submit a proposal/request for a new vllm feature -title: "[Feature]: " -labels: ["feature request"] - -body: -- type: markdown - attributes: - value: > - #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). -- type: textarea - attributes: - label: 🚀 The feature, motivation and pitch - description: > - A clear and concise description of the feature proposal. Please outline the motivation for the proposal. 
Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. - validations: - required: true -- type: textarea - attributes: - label: Alternatives - description: > - A description of any alternative solutions or features you've considered, if any. -- type: textarea - attributes: - label: Additional context - description: > - Add any other context or screenshots about the feature request. -- type: markdown - attributes: - value: > - Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/500-feature-request.yml b/.github/ISSUE_TEMPLATE/500-feature-request.yml new file mode 100644 index 000000000000..097d88f50930 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/500-feature-request.yml @@ -0,0 +1,38 @@ +name: 🚀 Feature request +description: Submit a proposal/request for a new vllm feature +title: "[Feature]: " +labels: ["feature request"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: 🚀 The feature, motivation and pitch + description: > + A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. + validations: + required: true +- type: textarea + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. +- type: textarea + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. 
+- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/600-new model.yml b/.github/ISSUE_TEMPLATE/600-new model.yml deleted file mode 100644 index bbddbfd67138..000000000000 --- a/.github/ISSUE_TEMPLATE/600-new model.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: 🤗 Support request for a new model from huggingface -description: Submit a proposal/request for a new model from huggingface -title: "[New Model]: " -labels: ["new model"] - -body: -- type: markdown - attributes: - value: > - #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). - - #### We also highly recommend you read https://docs.vllm.ai/en/latest/models/adding_model.html first to understand how to add a new model. -- type: textarea - attributes: - label: The model to consider. - description: > - A huggingface url, pointing to the model, e.g. https://huggingface.co/openai-community/gpt2 . - validations: - required: true -- type: textarea - attributes: - label: The closest model vllm already supports. - description: > - Here is the list of models already supported by vllm: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models . Which model is the most similar to the model you want to add support for? -- type: textarea - attributes: - label: What's your difficulty of supporting the model you want? - description: > - For example, any new operators or new architecture? 
-- type: markdown - attributes: - value: > - Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/600-new-model.yml b/.github/ISSUE_TEMPLATE/600-new-model.yml new file mode 100644 index 000000000000..713e76c1a5ce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/600-new-model.yml @@ -0,0 +1,40 @@ +name: 🤗 Support request for a new model from huggingface +description: Submit a proposal/request for a new model from huggingface +title: "[New Model]: " +labels: ["new model"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). + + #### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model. +- type: textarea + attributes: + label: The model to consider. + description: > + A huggingface url, pointing to the model, e.g. https://huggingface.co/openai-community/gpt2 . + validations: + required: true +- type: textarea + attributes: + label: The closest model vllm already supports. + description: > + Here is the list of models already supported by vllm: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models . Which model is the most similar to the model you want to add support for? +- type: textarea + attributes: + label: What's your difficulty of supporting the model you want? + description: > + For example, any new operators or new architecture? +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. 
+ required: true diff --git a/.github/ISSUE_TEMPLATE/700-performance discussion.yml b/.github/ISSUE_TEMPLATE/700-performance discussion.yml deleted file mode 100644 index 4f8843420a94..000000000000 --- a/.github/ISSUE_TEMPLATE/700-performance discussion.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: ⚡ Discussion on the performance of vllm -description: Submit a proposal/discussion about the performance of vllm -title: "[Performance]: " -labels: ["performance"] - -body: -- type: markdown - attributes: - value: > - #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). -- type: textarea - attributes: - label: Proposal to improve performance - description: > - How do you plan to improve vllm's performance? - validations: - required: false -- type: textarea - attributes: - label: Report of performance regression - description: > - Please provide detailed description of performance comparison to confirm the regression. You may want to run the benchmark script at https://github.com/vllm-project/vllm/tree/main/benchmarks . - validations: - required: false -- type: textarea - attributes: - label: Misc discussion on performance - description: > - Anything about the performance. - validations: - required: false -- type: textarea - attributes: - label: Your current environment (if you think it is necessary) - description: | - Please run the following and paste the output below. - ```sh - wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py - # For security purposes, please feel free to check the contents of collect_env.py before running it. - python collect_env.py - ``` - It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. 
- value: | - ```text - The output of `python collect_env.py` - ``` - validations: - required: false -- type: markdown - attributes: - value: > - Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/700-performance-discussion.yml b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml new file mode 100644 index 000000000000..273f50d59cf7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml @@ -0,0 +1,59 @@ +name: ⚡ Discussion on the performance of vllm +description: Submit a proposal/discussion about the performance of vllm +title: "[Performance]: " +labels: ["performance"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: Proposal to improve performance + description: > + How do you plan to improve vllm's performance? + validations: + required: false +- type: textarea + attributes: + label: Report of performance regression + description: > + Please provide detailed description of performance comparison to confirm the regression. You may want to run the benchmark script at https://github.com/vllm-project/vllm/tree/main/benchmarks . + validations: + required: false +- type: textarea + attributes: + label: Misc discussion on performance + description: > + Anything about the performance. + validations: + required: false +- type: textarea + attributes: + label: Your current environment (if you think it is necessary) + description: | + Please run the following and paste the output below. + ```sh + wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py + # For security purposes, please feel free to check the contents of collect_env.py before running it. 
+ python collect_env.py + ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. + value: | + ```text + The output of `python collect_env.py` + ``` + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml index 5382b124dcd7..e447c077473f 100644 --- a/.github/ISSUE_TEMPLATE/750-RFC.yml +++ b/.github/ISSUE_TEMPLATE/750-RFC.yml @@ -47,3 +47,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/800-misc discussion.yml b/.github/ISSUE_TEMPLATE/800-misc discussion.yml deleted file mode 100644 index ddb10f72db29..000000000000 --- a/.github/ISSUE_TEMPLATE/800-misc discussion.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: 🎲 Misc/random discussions that do not fit into the above categories. -description: Submit a discussion as you like. Note that developers are heavily overloaded and we mainly rely on community users to answer these issues. 
-title: "[Misc]: " -labels: ["misc"] - -body: -- type: markdown - attributes: - value: > - #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). -- type: textarea - attributes: - label: Anything you want to discuss about vllm. - description: > - Anything you want to discuss about vllm. - validations: - required: true -- type: markdown - attributes: - value: > - Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 3ba13e0cec6c..fa40268d6772 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1 +1,5 @@ blank_issues_enabled: false +contact_links: + - name: Questions + url: https://discuss.vllm.ai + about: Ask questions and discuss with other vLLM community members diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 262ce8e1530a..a20c5baf895c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,63 +2,5 @@ FILL IN THE PR DESCRIPTION HERE FIX #xxxx (*link existing issues this PR will resolve*) -**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** - ---- - -
- - PR Checklist (Click to Expand) - -

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

- -

PR Title and Classification

-

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

-
    -
  • [Bugfix] for bug fixes.
  • -
  • [CI/Build] for build or continuous integration improvements.
  • -
  • [Doc] for documentation fixes and improvements.
  • -
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • -
  • [Frontend] for changes to the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • -
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • -
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • -
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • -
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.
  • -
-

Note: If the PR spans more than one category, please include all relevant prefixes.

- -

Code Quality

- -

The PR needs to meet the following code quality standards:

- -
    -
  • We adhere to Google Python style guide and Google C++ style guide.
  • -
  • Pass all linter checks. Please use format.sh to format your code.
  • -
  • The code needs to be well-documented to ensure future contributors can easily understand the code.
  • -
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • -
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.
  • -
- -

Notes for Large Changes

-

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not review the PR.

- -

What to Expect for the Reviews

- -

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

- -
    -
  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • -
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • -
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • -
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion. -
  • -
- -

Thank You

- -

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

- - -
- - + +**BEFORE SUBMITTING, PLEASE READ ** diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..a017d69be991 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,31 @@ +version: 2 +updates: + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + labels: ["dependencies"] + open-pull-requests-limit: 5 + reviewers: ["khluu", "simon-mo"] + allow: + - dependency-type: "all" + ignore: + - dependency-name: "*" + update-types: ["version-update:semver-patch"] + - dependency-name: "torch" + - dependency-name: "torchvision" + - dependency-name: "xformers" + - dependency-name: "lm-format-enforcer" + - dependency-name: "gguf" + - dependency-name: "compressed-tensors" + - dependency-name: "ray[cgraph]" # Ray Compiled Graph + - dependency-name: "lm-eval" + groups: + minor-update: + applies-to: version-updates + update-types: ["minor"] diff --git a/.github/mergify.yml b/.github/mergify.yml new file mode 100644 index 000000000000..54f56210b286 --- /dev/null +++ b/.github/mergify.yml @@ -0,0 +1,113 @@ +pull_request_rules: +- name: label-documentation + description: Automatically apply documentation label + conditions: + - or: + - files~=^[^/]+\.md$ + - files~=^docs/ + - files~=^examples/ + actions: + label: + add: + - documentation + +- name: label-ci-build + description: Automatically apply ci/build label + conditions: + - or: + - files~=^\.github/ + - files~=\.buildkite/ + - files~=^cmake/ + - files=CMakeLists.txt + - files~=^Dockerfile + - files~=^requirements.*\.txt + - files=setup.py + actions: + label: + add: + - ci/build + +- name: label-frontend + description: Automatically apply frontend label + conditions: + - files~=^vllm/entrypoints/ + actions: + label: + add: + - frontend + +- name: label-multi-modality + description: Automatically apply multi-modality label + 
conditions: + - or: + - files~=^vllm/multimodal/ + - files~=^tests/multimodal/ + - files~=^tests/models/multimodal/ + - files~=^tests/models/*/audio_language/ + - files~=^tests/models/*/vision_language/ + - files=tests/models/test_vision.py + actions: + label: + add: + - multi-modality + +- name: label-structured-output + description: Automatically apply structured-output label + conditions: + - or: + - files~=^vllm/model_executor/guided_decoding/ + - files=tests/model_executor/test_guided_processors.py + - files=tests/entrypoints/llm/test_guided_generate.py + - files=benchmarks/benchmark_serving_guided.py + - files=benchmarks/benchmark_guided.py + actions: + label: + add: + - structured-output + +- name: label-speculative-decoding + description: Automatically apply speculative-decoding label + conditions: + - or: + - files~=^vllm/spec_decode/ + - files=vllm/model_executor/layers/spec_decode_base_sampler.py + - files~=^tests/spec_decode/ + actions: + label: + add: + - speculative-decoding + +- name: label-v1 + description: Automatically apply v1 label + conditions: + - or: + - files~=^vllm/v1/ + - files~=^tests/v1/ + actions: + label: + add: + - v1 + +- name: ping author on conflicts and add 'needs-rebase' label + conditions: + - conflict + - -closed + actions: + label: + add: + - needs-rebase + comment: + message: | + This pull request has merge conflicts that must be resolved before it can be + merged. Please rebase the PR, @{{author}}. 
+ + https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork + +- name: remove 'needs-rebase' label when conflict is resolved + conditions: + - -conflict + - -closed + actions: + label: + remove: + - needs-rebase diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh new file mode 100755 index 000000000000..3246c6f9bc4b --- /dev/null +++ b/.github/scripts/cleanup_pr_body.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -eu + +# ensure 1 argument is passed +if [ "$#" -ne 1 ]; then + echo "Usage: $0 " + exit 1 +fi + +PR_NUMBER=$1 +OLD=/tmp/orig_pr_body.txt +NEW=/tmp/new_pr_body.txt + +gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}" +cp "${OLD}" "${NEW}" + +# Remove "FIX #xxxx (*link existing issues this PR will resolve*)" +sed -i '/FIX #xxxx.*$/d' "${NEW}" + +# Remove "FILL IN THE PR DESCRIPTION HERE" +sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}" + +# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**" +sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" + +# Remove HTML
section that includes text of "PR Checklist (Click to Expand)" +python3 - <.*?.*?PR Checklist \(Click to Expand\).*?.*?
', re.DOTALL) +content = re.sub(pattern, '', content) + +with open("${NEW}", "w") as file: + file.write(content) +EOF + +# Run this only if ${NEW} is different than ${OLD} +if ! cmp -s "${OLD}" "${NEW}"; then + gh pr edit --body-file "${NEW}" "${PR_NUMBER}" + echo + echo "Updated PR body:" + echo + cat "${NEW}" +else + echo "No changes needed" +fi diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml index cd53b764c720..c9d6d4259df9 100644 --- a/.github/workflows/add_label_automerge.yml +++ b/.github/workflows/add_label_automerge.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add label - uses: actions/github-script@v5 + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: script: | github.rest.issues.addLabels({ diff --git a/.github/workflows/add_label_ready_comment.yml b/.github/workflows/add_label_ready_comment.yml deleted file mode 100644 index 729c1452af03..000000000000 --- a/.github/workflows/add_label_ready_comment.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Add Ready Label on Ready Comment - -on: - issue_comment: - types: [created] - -jobs: - add-ready-label: - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/ready') - steps: - - name: Add label - uses: actions/github-script@v5 - with: - script: | - github.rest.issues.addLabels({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - labels: ['ready'] - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml deleted file mode 100644 index e9b6e28fa6bc..000000000000 --- a/.github/workflows/clang-format.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: clang-format - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - clang-format: - runs-on: 
ubuntu-latest - strategy: - matrix: - python-version: ["3.11"] - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install clang-format==18.1.5 - - name: Running clang-format - run: | - EXCLUDES=( - 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' - 'csrc/punica/bgmv/bgmv_config.h' - 'csrc/punica/bgmv/bgmv_impl.cuh' - 'csrc/punica/bgmv/vec_dtypes.cuh' - 'csrc/punica/punica_ops.cu' - 'csrc/punica/type_convert.h' - ) - find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ - | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ - | xargs clang-format --dry-run --Werror \ No newline at end of file diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml new file mode 100644 index 000000000000..50fea0c43cb8 --- /dev/null +++ b/.github/workflows/cleanup_pr_body.yml @@ -0,0 +1,26 @@ +name: Cleanup PR Body + +on: + pull_request_target: + types: [opened, reopened, edited] + +permissions: + pull-requests: write + +jobs: + update-description: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python + uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + with: + python-version: '3.12' + + - name: Update PR description + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: .github/scripts/cleanup_pr_body.sh "${{ github.event.number }}" diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml new file mode 100644 index 000000000000..b199d0867a64 --- /dev/null +++ b/.github/workflows/lint-and-deploy.yaml @@ -0,0 +1,82 @@ +name: Lint and Deploy Charts + +on: pull_request + +jobs: + lint-and-deploy: + runs-on: 
ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + + - name: Set up Helm + uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0 + with: + version: v3.14.4 + + #Python is required because ct lint runs Yamale and yamllint which require Python. + - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + with: + python-version: '3.13' + + - name: Set up chart-testing + uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0 + with: + version: v3.10.1 + + - name: Run chart-testing (lint) + run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm + + - name: Setup minio + run: | + docker network create vllm-net + docker run -d -p 9000:9000 --name minio --net vllm-net \ + -e "MINIO_ACCESS_KEY=minioadmin" \ + -e "MINIO_SECRET_KEY=minioadmin" \ + -v /tmp/data:/data \ + -v /tmp/config:/root/.minio \ + minio/minio server /data + export AWS_ACCESS_KEY_ID=minioadmin + export AWS_SECRET_ACCESS_KEY=minioadmin + export AWS_EC2_METADATA_DISABLED=true + mkdir opt-125m + cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. + aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket + aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive + + - name: Create kind cluster + uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0 + + - name: Build the Docker image vllm cpu + run: docker buildx build -f Dockerfile.cpu -t vllm-cpu-env . 
+ + - name: Configuration of docker images, network and namespace for the kind cluster + run: | + docker pull amazon/aws-cli:2.6.4 + kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing + kind load docker-image vllm-cpu-env:latest --name chart-testing + docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" + kubectl create ns ns-vllm + + - name: Run chart-testing (install) + run: | + export AWS_ACCESS_KEY_ID=minioadmin + export AWS_SECRET_ACCESS_KEY=minioadmin + sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & + helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" + + - name: curl test + run: | + kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & + sleep 10 + CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ + --header "Content-Type: application/json" \ + --data '{ + "model": "opt-125m", + "prompt": "San Francisco is a", + "max_tokens": 7, + "temperature": 0 + }'):$CODE" + echo "$CODE" \ No newline at end of file diff --git a/.github/workflows/matchers/actionlint.json b/.github/workflows/matchers/actionlint.json new file mode 100644 
index 000000000000..4613e1617bfe --- /dev/null +++ b/.github/workflows/matchers/actionlint.json @@ -0,0 +1,17 @@ +{ + "problemMatcher": [ + { + "owner": "actionlint", + "pattern": [ + { + "regexp": "^(?:\\x1b\\[\\d+m)?(.+?)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*: (?:\\x1b\\[\\d+m)*(.+?)(?:\\x1b\\[\\d+m)* \\[(.+?)\\]$", + "file": 1, + "line": 2, + "column": 3, + "message": 4, + "code": 5 + } + ] + } + ] +} diff --git a/.github/workflows/matchers/mypy.json b/.github/workflows/matchers/mypy.json new file mode 100644 index 000000000000..f048fce52894 --- /dev/null +++ b/.github/workflows/matchers/mypy.json @@ -0,0 +1,16 @@ +{ + "problemMatcher": [ + { + "owner": "mypy", + "pattern": [ + { + "regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$", + "file": 1, + "line": 2, + "severity": 3, + "message": 4 + } + ] + } + ] +} diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml deleted file mode 100644 index 721c9c026cf1..000000000000 --- a/.github/workflows/mypy.yaml +++ /dev/null @@ -1,48 +0,0 @@ -name: mypy - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - ruff: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install mypy==1.9.0 - pip install types-setuptools - pip install types-PyYAML - pip install types-requests - pip install types-setuptools - - name: Mypy - run: | - mypy tests --follow-imports skip - mypy vllm/attention --follow-imports skip - mypy vllm/core --follow-imports skip - mypy vllm/distributed --follow-imports skip - mypy vllm/engine --follow-imports skip - 
mypy vllm/entrypoints --follow-imports skip - mypy vllm/executor --follow-imports skip - mypy vllm/lora --follow-imports skip - mypy vllm/model_executor --follow-imports skip - mypy vllm/prompt_adapter --follow-imports skip - mypy vllm/spec_decode --follow-imports skip - mypy vllm/worker --follow-imports skip - mypy - diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 000000000000..6ab63a402770 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,20 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + with: + python-version: "3.12" + - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" + - run: echo "::add-matcher::.github/workflows/matchers/mypy.json" + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 + with: + extra_args: --all-files --hook-stage manual diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 15c2ec05b25d..bfd02879965e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -21,16 +21,16 @@ jobs: upload_url: ${{ steps.create_release.outputs.upload_url }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Extract branch info shell: bash run: | - echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV + echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" - name: Create Release id: create_release - uses: "actions/github-script@v6" + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 env: RELEASE_TAG: ${{ env.release_tag }} with: @@ -39,67 +39,68 @@ jobs: const script = require('.github/workflows/scripts/create_release.js') await script(github, 
context, core) - wheel: - name: Build Wheel - runs-on: ${{ matrix.os }} - needs: release - - strategy: - fail-fast: false - matrix: - os: ['ubuntu-20.04'] - python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.3.1'] # Must be the most recent version that meets requirements-cuda.txt. - cuda-version: ['11.8', '12.1'] - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Setup ccache - uses: hendrikmuhs/ccache-action@v1.2 - with: - create-symlink: true - key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - - - name: Set up Linux Env - if: ${{ runner.os == 'Linux' }} - run: | - bash -x .github/workflows/scripts/env.sh - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install CUDA ${{ matrix.cuda-version }} - run: | - bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} - run: | - bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - - - name: Build wheel - shell: bash - env: - CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size - run: | - bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename) - asset_name=${wheel_name//"linux"/"manylinux1"} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - echo "asset_name=${asset_name}" >> $GITHUB_ENV - - - name: Upload Release Asset - uses: actions/upload-release-asset@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ needs.release.outputs.upload_url }} - asset_path: ./dist/${{ env.wheel_name }} - asset_name: ${{ env.asset_name }} - asset_content_type: application/* + # NOTE(simon): No longer build wheel using GitHub Actions. 
See buildkite's release workflow. + # wheel: + # name: Build Wheel + # runs-on: ${{ matrix.os }} + # needs: release + + # strategy: + # fail-fast: false + # matrix: + # os: ['ubuntu-20.04'] + # python-version: ['3.9', '3.10', '3.11', '3.12'] + # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt. + # cuda-version: ['11.8', '12.1'] + + # steps: + # - name: Checkout + # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + # - name: Setup ccache + # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 + # with: + # create-symlink: true + # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} + + # - name: Set up Linux Env + # if: ${{ runner.os == 'Linux' }} + # run: | + # bash -x .github/workflows/scripts/env.sh + + # - name: Set up Python + # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + # with: + # python-version: ${{ matrix.python-version }} + + # - name: Install CUDA ${{ matrix.cuda-version }} + # run: | + # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} + + # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} + # run: | + # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} + + # - name: Build wheel + # shell: bash + # env: + # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size + # run: | + # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} + # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) + # asset_name=${wheel_name//"linux"/"manylinux1"} + # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" + # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" + + # - name: Upload Release Asset + # uses: 
actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 + # env: + # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # with: + # upload_url: ${{ needs.release.outputs.upload_url }} + # asset_path: ./dist/${{ env.wheel_name }} + # asset_name: ${{ env.asset_name }} + # asset_content_type: application/* # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested # - name: Publish package diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 390c88bb6530..27318c2fdd93 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -2,20 +2,24 @@ name: PR Reminder Comment Bot on: pull_request_target: types: [opened] - jobs: pr_reminder: runs-on: ubuntu-latest steps: - name: Remind to run full CI on PR - uses: actions/github-script@v6 + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: script: | github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number, - body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀' + body: '👋 Hi! 
Thank you for contributing to the vLLM project.\n\n' + + '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + + 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' + + 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + + 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + + '🚀' }) env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/remove_label_not_ready_comment.yml b/.github/workflows/remove_label_not_ready_comment.yml deleted file mode 100644 index d1da7726eaee..000000000000 --- a/.github/workflows/remove_label_not_ready_comment.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Remove ready Label on notready Comment - -on: - issue_comment: - types: [created] - -jobs: - add-ready-label: - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready') - steps: - - name: Remove ready label - uses: actions/github-script@v5 - with: - script: | - github.rest.issues.removeLabel({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - name: 'ready' - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 773def58fd96..000000000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: ruff - -on: - # 
Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - ruff: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2 - - name: Analysing the code with ruff - run: | - ruff . - - name: Spelling check with codespell - run: | - codespell --toml pyproject.toml - - name: Run isort - run: | - isort . --check-only diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 60a3978f9abd..0f010832b465 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -eux python_executable=python$1 cuda_home=/usr/local/cuda-$2 @@ -8,14 +9,15 @@ PATH=${cuda_home}/bin:$PATH LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH # Install requirements -$python_executable -m pip install wheel packaging -$python_executable -m pip install -r requirements-cuda.txt +$python_executable -m pip install -r requirements/build.txt -r requirements/cuda.txt # Limit the number of parallel jobs to avoid OOM export MAX_JOBS=1 -# Make sure punica is built for the release (for LoRA) -export VLLM_INSTALL_PUNICA_KERNELS=1 # Make sure release wheels are built for the following architectures export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" +export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" + +bash tools/check_repo.sh + # Build $python_executable setup.py bdist_wheel --dist-dir=dist diff --git a/.github/workflows/scripts/create_release.js b/.github/workflows/scripts/create_release.js index 475742118afe..0feb5dc2cf84 100644 --- 
a/.github/workflows/scripts/create_release.js +++ b/.github/workflows/scripts/create_release.js @@ -1,4 +1,4 @@ -// Uses Github's API to create the release and wait for result. +// Uses GitHub's API to create the release and wait for result. // We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately. module.exports = async (github, context, core) => { diff --git a/.github/workflows/scripts/cuda-install.sh b/.github/workflows/scripts/cuda-install.sh index 312c6e82f33a..3d0b7a1fe040 100644 --- a/.github/workflows/scripts/cuda-install.sh +++ b/.github/workflows/scripts/cuda-install.sh @@ -1,16 +1,16 @@ #!/bin/bash # Replace '.' with '-' ex: 11.8 -> 11-8 -cuda_version=$(echo $1 | tr "." "-") +cuda_version=$(echo "$1" | tr "." "-") # Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004 -OS=$(echo $2 | tr -d ".\-") +OS=$(echo "$2" | tr -d ".\-") # Installs CUDA -wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb +wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb" sudo dpkg -i cuda-keyring_1.1-1_all.deb rm cuda-keyring_1.1-1_all.deb sudo apt -qq update -sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version} +sudo apt -y install "cuda-${cuda_version}" "cuda-nvcc-${cuda_version}" "cuda-libraries-dev-${cuda_version}" sudo apt clean # Test nvcc diff --git a/.github/workflows/scripts/pytorch-install.sh b/.github/workflows/scripts/pytorch-install.sh index dfc1851d7692..e3cda7dad2d1 100644 --- a/.github/workflows/scripts/pytorch-install.sh +++ b/.github/workflows/scripts/pytorch-install.sh @@ -6,7 +6,7 @@ cuda_version=$3 # Install torch $python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya -$python_executable -m pip install 
torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./} +$python_executable -m pip install torch=="${pytorch_version}+cu${cuda_version//./}" --extra-index-url "https://download.pytorch.org/whl/cu${cuda_version//./}" # Print version information $python_executable --version diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 000000000000..656f3d3fa7bc --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,52 @@ +name: 'Close inactive issues and PRs' + +on: + schedule: + # Daily at 1:30 AM UTC + - cron: '30 1 * * *' + +jobs: + close-issues-and-pull-requests: + permissions: + issues: write + pull-requests: write + actions: write + runs-on: ubuntu-latest + steps: + - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 + with: + # Increasing this value ensures that changes to this workflow + # propagate to all issues and PRs in days rather than months + operations-per-run: 1000 + + exempt-draft-pr: true + exempt-issue-labels: 'keep-open' + exempt-pr-labels: 'keep-open' + + labels-to-add-when-unstale: 'unstale' + labels-to-remove-when-stale: 'unstale' + + days-before-issue-stale: 90 + days-before-issue-close: 30 + stale-issue-label: 'stale' + stale-issue-message: > + This issue has been automatically marked as stale because it has not + had any activity within 90 days. It will be automatically closed if no + further activity occurs within 30 days. Leave a comment if + you feel this issue should remain open. Thank you! + close-issue-message: > + This issue has been automatically closed due to inactivity. Please + feel free to reopen if you feel it is still relevant. Thank you! + + days-before-pr-stale: 90 + days-before-pr-close: 30 + stale-pr-label: 'stale' + stale-pr-message: > + This pull request has been automatically marked as stale because it + has not had any activity within 90 days. 
It will be automatically + closed if no further activity occurs within 30 days. Leave a comment + if you feel this pull request should remain open. Thank you! + close-pr-message: > + This pull request has been automatically closed due to inactivity. + Please feel free to reopen if you intend to continue working on it. + Thank you! diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml deleted file mode 100644 index 04f307bcf8b0..000000000000 --- a/.github/workflows/yapf.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: yapf - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - pull_request: - branches: - - main -jobs: - yapf: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install yapf==0.32.0 - pip install toml==0.10.2 - - name: Running yapf - run: | - yapf --diff --recursive . 
diff --git a/.gitignore b/.gitignore index 17184b19127c..6f5cbd0733da 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ -# vllm commit id, generated by setup.py -vllm/commit_id.py +# version file generated by setuptools-scm +/vllm/_version.py + +# vllm-flash-attn built from source +vllm/vllm_flash_attn/* +!vllm/vllm_flash_attn/fa_utils.py # Byte-compiled / optimized / DLL files __pycache__/ @@ -12,6 +16,8 @@ __pycache__/ # Distribution / packaging .Python build/ +cmake-build-*/ +CMakeUserPresets.json develop-eggs/ dist/ downloads/ @@ -28,6 +34,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +/.deps/ # PyInstaller # Usually these files are written by a python script from a template @@ -73,8 +80,7 @@ instance/ # Sphinx documentation docs/_build/ -docs/source/getting_started/examples/*.rst -!**/*.template.rst +docs/source/getting_started/examples/ # PyBuilder .pybuilder/ @@ -87,6 +93,9 @@ target/ profile_default/ ipython_config.py +# generated files +**/generated/** + # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: @@ -189,4 +198,8 @@ _build/ hip_compat.h # Benchmark dataset -*.json +benchmarks/**/*.json + +# Linting +actionlint +shellcheck*/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000000..484cd171f5f5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,129 @@ +default_stages: + - pre-commit # Run locally + - manual # Run in CI +exclude: 'vllm/third_party/.*' +repos: +- repo: https://github.com/google/yapf + rev: v0.43.0 + hooks: + - id: yapf + args: [--in-place, --verbose] + additional_dependencies: [toml] # TODO: Remove when yapf is upgraded +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.3 + hooks: + - id: ruff + args: [--output-format, github, --fix] +- repo: https://github.com/codespell-project/codespell + rev: v2.4.0 + hooks: + - id: codespell + 
additional_dependencies: ['tomli'] + args: ['--toml', 'pyproject.toml'] +- repo: https://github.com/PyCQA/isort + rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0 + hooks: + - id: isort +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v19.1.7 + hooks: + - id: clang-format + exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' + types_or: [c++, cuda] + args: [--style=file, --verbose] +- repo: https://github.com/jackdewinter/pymarkdown + rev: v0.9.27 + hooks: + - id: pymarkdown + args: [fix] +- repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint +- repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.6.2 + hooks: + - id: pip-compile + args: [requirements/test.in, -o, requirements/test.txt] + files: ^requirements/test\.(in|txt)$ +- repo: local + hooks: + - id: mypy-local + name: Run mypy for local Python installation + entry: tools/mypy.sh 0 "local" + language: python + types: [python] + additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests] + stages: [pre-commit] # Don't run in CI + - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.9 + entry: tools/mypy.sh 1 "3.9" + language: python + types: [python] + additional_dependencies: *mypy_deps + stages: [manual] # Only run in CI + - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.10 + entry: tools/mypy.sh 1 "3.10" + language: python + types: [python] + additional_dependencies: *mypy_deps + stages: [manual] # Only run in CI + - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.11 + entry: tools/mypy.sh 1 "3.11" + language: python + types: [python] + additional_dependencies: 
*mypy_deps + stages: [manual] # Only run in CI + - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.12 + entry: tools/mypy.sh 1 "3.12" + language: python + types: [python] + additional_dependencies: *mypy_deps + stages: [manual] # Only run in CI + - id: shellcheck + name: Lint shell scripts + entry: tools/shellcheck.sh + language: script + types: [shell] + - id: png-lint + name: Lint PNG exports from excalidraw + entry: tools/png-lint.sh + language: script + types: [png] + - id: signoff-commit + name: Sign-off Commit + entry: bash + args: + - -c + - | + if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then + printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG + fi + language: system + verbose: true + stages: [commit-msg] + - id: check-spdx-header + name: Check SPDX headers + entry: python tools/check_spdx_header.py + language: python + types: [python] + - id: check-filenames + name: Check for spaces in all filenames + entry: bash + args: + - -c + - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" 
&& exit 1 || exit 0' + language: system + always_run: true + pass_filenames: false + # Keep `suggestion` last + - id: suggestion + name: Suggestion + entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' + language: system + verbose: true + pass_filenames: false + # Insert new entries above the `suggestion` entry diff --git a/.readthedocs.yaml b/.readthedocs.yaml index f1959ad2743f..2781ec223b66 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,17 +6,16 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.12" sphinx: - configuration: docs/source/conf.py - fail_on_warning: true + configuration: docs/source/conf.py + fail_on_warning: true # If using Sphinx, optionally build your docs in additional formats such as PDF -formats: - - pdf +formats: [] # Optionally declare the Python requirements required to build your docs python: - install: - - requirements: docs/requirements-docs.txt + install: + - requirements: requirements/docs.txt diff --git a/.shellcheckrc b/.shellcheckrc new file mode 100644 index 000000000000..f3b6eedf8d90 --- /dev/null +++ b/.shellcheckrc @@ -0,0 +1,9 @@ +# rules currently disabled: +# +# SC1091 (info): Not following: was not specified as input (see shellcheck -x) +# SC2004 (style): $/${} is unnecessary on arithmetic variables. +# SC2129 (style): Consider using { cmd1; cmd2; } >> file instead of individual redirects. +# SC2155 (warning): Declare and assign separately to avoid masking return values. +# SC2164 (warning): Use 'cd ... || exit' or 'cd ... || return' in case cd fails. +# +disable=SC1091,SC2004,SC2129,SC2155,SC2164 diff --git a/CMakeLists.txt b/CMakeLists.txt index bf00a36edc50..65d1ddbeee0b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,16 @@ -cmake_minimum_required(VERSION 3.21) +cmake_minimum_required(VERSION 3.26) +# When building directly using CMake, make sure you run the install step +# (it places the .so files in the correct location). 
+# +# Example: +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. .. +# cmake --build . --target install +# +# If you want to only build one target, make sure to install it manually: +# cmake --build . --target _C +# cmake --install . --component _C project(vllm_extensions LANGUAGES CXX) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) @@ -10,17 +21,20 @@ message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) +# Suppress potential warnings about unused manually-specified variables +set(ignoreMe "${VLLM_PYTHON_PATH}") + # # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") +set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") +set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101") # # Supported/expected torch versions for CUDA/ROCm. @@ -32,8 +46,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1") -set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0") # # Try to find python package with an executable that exactly matches @@ -74,7 +88,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND if (VLLM_TARGET_DEVICE STREQUAL "cpu") include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) else() - message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + return() endif() return() endif() @@ -108,14 +122,32 @@ else() message(FATAL_ERROR "Can't find CUDA or HIP installation.") endif() -# -# Override the GPU architectures detected by cmake/torch and filter them by -# the supported versions for the current language. -# The final set of arches is stored in `VLLM_GPU_ARCHES`. -# -override_gpu_arches(VLLM_GPU_ARCHES - ${VLLM_GPU_LANG} - "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + # + # For cuda we want to be able to control which architectures we compile for on + # a per-file basis in order to cut down on compile time. So here we extract + # the set of architectures we want to compile for and remove them from the + # CMAKE_CUDA_FLAGS so that they are not applied globally. + # + clear_cuda_arches(CUDA_ARCH_FLAGS) + extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") + message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") + # Filter the target architectures by the supported archs + # since for some files we will build for all CUDA_ARCHS. + cuda_archs_loose_intersection(CUDA_ARCHS + "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") + message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}") +else() + # + # For other GPU targets override the GPU architectures detected by cmake/torch + # and filter them by the supported versions for the current language. + # The final set of arches is stored in `VLLM_GPU_ARCHES`.
+ # + override_gpu_arches(VLLM_GPU_ARCHES + ${VLLM_GPU_LANG} + "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") +endif() # # Query torch for additional GPU compilation flags for the given @@ -131,9 +163,64 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() + # -# Define extension targets +# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. +# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. +# Each dependency that produces build artifacts should override its BINARY_DIR to avoid +# conflicts between build types. It should instead be set to ${CMAKE_BINARY_DIR}/. # +include(FetchContent) +file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists +message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") + +# +# Set rocm version dev int. +# +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # Overriding the default -O set up by cmake, adding ggdb3 for the most verbose debug info + # + set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3") + + + # + # Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates + # a lot of warnings that always mask real issues. Suppressing until this is properly addressed.
+ # + set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result") +endif() + +# +# Define other extension targets +# + +# +# cumem_allocator extension +# + +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_CUMEM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling cumem allocator extension.") + # link against cuda driver library + list(APPEND CUMEM_LIBS CUDA::cuda_driver) + define_gpu_extension_target( + cumem_allocator + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_CUMEM_EXT_SRC} + LIBRARIES ${CUMEM_LIBS} + USE_SABI 3.8 + WITH_SOABI) +endif() # # _C extension @@ -141,58 +228,312 @@ endif() set(VLLM_EXT_SRC "csrc/cache_kernels.cu" - "csrc/attention/attention_kernels.cu" + "csrc/attention/paged_attention_v1.cu" + "csrc/attention/paged_attention_v2.cu" "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" - "csrc/quantization/squeezellm/quant_cuda_kernel.cu" + "csrc/layernorm_quant_kernels.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" - "csrc/moe_align_block_size_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") - include(FetchContent) - SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) - FetchContent_Declare( + SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. + # Please keep this in sync with FetchContent_Declare line below. 
+ set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use") + + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) + set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) + endif() + + if(VLLM_CUTLASS_SRC_DIR) + if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) + get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) + endif() + message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") + FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) + else() + FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.0 - GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc - ) + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG v3.8.0 + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
+ # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) + endif() FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" - "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/gptq_marlin/gptq_marlin.cu" - "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" - "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" + "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + "csrc/quantization/fp4/nvfp4_quant_entry.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/cutlass_extensions/common.cpp") + + set_gencode_flags_for_srcs( + SRCS "${VLLM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + + # Only build Marlin kernels if we are building for at least some compatible archs. + # Keep building Marlin for 9.0 as there are some group sizes and shapes that + # are not supported by Machete yet. 
+ cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + if (MARLIN_ARCHS) + set(MARLIN_SRCS + "csrc/quantization/fp8/fp8_marlin.cu" + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" + "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_SRCS}" + CUDA_ARCHS "${MARLIN_ARCHS}") + list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") + else() + message(STATUS "Not building Marlin kernels as no compatible archs found" + " in CUDA target architectures") + endif() + + # Only build AllSpark kernels if we are building for at least some compatible archs. + cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") + if (ALLSPARK_ARCHS) + set(ALLSPARK_SRCS + "csrc/quantization/gptq_allspark/allspark_repack.cu" + "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") + set_gencode_flags_for_srcs( + SRCS "${ALLSPARK_SRCS}" + CUDA_ARCHS "${ALLSPARK_ARCHS}") + list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") + message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") + else() + message(STATUS "Not building AllSpark kernels as no compatible archs found" + " in CUDA target architectures") + endif() + + + set(SCALED_MM_3X_ARCHS) + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. 
CUTLASS 3.x) require + # CUDA 12.0 or later + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. 
CUTLASS 3.x) require + # CUDA 12.8 or later + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # + # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) + # kernels for the remaining archs that are not already built for 3x. 
+ cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + "7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + # subtract out the archs that are already built for 3x + list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) + if (SCALED_MM_2X_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") + message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") + else() + if (SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c2x as all archs are already built" + " for and covered by scaled_mm_c3x") + else() + message(STATUS "Not building scaled_mm_c2x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() # - # The CUTLASS kernels for Hopper require sm90a to be enabled. - # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. - # That adds an extra 17MB to compiled binary, so instead we selectively enable it. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") + # 2:4 Sparse Kernels + + # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor + # require CUDA 12.2 or later (and only work on Hopper). 
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS) + set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") + message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS) + message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " + "if you intend on running FP8 sparse quantized models on Hopper.") + else() + message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() endif() + # FP4 Archs and flags + cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") + else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) + endif() + + # + # Machete kernels + + # The machete kernels only work on hopper and require CUDA 12.0 or later. 
+ # Only build Machete kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS) + # + # For the Machete kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MACHETE_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py) + file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH) + + message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}") + message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH} + OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." 
+ " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH} + CACHE STRING "Last run machete generate script hash" FORCE) + message(STATUS "Machete generation completed successfully.") + endif() + else() + message(STATUS "Machete generation script has not changed, skipping generation.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") + list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) + + # forward compatible + set_gencode_flags_for_srcs( + SRCS "${MACHETE_GEN_SOURCES}" + CUDA_ARCHS "${MACHETE_ARCHS}") + + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) + + message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 + AND MACHETE_ARCHS) + message(STATUS "Not building Machete kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building Machete kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() +# if CUDA endif endif() +message(STATUS "Enabling C extension.") define_gpu_extension_target( _C DESTINATION vllm @@ -200,18 +541,68 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) +# If CUTLASS is compiled on NVCC >= 12.5, it by default uses +# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the +# driver API. This causes problems when linking with earlier versions of CUDA. 
+# Setting this variable sidesteps the issue by calling the driver directly. +target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + # # _moe_C extension # set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/topk_softmax_kernels.cu") +if(VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") +endif() + +set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(VLLM_MOE_WNA16_SRC + "csrc/moe/moe_wna16.cu") + + set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_WNA16_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + if (MARLIN_MOE_ARCHS) + set(MARLIN_MOE_SRC + "csrc/moe/marlin_kernels/marlin_moe_kernel.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" + "csrc/moe/marlin_moe_ops.cu") + + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SRC}" + CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}") + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") + else() + message(STATUS "Not building Marlin MOE kernels as no compatible archs found" + " in CUDA target architectures") + endif() +endif() + +message(STATUS "Enabling moe extension.") define_gpu_extension_target( _moe_C DESTINATION vllm @@ -222,90 +613,27 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -# -# _punica_C extension -# - -set(VLLM_PUNICA_EXT_SRC - "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" - 
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cu" - "csrc/punica/torch_bindings.cpp") - -# -# Copy GPU compilation flags+update for punica -# -set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS}) -list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS - "-D__CUDA_NO_HALF_OPERATORS__" - "-D__CUDA_NO_HALF_CONVERSIONS__" - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" - "-D__CUDA_NO_HALF2_OPERATORS__") - -# -# Filter out CUDA architectures < 8.0 for punica. -# -if (${VLLM_GPU_LANG} STREQUAL "CUDA") - set(VLLM_PUNICA_GPU_ARCHES) - foreach(ARCH ${VLLM_GPU_ARCHES}) - string_to_ver(CODE_VER ${ARCH}) - if (CODE_VER GREATER_EQUAL 8.0) - list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH}) - endif() - endforeach() - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -elseif(${VLLM_GPU_LANG} STREQUAL "HIP") - set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -endif() +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # _rocm_C extension + # + set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/attention.cu") -if (VLLM_PUNICA_GPU_ARCHES) define_gpu_extension_target( - _punica_C + _rocm_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_PUNICA_EXT_SRC} - COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} - ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} + SOURCES ${VLLM_ROCM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} USE_SABI 3 WITH_SOABI) -else() - message(WARNING "Unable to create _punica_C target because none of the " - "requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0") endif() -# -# Add the `default` target which detects which extensions should be -# built based on platform/architecture. This is the same logic that -# setup.py uses to select which extensions should be built and should -# be kept in sync. 
-# -# The `default` target makes direct use of cmake easier since knowledge -# of which extensions are supported has been factored in, e.g. -# -# mkdir build && cd build -# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. -# cmake --build . --target default -# -add_custom_target(default) - -if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") - message(STATUS "Enabling C extension.") - add_dependencies(default _C) - - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) - - # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or - # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and - # there are supported target arches. - if (VLLM_PUNICA_GPU_ARCHES AND - (ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS)) - message(STATUS "Enabling punica extension.") - add_dependencies(default _punica_C) - endif() -endif() +# For CUDA we also build and ship some external projects. +if (VLLM_GPU_LANG STREQUAL "CUDA") + include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/vllm_flash_attn.cmake) +endif () diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000000..5268ff135c9d --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,127 @@ + +# vLLM Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socioeconomic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official email address, +posting via an official social media account, or acting as an appointed +representative at an online or offline/IRL event. 
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement in the #code-of-conduct +channel in the [vLLM Slack](https://slack.vllm.ai). +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. 
Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), +version 2.1, available at +[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html). + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). + +For answers to common questions about this code of conduct, see the +[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at +[Contributor Covenant translations](https://www.contributor-covenant.org/translations). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81a8db2b268b..6d46a6dca371 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,56 +1,3 @@ # Contributing to vLLM -Thank you for your interest in contributing to vLLM! -Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large. -There are several ways you can contribute to the project: - -- Identify and report any issues or bugs. -- Request or add a new model. -- Suggest or implement new features. - -However, remember that contributions aren't just about code. -We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions. - -Finally, one of the most impactful ways to support us is by raising awareness about vLLM. -Talk about it in your blog posts, highlighting how it's driving your incredible projects. -Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository. 
- - -## Setup for development - -### Build from source - -```bash -pip install -e . # This may take several minutes. -``` - -### Testing - -```bash -pip install -r requirements-dev.txt - -# linting and formatting -bash format.sh -# Static type checking -mypy -# Unit tests -pytest tests/ -``` -**Note:** Currently, the repository does not pass the mypy tests. - - -## Contributing Guidelines - -### Issue Reporting - -If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it. -If not, please file a new issue, providing as much relevant information as possible. - -### Pull Requests & Code Reviews - -Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE.md) for detailed guide for contribution. - -### Thank You - -Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. -Your contributions make vLLM a great tool for everyone! +You may find information about contributing to vLLM on [docs.vllm.ai](https://docs.vllm.ai/en/latest/contributing/overview.html). diff --git a/DCO b/DCO new file mode 100644 index 000000000000..49b8cb054926 --- /dev/null +++ b/DCO @@ -0,0 +1,34 @@ +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. 
+ + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. diff --git a/Dockerfile b/Dockerfile index b9a56e67e8d7..d1ecef586d50 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,35 +2,46 @@ # to run the OpenAI compatible server. 
# Please update any changes made here to -# docs/source/dev/dockerfile/dockerfile.rst and -# docs/source/assets/dev/dockerfile-stages-dependency.png +# docs/source/contributing/dockerfile/dockerfile.md and +# docs/source/assets/contributing/dockerfile-stages-dependency.png ARG CUDA_VERSION=12.4.1 #################### BASE BUILD IMAGE #################### # prepare basic build environment FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base - ARG CUDA_VERSION=12.4.1 -ARG PYTHON_VERSION=3.10 - +ARG PYTHON_VERSION=3.12 +ARG TARGETPLATFORM ENV DEBIAN_FRONTEND=noninteractive +# Install Python and other dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ - && apt-get install -y ccache software-properties-common \ + && apt-get install -y ccache software-properties-common git curl sudo \ && add-apt-repository ppa:deadsnakes/ppa \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ - && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \ - && python3 --version - -RUN apt-get update -y \ - && apt-get install -y git curl sudo - -# Install pip s.t. 
it will be compatible with our PYTHON_VERSION -RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} -RUN python3 -m pip --version + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version +# Install uv for faster pip installs +RUN --mount=type=cache,target=/root/.cache/uv \ + python3 -m pip install uv + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 +# as it was causing spam when compiling the CUTLASS kernels +RUN apt-get install -y gcc-10 g++-10 +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 +RUN <> /etc/environment +# Install Python and other dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ - && apt-get install -y ccache software-properties-common \ + && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && add-apt-repository ppa:deadsnakes/ppa \ && apt-get update -y \ - && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ - && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \ - && python3 --version - -RUN apt-get update -y \ - && apt-get install -y python3-pip git 
vim curl libibverbs-dev - -# Install pip s.t. it will be compatible with our PYTHON_VERSION -RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} -RUN python3 -m pip --version + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version +# Install uv for faster pip installs +RUN --mount=type=cache,target=/root/.cache/uv \ + python3 -m pip install uv + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 # Workaround for https://github.com/openai/triton/issues/2507 and # https://github.com/pytorch/pytorch/issues/107960 -- hopefully @@ -184,20 +207,50 @@ RUN python3 -m pip --version # or future versions of triton. RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. 
-f1,2)/compat/ -# install vllm wheel first, so that torch etc will be installed -RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install dist/*.whl --verbose +# arm64 (GH200) build follows the practice of "use existing pytorch" build, +# we need to install torch and torchvision from the nightly builds first, +# pytorch will not appear as a vLLM dependency in all of the following steps +# after this step +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \ + uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \ + fi -RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir +# Install vllm wheel first, so that torch etc will be installed. +RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ + --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system dist/*.whl --verbose + +# If we need to build FlashInfer wheel before its release: +# $ export FLASHINFER_ENABLE_AOT=1 +# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+ +# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' +# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive +# $ cd flashinfer +# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4 +# $ rm -rf build +# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose +# $ ls dist +# $ # upload the wheel to a public location, e.g. 
https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl + +RUN --mount=type=cache,target=/root/.cache/uv \ +. /etc/environment && \ +if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ + uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \ +fi +COPY examples examples + +# Although we build Flashinfer with AOT mode, there's still +# some issues w.r.t. JIT compilation. Therefore we need to +# install build dependencies for JIT compilation. +# TODO: Remove this once FlashInfer AOT wheel is fixed +COPY requirements/build.txt requirements/build.txt +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/build.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl #################### vLLM installation IMAGE #################### - #################### TEST IMAGE #################### # image to run unit testing suite # note that this uses vllm installed by `pip` @@ -205,9 +258,25 @@ FROM vllm-base AS test ADD . 
/vllm-workspace/ +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# install development dependencies (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/dev.txt + # install development dependencies (for testing) -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-dev.txt +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -e tests/vllm_test_utils + +# enable fast downloads from hf (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system hf_transfer +ENV HF_HUB_ENABLE_HF_TRANSFER 1 + +# Copy in the v1 package for testing (it isn't distributed yet) +COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1 # doc requires source code # we hide them inside `test_docs/` , so that this source code @@ -215,18 +284,34 @@ RUN --mount=type=cache,target=/root/.cache/pip \ RUN mkdir test_docs RUN mv docs test_docs/ RUN mv vllm test_docs/ - #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### -# openai api server alternative -FROM vllm-base AS vllm-openai +# base openai image with additional requirements, for any subsequent openai-style images +FROM vllm-base AS vllm-openai-base + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 # install additional dependencies for openai api server -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 
'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ + else \ + uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.3' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ + fi ENV VLLM_USAGE_SOURCE production-docker-image +# define sagemaker first, so it is not default from `docker build` +FROM vllm-openai-base AS vllm-sagemaker + +COPY examples/online_serving/sagemaker-entrypoint.sh . +RUN chmod +x sagemaker-entrypoint.sh +ENTRYPOINT ["./sagemaker-entrypoint.sh"] + +FROM vllm-openai-base AS vllm-openai + ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] #################### OPENAI API SERVER #################### diff --git a/Dockerfile.arm b/Dockerfile.arm new file mode 100644 index 000000000000..bad093684239 --- /dev/null +++ b/Dockerfile.arm @@ -0,0 +1,62 @@ +# This vLLM Dockerfile is used to construct an image that can build and run vLLM on ARM CPU platform. + +FROM ubuntu:22.04 AS cpu-test-arm + +ENV CCACHE_DIR=/root/.cache/ccache + +ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update -y \ + && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ + && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +# tcmalloc provides better memory allocation efficiency, e.g., holding memory in caches to speed up access of commonly-used objects. 
+RUN --mount=type=cache,target=/root/.cache/pip \ + pip install py-cpuinfo # Use this to gather CPU info and optimize based on ARM Neoverse cores + +# Set LD_PRELOAD for tcmalloc on ARM +ENV LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4" + +RUN echo 'ulimit -c 0' >> ~/.bashrc + +WORKDIR /workspace + +ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" +ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \ + pip install --upgrade pip && \ + pip install -r requirements/build.txt + +FROM cpu-test-arm AS build + +WORKDIR /workspace/vllm + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \ + --mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \ + pip install -v -r requirements/cpu.txt + +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi + +# Disabling AVX512 specific optimizations for ARM +ARG VLLM_CPU_DISABLE_AVX512="true" +ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ + pip install dist/*.whl && \ + rm -rf dist + +WORKDIR /workspace/ + +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] \ No newline at end of file diff --git a/Dockerfile.cpu b/Dockerfile.cpu index c473ba431e68..a10090529d8a 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -2,40 +2,68 @@ FROM ubuntu:22.04 AS cpu-test-1 -RUN apt-get update -y \ - && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip 
libtcmalloc-minimal4 libnuma-dev \ +ENV CCACHE_DIR=/root/.cache/ccache + +ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update -y \ + && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html # intel-openmp provides additional performance improvement vs. openmp # tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects. -RUN pip install intel-openmp +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install intel-openmp==2025.0.1 -ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD" +ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so" RUN echo 'ulimit -c 0' >> ~/.bashrc -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl +RUN pip install intel_extension_for_pytorch==2.6.0 -RUN pip install --upgrade pip \ - && pip install wheel packaging ninja "setuptools>=49.4.0" numpy +WORKDIR /workspace -FROM cpu-test-1 AS build +ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" +ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \ + pip install --upgrade pip && \ + pip install -r requirements/build.txt -COPY ./ /workspace/vllm +FROM cpu-test-1 AS build WORKDIR /workspace/vllm -RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/test/cpu 
+RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \ + --mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \ + pip install -v -r requirements/cpu.txt + +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... ARG VLLM_CPU_DISABLE_AVX512 ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ + pip install dist/*.whl && \ + rm -rf dist WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks +# install development dependencies (for testing) +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -e tests/vllm_test_utils + ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.hpu b/Dockerfile.hpu new file mode 100644 index 000000000000..48211c88f872 --- /dev/null +++ b/Dockerfile.hpu @@ -0,0 +1,21 @@ +FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements/hpu.txt + +ENV no_proxy=localhost,127.0.0.1 +ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true + +RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install + +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + +WORKDIR /workspace/ + +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] 
diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 010f23a14301..067645906366 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,36 +1,55 @@ # default base image -ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04" +# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04" FROM $BASE_IMAGE RUN echo "Base image is $BASE_IMAGE" # Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python3-pip \ + ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### -# When launching the container, mount the code directory to /app -ARG APP_MOUNT=/app +# When launching the container, mount the code directory to /workspace +ARG APP_MOUNT=/workspace VOLUME [ ${APP_MOUNT} ] -WORKDIR ${APP_MOUNT} +WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas -RUN python3 -m pip install sentencepiece transformers==4.36.2 -U -RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install sentencepiece transformers==4.45.2 -U +RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install pytest -COPY ./vllm /app/vllm/vllm -COPY ./setup.py /app/vllm/setup.py -COPY ./requirements-common.txt /app/vllm/requirements-common.txt -COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt +# uninstall transformers-neuronx package explicitly to avoid version conflict +RUN python3 -m pip uninstall -y transformers-neuronx -RUN cd /app/vllm \ - && python3 -m 
pip install -U -r requirements-neuron.txt +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi + +RUN python3 -m pip install -U \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ + -r requirements/neuron.txt ENV VLLM_TARGET_DEVICE neuron -RUN cd /app/vllm \ - && pip install -e . \ - && cd .. +RUN --mount=type=bind,source=.git,target=.git \ + pip install --no-build-isolation -v -e . + +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + +# install transformers-neuronx package as an optional dependencies (for V0) +# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict +RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps + +# overwrite entrypoint to run bash script +RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino deleted file mode 100644 index 7c62dd845aa9..000000000000 --- a/Dockerfile.openvino +++ /dev/null @@ -1,26 +0,0 @@ -# The vLLM Dockerfile is used to construct vLLM image that can be directly used -# to run the OpenAI compatible server. 
- -FROM ubuntu:22.04 AS dev - -RUN apt-get update -y && \ - apt-get install -y python3-pip git -WORKDIR /workspace - -# copy requirements -COPY requirements-build.txt /workspace/vllm/ -COPY requirements-common.txt /workspace/vllm/ -COPY requirements-openvino.txt /workspace/vllm/ - -COPY vllm/ /workspace/vllm/vllm -COPY setup.py /workspace/vllm/ - -# install build requirements -RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt -# build vLLM with OpenVINO backend -RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ - -COPY examples/ /workspace/vllm/examples -COPY benchmarks/ /workspace/vllm/benchmarks - -CMD ["/bin/bash"] diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index d4e4c483cada..c5ca20d76e3e 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -2,21 +2,36 @@ FROM mambaorg/micromamba ARG MAMBA_DOCKERFILE_ACTIVATE=1 USER root -RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" -# Some packages in requirements-cpu are installed here +RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev + +# Some packages in requirements/cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba # Currently these may not be available for venv or pip directly -RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes +RUN micromamba install 
-y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes COPY ./ /workspace/vllm WORKDIR /workspace/vllm +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi + +RUN --mount=type=cache,target=/root/.cache/pip \ + RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ + -r requirements/cpu.txt \ + xformers uvloop==0.20.0 + +RUN --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py install + +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils -# These packages will be in rocketce eventually -RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing +WORKDIR /workspace/ -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks -WORKDIR /vllm-workspace ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 64bc0f3c12c7..841e7978a424 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,182 +1,120 @@ -# Default ROCm 6.1 base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" - -# Default ROCm ARCHes to build vLLM for. -ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" - -# Whether to install CK-based flash-attention -# If 0, will not install flash-attention -ARG BUILD_FA="1" -# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL` -# If this succeeds, we use the downloaded wheel and skip building flash-attention. 
-# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the -# architectures specified in `FA_GFX_ARCHS` -ARG TRY_FA_WHEEL="1" -ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl" -ARG FA_GFX_ARCHS="gfx90a;gfx942" -ARG FA_BRANCH="23a2b1c2" - -# Whether to build triton on rocm -ARG BUILD_TRITON="1" -ARG TRITON_BRANCH="e0fc12c" - -### Base image build stage -FROM $BASE_IMAGE AS base - -# Import arg(s) defined before this build stage -ARG PYTORCH_ROCM_ARCH +# default base image +ARG REMOTE_VLLM="0" +ARG USE_CYTHON="0" +ARG BUILD_RPD="1" +ARG COMMON_WORKDIR=/app +ARG BASE_IMAGE=rocm/vllm-dev:base + +FROM ${BASE_IMAGE} AS base + +ARG ARG_PYTORCH_ROCM_ARCH +ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}} # Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y -RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - sudo \ - git \ - bzip2 \ - libx11-6 \ - build-essential \ - wget \ - unzip \ - tmux \ - ccache \ - && rm -rf /var/lib/apt/lists/* - -# When launching the container, mount the code directory to /vllm-workspace -ARG APP_MOUNT=/vllm-workspace -WORKDIR ${APP_MOUNT} - -RUN python3 -m pip install --upgrade pip -# Remove sccache so it doesn't interfere with ccache -# TODO: implement sccache support across components +RUN apt-get update -q -y && apt-get install -q -y \ + sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev +# Remove sccache +RUN python3 -m pip install --upgrade pip && pip install setuptools_scm RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" -# Install torch == 2.5.0 on ROCm -RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ - python3 -m pip uninstall -y torch torchvision \ - && python3 -m pip install --no-cache-dir --pre \ - torch==2.5.0.dev20240726 \ - 
torchvision==0.20.0.dev20240726 \ - --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ +ARG COMMON_WORKDIR +WORKDIR ${COMMON_WORKDIR} + + +# ----------------------- +# vLLM fetch stages +FROM base AS fetch_vllm_0 +ONBUILD COPY ./ vllm/ +FROM base AS fetch_vllm_1 +ARG VLLM_REPO="https://github.com/vllm-project/vllm.git" +ARG VLLM_BRANCH="main" +ONBUILD RUN git clone ${VLLM_REPO} \ + && cd vllm \ + && git checkout ${VLLM_BRANCH} +FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm + +# ----------------------- +# vLLM build stages +FROM fetch_vllm AS build_vllm +ARG USE_CYTHON +# Build vLLM +RUN cd vllm \ + && python3 -m pip install -r requirements/rocm.txt \ + && python3 setup.py clean --all \ + && if [ ${USE_CYTHON} -eq "1" ]; then python3 tests/build_cython.py build_ext --inplace; fi \ + && python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch AS export_vllm +ARG COMMON_WORKDIR +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl / +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite + +# ----------------------- +# Test vLLM image +FROM base AS test + +RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/* + +# Install vLLM +RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ + cd /install \ + && pip install -U -r requirements/rocm.txt \ + && pip install -U -r requirements/rocm-test.txt \ + && pip uninstall -y vllm \ + && pip install *.whl + +WORKDIR /vllm-workspace +ARG COMMON_WORKDIR +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace + +# install development dependencies (for testing) +RUN cd /vllm-workspace \ + && rm -rf vllm \ + && python3 -m pip install -e tests/vllm_test_utils \ + && python3 -m pip install 
lm-eval[api]==0.4.4 \ + && python3 -m pip install pytest-shard + +# ----------------------- +# Final vLLM image +FROM base AS final + +RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/* +# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. +# Manually remove it so that later steps of numpy upgrade can continue +RUN case "$(which python3)" in \ + *"/opt/conda/envs/py_3.9"*) \ + rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ *) ;; esac -ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer -ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: -ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: -ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: - -ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} -ENV CCACHE_DIR=/root/.cache/ccache - - -### AMD-SMI build stage -FROM base AS build_amdsmi -# Build amdsmi wheel always -RUN cd /opt/rocm/share/amd_smi \ - && python3 -m pip wheel . 
--wheel-dir=/install - - -### Flash-Attention wheel build stage -FROM base AS build_fa -ARG BUILD_FA -ARG TRY_FA_WHEEL -ARG FA_WHEEL_URL -ARG FA_GFX_ARCHS -ARG FA_BRANCH -# Build ROCm flash-attention wheel if `BUILD_FA = 1` -RUN --mount=type=cache,target=${CCACHE_DIR} \ - if [ "$BUILD_FA" = "1" ]; then \ - if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \ - # If a suitable wheel exists, we download it instead of building FA - mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \ - else \ - mkdir -p libs \ - && cd libs \ - && git clone https://github.com/ROCm/flash-attention.git \ - && cd flash-attention \ - && git checkout "${FA_BRANCH}" \ - && git submodule update --init \ - && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ - fi; \ - # Create an empty directory otherwise as later build stages expect one - else mkdir -p /install; \ - fi - - -### Triton wheel build stage -FROM base AS build_triton -ARG BUILD_TRITON -ARG TRITON_BRANCH -# Build triton wheel if `BUILD_TRITON = 1` -RUN --mount=type=cache,target=${CCACHE_DIR} \ - if [ "$BUILD_TRITON" = "1" ]; then \ - mkdir -p libs \ - && cd libs \ - && git clone https://github.com/OpenAI/triton.git \ - && cd triton \ - && git checkout "${TRITON_BRANCH}" \ - && cd python \ - && python3 setup.py bdist_wheel --dist-dir=/install; \ - # Create an empty directory otherwise as later build stages expect one - else mkdir -p /install; \ - fi - - -### Final vLLM build stage -FROM base AS final -# Import the vLLM development directory from the build context -COPY . . 
+RUN python3 -m pip install --upgrade huggingface-hub[cli] +ARG BUILD_RPD +RUN if [ ${BUILD_RPD} -eq "1" ]; then \ + git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \ + && cd rocmProfileData/rpd_tracer \ + && pip install -r requirements.txt && cd ../ \ + && make && make install \ + && cd hipMarker && python3 setup.py install ; fi + +# Install vLLM +RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ + cd /install \ + && pip install -U -r requirements/rocm.txt \ + && pip uninstall -y vllm \ + && pip install *.whl -# Package upgrades for useful functionality or to avoid dependency issues -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install --upgrade numba scipy huggingface-hub[cli] +ARG COMMON_WORKDIR + +# Copy over the benchmark scripts as well +COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks +COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples -# Make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 -# Workaround for ray >= 2.10.0 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 -# Silences the HF Tokenizers warning ENV TOKENIZERS_PARALLELISM=false -RUN --mount=type=cache,target=${CCACHE_DIR} \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -Ur requirements-rocm.txt \ - && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ - # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM - wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \ - # Prevent interference if torch bundles its own HIP runtime - && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ - *) ;; esac \ - && python3 setup.py clean --all \ - && python3 setup.py develop - -# Copy amdsmi wheel into final image -RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ - mkdir -p libs \ - && cp /install/*.whl libs \ - # Preemptively 
uninstall to avoid same-version no-installs - && python3 -m pip uninstall -y amdsmi; - -# Copy triton wheel(s) into final image if they were built -RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ - mkdir -p libs \ - && if ls /install/*.whl; then \ - cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && python3 -m pip uninstall -y triton; fi - -# Copy flash-attn wheel(s) into final image if they were built -RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ - mkdir -p libs \ - && if ls /install/*.whl; then \ - cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && python3 -m pip uninstall -y flash-attn; fi - -# Install wheels that were built to the final image -RUN --mount=type=cache,target=/root/.cache/pip \ - if ls libs/*.whl; then \ - python3 -m pip install libs/*.whl; fi +# Performance environment variable. +ENV HIP_FORCE_DEV_KERNARG=1 CMD ["/bin/bash"] + diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base new file mode 100644 index 000000000000..38d6a33636eb --- /dev/null +++ b/Dockerfile.rocm_base @@ -0,0 +1,172 @@ +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete +ARG HIPBLASLT_BRANCH="4d40e36" +ARG HIPBLAS_COMMON_BRANCH="7c1566b" +ARG LEGACY_HIPBLASLT_OPTION= +ARG RCCL_BRANCH="648a58d" +ARG RCCL_REPO="https://github.com/ROCm/rccl" +ARG TRITON_BRANCH="e5be006" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +ARG PYTORCH_BRANCH="3a585126" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" +ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" +ARG FA_BRANCH="b7d29fb" +ARG FA_REPO="https://github.com/ROCm/flash-attention.git" +ARG AITER_BRANCH="21d47a9" +ARG AITER_REPO="https://github.com/ROCm/aiter.git" + +FROM ${BASE_IMAGE} AS base + +ENV PATH=/opt/rocm/llvm/bin:$PATH +ENV ROCM_PATH=/opt/rocm +ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: +ARG 
PYTORCH_ROCM_ARCH=gfx90a;gfx942 +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} + +ARG PYTHON_VERSION=3.12 + +RUN mkdir -p /app +WORKDIR /app +ENV DEBIAN_FRONTEND=noninteractive + +# Install Python and other dependencies +RUN apt-get update -y \ + && apt-get install -y software-properties-common git curl sudo vim less \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ + python${PYTHON_VERSION}-lib2to3 python-is-python3 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version + +RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython + +FROM base AS build_hipblaslt +ARG HIPBLASLT_BRANCH +ARG HIPBLAS_COMMON_BRANCH +# Set to "--legacy_hipblas_direct" for ROCm<=6.2 +ARG LEGACY_HIPBLASLT_OPTION +RUN git clone https://github.com/ROCm/hipBLAS-common.git +RUN cd hipBLAS-common \ + && git checkout ${HIPBLAS_COMMON_BRANCH} \ + && mkdir build \ + && cd build \ + && cmake .. 
\ + && make package \ + && dpkg -i ./*.deb +RUN git clone https://github.com/ROCm/hipBLASLt +RUN cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \ + && cd build/release \ + && make package +RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install + +FROM base AS build_rccl +ARG RCCL_BRANCH +ARG RCCL_REPO +RUN git clone ${RCCL_REPO} +RUN cd rccl \ + && git checkout ${RCCL_BRANCH} \ + && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} +RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install + +FROM base AS build_triton +ARG TRITON_BRANCH +ARG TRITON_REPO +RUN git clone ${TRITON_REPO} +RUN cd triton \ + && git checkout ${TRITON_BRANCH} \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install + +FROM base AS build_amdsmi +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . 
--wheel-dir=dist +RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install + +FROM base AS build_pytorch +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +ARG FA_BRANCH +ARG FA_REPO +RUN git clone ${PYTORCH_REPO} pytorch +RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ + pip install -r requirements.txt && git submodule update --init --recursive \ + && python3 tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN git clone ${PYTORCH_VISION_REPO} vision +RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN git clone ${FA_REPO} +RUN cd flash-attention \ + && git checkout ${FA_BRANCH} \ + && git submodule update --init \ + && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ + && cp /app/vision/dist/*.whl /app/install \ + && cp /app/flash-attention/dist/*.whl /app/install + +FROM base AS final +RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ + && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN 
--mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl + +ARG AITER_REPO +ARG AITER_BRANCH +RUN git clone --recursive ${AITER_REPO} +RUN cd aiter \ + && git checkout ${AITER_BRANCH} \ + && git submodule update --init --recursive \ + && pip install -r requirements.txt \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter + +ARG BASE_IMAGE +ARG HIPBLASLT_BRANCH +ARG HIPBLAS_COMMON_BRANCH +ARG LEGACY_HIPBLASLT_OPTION +ARG RCCL_BRANCH +ARG RCCL_REPO +ARG TRITON_BRANCH +ARG TRITON_REPO +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +ARG FA_BRANCH +ARG FA_REPO +RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ + && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \ + && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \ + && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \ + && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \ + && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \ + && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \ + && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ + && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ + && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \ + && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ + && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt diff --git a/Dockerfile.s390x b/Dockerfile.s390x new file mode 100644 index 000000000000..5a84dc12d8f7 --- /dev/null +++ b/Dockerfile.s390x @@ -0,0 +1,152 @@ +# Base UBI image for s390x architecture +ARG 
BASE_UBI_IMAGE_TAG=9.5-1736404155 +ARG PYTHON_VERSION=3.12 +FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base + +# Install basic dependencies +ARG PYTHON_VERSION +ENV PYTHON_VERSION=${PYTHON_VERSION} + +WORKDIR /workspace + +ENV LANG=C.UTF-8 \ + LC_ALL=C.UTF-8 + +# Install development utilities +RUN microdnf install -y \ + which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ + libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake && \ + microdnf clean all + +# Python Installation +FROM base AS python-install +ARG PYTHON_VERSION + +ENV VIRTUAL_ENV=/opt/vllm +ENV PATH="$VIRTUAL_ENV/bin:$PATH" +ENV PYTHON_VERSION=${PYTHON_VERSION} +RUN microdnf install -y \ + python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel && \ + python${PYTHON_VERSION} -m venv $VIRTUAL_ENV && pip install --no-cache -U pip wheel uv && microdnf clean all + +FROM python-install AS pyarrow + +# Build Apache Arrow +WORKDIR /tmp +RUN --mount=type=cache,target=/root/.cache/uv \ + git clone https://github.com/apache/arrow.git && \ + cd arrow/cpp && \ + mkdir release && cd release && \ + cmake -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/usr/local \ + -DARROW_PYTHON=ON \ + -DARROW_PARQUET=ON \ + -DARROW_ORC=ON \ + -DARROW_FILESYSTEM=ON \ + -DARROW_WITH_LZ4=ON \ + -DARROW_WITH_ZSTD=ON \ + -DARROW_WITH_SNAPPY=ON \ + -DARROW_JSON=ON \ + -DARROW_CSV=ON \ + -DARROW_DATASET=ON \ + -DPROTOBUF_PROTOC_EXECUTABLE=/usr/bin/protoc \ + -DARROW_DEPENDENCY_SOURCE=BUNDLED \ + .. 
&& \ + make -j$(nproc) && \ + make install && \ + cd ../../python && \ + export PYARROW_PARALLEL=4 && \ + export ARROW_BUILD_TYPE=release && \ + uv pip install -r requirements/build.txt && \ + python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel + +FROM python-install AS numa-build +# Install numactl (needed for numa.h dependency) +WORKDIR /tmp +RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz && \ + tar -xvzf v2.0.16.tar.gz && \ + cd numactl-2.0.16 && \ + ./autogen.sh && \ + ./configure && \ + make + +# Set include path +ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" + +FROM python-install AS rust +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" + +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \ + . "$CARGO_HOME/env" && \ + rustup default stable && \ + rustup show + +FROM python-install AS torch-vision +# Install torchvision +ARG TORCH_VERSION=2.7.0.dev20250304 +ARG TORCH_VISION_VERSION=v0.20.1 +WORKDIR /tmp +RUN --mount=type=cache,target=/root/.cache/uv \ + git clone https://github.com/pytorch/vision.git && \ + cd vision && \ + git checkout $TORCH_VISION_VERSION && \ + uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \ + python setup.py bdist_wheel + +# Final build stage +FROM python-install AS vllm-cpu +ARG PYTHON_VERSION + +# Set correct library path for torch and numactl +ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH" +ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" +ENV UV_LINK_MODE=copy +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" + +COPY . 
/workspace/vllm +WORKDIR /workspace/vllm + +RUN --mount=type=bind,from=numa-build,src=/tmp/numactl-2.0.16,target=/numactl \ + make -C /numactl install + +# Install dependencies, including PyTorch and Apache Arrow +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ + --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ + sed -i '/^torch/d' requirements/build.txt && \ + ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ + VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ + uv pip install -v \ + $ARROW_WHL_FILE \ + $VISION_WHL_FILE \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --index-strategy unsafe-best-match \ + -r requirements/build.txt \ + -r requirements/cpu.txt + +# Build and install vllm +RUN --mount=type=cache,target=/root/.cache/uv \ + VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ + uv pip install "$(echo dist/*.whl)[tensorizer]" + +# setup non-root user for vllm +RUN umask 002 && \ + useradd --uid 2000 --gid 0 vllm && \ + mkdir -p /home/vllm && \ + chmod g+rwx /home/vllm + +COPY LICENSE /licenses/vllm.md +COPY examples/*.jinja /app/data/template/ + +USER 2000 +WORKDIR /home/vllm + +# Set the default entrypoint +ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] \ No newline at end of file diff --git a/Dockerfile.tpu b/Dockerfile.tpu index adebb8ab5adc..50806d8820a3 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,23 +1,31 @@ -ARG NIGHTLY_DATE="20240726" +ARG NIGHTLY_DATE="20250124" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE -WORKDIR /workspace +WORKDIR /workspace/vllm -# Install aiohttp separately to avoid build 
errors. -RUN pip install aiohttp -# Install NumPy 1 instead of NumPy 2. -RUN pip install "numpy<2" -# Install the TPU and Pallas dependencies. -RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - -# Fix FastAPI dependence -RUN pip install "starlette<0.38.0" +# Install some basic utilities +RUN apt-get update && apt-get install -y \ + git \ + ffmpeg libsm6 libxext6 libgl1 # Build vLLM. -COPY . /workspace/vllm +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi + +# Remove existing versions of dependencies +RUN pip uninstall -y torch torch_xla torchvision + ENV VLLM_TARGET_DEVICE="tpu" -RUN cd /workspace/vllm && python setup.py develop +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + python3 -m pip install \ + -r requirements/tpu.txt +RUN python3 setup.py develop + +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu index f91baa11a375..ad4abf16b43b 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -1,22 +1,61 @@ -FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04 +# oneapi 2025.0.2 docker base image use rolling 2448 package. https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04, and we don't need install driver manually. 
+FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base -RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ - echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ - chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ - rm /etc/apt/sources.list.d/intel-graphics.list && \ - wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ - echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ - chmod 644 /usr/share/keyrings/intel-graphics.gpg +RUN rm /etc/apt/sources.list.d/intel-graphics.list -RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip - -COPY ./ /workspace/vllm +RUN apt-get update -y && \ + apt-get install -y --no-install-recommends --fix-missing \ + curl \ + ffmpeg \ + git \ + libsndfile1 \ + libsm6 \ + libxext6 \ + libgl1 \ + lsb-release \ + numactl \ + python3 \ + python3-dev \ + python3-pip \ + wget WORKDIR /workspace/vllm +COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt +COPY requirements/common.txt /workspace/vllm/requirements/common.txt + +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --no-cache-dir \ + -r requirements/xpu.txt + +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/" + +COPY . . 
+ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi -RUN pip install -v -r requirements-xpu.txt +ENV VLLM_TARGET_DEVICE=xpu -RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + python3 setup.py install + +# Per the XPU docs, intel-extension-for-pytorch 2.6.10+xpu must be installed manually because its dependencies conflict with torch 2.6.0+xpu. +# FIXME: This will be fixed in ipex 2.7; left here for awareness. +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install intel-extension-for-pytorch==2.6.10+xpu \ + --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ CMD ["/bin/bash"] + +FROM vllm-base AS vllm-openai + +# install additional dependencies for openai api server +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' +# Use KEY=VALUE pairs: the legacy "ENV key value" form treats the rest of the instruction as a single value, so the second variable would never be set. +ENV VLLM_USAGE_SOURCE=production-docker-image \ + TRITON_XPU_PROFILE=1 +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/MANIFEST.in b/MANIFEST.in index 82be639ef4d7..82fd22b845f0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,9 @@ include LICENSE -include requirements-common.txt -include requirements-cuda.txt -include requirements-rocm.txt -include requirements-neuron.txt -include requirements-cpu.txt +include requirements/common.txt +include requirements/cuda.txt +include requirements/rocm.txt +include requirements/neuron.txt +include requirements/cpu.txt include CMakeLists.txt recursive-include cmake * diff --git a/README.md b/README.md index 5f23f0813f60..f2da0467e5c3 100644 --- a/README.md +++ b/README.md @@ -10,13 +10,33 @@ Easy, fast, and cheap LLM serving for everyone

-| Documentation | Blog | Paper | Discord | - +| Documentation | Blog | Paper | Twitter/X | User Forum | Developer Slack |

--- +[2025/03] We are collaborating with Ollama to host an [Inference Night](https://lu.ma/vllm-ollama) at Y Combinator in San Francisco on Thursday, March 27, at 6 PM. Discuss all things inference local or data center! + +[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day) + +--- + *Latest News* 🔥 + +- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). +- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0). +- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted. +- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). +- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! 
Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). +- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! + +
+Previous News + +- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). +- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! +- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! +- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). 
@@ -26,20 +46,27 @@ Easy, fast, and cheap LLM serving for everyone - [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM. - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). +
+ --- ## About + vLLM is a fast and easy-to-use library for LLM inference and serving. +Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry. + vLLM is fast with: - State-of-the-art serving throughput -- Efficient management of attention key and value memory with **PagedAttention** +- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) - Continuous batching of incoming requests - Fast model execution with CUDA/HIP graph -- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache -- Optimized CUDA kernels +- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8. +- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. +- Speculative decoding +- Chunked prefill -**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)). +**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). 
The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script. vLLM is flexible and easy to use with: @@ -48,29 +75,30 @@ vLLM is flexible and easy to use with: - Tensor parallelism and pipeline parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs -- (Experimental) Prefix caching support -- (Experimental) Multi-lora support +- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. +- Prefix caching support +- Multi-lora support vLLM seamlessly supports most popular open-source models on HuggingFace, including: - Transformer-like LLMs (e.g., Llama) -- Mixture-of-Expert LLMs (e.g., Mixtral) +- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3) +- Embedding Models (e.g. E5-Mistral) - Multi-modal LLMs (e.g., LLaVA) Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). ## Getting Started -Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): +Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source): ```bash pip install vllm ``` -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. -- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) -- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) -- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) +Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more. 
+- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html) +- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html) +- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html) ## Contributing @@ -83,32 +111,40 @@ vLLM is a community project. Our compute resources for development and testing a - +Cash Donations: - a16z +- Dropbox +- Sequoia Capital +- Skywork AI +- ZhenFund + +Compute Resources: - AMD - Anyscale - AWS - Crusoe Cloud - Databricks - DeepInfra -- Dropbox - Google Cloud - Lambda Lab +- Nebius +- Novita AI - NVIDIA - Replicate - Roblox - RunPod -- Sequoia Capital - Trainy - UC Berkeley - UC San Diego -- ZhenFund + +Slack Sponsor: Anyscale We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): + ```bibtex @inproceedings{kwon2023efficient, title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, @@ -117,3 +153,15 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs year={2023} } ``` + +## Contact Us + +- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions) +- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai) +- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai) +- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature +- For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu) + +## Media Kit + +- If you wish to use vLLM's 
logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 000000000000..7f5270715212 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,54 @@ +# Releasing vLLM + +vLLM releases offer a reliable version of the code base, packaged into a binary format that can be conveniently accessed via PyPI. These releases also serve as key milestones for the development team to communicate with the community about newly available features, improvements, and upcoming changes that could affect users, including potential breaking changes. + +## Release Versioning + +vLLM uses a “right-shifted” versioning scheme where a new patch release is out every 2 weeks. Patch releases contain features and bug fixes (as opposed to semver, where a patch release contains only backwards-compatible bug fixes). When critical fixes need to be made, a special post1 release is published. + +* _major_ major architectural milestone and when incompatible API changes are made, similar to PyTorch 2.0. +* _minor_ major features +* _patch_ features and backwards-compatible bug fixes +* _post1_ or _patch-1_ backwards-compatible bug fixes, either explicit or implicit post release + +## Release Cadence + +A patch release is published on a bi-weekly basis. A post release follows 1-3 days after the patch release and uses the same branch as the patch release. +The following is the release cadence for the year 2025. All future release dates below are tentative. Please note: Post releases are optional. 
+ +| Release Date | Patch release versions | Post Release versions | +| --- | --- | --- | +| Jan 2025 | 0.7.0 | --- | +| Feb 2025 | 0.7.1, 0.7.2, 0.7.3 | --- | +| Mar 2025 | 0.7.4, 0.7.5 | --- | +| Apr 2025 | 0.7.6, 0.7.7 | --- | +| May 2025 | 0.7.8, 0.7.9 | --- | +| Jun 2025 | 0.7.10, 0.7.11 | --- | +| Jul 2025 | 0.7.12, 0.7.13 | --- | +| Aug 2025 | 0.7.14, 0.7.15 | --- | +| Sep 2025 | 0.7.16, 0.7.17 | --- | +| Oct 2025 | 0.7.18, 0.7.19 | --- | +| Nov 2025 | 0.7.20, 0.7.21 | --- | +| Dec 2025 | 0.7.22, 0.7.23 | --- | + +## Release branch + +Each release is built from a dedicated release branch. + +* For _major_, _minor_, _patch_ releases, the release branch cut is performed 1-2 days before the release goes live. +* For post releases, the previously cut release branch is reused. +* Release builds are triggered via a push to an RC tag like vX.Y.Z-rc1. This enables us to build and test multiple RCs for each release. +* The final tag, vX.Y.Z, does not trigger the build but is used for release notes and assets. +* After the branch cut, we monitor the main branch for any reverts and apply these reverts to the release branch. + +## Release Cherry-Pick Criteria + +After branch cut, we approach finalizing the release branch with clear criteria on what cherry picks are allowed in. Note: a cherry pick is a process to land a PR in the release branch after branch cut. These are typically limited to ensure that the team has sufficient time to complete a thorough round of testing on a stable code base. + +* Regression fixes - that address functional/performance regression against the most recent release (e.g. 0.7.0 for 0.7.1 release) +* Critical fixes - critical fixes for severe issues such as silent incorrectness, backwards compatibility, crashes, deadlocks, (large) memory leaks +* Fixes to new features introduced in the most recent release (e.g. 0.7.0 for 0.7.1 release) +* Documentation improvements +* Release branch specific changes (e.g. 
change version identifiers or CI fixes) + +Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000000..47196a1f1221 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,11 @@ +# Security Policy + +## Reporting a Vulnerability + +If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. + +Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). + +--- + +Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. diff --git a/benchmarks/README.md b/benchmarks/README.md index 192d6c4022c8..d41de1caa04c 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,8 +1,268 @@ # Benchmarking vLLM -## Downloading the ShareGPT dataset +This README guides you through running benchmark tests with the extensive +datasets supported on vLLM. It’s a living document, updated as new features and datasets +become available. + +## Dataset Overview + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
DatasetOnlineOfflineData Path
ShareGPTwget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
BurstGPTwget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv
SonnetLocal file: benchmarks/sonnet.txt
Randomsynthetic
HuggingFace🟡🟡Specify your dataset path on HuggingFace
VisionArenalmarena-ai/vision-arena-bench-v0.1 (a HuggingFace dataset)
+ +✅: supported + +🚧: to be supported + +🟡: Partial support. Currently, HuggingFaceDataset only supports dataset formats +similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`. +If you need support for other dataset formats, please consider contributing. + +**Note**: VisionArena’s `dataset-name` should be set to `hf` + +--- +## Example - Online Benchmark + +First start serving your model + +```bash +MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B" +vllm serve ${MODEL_NAME} --disable-log-requests +``` + +Then run the benchmarking script + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B" +NUM_PROMPTS=10 +BACKEND="vllm" +DATASET_NAME="sharegpt" +DATASET_PATH="/ShareGPT_V3_unfiltered_cleaned_split.json" +python3 vllm/benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS} +``` + +If successful, you will see the following output + +``` +============ Serving Benchmark Result ============ +Successful requests: 10 +Benchmark duration (s): 5.78 +Total input tokens: 1369 +Total generated tokens: 2212 +Request throughput (req/s): 1.73 +Output token throughput (tok/s): 382.89 +Total Token throughput (tok/s): 619.85 +---------------Time to First Token---------------- +Mean TTFT (ms): 71.54 +Median TTFT (ms): 73.88 +P99 TTFT (ms): 79.49 +-----Time per Output Token (excl. 
1st token)------ +Mean TPOT (ms): 7.91 +Median TPOT (ms): 7.96 +P99 TPOT (ms): 8.03 +---------------Inter-token Latency---------------- +Mean ITL (ms): 7.74 +Median ITL (ms): 7.70 +P99 ITL (ms): 8.39 +================================================== +``` + +### VisionArena Benchmark for Vision Language Models + +```bash +# need a model with vision capability here +vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests +``` + +```bash +MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct" +NUM_PROMPTS=10 +BACKEND="openai-chat" +DATASET_NAME="hf" +DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1" +DATASET_SPLIT='train' + +python3 vllm/benchmarks/benchmark_serving.py \ + --backend "${BACKEND}" \ + --model "${MODEL_NAME}" \ + --endpoint "/v1/chat/completions" \ + --dataset-name "${DATASET_NAME}" \ + --dataset-path "${DATASET_PATH}" \ + --hf-split "${DATASET_SPLIT}" \ + --num-prompts "${NUM_PROMPTS}" +``` + +### HuggingFaceDataset Examples + +Currently, HuggingFaceDataset only supports dataset formats +similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`. If you need support for other dataset +formats, please consider contributing. 
-You can download the dataset by running: ```bash -wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +# need a model with vision capability here +vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests ``` + +**`lmms-lab/LLaVA-OneVision-Data`** + +```bash +MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct" +NUM_PROMPTS=10 +BACKEND="openai-chat" +DATASET_NAME="hf" +DATASET_PATH="lmms-lab/LLaVA-OneVision-Data" +DATASET_SPLIT='train' +DATASET_SUBSET='chart2text(cauldron)' +python3 vllm/benchmarks/benchmark_serving.py \ + --backend "${BACKEND}" \ + --model "${MODEL_NAME}" \ + --endpoint "/v1/chat/completions" \ + --dataset-name "${DATASET_NAME}" \ + --dataset-path "${DATASET_PATH}" \ + --hf-split "${DATASET_SPLIT}" \ + --num-prompts "${NUM_PROMPTS}" \ + --hf-subset "${DATASET_SUBSET}" +``` + +**`Aeala/ShareGPT_Vicuna_unfiltered`** + +```bash +MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct" +NUM_PROMPTS=10 +BACKEND="openai-chat" +DATASET_NAME="hf" +DATASET_PATH="Aeala/ShareGPT_Vicuna_unfiltered" +DATASET_SPLIT='train' +python3 vllm/benchmarks/benchmark_serving.py \ + --backend "${BACKEND}" \ + --model "${MODEL_NAME}" \ + --endpoint "/v1/chat/completions" \ + --dataset-name "${DATASET_NAME}" \ + --dataset-path "${DATASET_PATH}" \ + --hf-split "${DATASET_SPLIT}" \ + --num-prompts "${NUM_PROMPTS}" +``` + +--- +## Example - Offline Throughput Benchmark + +```bash +MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B" +NUM_PROMPTS=10 +DATASET_NAME="sonnet" +DATASET_PATH="vllm/benchmarks/sonnet.txt" + +python3 vllm/benchmarks/benchmark_throughput.py \ + --model "${MODEL_NAME}" \ + --dataset-name "${DATASET_NAME}" \ + --dataset-path "${DATASET_PATH}" \ + --num-prompts "${NUM_PROMPTS}" +``` + +If successful, you will see the following output + +``` +Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s +Total num prompt tokens: 5014 +Total num output tokens: 1500 +``` + +### VisionArena 
Benchmark for Vision Language Models + +``` bash +MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct" +NUM_PROMPTS=10 +DATASET_NAME="hf" +DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1" +DATASET_SPLIT="train" + +python3 vllm/benchmarks/benchmark_throughput.py \ + --model "${MODEL_NAME}" \ + --backend "vllm-chat" \ + --dataset-name "${DATASET_NAME}" \ + --dataset-path "${DATASET_PATH}" \ + --num-prompts "${NUM_PROMPTS}" \ + --hf-split "${DATASET_SPLIT}" +``` + +The `num prompt tokens` now includes image token counts + +``` +Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s +Total num prompt tokens: 14527 +Total num output tokens: 1280 +``` + +### Benchmark with LoRA Adapters + +``` bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +MODEL_NAME="meta-llama/Llama-2-7b-hf" +BACKEND="vllm" +DATASET_NAME="sharegpt" +DATASET_PATH="/ShareGPT_V3_unfiltered_cleaned_split.json" +NUM_PROMPTS=10 +MAX_LORAS=2 +MAX_LORA_RANK=8 +ENABLE_LORA="--enable-lora" +LORA_PATH="yard1/llama-2-7b-sql-lora-test" + +python3 vllm/benchmarks/benchmark_throughput.py \ + --model "${MODEL_NAME}" \ + --backend "${BACKEND}" \ + --dataset_path "${DATASET_PATH}" \ + --dataset_name "${DATASET_NAME}" \ + --num-prompts "${NUM_PROMPTS}" \ + --max-loras "${MAX_LORAS}" \ + --max-lora-rank "${MAX_LORA_RANK}" \ + ${ENABLE_LORA} \ + --lora-path "${LORA_PATH}" + ``` diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index fbab547d094f..0f13c79ae234 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -1,10 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 + import json import os import sys import time import traceback from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import Optional, Union import aiohttp import huggingface_hub.constants @@ -12,6 +14,9 @@ from transformers 
import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +# NOTE(simon): do not import vLLM here so the benchmark script +# can run without vLLM installed. + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) @@ -22,8 +27,11 @@ class RequestFuncInput: prompt_len: int output_len: int model: str - best_of: int = 1 - use_beam_search: bool = False + model_name: Optional[str] = None + logprobs: Optional[int] = None + extra_body: Optional[dict] = None + multi_modal_content: Optional[dict] = None + ignore_eos: bool = False @dataclass @@ -31,9 +39,11 @@ class RequestFuncOutput: generated_text: str = "" success: bool = False latency: float = 0.0 + output_tokens: int = 0 ttft: float = 0.0 # Time to first token - itl: List[float] = field( - default_factory=list) # List of inter-token latencies + itl: list[float] = field( + default_factory=list) # list of inter-token latencies + tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" @@ -45,14 +55,15 @@ async def async_request_tgi( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: params = { - "best_of": request_func_input.best_of, "max_new_tokens": request_func_input.output_len, "do_sample": True, "temperature": 0.01, # TGI does not accept 0.0 temperature. "top_p": 0.99, # TGI does not accept 1.0 top_p. 
+ "truncate": request_func_input.prompt_len, + "ignore_eos_token": request_func_input.ignore_eos, } payload = { "inputs": request_func_input.prompt, @@ -60,6 +71,10 @@ async def async_request_tgi( } output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len + if request_func_input.ignore_eos: + output.output_tokens = request_func_input.output_len + else: + output.output_tokens = None ttft = 0.0 st = time.perf_counter() @@ -73,11 +88,11 @@ async def async_request_tgi( continue chunk_bytes = chunk_bytes.decode("utf-8") - #NOTE: Sometimes TGI returns a ping response without + # NOTE: Sometimes TGI returns a ping response without # any data, we should skip it. if chunk_bytes.startswith(":"): continue - chunk = remove_prefix(chunk_bytes, "data:") + chunk = chunk_bytes.removeprefix("data:") data = json.loads(chunk) timestamp = time.perf_counter() @@ -116,9 +131,8 @@ async def async_request_trt_llm( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search - assert request_func_input.best_of == 1 + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: payload = { "accumulate_tokens": True, "text_input": request_func_input.prompt, @@ -127,6 +141,8 @@ async def async_request_trt_llm( "max_tokens": request_func_input.output_len, "stream": True, } + if request_func_input.ignore_eos: + payload["min_length"] = request_func_input.output_len output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -141,15 +157,15 @@ async def async_request_trt_llm( if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), - "data:") + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data:") data = json.loads(chunk) output.generated_text += data["text_output"] timestamp = time.perf_counter() # First token if ttft == 0.0: - ttft = time.perf_counter() - 
st + ttft = timestamp - st output.ttft = ttft # Decoding phase @@ -179,9 +195,8 @@ async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert request_func_input.best_of == 1 - assert not request_func_input.use_beam_search + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: payload = { "prompt": request_func_input.prompt, @@ -225,19 +240,27 @@ async def async_request_openai_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "completions" - ), "OpenAI Completions API URL must end with 'completions'." + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, - "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, "stream": True, + "stream_options": { + "include_usage": True, + }, } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } @@ -246,45 +269,56 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) 
as response: if response.status == 200: + first_chunk_received = False async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), - "data: ") - if chunk == "[DONE]": - latency = time.perf_counter() - st - else: + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": data = json.loads(chunk) # NOTE: Some completion API might have a last # usage summary response without a token so we # want to check a token was generated - if data["choices"][0]["text"]: + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") timestamp = time.perf_counter() # First token - if ttft == 0.0: + if not first_chunk_received: + first_chunk_received = True ttft = time.perf_counter() - st output.ttft = ttft # Decoding phase - output.itl.append(timestamp - - most_recent_timestamp) + else: + output.itl.append(timestamp - + most_recent_timestamp) most_recent_timestamp = timestamp - generated_text += data["choices"][0]["text"] - + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!") output.generated_text = generated_text - output.success = True - output.latency = latency + output.latency = most_recent_timestamp - st else: output.error = response.reason or "" output.success = False @@ -304,23 +338,34 @@ async def async_request_openai_chat_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "chat/completions" + ("chat/completions", "profile") ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
- async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "messages": [ { "role": "user", - "content": request_func_input.prompt, + "content": content }, ], "temperature": 0.0, - "max_tokens": request_func_input.output_len, + "max_completion_tokens": request_func_input.output_len, "stream": True, + "stream_options": { + "include_usage": True, + }, } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", @@ -342,19 +387,17 @@ async def async_request_openai_chat_completions( if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), - "data: ") - if chunk == "[DONE]": - latency = time.perf_counter() - st - else: + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) - delta = data["choices"][0]["delta"] - if delta.get("content", None): + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") # First token if ttft == 0.0: - ttft = time.perf_counter() - st + ttft = timestamp - st output.ttft = ttft # Decoding phase @@ -362,13 +405,16 @@ async def async_request_openai_chat_completions( output.itl.append(timestamp - most_recent_timestamp) - generated_text += delta["content"] + generated_text += content or "" + elif usage := 
data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True - output.latency = latency + output.latency = most_recent_timestamp - st else: output.error = response.reason or "" output.success = False @@ -382,36 +428,54 @@ async def async_request_openai_chat_completions( return output -# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix) -# introduced in Python 3.9 -def remove_prefix(text: str, prefix: str) -> str: - if text.startswith(prefix): - return text[len(prefix):] - return text - - def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': from modelscope import snapshot_download - model_path = snapshot_download( - model_id=pretrained_model_name_or_path, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + from vllm.model_executor.model_loader.weight_utils import get_lock - return model_path + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
+ with get_lock(pretrained_model_name_or_path): + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + + return model_path return pretrained_model_name_or_path def get_tokenizer( - pretrained_model_name_or_path: str, trust_remote_code: bool + pretrained_model_name_or_path: str, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if pretrained_model_name_or_path is not None and not os.path.exists( pretrained_model_name_or_path): pretrained_model_name_or_path = get_model( pretrained_model_name_or_path) - return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, - trust_remote_code=trust_remote_code) + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + if tokenizer_mode == "mistral": + try: + from vllm.transformers_utils.tokenizer import MistralTokenizer + except ImportError as e: + raise ImportError("MistralTokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use mistral tokenizer mode.") from e + return MistralTokenizer.from_pretrained( + str(pretrained_model_name_or_path)) + else: + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **kwargs, + ) ASYNC_REQUEST_FUNCS = { @@ -423,4 +487,5 @@ def get_tokenizer( "openai-chat": async_request_openai_chat_completions, "tensorrt-llm": async_request_trt_llm, "scalellm": async_request_openai_completions, + "sglang": async_request_openai_completions, } diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py new file mode 100644 index 000000000000..0567875f9862 --- /dev/null +++ b/benchmarks/benchmark_dataset.py @@ -0,0 +1,717 @@ +# 
SPDX-License-Identifier: Apache-2.0 +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena + +TODO: Implement CustomDataset to parse a JSON file and convert its contents into +SampleRequest instances, similar to the approach used in ShareGPT. +""" + +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from typing import Any, Optional, Union + +import numpy as np +import pandas as pd +from datasets import load_dataset +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. 
+ """ + + prompt: Union[str, Any] + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. + self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def apply_multimodal_chat_transformation( + self, + prompt: str, + mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. 
+ """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. max_loras (Optional[int]): The maximum number of + LoRAs available. If None, LoRA is not used. lora_path + (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA + is not used. + + Returns: + tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first + element is a LoRARequest (or None if not applicable) and the second + element is the tokenizer associated with the LoRA request (or the + base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. 
+ + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. + """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. + + Args: + requests (List[SampleRequest]): The current list of sampled + requests. num_requests (int): The target number of requests. + """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. 
+ """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + For a PIL.Image.Image input: + - Converts the image to RGB. + - Saves the image as a JPEG in-memory. + - Encodes the JPEG data as a base64 string. + - Returns a dictionary with the image as a base64 data URL. + + For a string input: + - Treats the string as a URL or file path. + - Prepends "file://" if the string doesn't start with "http://" or + "file://". + - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is neither a PIL.Image.Image nor a string. + """ + if isinstance(image, Image.Image): + image = image.convert("RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError( + f"Invalid image input {image}. 
Must be a PIL.Image.Image or str.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. + DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 1.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + vocab_size = tokenizer.vocab_size + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + input_low = int(input_len * range_ratio) + output_low = int(output_len * range_ratio) + + input_lens = np.random.randint(input_low, + input_len + 1, + size=num_requests) + output_lens = np.random.randint(output_low, + output_len + 1, + size=num_requests) + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_lens[i]) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + 
Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. + self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = ( + entry["conversations"][0]["value"], + entry["conversations"][1]["value"], + ) + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + )) + self.maybe_oversample_requests(samples, num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset 
Implementation +# ----------------------------------------------------------------------------- + + +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. 
+ num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = round((prefix_len - base_offset) / avg_len) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + for _ in range(num_requests): + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + samples.append( + SampleRequest( + prompt=prompt_formatted + if return_prompt_formatted else prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + return samples + + +# ----------------------------------------------------------------------------- +# BurstGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BurstGPTDataset(BenchmarkDataset): + """ + Implements the BurstGPT dataset. Loads data from a CSV file and generates + sample requests based on synthetic prompt generation. Only rows with Model + "GPT-4" and positive response tokens are used. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self, ): + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + df = pd.read_csv(self.dataset_path) + # Filter to keep only GPT-4 rows. + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove failed requests (where Response tokens is 0 or less). + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Sample the desired number of rows. 
+ self.data = gpt4_df + + def _sample_loaded_data(self, num_requests: int) -> list: + if num_requests <= len(self.data): + data = self.data.sample(n=num_requests, + random_state=self.random_seed) + else: + data = self.data.sample( + n=num_requests, + random_state=self.random_seed, + replace=True, + ) + # Convert the dataframe to a list of lists. + return data.values.tolist() + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + **kwargs, + ) -> list[SampleRequest]: + samples = [] + data = self._sample_loaded_data(num_requests=num_requests) + for i in range(num_requests): + input_len = int(data[i][2]) + output_len = int(data[i][3]) + lora_req, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + vocab_size = tokenizer.vocab_size + # Generate a synthetic prompt: a list of token IDs computed as (i + + # j) modulo vocab_size. + token_ids = [(i + j) % vocab_size for j in range(input_len)] + prompt = tokenizer.decode(token_ids) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=output_len, + lora_request=lora_req, + )) + return samples + + +# ----------------------------------------------------------------------------- +# HuggingFace Dataset Implementation +# ----------------------------------------------------------------------------- + + +class HuggingFaceDataset(BenchmarkDataset): + """ + Dataset class for processing a HuggingFace dataset with conversation data + and optional images. 
+ """ + + def __init__( + self, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided for loading data.") + + self.data = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + if self.data.features is None or "conversations" \ + not in self.data.features: + raise ValueError( + "HuggingFaceDataset currently only supports datasets with " + "a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. " + "Please consider contributing if you would like to add " + "support for additional dataset formats.") + # Shuffle and filter examples with at least 2 conversations. + self.data = self.data.shuffle(seed=self.random_seed).filter( + lambda x: len(x["conversations"]) >= 2) + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + conv = item["conversations"] + prompt, completion = conv[0]["value"], conv[1]["value"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len): + continue + mm_content = process_image( + item["image"]) if "image" in item else None + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using 
request output to count the + # actual prompt len and output len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Vision Arena Dataset Implementation +# ----------------------------------------------------------------------------- + + +class VisionArenaDataset(HuggingFaceDataset): + """ + Vision Arena Dataset. + """ + + DEFAULT_OUTPUT_LEN = 128 + VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1" + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if self.dataset_path != self.VISION_ARENA_DATASET_PATH: + raise ValueError(f"Only support Vision Arena dataset.\ + This data path {self.dataset_path} is not valid.") + if self.dataset_subset is None and self.dataset_split != "train": + raise ValueError("Dataset split must be 'train'.") + + self.load_data() + + def load_data(self) -> None: + dataset = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + self.data = dataset.shuffle(seed=self.random_seed) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["turns"][0][0]["content"] + mm_content = process_image(item["images"][0]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and 
we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 97afd301c8f2..dfd9bb1e6a4d 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,70 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 """Benchmark the latency of processing a single batch of requests.""" + import argparse +import dataclasses import json +import os import time from pathlib import Path -from typing import List, Optional +from typing import Any, Optional import numpy as np import torch +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptInputs -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.inputs import PromptType +from vllm.sampling_params import BeamSearchParams from vllm.utils import FlexibleArgumentParser +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={"latency": results["latencies"]}, + extra_info={k: results[k] + for k in ["avg_latency", "percentiles"]}) + if pt_records: + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + def main(args: argparse.Namespace): print(args) + engine_args = EngineArgs.from_cli_args(args) + # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
- llm = LLM( - model=args.model, - speculative_model=args.speculative_model, - num_speculative_tokens=args.num_speculative_tokens, - speculative_draft_tensor_parallel_size=\ - args.speculative_draft_tensor_parallel_size, - tokenizer=args.tokenizer, - quantization=args.quantization, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - max_model_len=args.max_model_len, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - quantization_param_path=args.quantization_param_path, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - use_v2_block_manager=args.use_v2_block_manager, - enable_chunked_prefill=args.enable_chunked_prefill, - download_dir=args.download_dir, - block_size=args.block_size, - gpu_memory_utilization=args.gpu_memory_utilization, - load_format=args.load_format, - distributed_executor_backend=args.distributed_executor_backend, - otlp_traces_endpoint=args.otlp_traces_endpoint, - enable_prefix_caching=args.enable_prefix_caching, - ) + llm = LLM(**dataclasses.asdict(engine_args)) + assert llm.llm_engine.model_config.max_model_len >= ( + args.input_len + + args.output_len), ("Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") sampling_params = SamplingParams( n=args.n, - temperature=0.0 if args.use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=args.use_beam_search, ignore_eos=True, max_tokens=args.output_len, + detokenize=not args.disable_detokenize, ) print(sampling_params) dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: list[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] + def llm_generate(): + if not args.use_beam_search: + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + else: + llm.beam_search( + 
dummy_prompts, + BeamSearchParams( + beam_width=args.n, + max_tokens=args.output_len, + ignore_eos=True, + ), + ) + def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: with torch.profiler.profile( @@ -73,16 +85,13 @@ def run_to_completion(profile_dir: Optional[str] = None): torch.profiler.ProfilerActivity.CUDA, ], on_trace_ready=torch.profiler.tensorboard_trace_handler( - str(profile_dir))) as p: - llm.generate(dummy_inputs, - sampling_params=sampling_params, - use_tqdm=False) - print(p.key_averages()) + str(profile_dir)), + ) as p: + llm_generate() + print(p.key_averages().table(sort_by="self_cuda_time_total")) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, - sampling_params=sampling_params, - use_tqdm=False) + llm_generate() end_time = time.perf_counter() latency = end_time - start_time return latency @@ -94,9 +103,8 @@ def run_to_completion(profile_dir: Optional[str] = None): if args.profile: profile_dir = args.profile_result_dir if not profile_dir: - profile_dir = Path( - "." 
- ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" + profile_dir = (Path(".") / "vllm_benchmark_result" / + f"latency_result_{time.time()}") print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return @@ -108,9 +116,9 @@ def run_to_completion(profile_dir: Optional[str] = None): latencies = np.array(latencies) percentages = [10, 25, 50, 75, 90, 99] percentiles = np.percentile(latencies, percentages) - print(f'Avg latency: {np.mean(latencies)} seconds') + print(f"Avg latency: {np.mean(latencies)} seconds") for percentage, percentile in zip(percentages, percentiles): - print(f'{percentage}% percentile latency: {percentile} seconds') + print(f"{percentage}% percentile latency: {percentile} seconds") # Output JSON results if specified if args.output_json: @@ -121,165 +129,58 @@ def run_to_completion(profile_dir: Optional[str] = None): } with open(args.output_json, "w") as f: json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) -if __name__ == '__main__': +if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the latency of processing a single batch of ' - 'requests till completion.') - parser.add_argument('--model', type=str, default='facebook/opt-125m') - parser.add_argument('--speculative-model', type=str, default=None) - parser.add_argument('--num-speculative-tokens', type=int, default=None) - parser.add_argument('--speculative-draft-tensor-parallel-size', - '-spec-draft-tp', - type=int, - default=None) - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) - parser.add_argument('--input-len', type=int, default=32) - parser.add_argument('--output-len', type=int, default=128) - parser.add_argument('--batch-size', type=int, default=8) - 
parser.add_argument('--n', - type=int, - default=1, - help='Number of generated sequences per prompt.') - parser.add_argument('--use-beam-search', action='store_true') - parser.add_argument('--num-iters-warmup', - type=int, - default=10, - help='Number of iterations to run for warmup.') - parser.add_argument('--num-iters', - type=int, - default=30, - help='Number of iterations to run.') - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') + description="Benchmark the latency of processing a single batch of " + "requests till completion.") + parser.add_argument("--input-len", type=int, default=32) + parser.add_argument("--output-len", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) parser.add_argument( - '--max-model-len', + "--n", type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--enforce-eager', - action='store_true', - help='enforce eager mode and disable CUDA graph') - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. 
' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') - parser.add_argument( - '--profile', - action='store_true', - help='profile the generation process of a single batch') - parser.add_argument( - '--profile-result-dir', - type=str, - default=None, - help=('path to save the pytorch profiler output. Can be visualized ' - 'with ui.perfetto.dev or Tensorboard.')) + default=1, + help="Number of generated sequences per prompt.", + ) + parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, OpenVINO and ' - 'CPU.') - parser.add_argument('--block-size', + "--num-iters-warmup", + type=int, + default=10, + help="Number of iterations to run for warmup.", + ) + parser.add_argument("--num-iters", type=int, - default=16, - help='block size of key/value cache') - parser.add_argument( - '--enable-chunked-prefill', - action='store_true', - help='If True, the prefill requests can be chunked based on the ' - 'max_num_batched_tokens') - parser.add_argument("--enable-prefix-caching", - action='store_true', - help="Enable automatic prefix caching") - parser.add_argument('--use-v2-block-manager', action='store_true') + default=30, + help="Number of iterations to run.") parser.add_argument( - "--ray-workers-use-nsight", - action='store_true', - help="If specified, use nsight to profile ray workers", + "--profile", + action="store_true", + help="profile the generation process of a single batch", ) - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') parser.add_argument( - 
'--output-json', + "--profile-result-dir", type=str, default=None, - help='Path to save the latency results in JSON format.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') + help=("path to save the pytorch profiler output. Can be visualized " + "with ui.perfetto.dev or Tensorboard."), + ) parser.add_argument( - '--load-format', + "--output-json", type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], default=None, - help='Backend to use for distributed serving. 
When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') + help="Path to save the latency results in JSON format.", + ) parser.add_argument( - '--otlp-traces-endpoint', - type=str, - default=None, - help='Target URL to which OpenTelemetry traces will be sent.') + "--disable-detokenize", + action="store_true", + help=("Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py new file mode 100644 index 000000000000..21480578edbd --- /dev/null +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Offline benchmark to test the long document QA throughput. + +Example usage: + # This workload samples 8 different prompts with a default input + # length of 20000 tokens, then replicates each prompt 2 times + # in random order. + python benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --repeat-count 2 + +Commandline arguments: + --num-documents: The number of documents to sample prompts from. + + --document-length: The length of each document in tokens. + (Optional, default: 20000) + + --output-len: The number of tokens to generate for each prompt. + (Optional, default: 10) + + --repeat-count: The number of times to repeat each prompt. + (Optional, default: 2) + + --repeat-mode: The mode to repeat prompts. The supported modes are: + - 'random': shuffle the prompts randomly. (Default) + - 'tile': the entire prompt list is repeated in sequence. (Potentially + lowest cache hit) + - 'interleave': each prompt is repeated consecutively before + moving to the next element. 
(Highest cache hit) + + --shuffle-seed: Random seed when the repeat mode is "random". + (Optional, default: 0) + +In the meantime, it also supports all the vLLM engine args to initialize the +LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more +details. +""" + +import dataclasses +import random +import time + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def test_long_document_qa(llm=None, sampling_params=None, prompts=None): + """ + Test long document QA with the given prompts and sampling parameters. + Print the time spent in processing all the prompts. + + Args: + llm: The language model used for generating responses. + sampling_params: Sampling parameter used to generate the response. + prompts: A list of prompt strings to be processed by the LLM. + """ + start_time = time.time() + llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + print(f"Time to execute all requests: {end_time - start_time:.4f} secs") + + +def repeat_prompts(prompts, repeat_count, mode: str): + """ + Repeat each prompt in the list for a specified number of times. + The order of prompts in the output list depends on the mode. + + Args: + prompts: A list of prompts to be repeated. + repeat_count: The number of times each prompt is repeated. + mode: The mode of repetition. Supported modes are: + - 'random': Shuffle the prompts randomly after repetition. + - 'tile': Repeat the entire prompt list in sequence. + Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. + - 'interleave': Repeat each prompt consecutively before moving to + the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. + + Returns: + A list of repeated prompts in the specified order. + + Raises: + ValueError: If an invalid mode is provided. 
+ """ + print("Repeat mode: ", mode) + if mode == 'random': + repeated_prompts = prompts * repeat_count + random.shuffle(repeated_prompts) + return repeated_prompts + elif mode == 'tile': + return prompts * repeat_count + elif mode == 'interleave': + repeated_prompts = [] + for prompt in prompts: + repeated_prompts.extend([prompt] * repeat_count) + return repeated_prompts + else: + raise ValueError(f"Invalid mode: {mode}, only support " + "'random', 'tile', 'interleave'") + + +def main(args): + random.seed(args.shuffle_seed) + + # Prepare the prompts: + # we append the document id at the beginning to avoid any of the document + # being the prefix of other documents + prompts = [ + str(i) + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents) + ] + + prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) + + warmup_prompts = [ + "This is warm up request " + str(i) + \ + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents)] + + # Create the LLM engine + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**dataclasses.asdict(engine_args)) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + + print("------warm up------") + test_long_document_qa( + llm=llm, + prompts=warmup_prompts, + sampling_params=sampling_params, + ) + + print("------start generating------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description= + 'Benchmark the performance with or without automatic prefix caching.') + + parser.add_argument( + '--document-length', + type=int, + # Roughly the number of tokens for a system paper, + # excluding images + default=20000, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + + parser.add_argument('--num-documents', + type=int, + default=8, + help='Range of input lengths for sampling 
prompts,' + 'specified as "min:max" (e.g., "128:256").') + + parser.add_argument('--output-len', type=int, default=10) + + parser.add_argument('--repeat-count', + type=int, + default=2, + help='Number of times to repeat each prompt') + + parser.add_argument("--repeat-mode", + type=str, + default='random', + help='The mode to repeat prompts. The supported ' + 'modes are "random", "tile", and "interleave". ' + 'See repeat_prompts() in the source code for details.') + + parser.add_argument("--shuffle-seed", + type=int, + default=0, + help='Random seed when the repeat mode is "random"') + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 395107a5ec74..4fff7a8fc8ed 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -1,8 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Benchmark the efficiency of prefix caching. + +This script allows you to benchmark the performance of +a model with and without prefix caching using either fixed prompts +or prompts sampled from the ShareGPT dataset. + +Fixed example usage: + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 \ + --input-length-range 128:256 + +ShareGPT example usage: + # This command samples 20 prompts with input lengths + # between 128 and 256 tokens from the ShareGPT dataset, + # then replicates each prompt 5 times. 
+ python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +""" + +import dataclasses +import json +import random import time +from typing import Optional + +from transformers import PreTrainedTokenizerBase from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs from vllm.utils import FlexibleArgumentParser +try: + from vllm.transformers_utils.tokenizer import get_tokenizer +except ImportError: + from backend_request_func import get_tokenizer + PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. 
Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 @@ -15,25 +56,152 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): print(f"cost time {end_time - start_time}") +@dataclasses.dataclass +class Request: + prompt: str + prompt_len: int + output_len: int + + +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: + vocab = tokenizer.get_vocab() + # Remove the special tokens. + vocab = { + k: v + for k, v in vocab.items() if k not in tokenizer.all_special_ids + } + return random.choices(list(vocab.values()), k=length) + + +def sample_requests_from_dataset( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: tuple[int, int], + fixed_output_len: Optional[int], +) -> list[Request]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + min_len, max_len = input_length_range + assert min_len >= 0 and max_len >= min_len, "input_length_range too small" + + # Filter out sequences that are too long or too short + filtered_requests: list[Request] = [] + + for i in range(len(dataset)): + if len(filtered_requests) == num_requests: + break + + # Tokenize the prompts and completions. 
+ prompt_token_ids = tokenizer(dataset[i][0]).input_ids + prompt = tokenizer.decode(prompt_token_ids) + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = (len(completion_token_ids) + if fixed_output_len is None else fixed_output_len) + if min_len <= prompt_len <= max_len: + filtered_requests.append(Request(prompt, prompt_len, output_len)) + + return filtered_requests + + +def sample_requests_from_random( + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: tuple[int, int], + fixed_output_len: Optional[int], + prefix_len: int, +) -> list[Request]: + + requests = [] + prefix_token_ids = sample_tokens(tokenizer, prefix_len) + min_len, max_len = input_length_range + + for i in range(num_requests): + unique_part_token_ids = sample_tokens( + tokenizer, + random.randint(min_len - prefix_len, max_len - prefix_len)) + prompt_token_ids = prefix_token_ids + unique_part_token_ids + prompt = tokenizer.decode(prompt_token_ids) + prompt_len = len(prompt_token_ids) + assert (min_len <= prompt_len <= max_len + ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + requests.append(Request(prompt, prompt_len, fixed_output_len)) + return requests + + +def repeat_and_sort_requests(requests: list[Request], + repeat_count: int, + sort: bool = False) -> list[str]: + repeated_requests = requests * repeat_count + if sort: + repeated_requests.sort(key=lambda x: x[1]) + else: + random.shuffle(repeated_requests) + return [req.prompt for req in repeated_requests] + + def main(args): - llm = LLM(model=args.model, - tokenizer_mode='auto', - trust_remote_code=True, - enforce_eager=True, - use_v2_block_manager=args.use_v2_block_manager, - tensor_parallel_size=args.tensor_parallel_size, - enable_prefix_caching=args.enable_prefix_caching) - - num_prompts = 100 - prompts = [PROMPT] * num_prompts - sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) - - 
print("------warm up------") - test_prefix( - llm=llm, - prompts=prompts, - sampling_params=sampling_params, - ) + tokenizer = get_tokenizer(args.model, trust_remote_code=True) + input_length_range = tuple(map(int, args.input_length_range.split(':'))) + random.seed(args.seed) + if args.dataset_path is not None: + if args.prefix_len > 0: + raise ValueError("prefix-len is not supported when " + "dataset-path is provided.") + print(f"Start to sample {args.num_prompts} prompts " + f"from {args.dataset_path}") + filtered_requests = sample_requests_from_dataset( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + ) + else: + print(f"Start to sample {args.num_prompts} prompts from random") + filtered_requests = sample_requests_from_random( + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + prefix_len=args.prefix_len, + ) + + # Print some helpful stats of the requests. 
+ print(f"Sampled {len(filtered_requests)} requests.") + prompt_lens = [req.prompt_len for req in filtered_requests] + print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}") + print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}") + print(f"Min Prompt Length: {min(prompt_lens)}") + print(f"Max Prompt Length: {max(prompt_lens)}") + + engine_args = EngineArgs.from_cli_args(args) + + llm = LLM(**dataclasses.asdict(engine_args)) + + sampling_params = SamplingParams(temperature=0, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize) + + print("Testing filtered requests") + prompts = repeat_and_sort_requests(filtered_requests, + repeat_count=args.repeat_count, + sort=args.sort) print("------start generating------") test_prefix( @@ -45,18 +213,45 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the performance with or without automatic ' - 'prefix caching.') - parser.add_argument('--model', + description= + 'Benchmark the performance with or without automatic prefix caching.') + parser.add_argument("--dataset-path", type=str, - default='baichuan-inc/Baichuan2-13B-Chat') - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + default=None, + help="Path to the dataset.") parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--enable-prefix-caching', + parser.add_argument('--num-prompts', + type=int, + required=True, + help="Number of the prompts sampled from dataset") + parser.add_argument('--repeat-count', + type=int, + default=1, + help='Number of times to repeat each prompt') + parser.add_argument('--sort', action='store_true', - help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - help='Use BlockSpaceMangerV2') + help='Sort prompts by input length') + parser.add_argument('--input-length-range', + type=str, + required=True, + help='Range of input lengths for sampling 
prompts,' + 'specified as "min:max" (e.g., "128:256").') + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Specifies the length of a common prefix to be " + "added to the input prompt. The input-length-range will " + "subtract this length when filtering prompts. Only used " + "when dataset-path is not provided.", + ) + parser.add_argument( + '--disable-detokenize', + action='store_true', + help=("Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py new file mode 100644 index 000000000000..76fe00ede249 --- /dev/null +++ b/benchmarks/benchmark_prioritization.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark offline prioritization.""" +import argparse +import dataclasses +import json +import random +import time +from typing import Optional + +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +#Select a equi-probable random priority +def get_random_flag(): + return 0 if random.random() < 0.5 else 1 + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> list[tuple[str, int, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. 
+ random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: list[tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + priority = get_random_flag() + + filtered_dataset.append((prompt, prompt_len, output_len, priority)) + + return filtered_dataset + + +def run_vllm( + requests: list[tuple[str, int, int]], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> float: + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " input_len and output_len for all requests.") + + # Add the requests to the engine. + prompts = [] + sampling_params = [] + priority = [] + for prompt, _, output_len, _priority in requests: + prompts.append(prompt) + priority.append(_priority) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=output_len, + detokenize=not disable_detokenize, + )) + + start = time.perf_counter() + llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. 
+ tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len, + get_random_flag()) for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + elapsed_time = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args), + args.disable_detokenize) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len, priority in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. 
Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=200, + help="Number of prompts to process.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--disable-detokenize', + action='store_true', + help=("Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + main(args) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fc0dbf77f16b..47627126b668 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,4 +1,5 @@ -"""Benchmark online serving throughput. +# SPDX-License-Identifier: Apache-2.0 +r"""Benchmark online serving throughput. 
On the server side, run one of the following commands: vLLM OpenAI API server @@ -24,14 +25,16 @@ """ import argparse import asyncio +import gc import json import os import random import time import warnings +from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, Optional import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, @@ -49,6 +52,13 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser +from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset, + RandomDataset, SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + @dataclass class BenchmarkMetrics: @@ -56,167 +66,60 @@ class BenchmarkMetrics: total_input: int total_output: int request_throughput: float - input_throughput: float + request_goodput: float output_throughput: float + total_token_throughput: float mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float - p99_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float - p99_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] mean_itl_ms: float median_itl_ms: float std_itl_ms: float - p99_itl_ms: float - - -def sample_sharegpt_requests( - dataset_path: str, - num_requests: int, - tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int]]: - if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. 
- dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] - - # Shuffle the dataset. - random.shuffle(dataset) - - # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] - for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: - break - - # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids - completion = dataset[i][1] - completion_token_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - return filtered_dataset - - -def sample_sonnet_requests( - dataset_path: str, - num_requests: int, - input_len: int, - output_len: int, - prefix_len: int, - tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, str, int, int]]: - assert ( - input_len > prefix_len - ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." - - # Load the dataset. - with open(dataset_path) as f: - poem_lines = f.readlines() - - # Tokenize the poem lines. - poem_token_ids = tokenizer(poem_lines).input_ids - average_poem_len = sum( - len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids) - - # Base prefix for all requests. 
- base_prompt = "Pick as many lines as you can from these poem lines:\n" - base_message = [{ - "role": "user", - "content": base_prompt, - }] - base_prompt_formatted = tokenizer.apply_chat_template( - base_message, add_generation_prompt=True, tokenize=False) - base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids) - - assert ( - input_len > base_prompt_offset - ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}." - num_input_lines = round( - (input_len - base_prompt_offset) / average_poem_len) - - # First approximately `prefix_len` number of tokens in the - # prompt are fixed poem lines. - assert ( - prefix_len > base_prompt_offset - ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}." - - num_prefix_lines = round( - (prefix_len - base_prompt_offset) / average_poem_len) - prefix_lines = poem_lines[:num_prefix_lines] - - # Sample the rest of lines per request. - sampled_requests: List[Tuple[str, int, int]] = [] - for _ in range(num_requests): - sampled_lines = "".join( - prefix_lines + - random.sample(poem_lines, num_input_lines - num_prefix_lines)) - - prompt = f"{base_prompt}{sampled_lines}" - message = [ - { - "role": "user", - "content": prompt, - }, - ] - prompt_formatted = tokenizer.apply_chat_template( - message, add_generation_prompt=True, tokenize=False) - prompt_len = len(tokenizer(prompt_formatted).input_ids) - sampled_requests.append( - (prompt, prompt_formatted, prompt_len, output_len)) - - return sampled_requests - - -def sample_random_requests( - input_len: int, output_len: int, num_prompts: int, range_ratio: float, - tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]: - - input_lens = np.random.randint( - int(input_len * range_ratio), - input_len + 1, - size=num_prompts, - ) - output_lens = np.random.randint( - int(output_len * range_ratio), - output_len + 1, - size=num_prompts, - ) - offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) - input_requests = [] - 
for i in range(num_prompts): - prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size - for j in range(input_lens[i])]) - input_requests.append( - (prompt, int(input_lens[i]), int(output_lens[i]))) - - return input_requests + percentiles_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] async def get_request( - input_requests: List[Tuple[str, int, int]], + input_requests: list[SampleRequest], request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: - input_requests = iter(input_requests) + burstiness: float = 1.0, +) -> AsyncGenerator[SampleRequest, None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a SampleRequest. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ + input_requests: Iterable[SampleRequest] = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + for request in input_requests: yield request @@ -224,44 +127,82 @@ async def get_request( # If the request rate is infinity, then we don't need to wait. 
continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) # The next request will be sent after the interval. await asyncio.sleep(interval) def calculate_metrics( - input_requests: List[Tuple[str, int, int]], - outputs: List[RequestFuncOutput], + input_requests: list[SampleRequest], + outputs: list[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, -) -> Tuple[BenchmarkMetrics, List[int]]: - actual_output_lens: List[int] = [] + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], +) -> tuple[BenchmarkMetrics, list[int]]: + actual_output_lens: list[int] = [] total_input = 0 completed = 0 - itls: List[float] = [] - tpots: List[float] = [] - ttfts: List[float] = [] + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] for i in range(len(outputs)): if outputs[i].success: - # We use the tokenizer to count the number of output tokens for all - # serving backends instead of looking at len(outputs[i].itl) since - # multiple output tokens may be bundled together - # Note : this may inflate the output token count slightly - output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + output_len = outputs[i].output_tokens + + if output_len is None: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) 
actual_output_lens.append(output_len) - total_input += input_requests[i][1] + total_input += input_requests[i].prompt_len + tpot = 0 if output_len > 1: - tpots.append( - (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) itls += outputs[i].itl ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) completed += 1 else: actual_output_lens.append(0) + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + if completed == 0: warnings.warn( "All requests failed. 
This is likely due to a misconfiguration " @@ -272,21 +213,30 @@ def calculate_metrics( total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, - input_throughput=total_input / dur_s, + request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend - median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, - p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], mean_tpot_ms=np.mean(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, - p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], mean_itl_ms=np.mean(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, - p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], ) return metrics, actual_output_lens @@ -295,13 +245,22 @@ def calculate_metrics( async def benchmark( backend: str, api_url: str, + base_url: str, model_id: str, + model_name: str, tokenizer: PreTrainedTokenizerBase, - input_requests: List[Tuple[str, int, int]], - best_of: int, - use_beam_search: bool, + input_requests: list[SampleRequest], + logprobs: 
Optional[int], request_rate: float, + burstiness: float, disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + ignore_eos: bool, + goodput_config_dict: dict[str, float], + max_concurrency: Optional[int], + lora_modules: Optional[Iterable[str]], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -309,16 +268,28 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_prompt, test_prompt_len, test_output_len, test_mm_content = \ + input_requests[0].prompt, input_requests[0].prompt_len, \ + input_requests[0].expected_output_len, \ + input_requests[0].multi_modal_data + + if backend != "openai-chat" and test_mm_content is not None: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend.") + assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( model=model_id, + model_name=model_name, prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, - best_of=best_of, - use_beam_search=use_beam_search, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, ) + test_output = await request_func(request_func_input=test_input) if not test_output.success: raise ValueError( @@ -326,28 +297,93 @@ async def benchmark( f"are correctly specified. Error: {test_output.error}") else: print("Initial test run completed. Starting main benchmark run...") + + if lora_modules: + # For each input request, choose a LoRA module at random. 
+ lora_modules = iter( + [random.choice(lora_modules) \ + for _ in range(len(input_requests))]) + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput(model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. 
+ # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + benchmark_start_time = time.perf_counter() - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request - request_func_input = RequestFuncInput( - model=model_id, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - best_of=best_of, - use_beam_search=use_beam_search, - ) + tasks: list[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, burstiness): + prompt, prompt_len, output_len, mm_content = request.prompt, \ + request.prompt_len, request.expected_output_len, \ + request.multi_modal_data + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos) tasks.append( asyncio.create_task( - request_func(request_func_input=request_func_input, - pbar=pbar))) - outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + 
"/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") if pbar is not None: pbar.close() @@ -359,6 +395,9 @@ async def benchmark( outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, + selected_percentile_metrics=selected_percentile_metrics, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -370,27 +409,13 @@ async def benchmark( metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) - print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", - metrics.input_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) - print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) - print("{:<40} {:<10.2f}".format("Median TTFT (ms):", - metrics.median_ttft_ms)) - print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) - print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 
1st token)', - n=50, - c='-')) - print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) - print("{:<40} {:<10.2f}".format("Median TPOT (ms):", - metrics.median_tpot_ms)) - print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) - print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) - print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) - print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) - print("=" * 50) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) result = { "duration": benchmark_duration, @@ -398,20 +423,10 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, + "request_goodput:": + metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, + "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], @@ -419,9 +434,110 @@ async def benchmark( "generated_texts": [output.generated_text for output in outputs], "errors": [output.error for output in outputs], } + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" 
+ metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + return result +def check_goodput_args(args): + # Check and parse goodput arguments + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. 
" + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any], + file_name: str) -> None: + metrics = [ + "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", + "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", + "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + ] + # These raw data might be useful, but they are rather big. They can be added + # later if needed + ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={k: [results[k]] + for k in metrics}, + extra_info={ + k: results[k] + for k in results if k not in metrics and k not in ignored_metrics + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + def main(args: argparse.Namespace): print(args) random.seed(args.seed) @@ -429,95 +545,125 @@ def main(args: argparse.Namespace): backend = args.backend model_id = args.model + model_name = args.served_model_name tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" else: api_url = 
f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) - if args.dataset is not None: - warnings.warn( - "The '--dataset' argument will be deprecated in the next " - "release. Please use '--dataset-name' and " - "'--dataset-path' in the future runs.", - stacklevel=2) - input_requests = sample_sharegpt_requests( - dataset_path=args.dataset, - num_requests=args.num_prompts, - tokenizer=tokenizer, - fixed_output_len=args.sharegpt_output_len, - ) + if args.dataset_name is None: + raise ValueError( + "Please specify '--dataset-name' and the corresponding " + "'--dataset-path' if required.") - elif args.dataset_name == "sharegpt": - input_requests = sample_sharegpt_requests( + if args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. + if args.backend == "openai-chat": + input_requests = dataset.sample(num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False) + else: + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample(num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True) + + elif args.dataset_name == "hf": + # Choose between VisionArenaDataset + # and HuggingFaceDataset based on provided parameters. 
+ dataset_class = (VisionArenaDataset if args.dataset_path + == VisionArenaDataset.VISION_ARENA_DATASET_PATH + and args.hf_subset is None else HuggingFaceDataset) + input_requests = dataset_class( dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, - fixed_output_len=args.sharegpt_output_len, + random_seed=args.seed, + output_len=args.hf_output_len, ) - elif args.dataset_name == "sonnet": - # Do not format the prompt, pass to message directly - if args.backend == "openai-chat": - input_requests = sample_sonnet_requests( - dataset_path=args.dataset_path, - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, + else: + # For datasets that follow a similar structure, use a mapping. + dataset_mapping = { + "sharegpt": + lambda: ShareGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). + sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, - ) - input_requests = [(prompt, prompt_len, output_len) - for prompt, prompt_formatted, prompt_len, - output_len in input_requests] - else: - assert ( - tokenizer.chat_template or tokenizer.default_chat_template - ), "Tokenizer/model must have chat template for sonnet dataset." 
- input_requests = sample_sonnet_requests( - dataset_path=args.dataset_path, num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, ) - input_requests = [(prompt_formatted, prompt_len, output_len) - for prompt, prompt_formatted, prompt_len, - output_len in input_requests] - - elif args.dataset_name == "random": - input_requests = sample_random_requests( - input_len=args.random_input_len, - output_len=args.random_output_len, - num_prompts=args.num_prompts, - range_ratio=args.random_range_ratio, - tokenizer=tokenizer, - ) + } - else: - raise ValueError(f"Unknown dataset: {args.dataset_name}") + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + goodput_config_dict = check_goodput_args(args) + + # Avoid GC processing "static" data - reduce pause times. 
+ gc.collect() + gc.freeze() benchmark_result = asyncio.run( benchmark( backend=backend, api_url=api_url, + base_url=base_url, model_id=model_id, + model_name=model_name, tokenizer=tokenizer, input_requests=input_requests, - best_of=args.best_of, - use_beam_search=args.use_beam_search, + logprobs=args.logprobs, request_rate=args.request_rate, + burstiness=args.burstiness, disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, )) # Save config and results to json if args.save_result: - result_json: Dict[str, Any] = {} + result_json: dict[str, Any] = {} # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") @@ -525,8 +671,6 @@ def main(args: argparse.Namespace): result_json["backend"] = backend result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id - result_json["best_of"] = args.best_of - result_json["use_beam_search"] = args.use_beam_search result_json["num_prompts"] = args.num_prompts # Metadata @@ -540,22 +684,36 @@ def main(args: argparse.Namespace): "Invalid metadata format. Please use KEY=VALUE format." 
) + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", "output_lens", "ttfts", "itls", + "generated_texts", "errors" + ]: + if field in result_json: + del result_json[field] + # Traffic - result_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf") + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency # Merge with benchmark result result_json = {**result_json, **benchmark_result} # Save to file base_model_id = model_id.split("/")[-1] - file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa if args.result_filename: file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) - with open(file_name, "w") as outfile: + with open(file_name, "w", encoding='utf-8') as outfile: json.dump(result_json, outfile) + save_to_pytorch_benchmark_format(args, result_json, file_name) if __name__ == "__main__": @@ -573,7 +731,8 @@ def main(args: argparse.Namespace): default=None, help="Server or API base url if not using http host and port.", ) - parser.add_argument("--host", type=str, default="localhost") + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) parser.add_argument( "--endpoint", @@ -581,24 +740,31 @@ def main(args: argparse.Namespace): default="/v1/completions", help="API endpoint.", ) - parser.add_argument( - "--dataset", - type=str, - default=None, - help="Path to the ShareGPT dataset, will be deprecated in 
the " - "next release.", - ) parser.add_argument( "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset.") + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + parser.add_argument( "--model", type=str, @@ -611,13 +777,6 @@ def main(args: argparse.Namespace): help= "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) - parser.add_argument( - "--best-of", - type=int, - default=1, - help="Generates `best_of` sequences per prompt and " - "returns the best one.", - ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( "--num-prompts", @@ -626,52 +785,14 @@ def main(args: argparse.Namespace): help="Number of prompts to process.", ) parser.add_argument( - "--sharegpt-output-len", + "--logprobs", type=int, default=None, - help="Output length for each request. 
Overrides the output length " - "from the ShareGPT dataset.") - parser.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--random-input-len", - type=int, - default=1024, - help= - "Number of input tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-output-len", - type=int, - default=128, - help= - "Number of output tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-range-ratio", - type=float, - default=1.0, - help="Range of sampled ratio of input/output length, " - "used only for random sampling.", + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), ) parser.add_argument( "--request-rate", @@ -679,8 +800,20 @@ def main(args: argparse.Namespace): default=float("inf"), help="Number of requests per second. If this is inf, " "then all the requests are sent at time 0. " - "Otherwise, we use Poisson process to synthesize " - "the request arrival times.", + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. 
" + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", ) parser.add_argument("--seed", type=int, default=0) parser.add_argument( @@ -693,11 +826,23 @@ def main(args: argparse.Namespace): action="store_true", help="Specify to disable tqdm progress bar.", ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) parser.add_argument( "--save-result", action="store_true", help="Specify to save benchmark results to a json file", ) + parser.add_argument( + "--save-detailed", + action="store_true", + help="When saving the results, whether to include per request " + "information such as response, error, ttfs, tpots, etc.", + ) parser.add_argument( "--metadata", metavar="KEY=VALUE", @@ -722,6 +867,145 @@ def main(args: argparse.Namespace): "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" " format.", ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-seperated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " + "Default value is \"ttft,tpot,itl\".") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-seperated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\". 
" + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. 
Overrides the output length " + "from the ShareGPT dataset.") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for random sampling.", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') + + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. 
" + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.") args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py new file mode 100644 index 000000000000..c79a93faff19 --- /dev/null +++ b/benchmarks/benchmark_serving_structured_output.py @@ -0,0 +1,1010 @@ +# SPDX-License-Identifier: Apache-2.0 +r"""Benchmark online serving throughput with structured outputs. + +On the server side, run one of the following commands: + (vLLM OpenAI API server) + vllm serve --disable-log-requests + + (TGI backend) + ./launch_tgi_server.sh + +On the client side, run: + python benchmarks/benchmark_serving_structured_output.py \ + --backend \ + --model \ + --dataset json \ + --structured-output-ratio 1.0 \ + --structured-output-backend xgrammar \ + --request-rate 10 \ + --num-prompts 1000 + + when using tgi backend, add + --endpoint /generate_stream + to the end of the command above. 
+""" +import argparse +import asyncio +import copy +import dataclasses +import json +import os +import random +import time +import uuid +import warnings +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import Optional + +import datasets +import numpy as np +import pandas as pd +from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, + RequestFuncOutput) +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +try: + from vllm.transformers_utils.tokenizer import get_tokenizer +except ImportError: + from backend_request_func import get_tokenizer + +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + +from vllm.v1.structured_output.utils import ( + has_xgrammar_unsupported_json_features) + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] + + +@dataclasses.dataclass +class SampleRequest: + """A class representing a single inference request for benchmarking. + + Attributes: + prompt: The input text prompt for the model. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. 
+ images). + prompt_len: The length of the prompt in tokens. + expected_output_len: The expected length of the output in tokens. + """ + prompt: str + prompt_len: int + expected_output_len: int + schema: dict + structure_type: str + completion: str = None + + +def sample_requests(tokenizer: PreTrainedTokenizerBase, + args: argparse.Namespace) -> list[SampleRequest]: + if args.dataset == 'json' or args.dataset == 'json-unique': + if args.json_schema_path is None: + dir_path = os.path.dirname(os.path.realpath(__file__)) + args.json_schema_path = os.path.join(dir_path, + "structured_schemas", + "structured_schema_1.json") + json_schemas = [] + with open(args.json_schema_path) as f: + schema = json.load(f) + + if args.dataset == 'json-unique': + json_schemas = [ + copy.deepcopy(schema) for _ in range(args.num_prompts) + ] + for i in range(len(json_schemas)): + json_schemas[i]["properties"][ + f"__optional_field_{uuid.uuid4()}"] = { + "type": + "string", + "description": + "An unique optional field to avoid cached schemas" + } + + def gen_prompt(index: int): + schema = json_schemas[index % len(json_schemas)] + return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 + + def get_schema(index: int): + return json_schemas[index % len(json_schemas)] + + requests = [ + SampleRequest(prompt=gen_prompt(i), + prompt_len=len(tokenizer(gen_prompt(i)).input_ids), + expected_output_len=args.output_len, + schema=get_schema(i), + structure_type=args.structure_type) + for i in range(args.num_prompts) + ] + + elif args.dataset == "grammar": + schema = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + """ + prompt = "Generate an SQL query to show the 'username' \ + and 'email' from the 'users' table." 
+ + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "regex": + regex = r"\w+@\w+\.com\n" + args.regex = regex + prompt = "Generate an email address for Alan Turing, \ + who works in Enigma. End in .com and new line. \ + Example result: alan.turing@enigma.com\n" + + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=regex, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "choice": + choice = ["Positive", "Negative"] + args.choice = choice + prompt = "Classify this sentiment: vLLM is wonderful!" 
+ input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=choice, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "xgrammar_bench": + requests: list[SampleRequest] = [] + dataset = datasets.load_dataset("NousResearch/json-mode-eval", + split="train") + full_dataset_len = len(dataset) + + def _filter_func(item): + import json + schema = json.loads(item["schema"]) + return not has_xgrammar_unsupported_json_features(schema) + + dataset = dataset.filter(_filter_func) + num_filtered_out = full_dataset_len - len(dataset) + print(f"dataset has {len(dataset)} entries after filtering " + f"out {num_filtered_out} entries with unsupported features") + len_dataset = len(dataset) + for data_point_idx in range(args.num_prompts): + idx = data_point_idx + while idx >= len_dataset: + idx -= len_dataset + schema = dataset["schema"][idx] + prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], + tokenize=False) + input_len = len(tokenizer(prompt).input_ids) + completion = dataset["completion"][idx] + + requests.append( + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + completion=completion)) + + return requests + + +async def get_request( + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float = 1.0, +) -> AsyncGenerator[tuple[int, SampleRequest], None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a tuple. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. 
+ Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ + input_requests = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + + for i, request in enumerate(input_requests): + yield i, request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: list[tuple[str, int, int]], + outputs: list[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + goodput_config_dict: Optional[dict[str, float]] = None, +) -> tuple[BenchmarkMetrics, list[int]]: + actual_output_lens: list[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + # We use the tokenizer to count the number of output tokens for all + # serving backends instead of looking at len(outputs[i].itl) since + # multiple output tokens may be bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) + 
actual_output_lens.append(output_len) + total_input += input_requests[i].prompt_len + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + outputs[i].tpot = tpot + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + ) + + return metrics, actual_output_lens + + +async def benchmark( + backend: str, + api_url: str, + base_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[str], + ignore_eos: bool, + max_concurrency: Optional[int], + structured_output_ratio: float, + structured_output_backend: str, + goodput_config_dict: Optional[dict[str, float]] = None, +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = 
ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + def prepare_extra_body(request) -> dict: + extra_body = {} + # Add the schema to the extra_body + extra_body[request.structure_type] = request.schema + # Add the specific structured_output_backend + extra_body["guided_decoding_backend"] = structured_output_backend + return extra_body + + print("Starting initial single prompt test run...") + structured_output_req_idx = random.sample( + range(len(input_requests)), + int(len(input_requests) * structured_output_ratio)) + + test_request = input_requests[0] + test_req_extra_body = (prepare_extra_body(test_request) + if 0 in structured_output_req_idx else None) + test_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=api_url, + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=test_req_extra_body, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. 
Starting main benchmark run...") + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=base_url + "/start_profile", + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=test_req_extra_body, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. + # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + expected: list[str] = [] + async for i, request in get_request(input_requests, request_rate, + burstiness): + extra_body = prepare_extra_body( + request) if i in structured_output_req_idx else None + request_func_input = RequestFuncInput( + model=model_id, + prompt=request.prompt, + api_url=api_url, + prompt_len=request.prompt_len, + output_len=request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) + expected.append(request.completion) + tasks.append( + 
asyncio.create_task( + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + extra_body={test_request.structure_type: test_request.schema}, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentile_metrics=selected_percentile_metrics, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) + + result = { + "duration": + benchmark_duration, + "completed": + metrics.completed, + "total_input_tokens": + metrics.total_input, + 
"total_output_tokens": + metrics.total_output, + "request_throughput": + metrics.request_throughput, + "output_throughput": + metrics.output_throughput, + "total_token_throughput": + metrics.total_token_throughput, + "ttft_description": + pd.Series([output.ttft for output in outputs]).describe().to_dict(), + "tpot_description": + pd.Series([output.tpot for output in outputs]).describe().to_dict(), + "input_lens": [output.prompt_len for output in outputs], + "output_lens": + actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "errors": [output.error for output in outputs], + } + + ret = [{ + 'generated': output.generated_text, + 'expected': gt + } for output, gt in zip(outputs, expected)] + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. 
+ if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 
1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + + return result, ret + + +def evaluate(ret, args): + + def _eval_correctness_json(expected, actual): + # extract json string from string using regex + import re + actual = actual.replace('\n', '').replace(' ', '').strip() + try: + actual = re.search(r'\{.*\}', actual).group() + actual = json.loads(actual) + except Exception: + return False + + return True + + def _eval_correctness_choice(expected, actual): + return actual in args.choice + + def _eval_correctness_regex(expected, actual): + import re + return re.match(args.regex, actual) is not None + + def _eval_correctness(expected, actual): + if args.structure_type == 'guided_json': + return _eval_correctness_json(expected, actual) + elif args.structure_type == 'guided_regex': + return _eval_correctness_regex(expected, actual) + elif args.structure_type == 'guided_choice': + return _eval_correctness_choice(expected, actual) + else: + return None + + scores = [] + for res in ret: + score = _eval_correctness(res['expected'], res['generated']) + res['correctness'] = score + scores.append(score) + + not_none_scores = [score for score in scores if score is not None] + + return (sum(not_none_scores) / len(not_none_scores) * + 100) if len(not_none_scores) > 0 else None + + +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. 
" + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def check_goodput_args(args): + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. " + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + tokenizer = get_tokenizer( + tokenizer_id, + trust_remote_code=args.trust_remote_code, + tokenizer_mode=args.tokenizer_mode, + ) + + if args.dataset == 'grammar': + args.structure_type = 'guided_grammar' + elif args.dataset == 'regex': + args.structure_type = 'guided_regex' + elif args.dataset == 'choice': + args.structure_type = 'guided_choice' + else: + args.structure_type = 'guided_json' + + if args.no_structured_output: + args.structured_output_ratio = 0 + if args.save_results: + result_file_name = f'{args.structured_output_ratio}guided' + result_file_name += f"_{backend}" + result_file_name += f"_{args.request_rate}qps" + result_file_name += f"_{args.model.split('/')[-1]}" + result_file_name += 
f"_{args.dataset}" + result_file_name += f"_{args.num_prompts}" + result_file_name += f"_out{args.output_len}" + result_file_name += ".txt" + else: + result_file_name = None + + input_requests = sample_requests(tokenizer, args) + + goodput_config_dict = check_goodput_args(args) + + benchmark_result, ret = asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + max_concurrency=args.max_concurrency, + structured_output_ratio=args.structured_output_ratio, + structured_output_backend=args.structured_output_backend, + goodput_config_dict=goodput_config_dict, + )) + + # Save config and results to json + score = evaluate(ret, args) + print("correct_rate(%)", score, '\n') + if args.save_results: + results = { + "backend": + backend, + "model_id": + model_id, + "tokenizer_id": + tokenizer_id, + "num_prompts": + args.num_prompts, + "request_rate": + args.request_rate if args.request_rate < float("inf") else "inf", + "burstiness": + args.burstiness, + "max_concurrency": + args.max_concurrency, + "correct_rate(%)": + score + } + results = {"outputs": ret, **results, **benchmark_result} + + # Save to file + if args.result_filename: + result_file_name = args.result_filename + if args.result_dir: + result_file_name = os.path.join(args.result_dir, result_file_name) + with open(result_file_name, "w", encoding='utf-8') as outfile: + json.dump(results, outfile, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + default="vllm", + 
choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument("--dataset", + default='json', + choices=[ + 'json', 'json-unique', 'grammar', 'regex', + 'choice', 'xgrammar_bench' + ]) + parser.add_argument("--json_schema_path", + type=str, + default=None, + help="Path to json schema.") + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. 
This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default="auto", + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--output-len", + type=int, + default=128, + help="Number of output tokens.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. 
A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--save-results", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-seperated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " + "Default value is \"ttft,tpot,itl\".") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-seperated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\". 
" + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + + parser.add_argument("--no-structured-output", + action='store_true', + default=False, + help="Whether to disable JSON decoding or not.") + parser.add_argument("--structured-output-ratio", + type=float, + default=1.0, + help="Ratio of Structured Outputs requests") + parser.add_argument( + "--structured-output-backend", + type=str, + choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"], + default="xgrammar", + help="Backend to use for structured outputs") + + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index a52e67bbbe7e..53869db478c5 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,144 +1,204 @@ +# SPDX-License-Identifier: Apache-2.0 """Benchmark offline inference throughput.""" import argparse +import dataclasses import json +import os import random import time -from typing import List, Optional, Tuple +import warnings +from typing import Any, Optional, Union import torch +import uvloop +from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset, + RandomDataset, SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from 
vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser - - -def sample_requests( - dataset_path: str, - num_requests: int, - tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: - if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") - - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] - - # Shuffle the dataset. - random.shuffle(dataset) - - # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] - for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: - break - - # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids - completion = dataset[i][1] - completion_token_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. 
- continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - return filtered_dataset +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.utils import FlexibleArgumentParser, merge_async_iterators def run_vllm( - requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, + requests: list[SampleRequest], n: int, - use_beam_search: bool, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - distributed_executor_backend: Optional[str], - gpu_memory_utilization: float = 0.9, - download_dir: Optional[str] = None, - load_format: str = EngineArgs.load_format, -) -> float: + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams - llm = LLM( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - 
load_format=load_format, - ) - + llm = LLM(**dataclasses.asdict(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. - prompts: List[str] = [] - sampling_params: List[SamplingParams] = [] - for prompt, _, output_len in requests: - prompts.append(prompt) + prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests: Optional[list[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] + + use_beam_search = False + + outputs = None + if not use_beam_search: + start = time.perf_counter() + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) + end = time.perf_counter() + else: + assert lora_requests is None, "BeamSearch API does not support LoRA" + prompts = [request.prompt for request in requests] + # output_len should be the same for all requests. 
+ output_len = requests[0][2] + for request in requests: + assert request.expected_output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, max_tokens=output_len, + ignore_eos=True, )) + end = time.perf_counter() + return end - start, outputs + + +def run_vllm_chat( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + """ + Run vLLM chat benchmark. This function is recommended ONLY for benchmarking + multimodal models as it properly handles multimodal inputs and chat + formatting. For non-multimodal models, use run_vllm() instead. + """ + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests.") + prompts = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) + outputs = llm.chat(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() - return end - start + return end - start, outputs + + +async def run_vllm_async( + requests: list[SampleRequest], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + disable_detokenize: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + assert all( + llm.model_config.max_model_len >= 
(request.prompt_len + + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + + # Add the requests to the engine. + prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + lora_requests: list[Optional[LoRARequest]] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests.append(request.lora_request) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start def run_hf( - requests: List[Tuple[str, int, int]], + requests: list[SampleRequest], model: str, tokenizer: PreTrainedTokenizerBase, n: int, - use_beam_search: bool, max_batch_size: int, trust_remote_code: bool, + disable_detokenize: bool = False, ) -> float: - assert not use_beam_search llm = AutoModelForCausalLM.from_pretrained( model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": @@ -148,7 +208,7 @@ def run_hf( pbar = tqdm(total=len(requests)) start = time.perf_counter() - batch: List[str] = [] + batch: list[str] = [] max_prompt_len = 0 max_output_len = 0 for i in range(len(requests)): @@ -170,15 +230,16 @@ def run_hf( 
padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), - do_sample=not use_beam_search, + do_sample=True, num_return_sequences=n, temperature=1.0, top_p=1.0, use_cache=True, max_new_tokens=max_output_len, ) - # Include the decoding time. - tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + if not disable_detokenize: + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) pbar.update(len(batch)) # Clear the batch. @@ -190,14 +251,14 @@ def run_hf( def run_mii( - requests: List[Tuple[str, int, int]], + requests: list[SampleRequest], model: str, tensor_parallel_size: int, output_len: int, ) -> float: from mii import client, serve llm = serve(model, tensor_parallel=tensor_parallel_size) - prompts = [prompt for prompt, _, _ in requests] + prompts = [request.prompt for request in requests] start = time.perf_counter() llm.generate(prompts, max_new_tokens=output_len) @@ -207,46 +268,147 @@ def run_mii( return end - start +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "requests_per_second": [results["requests_per_second"]], + "tokens_per_second": [results["tokens_per_second"]], + }, + extra_info={ + k: results[k] + for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def get_requests(args, tokenizer): + # Common parameters for all dataset types. 
+ common_kwargs = { + "dataset_path": args.dataset_path, + "random_seed": args.seed, + } + sample_kwargs = { + "tokenizer": tokenizer, + "lora_path": args.lora_path, + "max_loras": args.max_loras, + "num_requests": args.num_prompts, + "input_len": args.input_len, + "output_len": args.output_len, + } + if args.dataset_path is None or args.dataset_name == "random": + sample_kwargs["range_ratio"] = args.random_range_ratio + sample_kwargs["prefix_len"] = args.prefix_len + dataset_cls = RandomDataset + elif args.dataset_name == "sharegpt": + dataset_cls = ShareGPTDataset + if args.backend == "vllm-chat": + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_name == "sonnet": + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + dataset_cls = SonnetDataset + sample_kwargs["prefix_len"] = args.prefix_len + sample_kwargs["return_prompt_formatted"] = True + elif args.dataset_name == "burstgpt": + dataset_cls = BurstGPTDataset + elif args.dataset_name == "hf": + if args.backend != "vllm-chat": + raise ValueError( + "hf datasets only are supported by vllm-chat backend") + # Choose between VisionArenaDataset and HuggingFaceDataset based on + # provided parameters. + dataset_cls = (VisionArenaDataset if args.dataset_path + == VisionArenaDataset.VISION_ARENA_DATASET_PATH + and args.hf_subset is None else HuggingFaceDataset) + common_kwargs['dataset_subset'] = args.hf_subset + common_kwargs['dataset_split'] = args.hf_split + sample_kwargs["enable_multimodal_chat"] = True + + else: + raise ValueError(f"Unknown dataset name: {args.dataset_name}") + # Remove None values + sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} + return dataset_cls(**common_kwargs).sample(**sample_kwargs) + + def main(args: argparse.Namespace): + if args.seed is None: + args.seed = 0 print(args) random.seed(args.seed) - # Sample the requests. 
tokenizer = AutoTokenizer.from_pretrained( args.tokenizer, trust_remote_code=args.trust_remote_code) - if args.dataset is None: - # Synthesize a prompt with the given input length. - prompt = "hi" * (args.input_len - 1) - requests = [(prompt, args.input_len, args.output_len) - for _ in range(args.num_prompts)] - else: - requests = sample_requests(args.dataset, args.num_prompts, tokenizer, - args.output_len) - + requests = get_requests(args, tokenizer) + is_multi_modal = any(request.multi_modal_data is not None + for request in requests) + request_outputs: Optional[list[RequestOutput]] = None if args.backend == "vllm": - elapsed_time = run_vllm( - requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.distributed_executor_backend, - args.gpu_memory_utilization, args.download_dir, args.load_format) + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + args.disable_detokenize, + )) + else: + elapsed_time, request_outputs = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.use_beam_search, args.hf_max_batch_size, - args.trust_remote_code) + args.hf_max_batch_size, args.trust_remote_code, + args.disable_detokenize) elif args.backend == "mii": elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, args.output_len) + elif args.backend == "vllm-chat": + elapsed_time, request_outputs = run_vllm_chat( + requests, args.n, EngineArgs.from_cli_args(args), + 
args.disable_detokenize) else: raise ValueError(f"Unknown backend: {args.backend}") - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len in requests) + + if request_outputs: + # Note: with the vllm and vllm-chat backends, + # we have request_outputs, which we use to count tokens. + total_prompt_tokens = 0 + total_output_tokens = 0 + for ro in request_outputs: + if not isinstance(ro, RequestOutput): + continue + total_prompt_tokens += len( + ro.prompt_token_ids) if ro.prompt_token_ids else 0 + total_output_tokens += sum( + len(o.token_ids) for o in ro.outputs if o) + total_num_tokens = total_prompt_tokens + total_output_tokens + else: + total_num_tokens = sum(r.prompt_len + r.expected_output_len + for r in requests) + total_output_tokens = sum(r.expected_output_len for r in requests) + total_prompt_tokens = total_num_tokens - total_output_tokens + + if is_multi_modal and args.backend != "vllm-chat": + print("\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details.") + # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. + # vllm-chat backend counts the image tokens now + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} tokens/s") + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print(f"Total num prompt tokens: {total_prompt_tokens}") + print(f"Total num output tokens: {total_output_tokens}") # Output JSON results if specified if args.output_json: @@ -259,18 +421,115 @@ def main(args: argparse.Namespace): } with open(args.output_json, "w") as f: json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) + + +def validate_args(args): + """ + Validate command-line arguments. 
+ """ + + # === Deprecation and Defaulting === + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next release. " + "Please use '--dataset-name' and '--dataset-path' instead.", + stacklevel=2) + args.dataset_path = args.dataset + + if not getattr(args, "tokenizer", None): + args.tokenizer = args.model + + # === Backend Validation === + valid_backends = {"vllm", "hf", "mii", "vllm-chat"} + if args.backend not in valid_backends: + raise ValueError(f"Unsupported backend: {args.backend}") + + # === Dataset Configuration === + if not args.dataset and not args.dataset_path: + print( + "When dataset path is not set, it will default to random dataset") + args.dataset_name = 'random' + if args.input_len is None: + raise ValueError("input_len must be provided for a random dataset") + + # === Dataset Name Specific Checks === + # --hf-subset and --hf-split: only used + # when dataset_name is 'hf' + if args.dataset_name != "hf" and ( + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None): + warnings.warn("--hf-subset and --hf-split will be ignored \ + since --dataset-name is not 'hf'.", + stacklevel=2) + elif args.dataset_name == "hf" and args.backend != "vllm-chat": + raise ValueError( + "When --dataset-name is 'hf', backend must be 'vllm-chat'") + + # --random-range-ratio: only used when dataset_name is 'random' + if args.dataset_name != 'random' and args.random_range_ratio is not None: + warnings.warn("--random-range-ratio will be ignored since \ + --dataset-name is not 'random'.", + stacklevel=2) + + # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not + # set. 
+ if args.dataset_name not in {"random", "sonnet", None + } and args.prefix_len is not None: + warnings.warn("--prefix-len will be ignored since --dataset-name\ + is not 'random', 'sonnet', or not set.", + stacklevel=2) + + # === LoRA Settings === + if getattr(args, "enable_lora", False) and args.backend != "vllm": + raise ValueError( + "LoRA benchmarking is only supported for vLLM backend") + if getattr(args, "enable_lora", False) and args.lora_path is None: + raise ValueError("LoRA path must be provided when enable_lora is True") + + # === Backend-specific Validations === + if args.backend == "hf" and args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend") + if args.backend != "hf" and args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + + if args.backend in {"hf", "mii"} and getattr(args, "quantization", + None) is not None: + raise ValueError("Quantization is only for vLLM backend.") + + if args.backend == "mii" and args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.backend == "mii" and args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.backend == "mii" and args.tokenizer != args.model: + raise ValueError( + "Tokenizer must be the same as the model for MII backend.") if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument("--backend", type=str, - choices=["vllm", "hf", "mii"], + choices=["vllm", "hf", "mii", "vllm-chat"], default="vllm") - parser.add_argument("--dataset", + parser.add_argument( + "--dataset-name", + type=str, + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + help="Name of the dataset to benchmark on.", + default="sharegpt") + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in\ + the next release. 
The dataset is expected to " + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") + parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset.") + help="Path to the dataset") parser.add_argument("--input-len", type=int, default=None, @@ -280,160 +539,70 @@ def main(args: argparse.Namespace): default=None, help="Output length for each request. Overrides the " "output length from the dataset.") - parser.add_argument("--model", type=str, default="facebook/opt-125m") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, help="Number of generated sequences per prompt.") - parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") - parser.add_argument("--seed", type=int, default=0) parser.add_argument("--hf-max-batch-size", type=int, default=None, help="Maximum batch size for HF backend.") - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', + '--output-json', type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. 
' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument("--enforce-eager", - action="store_true", - help="enforce eager execution") + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + "--disable-detokenize", + action="store_true", + help=("Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)")) + # LoRA parser.add_argument( - '--quantization-param-path', + "--lora-path", type=str, default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. ' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') - parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, OpenVINO and ' - 'CPU.') - parser.add_argument( - "--enable-prefix-caching", - action='store_true', - help="enable automatic prefix caching for vLLM backend.") - parser.add_argument("--enable-chunked-prefill", - action='store_true', - help="enable chunked prefill for vLLM backend.") - parser.add_argument('--max-num-batched-tokens', + help="Path to the lora adapters to use. This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser.add_argument("--prefix-len", type=int, default=None, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') - parser.add_argument( - '--output-json', - type=str, - default=None, - help='Path to save the throughput results in JSON format.') + help="Number of prefix tokens per request." + "This is for the RandomDataset and SonnetDataset") + # random dataset parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], + "--random-range-ratio", + type=float, default=None, - help='Backend to use for distributed serving. 
When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') + help="Range of sampled ratio of input/output length, " + "used only for RandomDataSet.", + ) + + # hf dtaset + parser.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + parser.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model - if args.dataset is None: - assert args.input_len is not None - assert args.output_len is not None - else: - assert args.input_len is None - - if args.backend == "vllm": - if args.hf_max_batch_size is not None: - raise ValueError("HF max batch size is only for HF backend.") - elif args.backend == "hf": - if args.hf_max_batch_size is None: - raise ValueError("HF max batch size is required for HF backend.") - if args.quantization 
is not None: - raise ValueError("Quantization is only for vLLM backend.") - elif args.backend == "mii": - if args.dtype != "auto": - raise ValueError("dtype must be auto for MII backend.") - if args.n != 1: - raise ValueError("n must be 1 for MII backend.") - if args.use_beam_search: - raise ValueError("Beam search is not supported for MII backend.") - if args.quantization is not None: - raise ValueError("Quantization is only for vLLM backend.") - if args.hf_max_batch_size is not None: - raise ValueError("HF max batch size is only for HF backend.") - if args.tokenizer != args.model: - raise ValueError("Tokenizer must be the same as the model for MII " - "backend.") + validate_args(args) main(args) diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py new file mode 100644 index 000000000000..45a0ddbd5d08 --- /dev/null +++ b/benchmarks/benchmark_utils.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import json +import math +import os +from typing import Any + + +def convert_to_pytorch_benchmark_format(args: argparse.Namespace, + metrics: dict[str, list], + extra_info: dict[str, Any]) -> list: + """ + Save the benchmark results in the format used by PyTorch OSS benchmark with + on metric per record + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + records = [] + if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): + return records + + for name, benchmark_values in metrics.items(): + record = { + "benchmark": { + "name": "vLLM benchmark", + "extra_info": { + "args": vars(args), + }, + }, + "model": { + "name": args.model, + }, + "metric": { + "name": name, + "benchmark_values": benchmark_values, + "extra_info": extra_info, + }, + } + + tp = record["benchmark"]["extra_info"]["args"].get( + "tensor_parallel_size") + # Save tensor_parallel_size parameter if it's part of the metadata + if not tp and "tensor_parallel_size" in extra_info: + 
record["benchmark"]["extra_info"]["args"][ + "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + + records.append(record) + + return records + + +class InfEncoder(json.JSONEncoder): + + def clear_inf(self, o: Any): + if isinstance(o, dict): + return {k: self.clear_inf(v) for k, v in o.items()} + elif isinstance(o, list): + return [self.clear_inf(v) for v in o] + elif isinstance(o, float) and math.isinf(o): + return "inf" + return o + + def iterencode(self, o: Any, *args, **kwargs) -> Any: + return super().iterencode(self.clear_inf(o), *args, **kwargs) + + +def write_to_json(filename: str, records: list) -> None: + with open(filename, "w") as f: + json.dump(records, f, cls=InfEncoder) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 000000000000..9e36b0a9d3bb --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import copy +import itertools +import pickle as pkl +import time +from collections.abc import Iterable +from typing import Callable + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def 
bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> 
Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, + k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", 
+ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = 
int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if 
__name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, 
default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 000000000000..fe4d8fdfc066 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Cutlass bench utils +from collections.abc import Iterable + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 
absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 70247e94e63c..e7b742d8bec9 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -1,102 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 + import argparse import copy import itertools import pickle as pkl import time -from typing import Callable, Iterable, List, Tuple +from collections.abc import 
Iterable +from typing import Callable, Optional import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_block_fp8_matmul) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - - -# impl - - -def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch.mm(a, b) - - -def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype) - - -def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - use_fast_accum=True) - - -def cutlass_impl(a: 
torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) - # bench -def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, out_dtype: torch.dtype, label: str, - sub_label: str, fn: Callable, description: str) -> TMeasurement: - +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: min_run_time = 1 globals = { - "a": a, - "b": b, - "scale_a": scale_a, - "scale_b": scale_b, - "out_dtype": out_dtype, + "args": args, + "kwargs": kwargs, "fn": fn, } return TBenchmark.Timer( - stmt="fn(a, b, scale_a, scale_b, out_dtype)", + stmt="fn(*args, **kwargs)", globals=globals, label=label, sub_label=sub_label, @@ -104,84 +43,149 @@ def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, ).blocked_autorange(min_run_time=min_run_time) -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_int8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + """Benchmark INT8-based kernels.""" assert dtype == torch.int8 a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) + + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), + "cutlass_i8_i8_bf16_scaled_mm": + 
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), + "cutlass_i8_i8_bf16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, + bias), + "cutlass_i8_i8_bf16_scaled_mm_azp": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj), + "cutlass_i8_i8_bf16_scaled_mm_azp_bias": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, None, bias), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, azp), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, azp, bias), + } timers = [] - # pytorch impl - timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) - - # cutlass impl - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. 
+ if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) return timers -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_fp8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + a_cont = a.contiguous() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + block_scale_a = torch.rand((m, k // 128), + device="cuda", + dtype=torch.float32) + block_scale_b = torch.rand((k // 128, n // 128), + device="cuda", + dtype=torch.float32) + block_scale_a_M_major = block_scale_a.t().contiguous().t() + block_scale_b_K_major = block_scale_b.t().contiguous().t() + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + print(m, k, n) + + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), + "pytorch_fp8_fp8_fp16_scaled_mm": + lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16), + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": + lambda: torch._scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.float16, + use_fast_accum=True), + "pytorch_fp8_fp8_bf16_scaled_mm": + lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16), + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": + lambda: torch._scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True), + "cutlass_fp8_fp8_bf16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, 
torch.bfloat16), + "cutlass_fp8_fp8_fp16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), + "cutlass_fp8_fp8_bf16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, + bias), + "cutlass_fp8_fp8_fp16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16)), + "triton_fp8_fp8_fp16_scaled_mm_blockwise": + lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a, + block_scale_b.t(), (128, 128)), + "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": + lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major, + block_scale_b_K_major, torch.float16), + } timers = [] + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. + if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) - # pytorch impl w. bf16 - timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) - - # pytorch impl: bf16 output, without fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) - - # pytorch impl: bf16 output, with fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) - - # pytorch impl: fp16 output, without fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) - - # pytorch impl: fp16 output, with fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl_fast_accum, - 
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) - - # cutlass impl: bf16 output - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm")) - # cutlass impl: fp16 output - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm")) return timers -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench(dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: if dtype == torch.int8: - return bench_int8(dtype, m, k, n, label, sub_label) + return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, m, k, n, label, sub_label) + return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels) raise ValueError("unsupported type") @@ -192,24 +196,26 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: - + MKNs: Iterable[tuple[int, int, int]], + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") + timers = bench(dtype, + m, + k, + n, + f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + bench_kernels=bench_kernels) print_timers(timers) results.extend(timers) - return results -# output makers def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[Tuple[int, int, int]], + MKNs: Iterable[tuple[int, int, int]], base_description: str, timestamp=None): - print(f"== All Results {base_description} ====") print_timers(data) @@ -219,15 +225,11 @@ def make_output(data: Iterable[TMeasurement], pkl.dump(data, f) -# argparse runners - - def run_square_bench(args): dim_sizes = list( 
range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, MKNs) - + data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -238,18 +240,16 @@ def run_range_bench(args): Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes MKNs = list(zip(Ms, Ks, Ns)) - data = run(args.dtype, MKNs) - + data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"range_bench-{args.dtype}") def run_model_bench(args): - print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") - def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: KNs = [] for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): KN[tp_split_dim] = KN[tp_split_dim] // tp_size @@ -266,7 +266,7 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: for k, n in KNs: MKNs.append((m, k, n)) - data = run(args.dtype, MKNs) + data = run(args.dtype, MKNs, bench_kernels=args.kernels) model_bench_data.append(data) # Print all results @@ -316,6 +316,15 @@ def to_torch_dtype(dt): type=to_torch_dtype, required=True, help="Available options are ['int8', 'fp8']") + parser.add_argument( + "--kernels", + nargs="+", + type=str, + default=None, + help= + "Exact names of the kernels to benchmark. If not set, runs all kernels." 
+ ) + subparsers = parser.add_subparsers(dest="cmd") square_parser = subparsers.add_parser("square_bench") diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d602862..3d1121df40d0 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + # Weight Shapes are in the format # ([K, N], TP_SPLIT_DIM) # Example: @@ -40,4 +42,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 000000000000..94999630bae1 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,145 @@ +#!/bin/bash + +# benchmark the overhead of disaggregated prefill. +# methodology: +# - send all request to prefill vLLM instance. It will buffer KV cache. +# - then send all request to decode instance. +# - The TTFT of decode instance is the overhead. + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. 
+ echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=10 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_tp1.json \ + --request-rate "inf" + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. 
+ python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_tp1_overhead.json \ + --request-rate "$qps" + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx datasets + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 000000000000..eb5d891d0d4a --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,163 @@ +#!/bin/bash + +# Requirement: 2x GPUs. + + +# Model: meta-llama/Meta-Llama-3.1-8B-Instruct +# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests +# Resource: 2x GPU +# Approaches: +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. 
+ pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +launch_chunked_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + wait_for_server 8100 + wait_for_server 8200 + python3 round_robin_proxy.py & + sleep 1 +} + + +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + python3 disagg_prefill_proxy_server.py & + sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + 
num_prompts=100 + qps=$1 + prefix_len=50 + input_len=1024 + output_len=$2 + tag=$3 + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename "$tag"-qps-"$qps".json \ + --request-rate "$qps" + + sleep 2 +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + (which lsof) || (apt-get -y install lsof) + + pip install quart httpx matplotlib aiohttp datasets + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_output_len=6 + + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes + + python3 visualize_benchmark_results.py + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 000000000000..980e68668911 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import aiohttp +from quart import Quart, make_response, request + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + + +async def forward_request(url, data): + 
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked( + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request['max_tokens'] = 1 + + # finish prefill + async for _ in forward_request('http://localhost:8100/v1/completions', + prefill_request): + continue + + # return decode + generator = forward_request('http://localhost:8200/v1/completions', + original_request_data) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == '__main__': + app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py new file mode 100644 index 000000000000..c2ad4916bf07 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import itertools + +import aiohttp +from aiohttp import web + + +class RoundRobinProxy: + + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) + + async def handle_request(self, request): + target_port = next(self.port_cycle) + target_url = 
f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse(status=response.status, + headers=response.headers) + await resp.prepare(request) + + # Stream the response content + async for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) + + +async def main(): + proxy = RoundRobinProxy([8100, 8200]) + app = web.Application() + app.router.add_route('*', '/{path:.*}', proxy.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8000) + await site.start() + + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 000000000000..a7b4b9e8bf30 --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json + +import matplotlib.pyplot as plt +import pandas as pd + +if __name__ == "__main__": + + data = [] + for name in ['disagg_prefill', 'chunked_prefill']: + for qps in [2, 4, 6, 8]: + with open(f"results/{name}-qps-{qps}.json") as f: + x = json.load(f) + x['name'] = name + x['qps'] = qps + data.append(x) + + df = pd.DataFrame.from_dict(data) + dis_df = df[df['name'] == 'disagg_prefill'] + chu_df = df[df['name'] == 'chunked_prefill'] + + plt.style.use('bmh') + plt.rcParams['font.size'] = 20 + + for key in [ + 'mean_ttft_ms', 'median_ttft_ms', 
'p99_ttft_ms', 'mean_itl_ms', + 'median_itl_ms', 'p99_itl_ms' + ]: + + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot(dis_df['qps'], + dis_df[key], + label='disagg_prefill', + marker='o', + linewidth=4) + plt.plot(chu_df['qps'], + chu_df[key], + label='chunked_prefill', + marker='o', + linewidth=4) + ax.legend() + + ax.set_xlabel('QPS') + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f'results/{key}.png') + plt.close(fig) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py new file mode 100644 index 000000000000..3da583a33448 --- /dev/null +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pickle as pkl +import time +from collections.abc import Iterable +from dataclasses import dataclass +from itertools import product +from typing import Callable, Optional + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from tqdm import tqdm + +import vllm._custom_ops as ops +from vllm.model_executor.layers.layernorm import RMSNorm + + +@dataclass +class bench_params_t: + num_tokens: int + hidden_size: int + add_residual: bool + dtype: torch.dtype + + def description(self): + return (f'N {self.num_tokens} ' + f'x D {self.hidden_size} ' + f'x R {self.add_residual} ' + f'x DT {self.dtype}') + + +def get_bench_params() -> list[bench_params_t]: + ## Test Fixtures + NUM_TOKENS = [2**x for x in range(11)] + HIDDEN_SIZES = list(range(1024, 8129, 1024)) + ADD_RESIDUAL = [True, False] + DTYPES = [torch.bfloat16, torch.float] + + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + bench_params = list(map(lambda x: \ + bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + return bench_params + + +# Reference impls +def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: 
torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _, _ = ops.scaled_int8_quant(torch_out) + + +def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = ops.scaled_fp8_quant(torch_out) + + +def fused_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + out, _ = ops.rms_norm_dynamic_per_token_quant(x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + residual=residual) + + +# Bench functions +def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, + quant_dtype: torch.dtype, label: str, sub_label: str, + fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "rms_norm_layer": rms_norm_layer, + "x": x, + "residual": residual, + "quant_dtype": quant_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + +def bench(params: bench_params_t, label: str, sub_label: str) \ + -> Iterable[TMeasurement]: + + # Make inputs + layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + # Make inputs + scale = 1 / params.hidden_size + x = torch.randn(params.num_tokens, + params.hidden_size, + dtype=params.dtype, + device='cuda') * scale + residual = (torch.randn_like(x) * scale).to(device='cuda') \ + if params.add_residual else None + + timers = [] + + # 
unfused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, + unfused_int8_impl, "unfused_int8_impl")) + + # unfused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + unfused_fp8_impl, "unfused_fp8_impl")) + + # fused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, + "fused_int8_impl")) + + # fused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + fused_impl, "fused_fp8_impl")) + + print_timers(timers) + + return timers + + +# launch bench +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def main(): + torch.set_default_device('cuda') + bench_params = get_bench_params() + + timers = [] + for bp in tqdm(bench_params): + timers.extend( + bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + print_timers(timers) + + # pickle all the results + timestamp = int(time.time()) + with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f: + pkl.dump(timers, f) + + +if __name__ == '__main__': + main() diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 601c4ea439ae..8d20b91560dd 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import os import sys from typing import Optional diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py new file mode 100644 index 000000000000..e12d74c01e43 --- /dev/null +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import torch + +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + 
@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         add_residual: bool,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    """Benchmark the RMSNorm layer on CUDA and print its per-call latency.

    Args:
        num_tokens: Number of rows of the input tensor.
        hidden_size: Number of columns of the input tensor / layer width.
        add_residual: When True, a residual tensor is passed to the layer.
        dtype: dtype of the layer and of the inputs.
        seed: Seed passed to ``current_platform.seed_everything``.
        do_profile: When True, run a single iteration between
            cudaProfilerStart/Stop instead of ``num_iters`` iterations.
        num_warmup_iters: Warm-up iterations; 0 skips the warm-up.
        num_iters: Timed iterations (ignored when ``do_profile`` is set).

    Raises:
        ValueError: If ``num_iters`` is smaller than 1 and profiling is off.
    """
    # Reject iteration counts that would otherwise fail with a
    # ZeroDivisionError (or yield a negative latency) further down.
    if not do_profile and num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Return the mean wall-clock seconds per layer call."""
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            layer(x, residual)
        # Wait for all queued kernels before stopping the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup (skipped when no warm-up iterations were requested).
    if num_warmup_iters > 0:
        print("Warming up...")
        run_cuda_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")
def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
                             sort_by_lora_id: bool,
                             device: str) -> torch.Tensor:
    """
    All prompts are mapped to a LoRA ID in range [0, num_active_loras),
    where 0 refers to first lora, 1 refers to second lora and so on.

    Args:
        num_prompts: Number of prompts, i.e. length of the returned tensor.
        num_active_loras: Number of distinct LoRA IDs to draw from (> 0).
        sort_by_lora_id: When False the IDs are drawn uniformly at random;
            when True prompts are split into equal, ordered chunks per LoRA
            and any remainder is assigned to the last LoRA.
        device: Device on which the returned tensor is created.

    Returns:
        A ``torch.long`` tensor of shape ``(num_prompts,)`` on ``device``.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        # Create the tensor on the requested device, like the sorted branch.
        return torch.randint(0,
                             num_active_loras, (num_prompts, ),
                             dtype=torch.long,
                             device=device)

    # Divide LoRAs equally and in order.
    part_size = max(num_prompts // num_active_loras, 1)
    last_lora_id = num_active_loras - 1

    # Prompt i belongs to chunk i // part_size; chunks past the last LoRA
    # are folded into the last LoRA.
    prompt_lora_mapping = [
        min(prompt_idx // part_size, last_lora_id)
        for prompt_idx in range(num_prompts)
    ]
    return torch.tensor(prompt_lora_mapping, dtype=torch.long, device=device)
class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """
    LORA_SHRINK = auto()
    LORA_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """Parse a case-insensitive op name ("lora_shrink"/"lora_expand").

        Raises:
            ValueError: If ``s`` names no known op.
        """
        name = s.lower()
        if name == "lora_shrink":
            return OpType.LORA_SHRINK
        if name == "lora_expand":
            return OpType.LORA_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        """Return True for the shrink op."""
        return self in [OpType.LORA_SHRINK]

    def is_expand_fn(self) -> bool:
        """Return True for the expand op."""
        return self in [OpType.LORA_EXPAND]

    def num_slices(self) -> list[int]:
        """Return the slice counts that are benchmarked for this op."""
        return [1, 2, 3]

    def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int) -> tuple[int, int, int]:
        """Return the (m, k, n) matmul dimensions for this op.

        m is always ``batch_size * seq_length``. Shrink maps
        hidden_size -> lora_rank, expand maps lora_rank -> hidden_size.
        """
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            return num_tokens, hidden_size, lora_rank
        assert self.is_expand_fn()
        return num_tokens, lora_rank, hidden_size

    def matmul_dtypes(
            self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        assert self.is_expand_fn()
        return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
            self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int, num_loras: int,
            num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type

        Raises:
            ValueError: If the op is neither shrink nor expand.
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self.is_shrink_fn():
            # LoRA shrink kernels support num_slices inherently in the kernel.
            return ((m, k), b_shape, (num_slices, m, n))
        if self.is_expand_fn():
            # LoRA expand kernels support num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """Return the kernel entry point that implements this op.

        Raises:
            ValueError: If the op is neither shrink nor expand.
        """
        if self == OpType.LORA_SHRINK:
            return lora_shrink
        if self == OpType.LORA_EXPAND:
            return lora_expand

        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                           lora_weights: list[torch.Tensor],
                           **kwargs) -> None:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.

        The result is written into ``output`` in place; nothing is returned.
        Extra keyword arguments are forwarded to ``ref_group_gemm``.

        Raises:
            ValueError: If the op is neither shrink nor expand.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self.is_shrink_fn():
            # One output plane per slice: output[slice_idx] is (m, n).
            for slice_idx in range(num_slices):
                ref_group_gemm(ref_out=output[slice_idx, :],
                               input=input,
                               lora_weights=lora_weights[slice_idx],
                               **kwargs)
        elif self.is_expand_fn():
            # Slices are laid out side by side along the output columns.
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset:slice_offset + hidden_size],
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs)
        else:
            raise ValueError(f"Unrecognized optype {self}")
def io_types(self) -> str:
    """Describe the matmul dtypes as "<input>x<weight>=><output>"."""
    in_name = dtype_to_str(self.input.dtype)
    weight_name = dtype_to_str(self.lora_weights_lst[0].dtype)
    out_name = dtype_to_str(self.output.dtype)
    return f"{in_name}x{weight_name}=>{out_name}"
def sanity_check(self) -> None:
    """
    Assert that the metadata tensors agree with the input tensor.

    The token count is taken from the second-to-last input dimension; the
    sequence lengths must add up to it, there must be one prompt->LoRA entry
    per sequence and one token->LoRA entry per token.
    """
    token_count = self.input.shape[-2]
    seq_count = self.seq_lens.shape[0]

    # check metadata tensors
    assert self.seq_lens.sum() == token_count
    assert self.prompt_lora_mapping.shape[0] == seq_count
    token_mapping = self.lora_kernel_meta.token_lora_mapping
    assert token_mapping.shape[0] == token_count
def metadata(self) -> tuple[int, int, int, int]:
    """
    Return num_seqs, num_tokens, max_seq_len and num_slices.

    num_seqs is the number of entries in ``seq_lens``, num_tokens the length
    of the token->LoRA mapping, max_seq_len the largest entry of
    ``seq_lens`` and num_slices the number of LoRA weight tensors.
    """
    num_seqs = self.seq_lens.shape[0]
    num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
    max_seq_len = torch.max(self.seq_lens).item()
    num_slices = len(self.lora_weights_lst)
    return num_seqs, num_tokens, max_seq_len, num_slices
def bench_fn_kwargs(self,
                    op_type: OpType,
                    add_inputs: Optional[bool] = None) -> dict[str, Any]:
    """Build the keyword arguments for the kernel that implements op_type.

    Args:
        op_type: Operation to build arguments for.
        add_inputs: Must be None for shrink ops and not None for all other
            ops; forwarded to the expand kwargs builder.

    Returns:
        The kwargs dict from ``as_lora_shrink_kwargs`` or
        ``as_lora_expand_kwargs``.

    Raises:
        ValueError: If ``op_type`` is neither a shrink nor an expand op.
    """
    if op_type.is_shrink_fn():
        assert add_inputs is None
        return self.as_lora_shrink_kwargs()

    assert add_inputs is not None
    if op_type.is_expand_fn():
        return self.as_lora_expand_kwargs(add_inputs)
    # Report the offending op type (not this tensors object).
    raise ValueError(f"Unrecognized optype {op_type}")
def bench_optype(ctx: BenchmarkContext,
                 arg_pool_size: int,
                 op_type: OpType,
                 cuda_graph_nops: Optional[int] = None,
                 expand_fn_add_inputs: Optional[bool] = None,
                 test_correctness: bool = False) -> TMeasurement:
    """Benchmark ``op_type`` for ``ctx`` over a pool of argument sets.

    ``arg_pool_size`` independent tensor sets are created and rotated through
    during the benchmark. ``expand_fn_add_inputs`` must be None for shrink ops
    and set for every other op. With ``test_correctness`` each tensor set is
    first checked against the reference implementation. When
    ``cuda_graph_nops`` is truthy the benchmark runs through CUDA graphs.
    """
    assert arg_pool_size >= 1
    add_inputs_given = expand_fn_add_inputs is not None
    if op_type.is_shrink_fn():
        assert not add_inputs_given
    else:
        assert add_inputs_given

    # BenchmarkContext -> BenchmarkTensors
    tensor_pool: list[BenchmarkTensors] = []
    for _ in range(arg_pool_size):
        tensor_pool.append(BenchmarkTensors.make(ctx, op_type))
    for tensors in tensor_pool:
        tensors.sanity_check()

    # Test correctness of our implementation.
    if test_correctness:
        outcomes = [
            tensors.test_correctness(op_type, expand_fn_add_inputs)
            for tensors in tensor_pool
        ]
        assert all(outcomes)

    # BenchmarkTensors -> dict (kwargs)
    per_set_kwargs = [
        tensors.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for tensors in tensor_pool
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
    for call_kwargs in per_set_kwargs:
        op_type.bench_fn()(**call_kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    pooled_kwargs = {key: ArgPool([]) for key in per_set_kwargs[0]}
    for call_kwargs in per_set_kwargs:
        for key, value in call_kwargs.items():
            pooled_kwargs[key].values.append(value)

    if add_inputs_given:
        describe_args = f"add_inputs={expand_fn_add_inputs}"
    else:
        describe_args = ""
    io_desc = tensor_pool[0].io_types()
    description = f"{op_type.name}({describe_args}) ({io_desc})"

    cuda_graph_params = (CudaGraphBenchParams(cuda_graph_nops)
                         if cuda_graph_nops else None)
    with Bench(cuda_graph_params, ctx.bench_label(),
               ctx.bench_sublabel(op_type), description, op_type.bench_fn(),
               **pooled_kwargs) as bench:
        return bench.run()
def use_cuda_graph_recommendation() -> str:
    """Return the help text that recommends benchmarking with CUDA graphs."""
    return """
            Triton kernels have a significant launch overhead when
            launched directly via python. This overhead is more noticeable
            for small problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from
            inside a CUDA Graph. Note that the returned measurement is for N
            invocations of the operation.
            """
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    """Run every requested benchmark and report / persist the measurements.

    For each context and sequence length, every op type in ``args.op_types``
    is benchmarked for each of its slice counts: first the ``torch.mm``
    roofline, then the LoRA op itself (once per ``add_inputs`` setting for
    non-shrink ops). Results are printed per sequence length and once more in
    full at the end. When ``args.output_directory`` is set, all measurements
    are pickled into ``lora_bench-<timestamp>.pkl`` in that directory.
    """
    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA "
              "Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types
            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices)
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(_ctx, args.arg_pool_size, bench_op,
                                       args.cuda_graph_nops))

                    # Benchmark bench_op
                    expand_fn_add_inputs = [
                        None
                    ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(_ctx, args.arg_pool_size, bench_op,
                                         args.cuda_graph_nops, add_input_arg,
                                         args.test_correctness))

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        # Create missing parents too and tolerate an already existing
        # directory, so a nested output path does not discard a finished run.
        od.mkdir(parents=True, exist_ok=True)

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)
def run_model_bench(args: argparse.Namespace):
    """Benchmark the hidden sizes derived from the selected models.

    For every (model, tensor-parallel size) pair in ``args.models`` x
    ``args.tp_sizes`` the weight shapes in ``WEIGHT_SHAPES`` are split along
    their TP dimension; the resulting second shape entries form the set of
    hidden sizes to benchmark.
    """
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        """Return the hidden sizes of ``model`` under ``tp_size``."""
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Work on a copy: WEIGHT_SHAPES is shared module state, and
            # dividing it in place would compound across TP sizes.
            split_kn = list(KN)
            split_kn[tp_split_dim] = split_kn[tp_split_dim] // tp_size
            hidden_sizes.add(split_kn[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes = hidden_sizes.union(
            hidden_sizes_from_model(model_name, tp_size))

    print("Model bench :\n"
          f" Hidden Sizes {hidden_sizes}"
          f" LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts = as_benchmark_contexts(hidden_sizes=hidden_sizes,
                                           lora_ranks=args.lora_ranks,
                                           args=args)

    run(args, bench_contexts)
When None, all LoRAs are active") + p.add_argument("--sort-by-lora-id", + nargs="+", + type=get_bool, + default=DEFAULT_SORT_BY_LORA_IDS) + p.add_argument("--op-types", + nargs="+", + type=OpType.from_str, + default=list(OpType)) + p.add_argument('--seq-lengths', + nargs="+", + type=int, + default=DEFAULT_SEQ_LENGTHS) + p.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + p.add_argument("--expand-fn-add-inputs", + nargs="+", + type=get_bool, + default=DEFAULT_EXPAND_FN_ADD_INPUTS) + p.add_argument( + '-o', + '--output-directory', + type=str, + help=("Output directory to store a the list of benchmarking" + "TMeasurement objects as a pickle file")) + + p.add_argument( + "--test-correctness", + action='store_true', + help=("When enabled, the benchmarking functions are tested" + "for correctness before the actual benchmarking")) + + parser = FlexibleArgumentParser( + description=f""" +Benchmark LoRA kernels: + {use_cuda_graph_recommendation()} + + list_bench example: + python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + + model_bench example: + python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + + range_bench example: + python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 + """, 
# noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + subparsers = parser.add_subparsers(dest="cmd", required=True) + + list_parser = subparsers.add_parser("list_bench") + list_parser.add_argument("--hidden-sizes", + nargs="+", + type=int, + default=DEFAULT_HIDDEN_SIZES) + list_parser.add_argument("--lora-ranks", + nargs="+", + type=int, + default=DEFAULT_LORA_RANKS) + add_common_command_args(list_parser) + list_parser.set_defaults(func=run_list_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--hidden-sizes-start", type=int, required=True) + range_parser.add_argument("--hidden-sizes-end", type=int, required=True) + range_parser.add_argument("--hidden-sizes-increment", + type=int, + required=True) + range_parser.add_argument("--lora-ranks-start", type=int, required=True) + range_parser.add_argument("--lora-ranks-end", type=int, required=True) + range_parser.add_argument("--lora-ranks-increment", + type=int, + required=True) + add_common_command_args(range_parser) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--lora-ranks", + nargs="+", + type=int, + default=DEFAULT_LORA_RANKS) + add_common_command_args(model_parser) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py new file mode 100644 index 000000000000..a661ea9d7e60 --- /dev/null +++ b/benchmarks/kernels/benchmark_machete.py @@ -0,0 +1,674 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import copy +import itertools +import math +import os +import pickle as pkl +import time +from collections.abc import 
Iterable +from dataclasses import dataclass +from itertools import product +from typing import Callable, Optional + +import pandas as pd +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, + marlin_zero_points) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] +DEFAULT_TP_SIZES = [1] + +NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) + +if NVTX_PROFILE: + import nvtx + + +def terse_type_name(dt): + return { + torch.bfloat16: "bf16", + torch.float16: "fp16", + torch.int8: "int8", + torch.float8_e4m3fn: "fp8", + torch.float: "float", + torch.int: "int", + }[dt] + + +@dataclass +class BenchmarkTensors: + w_ref: torch.Tensor + a: torch.Tensor + + w_q: torch.Tensor + group_size: Optional[int] + wtype: ScalarType + w_g_s: torch.Tensor + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +def rand_data(shape, dtype=torch.float16, scale=1): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype) + else: + return 
def quantize_and_pack(atype: torch.dtype,
                      w: torch.Tensor,
                      wtype: ScalarType,
                      stype: Optional[torch.dtype],
                      group_size: Optional[int],
                      zero_points: bool = False):
    """Quantize ``w`` to ``wtype`` and pack the quantized values by rows.

    ``atype`` and ``stype`` are accepted for signature compatibility but are
    not read by this function.

    Returns:
        ``(w_ref, w_q, w_s, w_zp)`` as produced by ``quantize_weights``, with
        ``w_q`` replaced by the output of ``pack_rows``.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    quantized = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True)
    w_ref, unpacked, w_s, w_zp = quantized

    packed = pack_rows(unpacked, wtype.size_bits, *unpacked.shape)
    return w_ref, packed, w_s, w_zp
def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    """Build a zero-argument callable that runs ``torch.matmul`` as baseline.

    The reference weights are first cast to the activation dtype. When the
    activations are neither fp16 nor bf16, both operands are then cast to
    fp16 before being captured by the returned callable.
    """
    activations = bt.a
    # use float reference tensor
    weights = bt.w_ref.to(activations.dtype)

    already_half = activations.dtype in (torch.float16, torch.bfloat16)
    if not already_half:
        activations = activations.to(torch.float16)
        weights = weights.to(torch.float16)

    def run_matmul():
        return torch.matmul(activations, weights)

    return run_matmul
def machete_create_bench_fn(bt: BenchmarkTensors,
                            out_type: Optional[torch.dtype] = None,
                            schedule=None) -> Callable:
    """Build a zero-argument callable that runs ``ops.machete_mm`` on ``bt``.

    Args:
        bt: tensors for one benchmark instance; ``bt.w_q`` is prepacked here,
            once, so the returned callable only measures the GEMM itself.
        out_type: output dtype forwarded to ``ops.machete_mm``. Defaults to
            ``None``. The previous default was the class ``torch.dtype``
            itself (an annotation written as a default value), which is not
            a usable dtype.
        schedule: optional schedule name forwarded to ``ops.machete_mm``.

    Returns:
        A callable taking no arguments that performs one machete GEMM.
    """
    w_q = bt.w_q.t().contiguous().t()  # make col major
    w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
                                None if bt.w_g_s is None else bt.w_g_s.dtype)

    w_g_zp = bt.w_g_zp
    if w_g_zp is not None:
        # Pre-multiply the zero points by the negated group scales before
        # handing them to the kernel.
        w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))

    return lambda: ops.machete_mm(
        a=bt.a,
        b_q=w_q,
        b_type=bt.wtype,
        b_group_scales=bt.w_g_s,
        b_group_zeros=w_g_zp,
        b_group_size=bt.group_size,
        b_channel_scales=bt.w_ch_s,
        a_token_scales=bt.w_tok_s,
        out_type=out_type,
        schedule=schedule,
    )
def bench(types: TypeConfig,
          group_size: Optional[int],
          m: int,
          k: int,
          n: int,
          label: str,
          sub_label: str,
          sweep_schedules: bool = True) -> list[TMeasurement]:
    """Time every applicable GEMM implementation for one (m, k, n) problem.

    Always times ``torch.matmul`` and machete. ``cutlass_scaled_mm`` is timed
    for int8 / fp8 activations, and marlin for every activation type except
    fp8. When ``sweep_schedules`` is true, each supported machete schedule
    that survives pruning is also timed, one row per schedule is appended to
    the module-level ``_SWEEP_SCHEDULES_RESULTS`` frame, and the fastest
    measurement is added to the returned list.

    Args:
        types: activation / weight / scale / output type configuration.
        group_size: quantization group size (may be ``None``).
        m, k, n: GEMM problem dimensions.
        label: benchmark label passed to the timers.
        sub_label: benchmark sub-label; the number of weight copies is
            appended to it.
        sweep_schedules: whether to sweep machete schedules.

    Returns:
        The list of measurements collected for this problem size.

    Raises:
        ValueError: if a sweep is requested but no schedules are reported.
    """
    benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
    sub_label += f", L={len(benchmark_tensors)}"

    name_type_string = f"W{types.weight_type}"+\
        f"-A{terse_type_name(types.act_type)}"
    if types.group_scale_type is not None:
        name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
    if types.group_zero_type is not None:
        name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
    if group_size is not None:
        name_type_string += f"-G{group_size}"
    if types.channel_scale_type is not None:
        name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
    if types.token_scale_type is not None:
        name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"

    timers = []
    # pytorch impl
    timers.append(
        bench_fns(
            label, sub_label, "torch.matmul (fp16)",
            [torch_matmul_f16_create_bench_fn(bt)
             for bt in benchmark_tensors]))

    if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label, sub_label,
                f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
                    cutlass_scaled_mm_create_bench_fn(bt)
                    for bt in benchmark_tensors
                ]))

    if types.act_type != torch.float8_e4m3fn:
        timers.append(
            bench_fns(label, sub_label, f"marlin ({name_type_string})",
                      [marlin_create_bench_fn(bt)
                       for bt in benchmark_tensors]))

    # machete
    timers.append(
        bench_fns(label, sub_label, f"machete ({name_type_string})", [
            machete_create_bench_fn(bt, out_type=types.output_type)
            for bt in benchmark_tensors
        ]))

    if sweep_schedules:
        global _SWEEP_SCHEDULES_RESULTS

        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(
            a_type=types.act_type,
            b_type=types.weight_type,
            group_scales_type=types.group_scale_type,
            group_zeros_type=types.group_zero_type,
            token_scales_type=types.token_scale_type,
            channel_scales_type=types.channel_scale_type,
            out_type=types.output_type)

        if schedules is None or len(schedules) == 0:
            raise ValueError("No schedules found to sweep")

        for schedule in reversed(schedules):
            # The tile M is parsed from the second "x"-separated field of
            # the schedule name's first "_"-separated component.
            schedule_M = int(schedule.split("_")[0].split("x")[1])

            # Prune known bad schedules
            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                continue

            res = bench_fns(label, sub_label, "machete_best", [
                machete_create_bench_fn(
                    bt, out_type=types.output_type, schedule=schedule)
                for bt in benchmark_tensors
            ])

            results_row = {
                "M": m,
                "K": k,
                "N": n,
                "group_size": group_size,
                "schedule": schedule,
                "median": res.median,
            }
            if _SWEEP_SCHEDULES_RESULTS is None:
                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
                    columns=results_row.keys())
            _SWEEP_SCHEDULES_RESULTS.\
                loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

            print(f" {res.median:5.5} ", schedule)
            if best is None or res.median < best.median:
                best = res
                best_schedule = schedule

        if best is None:
            # Every schedule was pruned; appending None would break the
            # Compare table built from the returned measurements.
            print("No schedule survived pruning; skipping machete_best")
        else:
            print("Best schedule:", best_schedule)
            timers.append(best)

    return timers
def run_square_bench(args):
    """Benchmark square GEMMs (M == K == N) over an inclusive size range.

    Reads ``dim_start``, ``dim_end`` and ``dim_increment`` from ``args`` and
    hands the whole namespace to ``run``, which reads the type options and
    ``sweep_schedules`` from it. Results are printed and pickled by
    ``make_output``.
    """
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    # run() takes (args, MKNs); the parser defines --act-type, not --dtype.
    data = run(args, MKNs)

    make_output(data, MKNs, f"square_bench-{args.act_type}")


def run_range_bench(args):
    """Benchmark the cartesian product of M, K and N ranges.

    ``dim_start``, ``dim_end`` and ``dim_increment`` are each a
    comma-separated ``M,K,N`` triple; every range is inclusive of its end
    value. Results are printed and pickled by ``make_output``.
    """
    m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
    m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
    m_increment, k_increment, n_increment = \
        (int(x) for x in args.dim_increment.split(","))
    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))

    # run() takes (args, MKNs); the parser defines --act-type, not --dtype.
    data = run(args, MKNs)

    make_output(data, MKNs, f"range_bench-{args.act_type}")
if __name__ == "__main__":

    def to_torch_dtype(dt):
        """Map a dtype name given on the command line to its torch dtype."""
        return {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "int8": torch.int8,
            "float8_e4m3fn": torch.float8_e4m3fn,
            "int": torch.int,
            "float": torch.float,
        }[dt]

    class ToTorchDtype(argparse.Action):
        """Store the torch dtype named by the argument string.

        Converting in the action rather than via ``type=`` lets ``choices``
        be validated against the raw strings.
        """

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, to_torch_dtype(values))

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 range_bench --dim-start 128,16384,16384 --dim-end 512,16384,16384 --dim-increment 64,1,1

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--act-type",
        action=ToTorchDtype,
        required=True,
        choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
    )
    parser.add_argument(
        "--group-scale-type",
        action=ToTorchDtype,
        choices=['bfloat16', 'float16'],
    )
    # Uses the action like its siblings: with type=to_torch_dtype the
    # converted torch dtype was compared against the string choices and
    # every value was rejected.
    parser.add_argument(
        "--group-zero-type",
        action=ToTorchDtype,
        choices=['bfloat16', 'float16'],
    )
    parser.add_argument(
        "--channel-scale-type",
        action=ToTorchDtype,
        choices=['float'],
    )
    parser.add_argument(
        "--token-scale-type",
        action=ToTorchDtype,
        choices=['float'],
    )
    parser.add_argument(
        "--out-type",
        action=ToTorchDtype,
        choices=['bfloat16', 'float16'],
    )
    # NOTE(review): the help text lists 'None', but type=int cannot parse
    # that string - confirm the intended way to request no grouping.
    parser.add_argument(
        "--group-size",
        type=int,
        help="Available options are ['None', '-1', '128'], default=128",
        default=128,
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument("--sweep-csv-out",
                        help="CSV to store sweep results",
                        default="sch_sweep_results.csv")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma separated list")
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma separated list")
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma separated list")
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    # The frame is only created when a schedule sweep actually ran.
    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, quantize_weights, sort_weights) + gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) +from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] -def bench_run(results: List[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, num_bits: int, group_size: int, - size_m: int, size_k: int, size_n: int): +def bench_run(results: list[benchmark.Measurement], model: str, + act_order: bool, is_k_full: bool, quant_type: ScalarType, + group_size: int, size_m: int, size_k: int, size_n: int): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, b={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, - group_size, size_m, size_k, size_n)) + sub_label = ("{}, act={} k_full={}, q={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, + str(quant_type), group_size, size_m, + size_k, size_n)) print(f"Testing: {sub_label}") @@ -50,18 +54,18 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_g_idx, marlin_sort_indices, marlin_rand_perm, - ) = marlin_quantize(b, num_bits, group_size, act_order) + ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant (w_ref, q_w, s, g_idx, - rand_perm) = quantize_weights(b, num_bits, group_size, act_order) - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + q_w_gptq = 
gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" # so that group ids are increasing @@ -75,10 +79,32 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) + + # AllSpark W8A16 quant + as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 and not act_order and is_k_full) + if as_supported_case: + properties = torch.cuda.get_device_properties(b.device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + + supported_arch = (sm_version >= 80 and sm_version < 90) + as_supported_case = as_supported_case and supported_arch + if supported_arch: + has_zp = False + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, + has_zp) + qw = qw.to(torch.uint8) + + qw_reorder, s_reorder, zp_reorder = \ + ops.allspark_repack_weight( + qw, s, zp, has_zp) + CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD globals = { # Gen params - "num_bits": num_bits, + "quant_type": quant_type, "group_size": group_size, "size_m": size_m, "size_n": size_n, @@ -104,10 +130,19 @@ def bench_run(results: List[benchmark.Measurement], model: str, # GPTQ params "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, + # AllSpark W8A16 params + "qw_reorder": qw_reorder if as_supported_case else None, + "s_reorder": s_reorder if as_supported_case else None, + "zp_reorder": zp_reorder if as_supported_case else None, + "sm_count": sm_count if as_supported_case else None, + "sm_version": sm_version if as_supported_case else None, + "CUBLAS_M_THRESHOLD": + CUBLAS_M_THRESHOLD if as_supported_case else None, # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_repack": ops.gptq_marlin_repack, + "allspark_w8a16_gemm": 
ops.allspark_w8a16_gemm, } min_run_time = 1 @@ -128,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -138,19 +173,19 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp32", ).blocked_autorange(min_run_time=min_run_time)) - if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): results.append( benchmark.Timer( stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -160,20 +195,31 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( 
stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_repack", ).blocked_autorange(min_run_time=min_run_time)) + if as_supported_case: + results.append( + benchmark.Timer( + stmt= + "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="allspark_w8a16_gemm_fp32", + ).blocked_autorange(min_run_time=min_run_time)) + def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") - results: List[benchmark.Measurement] = [] + results: list[benchmark.Measurement] = [] for model in args.models: for layer in WEIGHT_SHAPES[model]: @@ -196,9 +242,10 @@ def main(args): ) > 0 and is_k_full not in args.limit_k_full: continue - for num_bits in MARLIN_SUPPORTED_NUM_BITS: - if len(args.limit_num_bits - ) > 0 and num_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types( + False): + if len(args.limit_num_bits) > 0 and \ + quant_type.size_bits not in args.limit_num_bits: continue for group_size in MARLIN_SUPPORTED_GROUP_SIZES: @@ -215,8 +262,8 @@ def main(args): for size_m in args.batch_sizes: bench_run(results, model, act_order, is_k_full, - num_bits, group_size, size_m, size_k, - size_n) + quant_type, group_size, size_m, + size_k, size_n) compare = benchmark.Compare(results) compare.print() diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index e00696d6d43c..491f8c3962f7 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -1,7 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 + import argparse +import json import 
time +from contextlib import nullcontext from datetime import datetime -from typing import Any, Dict, List, Tuple, TypedDict +from itertools import product +from typing import Any, TypedDict import ray import torch @@ -10,8 +15,11 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser +FP8_DTYPE = current_platform.fp8_dtype() + class BenchmarkConfig(TypedDict): BLOCK_SIZE_M: int @@ -30,19 +38,37 @@ def benchmark_config( hidden_size: int, topk: int, dtype: torch.dtype, - use_fp8: bool, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, num_iters: int = 100, + block_quant_shape: List[int] = None, ) -> float: - init_dtype = torch.float16 if use_fp8 else dtype + init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) - w1 = torch.randn(num_experts, - shard_intermediate_size, - hidden_size, - dtype=init_dtype) - w2 = torch.randn(num_experts, - hidden_size, - shard_intermediate_size // 2, - dtype=init_dtype) + if use_int8_w8a16: + w1 = torch.randint(-127, + 127, ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8) + w2 = torch.randint(-127, + 127, ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8) + else: + w1 = torch.randn(num_experts, + shard_intermediate_size, + hidden_size, + dtype=init_dtype) + w2 = torch.randn(num_experts, + hidden_size, + shard_intermediate_size // 2, + dtype=init_dtype) gating_output = torch.randn(num_iters, num_tokens, num_experts, @@ -52,14 +78,34 @@ def benchmark_config( w2_scale = None a1_scale = None a2_scale = None - if use_fp8: - w1_scale = torch.randn(num_experts, dtype=torch.float32) - w2_scale = torch.randn(num_experts, dtype=torch.float32) + if use_int8_w8a16: + w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), + dtype=torch.float32) + w2_scale = torch.randn((hidden_size, 
num_experts), dtype=torch.float32) + if use_fp8_w8a8: + if block_quant_shape: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + E = num_experts + N = shard_intermediate_size // 2 + K = hidden_size + factor_for_scale = 1e-2 + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1), + dtype=torch.float32) * factor_for_scale + w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2), + dtype=torch.float32) * factor_for_scale + else: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32) - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) + w1 = w1.to(FP8_DTYPE) + w2 = w2.to(FP8_DTYPE) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) @@ -67,21 +113,24 @@ def prepare(i: int): input_gating.copy_(gating_output[i]) def run(): - fused_moe( - x, - w1, - w2, - input_gating, - topk, - renormalize=True, - inplace=True, - override_config=config, - use_fp8=use_fp8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - ) + from vllm.model_executor.layers.fused_moe import override_config + with override_config(config): + fused_moe( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + ) # JIT compilation & warmup run() @@ -102,7 +151,7 @@ def run(): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - latencies: List[float] = [] + latencies: list[float] = [] for i in range(num_iters): prepare(i) torch.cuda.synchronize() @@ 
def get_rocm_tuning_space(use_fp16):
    """Return the ROCm tuning search space as ``{parameter: candidates}``.

    When ``use_fp16`` is false, ``BLOCK_SIZE_K`` omits 16 and the
    ``matrix_instr_nonkdim`` / ``kpack`` parameters are left out entirely.
    """
    mn_blocks = [16, 32, 64, 128, 256]
    # BLOCK_K=16 not supported for fp8
    k_blocks = [16, 32, 64, 128, 256] if use_fp16 else [32, 64, 128, 256]

    space = {
        "BLOCK_SIZE_M": mn_blocks,
        "BLOCK_SIZE_N": mn_blocks,
        "BLOCK_SIZE_K": k_blocks,
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        "num_stages": [2],
        "waves_per_eu": [0],
    }
    if use_fp16:
        space.update(matrix_instr_nonkdim=[16, 32], kpack=[1, 2])

    return space
+ block_m_range = [16, 32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [64, 128, 256] + num_warps_range = [4, 8] + group_m_range = [1, 16, 32, 64] + num_stage_range = [2, 3, 4, 5] + + param_ranges = { + "BLOCK_SIZE_M": block_m_range, + "BLOCK_SIZE_N": block_n_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) + + # Remove configs that are not compatible with fp8 block quantization + # BLOCK_SIZE_K must be a multiple of block_k + # BLOCK_SIZE_N must be a multiple of block_n + if block_quant_shape is not None and not use_fp16: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + for config in configs[:]: + if config["BLOCK_SIZE_K"] % block_k != 0 or config[ + "BLOCK_SIZE_N"] % block_n != 0: + configs.remove(config) return configs +def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, + search_space, is_fp16, topk): + N1, K1 = shard_intermediate_size, hidden_size + N2, K2 = hidden_size, shard_intermediate_size // 2 + pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, + search_space, is_fp16) + pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, + search_space, is_fp16) + search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) + return search_space + + +# The following code is inspired by ROCm/Triton GEMM tuning script: +# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89 +def prune_rocm_configs(M, N, K, configs, is_fp16=True): + pruned_configs = [] + elemBytes_a = 2 if is_fp16 else 1 + elemBytes_b = 2 if is_fp16 else 1 + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = 
True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + + if is_fp16: + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = config.get("SPLIT_K", 1) + GROUP_M = config.get("GROUP_SIZE_M") + if is_fp16: + if (matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N): + continue + if (matrix_instr_nonkdim >= M + and matrix_instr_nonkdim != BLOCK_SIZE_M): + continue + if (matrix_instr_nonkdim >= N + and matrix_instr_nonkdim != BLOCK_SIZE_N): + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + 
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def merge_unique_dicts(list1, list2): + result = [] + combined_list = list1.copy() + combined_list.extend(list2) + for dictionary in combined_list: + if dictionary not in result: + result.append(dictionary) + return result + + @ray.remote(num_gpus=1) class BenchmarkWorker: def __init__(self, seed: int) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(seed) + current_platform.seed_everything(seed) self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. + self.device_id = int(ray.get_gpu_ids()[0]) def benchmark( self, @@ -155,25 +363,40 @@ def benchmark( hidden_size: int, topk: int, dtype: torch.dtype, - use_fp8: bool, - ) -> Tuple[Dict[str, int], float]: - torch.cuda.manual_seed_all(self.seed) - - dtype_str = "float8" if use_fp8 else None + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: List[int] = None, + ) -> tuple[dict[str, int], float]: + current_platform.seed_everything(self.seed) + dtype_str = get_config_dtype_str(dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. 
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, dtype_str) if op_config is None: - config = get_default_config(num_tokens, num_experts, - shard_intermediate_size, hidden_size, - topk, dtype_str) + config = get_default_config(num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype_str, + is_marlin=False) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] - kernel_time = benchmark_config(config, num_tokens, num_experts, - shard_intermediate_size, hidden_size, - topk, dtype, use_fp8) + kernel_time = benchmark_config(config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + block_quant_shape=block_quant_shape) return config, kernel_time def tune( @@ -184,29 +407,43 @@ def tune( hidden_size: int, topk: int, dtype: torch.dtype, - use_fp8: bool, - search_space: List[BenchmarkConfig], - ) -> BenchmarkConfig: + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + search_space: list[dict[str, int]], + block_quant_shape: list[int], + ) -> dict[str, int]: best_config = None best_time = float("inf") - for config in tqdm(search_space): - try: - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8, - num_iters=10) - except triton.runtime.autotuner.OutOfResources: - # Some configurations may be invalid and fail to compile. 
- continue - - if kernel_time < best_time: - best_time = kernel_time - best_config = config + if current_platform.is_rocm(): + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) + search_space = prune_rocm_search_space(num_tokens, + shard_intermediate_size, + hidden_size, search_space, + is_fp16, topk) + + with torch.cuda.device(self.device_id) if current_platform.is_rocm( + ) else nullcontext(): + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=20, + block_quant_shape=block_quant_shape) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config now = datetime.now() print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") assert best_config is not None @@ -215,44 +452,84 @@ def tune( def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: return { - "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": config["GROUP_SIZE_M"], - "num_warps": config["num_warps"], - "num_stages": config["num_stages"], + "BLOCK_SIZE_M": + config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": + config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": + config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": + config["GROUP_SIZE_M"], + "num_warps": + config["num_warps"], + "num_stages": + config["num_stages"], + **({ + "waves_per_eu": config["waves_per_eu"] + } if "waves_per_eu" in config else {}), + **({ + "matrix_instr_nonkdim": config["matrix_instr_nonkdim"] + } if "matrix_instr_nonkdim" in config else {}), + **({ + "kpack": config["kpack"] + } if "kpack" in config else {}), } -def save_configs( - configs: Dict[int, BenchmarkConfig], - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, 
- dtype: torch.dtype, - use_fp8: bool, -) -> None: - dtype_str = "float8" if use_fp8 else None +def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, + shard_intermediate_size: int, hidden_size: int, topk: int, + dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_quant_shape: List[int]) -> None: + dtype_str = get_config_dtype_str(dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8) + # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - dtype_str) + dtype_str, block_quant_shape) + print(f"Writing best config to {filename}...") with open(filename, "w") as f: json.dump(configs, f, indent=4) f.write("\n") +def get_weight_block_size_safety(config, default_value=None): + + quantization_config = getattr(config, 'quantization_config', {}) + if isinstance(quantization_config, dict): + return quantization_config.get('weight_block_size', default_value) + return default_value + + def main(args: argparse.Namespace): print(args) - - config = AutoConfig.from_pretrained(args.model) + block_quant_shape = None + config = AutoConfig.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif (config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM"): + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size 
+ shard_intermediate_size = 2 * intermediate_size // args.tp_size + block_quant_shape = get_weight_block_size_safety(config) + elif config.architectures[0] == "Qwen2MoeForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral. E = config.num_local_experts @@ -261,8 +538,9 @@ def main(args: argparse.Namespace): shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = config.torch_dtype - use_fp8 = args.dtype == "fp8" + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" if args.batch_size is None: batch_sizes = [ @@ -276,7 +554,7 @@ def main(args: argparse.Namespace): num_gpus = int(ray.available_resources()["GPU"]) workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] - def _distribute(method: str, inputs: List[Any]) -> List[Any]: + def _distribute(method: str, inputs: list[Any]) -> list[Any]: outputs = [] worker_idx = 0 for input_args in inputs: @@ -288,27 +566,31 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: return ray.get(outputs) if args.tune: - search_space = get_configs_compute_bound() + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) + search_space = get_configs_compute_bound(is_fp16, block_quant_shape) print(f"Start tuning over {len(search_space)} configurations...") start = time.time() configs = _distribute( - "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8, search_space) - for batch_size in batch_sizes]) + "tune", + [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, + use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape) + for batch_size in batch_sizes]) best_configs = { M: sort_config(config) for M, config in zip(batch_sizes, configs) } 
save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8) + topk, dtype, use_fp8_w8a8, use_int8_w8a16, + block_quant_shape) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: - outputs = _distribute("benchmark", - [(batch_size, E, shard_intermediate_size, - hidden_size, topk, dtype, use_fp8) - for batch_size in batch_sizes]) + outputs = _distribute( + "benchmark", + [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, + use_fp8_w8a8, use_int8_w8a16, block_quant_shape) + for batch_size in batch_sizes]) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}, config: {config}") @@ -320,14 +602,19 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--tp-size", "-tp", type=int, default=2) + parser.add_argument("--tp-size", + "-tp", + "--tensor-parallel-size", + type=int, + default=2) parser.add_argument("--dtype", type=str, - choices=["auto", "fp8"], + choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") + parser.add_argument("--trust-remote-code", action="store_true") args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a04433142da4..48b351bc4814 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,15 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 + import random import time -from typing import List, Optional +from typing import Optional import torch from vllm import _custom_ops as ops +from vllm.platforms import current_platform from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, 
FlexibleArgumentParser, create_kv_caches_with_random) -NUM_BLOCKS = 1024 +NUM_BLOCKS = 128 * 1024 PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 @torch.inference_mode() @@ -28,10 +32,7 @@ def main( device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + current_platform.seed_everything(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, @@ -54,7 +55,7 @@ def main( # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables_lst: List[List[int]] = [] + block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ random.randint(0, NUM_BLOCKS - 1) @@ -80,6 +81,12 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": + if current_platform.is_rocm(): + global PARTITION_SIZE + if not args.custom_paged_attn: + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -100,7 +107,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, + dtype=torch.float32, + device=device) for _ in range(num_iters): if version == "v1": @@ -121,32 +130,53 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: v_scale, ) elif version == "v2": - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - ) + if not args.custom_paged_attn: + ops.paged_attention_v2( + output, + exp_sums, + 
max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() end_time = time.perf_counter() if profile: - torch.cuda.cudart().cudaProfilerStart() + torch.cuda.cudart().cudaProfilerStop() return (end_time - start_time) / num_iters # Warmup. @@ -193,6 +223,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") + parser.add_argument("--custom-paged-attn", + action="store_true", + help="Use custom paged attention") args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py new file mode 100644 index 000000000000..b643897a60ee --- /dev/null +++ b/benchmarks/kernels/benchmark_quant.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +@torch.inference_mode() +def main(num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100) -> None: + current_platform.seed_everything(seed) + torch.set_default_device("cuda") + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + scale = torch.randn(1, 
1, dtype=torch.float32) if static_scale else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if quant_dtype == torch.int8: + ops.scaled_int8_quant(x, scale) + else: + ops.scaled_fp8_quant(x, scale) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStop() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError(f"Unsupported dtype: {dt}") + + parser = FlexibleArgumentParser( + description="Benchmark the quantization (fp8 or int8) kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--static-scale", action="store_true") + parser.add_argument("--quant-dtype", + type=str, + choices=["fp8", "int8"], + default="int8") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. 
" + "If --profile is set, this number is ignored") + + args = parser.parse_args() + print(args) + + main(num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters) diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py new file mode 100644 index 000000000000..eaf6b25e8ca4 --- /dev/null +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from typing import Optional, Union + +import torch +import triton +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from torch import nn + +from vllm import _custom_ops as vllm_ops + + +class HuggingFaceRMSNorm(nn.Module): + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, 
residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_vllm = rmsnorm_vllm( + 
x.clone(), weight, + residual.clone() if residual is not None else None) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"vLLM output={output_vllm}") + + if torch.allclose(output_naive, output_flashinfer, atol=1e-2, + rtol=1e-2) and torch.allclose( + output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list( + itertools.product(head_num_range, batch_size_range, seq_length_range)) + + +def get_benchmark(use_residual): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm"], + line_names=["HuggingFace", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name= + f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + args={}, + )) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_flashinfer( + x.clone(), + 
weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument("--use-residual", + action="store_true", + help="Whether to use residual connection") + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual) + + # Get the benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index f542684a9a2a..05d24fc4b16d 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,11 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + from itertools import accumulate -from typing import List, Optional +from typing import Optional import nvtx import torch from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, get_rope) +from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -22,9 +25,7 @@ def 
benchmark_rope_kernels_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -33,17 +34,17 @@ def benchmark_rope_kernels_multi_lora( # batched RoPE can take multiple scaling factors batched_rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": tuple(scaling_factors) }) # non-batched RoPE takes only one scaling factor, we create multiple # instances to simulate the same behavior - non_batched_ropes: List[RotaryEmbedding] = [] + non_batched_ropes: list[RotaryEmbedding] = [] for scaling_factor in scaling_factors: non_batched_ropes.append( get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": (scaling_factor, ) })) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 4eeeca35a37c..c375e61e4187 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + WEIGHT_SHAPES = { "ideal": [[4 * 256 * 32, 256 * 32]], "mistralai/Mistral-7B-v0.1/TP1": [ diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py new file mode 100644 index 000000000000..8f07bc8ca52e --- /dev/null +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -0,0 +1,420 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from sglang quantization/tuning_block_wise_kernel.py + +import argparse +import json +import multiprocessing as mp +import os +import time +from datetime import datetime +from typing import Any + +import torch +import tqdm +import triton + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + 
_w8a8_block_fp8_matmul) +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser + +mp.set_start_method("spawn", force=True) + +assert current_platform.is_cuda( +), "Only support tune w8a8 block fp8 kernel on CUDA device." + +DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "half": torch.half, + "bfloat16": torch.bfloat16, +} + + +def w8a8_block_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + config: dict[str, Any], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with + block-wise quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. + It should be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. 
+ """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + if A.dtype == torch.float8_e4m3fn: + kernel = _w8a8_block_fp8_matmul + else: + raise RuntimeError( + "Currently, only support tune w8a8 block fp8 kernel.") + + kernel[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C + + +def get_configs_compute_bound(): + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + }) + return configs + + +def get_weight_shapes(tp_size): + # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. + # Modify them, if you tune for another different model. 
+ # cannot TP + total = [ + (512 + 64, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + ] + # N can TP + n_tp = [ + (18432 * 2, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (24576, 1536), + (12288, 7168), + (4096, 7168), + ] + # K can TP + k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)] + + weight_shapes = [] + for t in total: + weight_shapes.append(t) + for n_t in n_tp: + new_t = (n_t[0] // tp_size, n_t[1]) + weight_shapes.append(new_t) + for k_t in k_tp: + new_t = (k_t[0], k_t[1] // tp_size) + weight_shapes.append(new_t) + return weight_shapes + + +def benchmark_config(A, + B, + As, + Bs, + block_size, + config, + out_dtype=torch.float16, + num_iters=10): + + def run(): + w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) + + torch.cuda.synchronize() + # JIT compilation & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune(M, N, K, block_size, out_dtype, search_space, input_type): + factor_for_scale = 1e-2 + + if input_type == "fp8": + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * + fp8_max) + A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * + fp8_max) + B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + else: + raise RuntimeError( + "Currently, only support tune w8a8 block fp8 kernel.") + + block_n, block_k = 
block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32, + device="cuda") * factor_for_scale + Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") * + factor_for_scale) + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + A, + B, + As, + Bs, + block_size, + config, + out_dtype, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + block_n, + block_k, + configs, + save_path, + input_type="fp8", +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = current_platform.get_device_name().replace(" ", "_") + json_file_name = ( + f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8," + f"block_shape=[{block_n},{block_k}].json") + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def tune_on_gpu(args_dict): + """Run tuning on a specific GPU.""" + gpu_id = args_dict["gpu_id"] + batch_sizes = args_dict["batch_sizes"] + weight_shapes = args_dict["weight_shapes"] + args = args_dict["args"] + + torch.cuda.set_device(gpu_id) + print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + block_n = args.block_n + block_k = args.block_k + out_dtype = DTYPE_MAP[args.out_dtype] + save_path = args.save_path + input_type = args.input_type + + search_space = get_configs_compute_bound() + search_space = [ + config for config in search_space + if 
block_k % config["BLOCK_SIZE_K"] == 0 + ] + + start = time.time() + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + N, K = shape[0], shape[1] + print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune( + batch_size, + N, + K, + [block_n, block_k], + out_dtype, + search_space, + input_type, + ) for batch_size in tqdm(batch_sizes, + desc=f"GPU {gpu_id} - Batch sizes") + ] + best_configs = { + M: config + for M, config in zip(batch_sizes, benchmark_results) + } + save_configs(N, K, block_n, block_k, best_configs, save_path, + input_type) + + end = time.time() + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + + +def distribute_batch_sizes(batch_sizes, num_gpus): + """Distribute batch sizes across available GPUs.""" + batches_per_gpu = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus) + + process_args = [] + for gpu_id in range(num_gpus): + process_args.append({ + "gpu_id": gpu_id, + "batch_sizes": batches_per_gpu[gpu_id], + "weight_shapes": + weight_shapes, # Each GPU processes all weight shapes + "args": args, + }) + + ctx = mp.get_context("spawn") + with ctx.Pool(num_gpus) as pool: + pool.map(tune_on_gpu, process_args) + + 
print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description=""" +Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1: + python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8 +Then copy to model_executor/layers/quantization/utils/configs + """, + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--tp-size", "-tp", type=int, default=8) + parser.add_argument("--input-type", + type=str, + choices=["fp8"], + default="fp8") + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="float16", + ) + parser.add_argument("--block-n", type=int, default=128) + parser.add_argument("--block-k", type=int, default=128) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--save-path", type=str, default="./") + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md new file mode 100644 index 000000000000..917e814010f8 --- /dev/null +++ b/benchmarks/kernels/deepgemm/README.md @@ -0,0 +1,129 @@ +# DeepSeek DeepGEMM Kernels Benchmark + +This directory includes benchmarks comparing DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels. + +Currently this just includes dense GEMMs and only works on Hopper GPUs. + +## Setup + +You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory: + +``` +git clone --recursive https://github.com/deepseek-ai/DeepGEMM +cd DeepGEMM +python setup.py install +uv pip install -e . +``` + +## Usage + +``` +python benchmark_fp8_block_dense_gemm.py +INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda. 
+===== STARTING FP8 GEMM BENCHMARK ===== +PyTorch version: 2.5.1+cu124 +CUDA version: 12.4 +Triton version: 3.1.0 +Using device: NVIDIA H100 80GB HBM3 +WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. +WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. +INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. 
+ +===== PERFORMANCE COMPARISON ===== + +DeepGEMM Implementation: ++------+-------+-------+-----------+--------+--------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | ++------+-------+-------+-----------+--------+--------+ +| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 | +| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 | +| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 | +| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 | +| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 | +| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 | +| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 | +| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 | +| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 | +| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 | +| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 | +| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 | +| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 | +| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 | +| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 | +| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 | ++------+-------+-------+-----------+--------+--------+ + +vLLM Triton Implementation: ++------+-------+-------+-----------+--------+--------+--------------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | ++------+-------+-------+-----------+--------+--------+--------------+ +| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster | +| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower | +| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower | +| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower | +| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower | +| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower | +| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster | +| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster | +| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower | +| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster | +| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower | +| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 | 0.30x slower | +| 1024 | 4096 | 
7168 | 242.3 | 248.2 | 186.1 | 0.37x slower | +| 1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower | +| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower | +| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower | ++------+-------+-------+-----------+--------+--------+--------------+ + +vLLM CUTLASS Implementation: ++------+-------+-------+-----------+--------+--------+--------------+--------------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton | ++------+-------+-------+-----------+--------+--------+--------------+--------------+ +| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster | +| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster | +| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster | +| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster | +| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster | +| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster | +| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster | +| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster | +| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster | +| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster | +| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster | +| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster | +| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster | +| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster | +| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster | +| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster | ++------+-------+-------+-----------+--------+--------+--------------+--------------+ + +===== AVERAGE PERFORMANCE ===== ++----------------+------------+----------+---------------+ +| 
Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) | ++----------------+------------+----------+---------------+ +| DeepGEMM | 310.98 | 1052.10 | 0.11 | +| vLLM Triton | 144.30 | 715.60 | 0.23 | +| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 | ++----------------+------------+----------+---------------+ + +===== AVERAGE SPEEDUPS ===== ++-----------------------------+--------------+ +| Comparison | Speedup | ++-----------------------------+--------------+ +| DeepGEMM vs vLLM Triton | 1.71x faster | +| DeepGEMM vs vLLM CUTLASS | 0.94x slower | +| vLLM CUTLASS vs vLLM Triton | 1.84x faster | ++-----------------------------+--------------+ + +===== ACCURACY COMPARISON ===== ++----------------+-----------------------+ +| Implementation | Avg Diff vs Reference | ++----------------+-----------------------+ +| DeepGEMM | 0.000684 | +| vLLM Triton | 0.000684 | +| vLLM CUTLASS | 0.000684 | ++----------------+-----------------------+ +``` diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py new file mode 100644 index 000000000000..7892f126e7d6 --- /dev/null +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -0,0 +1,464 @@ +# SPDX-License-Identifier: Apache-2.0 +# fmt: off +# ruff: noqa: E501 +import time + +# Import DeepGEMM functions +import deep_gemm +import torch +import triton +from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor + +# Import vLLM functions +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, w8a8_block_fp8_matmul) + + +# Copied from +# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9 +def per_token_cast_to_fp8( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert tensor to FP8 format with per-token scaling.""" + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape 
+ x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to( + torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + +# Copied from +# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17 +def per_block_cast_to_fp8( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert tensor to FP8 format with per-block scaling.""" + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( + x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +def benchmark_shape(m: int, + n: int, + k: int, + warmup: int = 100, + repeat: int = 10000, + verbose: bool = False) -> dict: + """Benchmark all implementations for a specific (m, n, k) shape.""" + if verbose: + print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") + + # Create test tensors + A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + + # Reference result in BF16 + torch.cuda.synchronize() + C_ref = A @ B.t() + + # Pre-quantize B for all implementations + # (weights can be pre-quantized offline) + B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B) + B_vllm, B_scale_vllm = per_block_cast_to_fp8(B) + + # Block size configuration + block_size = [128, 128] + + # Pre-quantize A for all implementations + A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) + A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) + C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + 
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + A, block_size[1], column_major_scales=True) + + # === DeepGEMM Implementation === + def deepgemm_gemm(): + # A quantization is inside the loop as it depends on activations + # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) + # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( + # A, block_size[1]) + # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) + # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), + (B_deepgemm, B_scale_deepgemm), + C_deepgemm) + return C_deepgemm + + # === vLLM Triton Implementation === + def vllm_triton_gemm(): + # A quantization is inside the loop as it depends on activations + # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + return w8a8_block_fp8_matmul(A_vllm, + B_vllm, + A_scale_vllm, + B_scale_vllm, + block_size, + output_dtype=torch.bfloat16) + + # === vLLM CUTLASS Implementation === + def vllm_cutlass_gemm(): + # A quantization is inside the loop as it depends on activations + # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + # A, block_size[1], column_major_scales=True) + return ops.cutlass_scaled_mm(A_vllm_cutlass, + B_vllm.T, + scale_a=A_scale_vllm_cutlass, + scale_b=B_scale_vllm.T, + out_dtype=torch.bfloat16) + + # Run correctness check first + if verbose: + print("Running correctness check...") + C_deepgemm = deepgemm_gemm() + C_vllm_triton = vllm_triton_gemm() + C_vllm_cutlass = vllm_cutlass_gemm() + + deepgemm_diff = calc_diff(C_deepgemm, C_ref) + vllm_triton_diff = calc_diff(C_vllm_triton, C_ref) + vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref) + + if verbose: + print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") + print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") + print(f"vLLM CUTLASS vs 
Reference difference: {vllm_cutlass_diff:.6f}") + print("vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") + print("vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + + # Benchmark implementations + implementations = { + "DeepGEMM": deepgemm_gemm, + "vLLM Triton": vllm_triton_gemm, + "vLLM CUTLASS": vllm_cutlass_gemm + } + + benchmark_results = { + "shape": { + "m": m, + "n": n, + "k": k + }, + "implementations": {} + } + + for name, func in implementations.items(): + # Warmup + for _ in range(warmup): + func() + torch.cuda.synchronize() + + # Timing loop + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + func() + torch.cuda.synchronize() + end = time.time() + + # Calculate timing and TFLOPS + avg_time_ms = (end - start) / repeat * 1000 + avg_time_us = avg_time_ms * 1000 + tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12 + gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3) + + benchmark_results["implementations"][name] = { + "time_ms": avg_time_ms, + "time_us": avg_time_us, + "tflops": tflops, + "gb_s": gb_s, + "diff": { + "DeepGEMM": + 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), + "Reference": + deepgemm_diff if name == "DeepGEMM" else + (vllm_triton_diff + if name == "vLLM Triton" else vllm_cutlass_diff) + } + } + + if verbose: + print( + f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" + ) + + # Calculate speedups + baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] + for name, data in benchmark_results["implementations"].items(): + if name != "DeepGEMM": + speedup = baseline / data["time_ms"] + benchmark_results["implementations"][name][ + "speedup_vs_deepgemm"] = speedup + if verbose: + print(f"DeepGEMM is {1/speedup:.2f}x " + f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") + + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ + "time_ms"] + 
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ + "time_ms"] + cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time + benchmark_results["implementations"]["vLLM CUTLASS"][ + "speedup_vs_triton"] = cutlass_vs_triton + if verbose: + print( + f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " + f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton" + ) + + return benchmark_results + + +def format_table_row(values, widths): + """Format a row with specified column widths.""" + return "| " + " | ".join(f"{val:{w}}" + for val, w in zip(values, widths)) + " |" + + +def print_table(headers, rows, title=None): + """Print a table with headers and rows.""" + if title: + print(f"\n{title}") + + # Calculate column widths based on headers and data + widths = [ + max(len(str(h)), max(len(str(row[i])) for row in rows)) + for i, h in enumerate(headers) + ] + + # Create separator line + separator = "+-" + "-+-".join("-" * w for w in widths) + "-+" + + # Print table + print(separator) + print(format_table_row(headers, widths)) + print(separator) + for row in rows: + print(format_table_row(row, widths)) + print(separator) + + +def format_speedup(value): + """Format speedup value with indicator if it's faster or slower.""" + return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}" + + +def run_benchmarks(verbose: bool = False): + """Run benchmarks for a set of common shapes.""" + print("===== STARTING FP8 GEMM BENCHMARK =====") + + # Make sure we're using the GPU + if not torch.cuda.is_available(): + print("CUDA not available! 
Tests require GPU.") + return + + # Print system information + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA version: {torch.version.cuda}") + print(f"Triton version: {triton.__version__}") + print(f"Using device: {torch.cuda.get_device_name()}") + + # Enable TF32 for better performance + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Set seeds for reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Define benchmark shapes (m, n, k) + shapes = [ + (8, 4096, 7168), + (8, 7168, 18432), + (8, 18432, 7168), + (64, 4096, 7168), + (64, 7168, 18432), + (64, 18432, 7168), + (64, 24576, 1536), + (64, 32768, 512), + (64, 7168, 16384), + (128, 4096, 7168), + (128, 7168, 18432), + (128, 18432, 7168), + (1024, 4096, 7168), + (1024, 18432, 7168), + (2048, 4096, 7168), + (4096, 4096, 7168), + ] + shapes = [ + # (64, 2112, 7168), + (64, 24576, 1536), + (64, 32768, 512), + (64, 7168, 16384), + (64, 4096, 7168), + (64, 7168, 2048), + # (128, 2112, 7168), + (128, 24576, 1536), + (128, 32768, 512), + (128, 7168, 16384), + (128, 4096, 7168), + (128, 7168, 2048), + # (4096, 2112, 7168), + (4096, 24576, 1536), + (4096, 32768, 512), + (4096, 7168, 16384), + (4096, 4096, 7168), + (4096, 7168, 2048), + ] + + all_results = [] + for m, n, k in shapes: + result = benchmark_shape(m, n, k, verbose=verbose) + all_results.append(result) + + # Print results in a nicely formatted table + print("\n===== PERFORMANCE COMPARISON =====") + + # Print DeepGEMM table + deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"] + deepgemm_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["DeepGEMM"] + deepgemm_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" + ]) + + print_table(deepgemm_headers, + deepgemm_rows, + title="DeepGEMM Implementation:") + + # Print vLLM Triton 
table + triton_headers = [ + "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" + ] + triton_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM Triton"] + speedup = impl_data.get("speedup_vs_deepgemm", 1.0) + triton_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", + format_speedup(speedup) + ]) + + print_table(triton_headers, + triton_rows, + title="vLLM Triton Implementation:") + + # Print vLLM CUTLASS table + cutlass_headers = [ + "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", + "vs Triton" + ] + cutlass_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM CUTLASS"] + vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) + vs_triton = impl_data.get("speedup_vs_triton", 1.0) + cutlass_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton) + ]) + + print_table(cutlass_headers, + cutlass_rows, + title="vLLM CUTLASS Implementation:") + + # Calculate and print averages + print("\n===== AVERAGE PERFORMANCE =====") + + implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] + avg_metrics = { + impl: { + "tflops": 0, + "gb_s": 0, + "time_ms": 0 + } + for impl in implementations + } + + for result in all_results: + for impl in implementations: + impl_data = result["implementations"][impl] + avg_metrics[impl]["tflops"] += impl_data["tflops"] + avg_metrics[impl]["gb_s"] += impl_data["gb_s"] + avg_metrics[impl]["time_ms"] += impl_data["time_ms"] + + num_shapes = len(all_results) + avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"] + avg_rows = [] + + for impl in implementations: + avg_tflops = avg_metrics[impl]["tflops"] / num_shapes + avg_mem_bw = 
avg_metrics[impl]["gb_s"] / num_shapes + avg_time = avg_metrics[impl]["time_ms"] / num_shapes + avg_rows.append([ + impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" + ]) + + print_table(avg_headers, avg_rows) + + # Calculate average speedups + avg_speedups = { + "DeepGEMM vs vLLM Triton": 0, + "DeepGEMM vs vLLM CUTLASS": 0, + "vLLM CUTLASS vs vLLM Triton": 0 + } + + for result in all_results: + deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] + vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ + "time_ms"] + + avg_speedups[ + "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups[ + "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups[ + "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time + + print("\n===== AVERAGE SPEEDUPS =====") + speedup_headers = ["Comparison", "Speedup"] + speedup_rows = [] + for comparison, total in avg_speedups.items(): + avg_speedup = total / num_shapes + status = "faster" if avg_speedup > 1 else "slower" + speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"]) + + print_table(speedup_headers, speedup_rows) + + # Average accuracy comparison + print("\n===== ACCURACY COMPARISON =====") + avg_diff = {impl: 0 for impl in implementations} + + for result in all_results: + for impl in implementations: + avg_diff[impl] += result["implementations"][impl]["diff"][ + "Reference"] + + diff_headers = ["Implementation", "Avg Diff vs Reference"] + diff_rows = [] + for impl in implementations: + diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"]) + + print_table(diff_headers, diff_rows) + + +if __name__ == "__main__": + run_benchmarks(verbose=False) diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py new file mode 100644 index 000000000000..bd62173a7b3a --- /dev/null +++ 
b/benchmarks/kernels/graph_machete_bench.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +import pickle +import re +from collections import defaultdict + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.utils import FlexibleArgumentParser + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') + parser.add_argument('filename', type=str) + + args = parser.parse_args() + + with open(args.filename, 'rb') as f: + data = pickle.load(f) + raw_results: list[TMeasurement] = data["results"] + + results = defaultdict(lambda: list()) + for v in raw_results: + result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) + if result is not None: + KN = result.group(1) + else: + raise Exception("MKN not found") + result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label) + if result is not None: + M = result.group(1) + else: + raise Exception("MKN not found") + + kernel = v.task_spec.description + results[KN].append({ + "kernel": kernel, + "batch_size": M, + "median": v.median + }) + + rows = int(math.ceil(len(results) / 2)) + fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) + axs = axs.flatten() + for axs_idx, (shape, data) in enumerate(results.items()): + plt.sca(axs[axs_idx]) + df = pd.DataFrame(data) + sns.lineplot(data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2") + plt.title(f"Shape: {shape}") + plt.ylabel("time (median, s)") + plt.tight_layout() + plt.savefig("graph_machete_bench.pdf") diff --git a/benchmarks/kernels/requirements.txt b/benchmarks/kernels/requirements.txt new file mode 100644 index 000000000000..1411a4a0b5ab --- /dev/null +++ b/benchmarks/kernels/requirements.txt @@ -0,0 +1 @@ +pandas \ No newline at end of file diff --git 
a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py new file mode 100644 index 000000000000..ac64f786f184 --- /dev/null +++ b/benchmarks/kernels/utils.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from collections.abc import Iterable +from typing import Any, Callable, Optional + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + + +@dataclasses.dataclass +class CudaGraphBenchParams: + num_ops_in_cuda_graph: int + + +@dataclasses.dataclass +class ArgPool: + """ + When some argument of the benchmarking function is annotated with this type, + the benchmarking class (Bench) will collapse the argument to pick a + single value from the given list of values, during function invocation. + For every invocation during a benchmarking run, it will choose a + different value from the list. + """ + values: Iterable[Any] + + def __getitem__(self, index): + return self.values[index] + + +class Bench: + + class ArgsIterator: + + def __init__(self, args_list, kwargs_list): + assert len(args_list) == len(kwargs_list) + self.args_list = args_list + self.kwargs_list = kwargs_list + self.n = len(self.args_list) + self.idx = 0 + + def __next__(self): + while True: + yield (self.args_list[self.idx], self.kwargs_list[self.idx]) + self.idx += 1 + self.idx = self.idx % self.n + + def reset(self): + self.idx = 0 + + @property + def n_args(self): + return self.n + + def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, sub_label: str, description: str, fn: Callable, + *args, **kwargs): + + self.cuda_graph_params = cuda_graph_params + self.use_cuda_graph = self.cuda_graph_params is not None + self.label = label + self.sub_label = sub_label + self.description = description + self.fn = fn + + # Process args + self._args = args + self._kwargs = kwargs + self.args_list, self.kwargs_list = self.collapse_argpool( + *args, **kwargs) + self.args_iterator 
= self.ArgsIterator(self.args_list, + self.kwargs_list) + + # Cudagraph runner + self.g = None + if self.use_cuda_graph: + self.g = self.get_cuda_graph_runner() + + # benchmark run params + self.min_run_time = 1 + + def collapse_argpool(self, *args, **kwargs): + argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [ + arg for arg in kwargs.values() if isinstance(arg, ArgPool) + ] + if len(argpool_args) == 0: + return [args], [kwargs] + + # Make sure all argpools are of the same size + argpool_size = len(argpool_args[0].values) + assert all([argpool_size == len(arg.values) for arg in argpool_args]) + + # create copies of the args + args_list = [] + kwargs_list = [] + for _ in range(argpool_size): + args_list.append(args) + kwargs_list.append(kwargs.copy()) + + for i in range(argpool_size): + # collapse args; Just pick the ith value + args_list[i] = tuple([ + arg[i] if isinstance(arg, ArgPool) else arg + for arg in args_list[i] + ]) + + # collapse kwargs + kwargs_i = kwargs_list[i] + arg_pool_keys = [ + k for k, v in kwargs_i.items() if isinstance(v, ArgPool) + ] + for k in arg_pool_keys: + # again just pick the ith value + kwargs_i[k] = kwargs_i[k][i] + kwargs_list[i] = kwargs_i + + return args_list, kwargs_list + + def get_cuda_graph_runner(self): + assert self.use_cuda_graph + assert self.args_iterator is not None + + num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph + + # warmup + args_it = self.args_iterator.__next__() + for _ in range(2): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + return g + + def run_cudagrah(self) -> TMeasurement: + assert self.use_cuda_graph + globals = {'g': self.g} + + return TBenchmark.Timer( + stmt="g.replay()", + 
globals=globals, + label=( + f"{self.label}" + f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops" + ), + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run_eager(self) -> TMeasurement: + setup = None + stmt = None + globals = None + + has_arg_pool = self.args_iterator.n_args > 1 + if has_arg_pool: + setup = ''' + args_iterator.reset() + args_it = args_iterator.__next__() + ''' + stmt = ''' + args, kwargs = next(args_it) + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + else: + # no arg pool. Just use the args and kwargs directly + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + args, kwargs = next(args_it) + + setup = "" + stmt = ''' + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + + return TBenchmark.Timer( + stmt=stmt, + setup=setup, + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run(self) -> TMeasurement: + timer = None + if self.use_cuda_graph: # noqa SIM108 + timer = self.run_cudagrah() + else: + timer = self.run_eager() + if not timer.meets_confidence() or timer.has_warnings: + print("Doesn't meet confidence - re-running bench ...") + return self.run() + return timer + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + print(f"exc type {exc_type}") + print(f"exc value {exc_value}") + print(f"exc traceback {traceback}") diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py new file mode 100644 index 000000000000..89b05d5882a3 --- /dev/null +++ b/benchmarks/kernels/weight_shapes.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following 
GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "meta-llama/Llama-3.1-405b-hf": [ + ([16384, 18432], 1), + ([16384, 16384], 0), + ([16384, 106496], 1), + ([53248, 16384], 0), + ], +} diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index f491c90d0683..ba7383d88dc4 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -4,13 +4,13 @@ PORT=8000 MODEL=$1 TOKENS=$2 -docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ - -v $PWD/data:/data \ - ghcr.io/huggingface/text-generation-inference:1.4.0 \ - --model-id $MODEL \ +docker run -e "HF_TOKEN=$HF_TOKEN" --gpus all --shm-size 1g -p $PORT:80 \ + -v "$PWD/data:/data" \ + ghcr.io/huggingface/text-generation-inference:2.2.0 \ + --model-id "$MODEL" \ --sharded false \ --max-input-length 1024 \ --max-total-tokens 2048 \ --max-best-of 5 \ --max-concurrent-requests 5000 \ - --max-batch-total-tokens $TOKENS + --max-batch-total-tokens "$TOKENS" diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index 203699e9a8d0..5f94552e9dc8 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ 
b/benchmarks/overheads/benchmark_hashing.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import cProfile import pstats @@ -16,7 +18,6 @@ def main(args): enforce_eager=True, enable_prefix_caching=True, tensor_parallel_size=args.tensor_parallel_size, - use_v2_block_manager=args.use_v2_block_manager, ) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) @@ -56,8 +57,5 @@ def main(args): parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/benchmarks/run_structured_output_benchmark.sh b/benchmarks/run_structured_output_benchmark.sh new file mode 100755 index 000000000000..126dfbc24416 --- /dev/null +++ b/benchmarks/run_structured_output_benchmark.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Define the model to use +MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"} + +# Define the backend to use +BACKEND=${2:-"vllm"} + +# Define the dataset to use +DATASET=${3:-"xgrammar_bench"} + +# Define the guided decoding backend +GUIDED_BACKEND=${4:-"xgrammar"} + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"} + +GUIDED_RATIO=${6:-0.5} + +# Create output directory if it doesn't exist +mkdir -p "$OUTPUT_DIR" + +# Define QPS values to test +QPS_VALUES=(70 60 50 25 20 15 10) + +# Common parameters +COMMON_PARAMS="--backend $BACKEND \ + --model $MODEL \ + --dataset $DATASET \ + --structured-output-backend $GUIDED_BACKEND \ + --structured-output-ratio $GUIDED_RATIO \ + --save-results \ + --result-dir $OUTPUT_DIR" + +echo "Starting structured output benchmark with model: $MODEL" +echo "Backend: $BACKEND" +echo "Dataset: $DATASET" +echo "Structured output backend: $GUIDED_BACKEND" +echo "Results will be saved to: $OUTPUT_DIR" +echo "----------------------------------------" + +# Run benchmarks 
with different QPS values +for qps in "${QPS_VALUES[@]}"; do + echo "Running benchmark with QPS: $qps" + + # Get git hash and branch for the filename + GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown") + GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") + + # Construct filename for this run + FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" + + # Run the benchmark + python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ + --request-rate $qps \ + --result-filename "$FILENAME" \ + --tokenizer-mode ${TOKENIZER_MODE:-"auto"} \ + --port ${PORT:-8000} + + echo "Completed benchmark with QPS: $qps" + echo "----------------------------------------" +done + +echo "All benchmarks completed!" +echo "Results saved to: $OUTPUT_DIR" diff --git a/benchmarks/structured_schemas/structured_schema_1.json b/benchmarks/structured_schemas/structured_schema_1.json new file mode 100644 index 000000000000..13bd6b6d16c6 --- /dev/null +++ b/benchmarks/structured_schemas/structured_schema_1.json @@ -0,0 +1,19 @@ +{ + "type": "object", + "properties": { + "name": { "type": "string" }, + "email": { "type": "string" }, + "street": { "type": "string" }, + "city": { "type": "string" }, + "state": { "type": "string" }, + "zip": { "type": "string" }, + "phone": { "type": "string" }, + "website": { "type": "string" }, + "company": { "type": "string" }, + "age": { "type": "integer" } + }, + "required": [ + "name", + "email" + ] +} diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 118f9b28e0ae..345b75d62233 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,5 +1,14 @@ +include(FetchContent) + +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(MACOSX_FOUND TRUE) +endif() + + # # Define environment variables for special configurations # @@ 
-9,21 +18,40 @@ endif() include_directories("${CMAKE_SOURCE_DIR}/csrc") + +set (ENABLE_NUMA TRUE) + # # Check the compile flags # -list(APPEND CXX_COMPILE_FLAGS - "-fopenmp" - "-DVLLM_CPU_EXTENSION") -execute_process(COMMAND cat /proc/cpuinfo - RESULT_VARIABLE CPUINFO_RET - OUTPUT_VARIABLE CPUINFO) +if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + list(APPEND CXX_COMPILE_FLAGS + "-mf16c" + ) +endif() -if (NOT CPUINFO_RET EQUAL 0) - message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") +if(MACOSX_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-Xpreprocessor" + "-fopenmp" + "-DVLLM_CPU_EXTENSION") +else() + list(APPEND CXX_COMPILE_FLAGS + "-fopenmp" + "-DVLLM_CPU_EXTENSION") endif() +if (NOT MACOSX_FOUND) + execute_process(COMMAND cat /proc/cpuinfo + RESULT_VARIABLE CPUINFO_RET + OUTPUT_VARIABLE CPUINFO) + if (NOT CPUINFO_RET EQUAL 0) + message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") + endif() +endif() + + function (find_isa CPUINFO TARGET OUT) string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) if(NOT ISA_FOUND EQUAL -1) @@ -44,10 +72,18 @@ endfunction() is_avx512_disabled(AVX512_DISABLED) -find_isa(${CPUINFO} "avx2" AVX2_FOUND) -find_isa(${CPUINFO} "avx512f" AVX512_FOUND) -find_isa(${CPUINFO} "POWER10" POWER10_FOUND) -find_isa(${CPUINFO} "POWER9" POWER9_FOUND) +if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + set(APPLE_SILICON_FOUND TRUE) +else() + find_isa(${CPUINFO} "avx2" AVX2_FOUND) + find_isa(${CPUINFO} "avx512f" AVX512_FOUND) + find_isa(${CPUINFO} "POWER10" POWER10_FOUND) + find_isa(${CPUINFO} "POWER9" POWER9_FOUND) + find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support + find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support + find_isa(${CPUINFO} "S390" S390_FOUND) +endif() + if (AVX512_FOUND AND NOT AVX512_DISABLED) list(APPEND CXX_COMPILE_FLAGS @@ -67,9 +103,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) else() message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in 
local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() + elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") message(WARNING "vLLM CPU backend using AVX2 ISA") + elseif (POWER9_FOUND OR POWER10_FOUND) message(STATUS "PowerPC detected") # Check for PowerPC VSX support @@ -77,18 +115,71 @@ elseif (POWER9_FOUND OR POWER10_FOUND) "-mvsx" "-mcpu=native" "-mtune=native") + +elseif (ASIMD_FOUND) + message(STATUS "ARMv8 or later architecture detected") + if(ARM_BF16_FOUND) + message(STATUS "BF16 extension detected") + set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16") + add_compile_definitions(ARM_BF16_SUPPORT) + else() + message(WARNING "BF16 functionality is not available") + set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16") + endif() + list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS}) +elseif(APPLE_SILICON_FOUND) + message(STATUS "Apple Silicon Detected") + set(ENABLE_NUMA OFF) +elseif (S390_FOUND) + message(STATUS "S390 detected") + # Check for S390 VXE support + list(APPEND CXX_COMPILE_FLAGS + "-mvx" + "-mzvector" + "-march=native" + "-mtune=native") else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA or ARMv8 support.") endif() -message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") - -list(APPEND LIBS "numa") - - # -# Define extension targets +# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms) # +if (AVX512_FOUND AND NOT AVX512_DISABLED) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.7.1 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + 
set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + FetchContent_MakeAvailable(oneDNN) + + list(APPEND LIBS dnnl) +endif() + +message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") + +if(ENABLE_NUMA) + list(APPEND LIBS numa) +else() + message(STATUS "NUMA is disabled") + add_compile_definitions(-DVLLM_NUMA_DISABLED) +endif() # # _C extension @@ -102,6 +193,16 @@ set(VLLM_EXT_SRC "csrc/cpu/pos_encoding.cpp" "csrc/cpu/torch_bindings.cpp") +if (AVX512_FOUND AND NOT AVX512_DISABLED) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) +endif() + +# +# Define extension targets +# + define_gpu_extension_target( _C DESTINATION vllm @@ -113,6 +214,4 @@ define_gpu_extension_target( WITH_SOABI ) -add_custom_target(default) -message(STATUS "Enabling C extension.") -add_dependencies(default _C) +message(STATUS "Enabling C extension.") \ No newline at end of file diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake new file mode 100644 index 000000000000..6291475164ba --- /dev/null +++ b/cmake/external_projects/flashmla.cmake @@ -0,0 +1,66 @@ +include(FetchContent) + +# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory +# instead of downloading. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. 
+if (DEFINED ENV{FLASH_MLA_SRC_DIR}) + set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR}) +endif() + +if(FLASH_MLA_SRC_DIR) + FetchContent_Declare( + flashmla + SOURCE_DIR ${FLASH_MLA_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + flashmla + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git + GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +endif() + + +FetchContent_MakeAvailable(flashmla) +message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") + +# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. +# Only build FlashMLA kernels if we are building for something compatible with +# sm90a +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) + set(FlashMLA_SOURCES + ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu) + + set(FlashMLA_INCLUDES + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/include) + + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + + define_gpu_extension_target( + _flashmla_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_SOURCES} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} + USE_SABI 3 + WITH_SOABI) +else() + # Create an empty target for setup.py when not targeting sm90a systems + add_custom_target(_flashmla_C) +endif() + diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake new file mode 100644 index 000000000000..afd7c47e8ac0 --- /dev/null +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -0,0 +1,67 @@ +# vLLM flash attention requires 
VLLM_GPU_ARCHES to contain the set of target +# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the +# arches in the CUDA case (and instead set the gencodes on a per file basis) +# we need to manually set VLLM_GPU_ARCHES here. +if(VLLM_GPU_LANG STREQUAL "CUDA") + foreach(_ARCH ${CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") + endforeach() +endif() + +# +# Build vLLM flash attention from source +# +# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. +# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. +# They should be identical but if they aren't, this is a massive footgun. +# +# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). +# If no component is specified, vllm-flash-attn is still installed. + +# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. +# This is to enable local development of vllm-flash-attn within vLLM. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. 
+if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) + set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) +endif() + +if(VLLM_FLASH_ATTN_SRC_DIR) + FetchContent_Declare( + vllm-flash-attn SOURCE_DIR + ${VLLM_FLASH_ATTN_SRC_DIR} + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +else() + FetchContent_Declare( + vllm-flash-attn + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22 + GIT_PROGRESS TRUE + # Don't share the vllm-flash-attn build between build types + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +endif() + + +# Fetch the vllm-flash-attn library +FetchContent_MakeAvailable(vllm-flash-attn) +message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") + +# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in +# case only one is built, in the case both are built redundant work is done) +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa2_C + FILES_MATCHING PATTERN "*.py" +) + +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa3_C + FILES_MATCHING PATTERN "*.py" +) diff --git a/cmake/hipify.py b/cmake/hipify.py index 340e41c8179e..a15577125eb1 100755 --- a/cmake/hipify.py +++ b/cmake/hipify.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 # # A command line tool for running pytorch's hipify preprocessor on CUDA diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 4869cad54113..c9cd099b82a7 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -58,8 +58,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS) # set(SRCS ${ORIG_SRCS}) set(CXX_SRCS ${ORIG_SRCS}) - list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$") - list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$") + list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$") + list(FILTER CXX_SRCS INCLUDE REGEX 
"\.(cc)|(cpp)|(hip)$") # # Generate ROCm/HIP source file names from CUDA file names. @@ -133,10 +133,202 @@ macro(string_to_ver OUT_VER IN_STR) string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) endmacro() +# +# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in +# `CUDA_ARCH_FLAGS`. +# +# Example: +# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" +# clear_cuda_arches(CUDA_ARCH_FLAGS) +# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75" +# CMAKE_CUDA_FLAGS="-Wall" +# +macro(clear_cuda_arches CUDA_ARCH_FLAGS) + # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` + string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS + ${CMAKE_CUDA_FLAGS}) + + # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified + # and passed back via the `CUDA_ARCHITECTURES` property. + string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS + ${CMAKE_CUDA_FLAGS}) +endmacro() + +# +# Extract unique CUDA architectures from a list of compute capabilities codes in +# the form `[]`, convert them to the form sort +# `.`, dedupes them and then sorts them in ascending order and +# stores them in `OUT_ARCHES`. 
+#
+# Example:
+#   CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
+#   extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
+#   OUT_ARCHES="7.5;...;9.0a"
+function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
+  set(_CUDA_ARCHES)
+  foreach(_ARCH ${CUDA_ARCH_FLAGS})
+    string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
+    if (_COMPUTE)
+      set(_COMPUTE ${CMAKE_MATCH_1})
+    endif()
+
+    string_to_ver(_COMPUTE_VER ${_COMPUTE})
+    list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
+  endforeach()
+
+  list(REMOVE_DUPLICATES _CUDA_ARCHES)
+  list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
+  set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
+endfunction()
+
+#
+# For a specific file set the `-gencode` flag in compile options conditionally
+# for the CUDA language.
+#
+# Example:
+#   set_gencode_flag_for_srcs(
+#     SRCS "foo.cu"
+#     ARCH "compute_75"
+#     CODE "sm_75")
+#   adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
+#    `foo.cu` (only for the CUDA language).
+#
+macro(set_gencode_flag_for_srcs)
+  set(options)
+  set(oneValueArgs ARCH CODE)
+  set(multiValueArgs SRCS)
+  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN} )
+  set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
+  set_property(
+    SOURCE ${arg_SRCS}
+    APPEND PROPERTY
+    COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
+  )
+
+  message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
+endmacro(set_gencode_flag_for_srcs)
+
+#
+# For a list of source files set the `-gencode` flags in the files specific
+# compile options (specifically for the CUDA language).
+#
+# arguments are:
+#  SRCS: list of source files
+#  CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
+#  BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
+#    for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
+#    that is larger than BUILD_PTX_FOR_ARCH.
+#
+macro(set_gencode_flags_for_srcs)
+  set(options)
+  set(oneValueArgs BUILD_PTX_FOR_ARCH)
+  set(multiValueArgs SRCS CUDA_ARCHS)
+  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN} )
+
+  foreach(_ARCH ${arg_CUDA_ARCHS})
+    string(REPLACE "." "" _ARCH "${_ARCH}")
+    set_gencode_flag_for_srcs(
+      SRCS ${arg_SRCS}
+      ARCH "compute_${_ARCH}"
+      CODE "sm_${_ARCH}")
+  endforeach()
+
+  if (${arg_BUILD_PTX_FOR_ARCH})
+    list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+    list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
+    if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
+      string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_PTX_ARCH}"
+        CODE "compute_${_PTX_ARCH}")
+    endif()
+  endif()
+endmacro()
+
+#
+# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
+#  `<major>.<minor>[letter]` compute the "loose intersection" with the
+#  `TGT_CUDA_ARCHS` list of gencodes.
+# The loose intersection is defined as:
+#   { max{ x \in src | x <= y } | y \in tgt, { x \in src | x <= y } != {} }
+#  where `<=` is the version comparison operator.
+# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
+#  in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
+# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+#  in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
+#  x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
+# The result is stored in `OUT_CUDA_ARCHS`.
+#
+# Example:
+#   SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
+#   TGT_CUDA_ARCHS="8.0;8.9;9.0"
+#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+#   OUT_CUDA_ARCHS="9.0a;8.0;8.6"
+#
+function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+  list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
+  set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
+
+  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in TGT_CUDA_ARCHS then we should
+  # remove x.0a from SRC_CUDA_ARCHS and append x.0a to _CUDA_ARCHS
+  set(_CUDA_ARCHS)
+  if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
+    if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
+      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
+      list(APPEND _CUDA_ARCHS "9.0a")
+    endif()
+  endif()
+
+  if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
+    if ("10.0" IN_LIST TGT_CUDA_ARCHS_)
+      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
+      list(APPEND _CUDA_ARCHS "10.0a")
+    endif()
+  endif()
+
+  list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+
+  # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
+  # is less or equal to ARCH (but has the same major version since SASS binary
+  # compatibility is only forward compatible within the same major version).
+ foreach(_ARCH ${TGT_CUDA_ARCHS_}) + set(_TMP_ARCH) + # Extract the major version of the target arch + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") + foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + # Extract the major version of the source arch + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") + # Check major-version match AND version-less-or-equal + if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) + if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) + set(_TMP_ARCH "${_SRC_ARCH}") + endif() + else() + # If we hit a version greater than the target, we can break + break() + endif() + endforeach() + + # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS + if (_TMP_ARCH) + list(APPEND _CUDA_ARCHS "${_TMP_ARCH}") + endif() + endforeach() + + list(REMOVE_DUPLICATES _CUDA_ARCHS) + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) +endfunction() + # # Override the GPU architectures detected by cmake/torch and filter them by # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in -# `GPU_ARCHES`. +# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set +# the architectures on a per file basis. # # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. # @@ -174,109 +366,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") endif() - - elseif(${GPU_LANG} STREQUAL "CUDA") - # - # Setup/process CUDA arch flags. - # - # The torch cmake setup hardcodes the detected architecture flags in - # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it - # can't modified on a per-target basis, e.g. for the `punica` extension. - # So, all the `-gencode` flags need to be extracted and removed from - # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. 
- # Since it's not possible to use `target_compiler_options` for adding target - # specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property - # must be used instead. This requires repackaging the architecture flags - # into a format that cmake expects for `CUDA_ARCHITECTURES`. - # - # This is a bit fragile in that it depends on torch using `-gencode` as opposed - # to one of the other nvcc options to specify architectures. - # - # Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override - # detected architectures. - # - message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") - - # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` - string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS - ${CMAKE_CUDA_FLAGS}) - - # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified - # and passed back via the `CUDA_ARCHITECTURES` property. - string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS - ${CMAKE_CUDA_FLAGS}) - - # If this error is triggered, it might mean that torch has changed how it sets - # up nvcc architecture code generation flags. - if (NOT _CUDA_ARCH_FLAGS) - message(FATAL_ERROR - "Could not find any architecture related code generation flags in " - "CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})") - endif() - - message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") - message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}") - - # Initialize the architecture lists to empty. - set(${GPU_ARCHES}) - - # Process each `gencode` flag. - foreach(_ARCH ${_CUDA_ARCH_FLAGS}) - # For each flag, extract the version number and whether it refers to PTX - # or native code. - # Note: if a regex matches then `CMAKE_MATCH_1` holds the binding - # for that match. 
- - string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) - if (_COMPUTE) - set(_COMPUTE ${CMAKE_MATCH_1}) - endif() - - string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH}) - if (_SM) - set(_SM ${CMAKE_MATCH_1}) - endif() - - string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH}) - if (_CODE) - set(_CODE ${CMAKE_MATCH_1}) - endif() - - # Make sure the virtual architecture can be matched. - if (NOT _COMPUTE) - message(FATAL_ERROR - "Could not determine virtual architecture from: ${_ARCH}.") - endif() - - # One of sm_ or compute_ must exist. - if ((NOT _SM) AND (NOT _CODE)) - message(FATAL_ERROR - "Could not determine a codegen architecture from: ${_ARCH}.") - endif() - - if (_SM) - # -real suffix let CMake to only generate elf code for the kernels. - # we want this, otherwise the added ptx (default) will increase binary size. - set(_VIRT "-real") - set(_CODE_ARCH ${_SM}) - else() - # -virtual suffix let CMake to generate ptx code for the kernels. - set(_VIRT "-virtual") - set(_CODE_ARCH ${_CODE}) - endif() - - # Check if the current version is in the supported arch list. - string_to_ver(_CODE_VER ${_CODE_ARCH}) - if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST) - message(STATUS "discarding unsupported CUDA arch ${_VER}.") - continue() - endif() - - # Add it to the arch list. - list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}") - endforeach() endif() - message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}") endmacro() # @@ -350,17 +440,15 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_include_directories(${GPU_MOD_NAME} PRIVATE csrc ${GPU_INCLUDE_DIRECTORIES}) - target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY} - ${GPU_LIBRARIES}) + target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of # dependencies that are not necessary and may not be installed. 
if (GPU_LANGUAGE STREQUAL "CUDA") - target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB} - ${CUDA_LIBRARIES}) + target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver) else() target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) endif() - install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION}) + install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) endfunction() diff --git a/collect_env.py b/collect_env.py index 083cb768f539..0ec9d4cae4ba 100644 --- a/collect_env.py +++ b/collect_env.py @@ -1,17 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 + # ruff: noqa # code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py -# Unlike the rest of the PyTorch this file must be python2 compliant. -# This script outputs relevant system environment info -# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` import datetime import locale import os import re import subprocess import sys +# Unlike the rest of the PyTorch this file must be python2 compliant. 
+# This script outputs relevant system environment info +# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` from collections import namedtuple +from vllm.envs import environment_variables + try: import torch TORCH_AVAILABLE = True @@ -52,6 +56,7 @@ 'vllm_version', # vllm specific field 'vllm_build_flags', # vllm specific field 'gpu_topo', # vllm specific field + 'env_vars', ]) DEFAULT_CONDA_PATTERNS = { @@ -65,6 +70,9 @@ "optree", "nccl", "transformers", + "zmq", + "nvidia", + "pynvml", } DEFAULT_PIP_PATTERNS = { @@ -77,6 +85,9 @@ "onnx", "nccl", "transformers", + "zmq", + "nvidia", + "pynvml", } @@ -261,12 +272,16 @@ def get_neuron_sdk_version(run_lambda): def get_vllm_version(): - try: - import vllm - return vllm.__version__ - except ImportError: - return 'N/A' + from vllm import __version__, __version_tuple__ + + if __version__ == "dev": + return "N/A (dev)" + + if len(__version_tuple__) == 4: # dev build + git_sha = __version_tuple__[-1][1:] # type: ignore + return f"{__version__} (git sha: {git_sha})" + return __version__ def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
@@ -278,9 +293,14 @@ def summarize_vllm_build_flags():
 
 
 def get_gpu_topo(run_lambda):
+    output = None
+
     if get_platform() == 'linux':
-        return run_and_read_all(run_lambda, 'nvidia-smi topo -m')
-    return None
+        output = run_and_read_all(run_lambda, 'nvidia-smi topo -m')
+        if output is None:
+            output = run_and_read_all(run_lambda, 'rocm-smi --showtopo')
+
+    return output
 
 
 # example outputs of CPU infos
@@ -497,6 +517,22 @@ def is_xnnpack_available():
     else:
         return "N/A"
 
+def get_env_vars():
+    env_vars = ''
+    secret_terms = ('secret', 'token', 'api', 'access', 'password')
+    report_prefix = ("TORCH", "NCCL", "PYTORCH",
+                     "CUDA", "CUBLAS", "CUDNN",
+                     "OMP_", "MKL_",
+                     "NVIDIA")
+    for k, v in os.environ.items():
+        if any(term in k.lower() for term in secret_terms):  # skip names that look like credentials
+            continue
+        if k in environment_variables:
+            env_vars = env_vars + "{}={}".format(k, v) + "\n"
+        elif k.startswith(report_prefix):  # elif: a variable matching both tests is listed once
+            env_vars = env_vars + "{}={}".format(k, v) + "\n"
+
+    return env_vars
 
 def get_env_info():
     run_lambda = run
@@ -568,6 +604,7 @@ def get_version_or_na(cfg, prefix):
         vllm_version=vllm_version,
         vllm_build_flags=vllm_build_flags,
         gpu_topo=gpu_topo,
+        env_vars=get_env_vars(),
     )
 
 
@@ -616,6 +653,8 @@ def get_version_or_na(cfg, prefix):
 {vllm_build_flags}
 GPU Topology:
 {gpu_topo}
+
+{env_vars}
 """.strip()
 
 
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 5ed1dc3b8f79..88275dbdd83a 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -9,8 +9,16 @@
 
 namespace vllm {
 
+template
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template + +template __global__ void act_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] @@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel( for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); - out[token_idx * d + idx] = ACT_FN(x) * y; + out[token_idx * d + idx] = compute(x, y); } } @@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { } // namespace vllm // Launch activation and gating kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ +// Use ACT_FIRST (bool) indicating whether to apply the activation function +// first. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \ int d = input.size(-1) / 2; \ int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ @@ -64,7 +74,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel> \ + vllm::act_and_mul_kernel, ACT_FIRST> \ <<>>(out.data_ptr(), \ input.data_ptr(), d); \ }); @@ -72,21 +82,71 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); +} + +void mul_and_silu(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + // The difference between mul_and_silu and silu_and_mul is that mul_and_silu + // applies the silu to the latter half of the input. 
+ LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false); } void gelu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true); } void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true); +} + +namespace vllm { + +template +__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) { + const float f = (float)x; + return (T)(f > threshold ? f : 0.0f); } +template +__global__ void act_and_mul_kernel_with_param( + scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, + const float param) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); + out[token_idx * d + idx] = ACT_FN(x, param) * y; + } +} + +} // namespace vllm + +#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d, \ + PARAM); \ + }); + +void fatrelu_and_mul(torch::Tensor& out, // [..., d], + torch::Tensor& input, // [..., 2 * d] + double threshold) { + LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold); +} namespace vllm { // Element-wise activation kernel template. 
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu deleted file mode 100644 index bcd170411e7c..000000000000 --- a/csrc/attention/attention_kernels.cu +++ /dev/null @@ -1,1002 +0,0 @@ -/* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include "attention_dtypes.h" -#include "attention_utils.cuh" - -#ifdef USE_ROCM - #include - #include "../quantization/fp8/amd/quant_utils.cuh" -typedef __hip_bfloat16 __nv_bfloat16; -#else - #include "../quantization/fp8/nvidia/quant_utils.cuh" -#endif - -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif - -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) - -namespace vllm { - -// Utility function for attention softmax. -template -inline __device__ float block_sum(float* red_smem, float sum) { - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - - // Compute the sum per warp. 
-#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < NUM_WARPS) { - sum = red_smem[lane]; - } - - // Parallel reduction inside the warp. -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); - } - - // Broadcast to other threads. - return VLLM_SHFL_SYNC(sum, 0); -} - -// TODO(woosuk): Merge the last two dimensions of the grid. -// Grid: (num_heads, num_seqs, max_num_partitions). -template // Zero means no partitioning. -__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { - const int seq_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int 
max_num_partitions = gridDim.z; - constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int seq_len = seq_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { - // No work to do. Terminate the thread block. - return; - } - - const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); - const int num_blocks_per_partition = - USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; - - // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = - USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = - MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); - const int num_blocks = end_block_idx - start_block_idx; - - // [start_token_idx, end_token_idx) is the range of tokens to process. - const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = - MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); - const int num_tokens = end_token_idx - start_token_idx; - - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = - NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE - // divides NUM_THREADS - assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = - DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = threadIdx.x; - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = blockIdx.x; - const int num_heads = gridDim.x; - const int num_queries_per_kv = num_heads / num_kv_heads; - const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = - alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - - // A vector type to store a part of a key or a query. 
- // The vector size is configured in such a way that the threads in a thread - // group fetch or compute 16 bytes at a time. For example, if the size of a - // thread group is 4 and the data type is half, then the vector size is 16 / - // (4 * sizeof(half)) == 2. - constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); - using K_vec = typename Vec::Type; - using Q_vec = typename Vec::Type; - using Quant_vec = typename Vec::Type; - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... th vectors of the query, and the second thread - // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because - // q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; - i += NUM_THREAD_GROUPS) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = - *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a - // memory wall right before we use q_vecs - - // Memory planning. - extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. - float* logits = reinterpret_cast(shared_mem); - // Workspace for reduction. 
- __shared__ float red_smem[2 * NUM_WARPS]; - - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(cache_t); - float qk_max = -FLT_MAX; - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - - // blocksparse specific vars - int bs_block_offset; - int q_bs_block_id; - if constexpr (IS_BLOCK_SPARSE) { - // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, - // blocksparse_block_size); - q_bs_block_id = (seq_len - 1) / blocksparse_block_size; - if (blocksparse_head_sliding_step >= 0) - // sliding on q heads - bs_block_offset = - (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; - else - // sliding on kv heads - bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * - (-blocksparse_head_sliding_step) + - 1; - } - - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). 
- // For blocksparse attention: skip computation on blocks that are not - // attended - if constexpr (IS_BLOCK_SPARSE) { - const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; - const bool is_remote = - ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); - const bool is_local = - (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); - if (!is_remote && !is_local) { - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = - (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - - if (thread_group_offset == 0) { - // NOTE(linxihui): assign very large number to skipped tokens to - // avoid contribution to the sumexp softmax normalizer. This will - // not be used at computing sum(softmax*v) as the blocks will be - // skipped. - logits[token_idx - start_token_idx] = -FLT_MAX; - } - } - continue; - } - } - const int64_t physical_block_number = - static_cast(block_table[block_idx]); - - // Load a key to registers. - // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... th vectors of the key, and the second thread - // has 1, 5, 9, ... th vectors of the key, and so on. 
- for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = - (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - K_vec k_vecs[NUM_VECS_PER_THREAD]; - -#pragma unroll - for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = - k_cache + physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride + physical_block_offset * x; - const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - - if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - k_vecs[j] = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } else { - // Vector conversion from Quant_vec to K_vec. - Quant_vec k_vec_quant = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - k_vecs[j] = fp8::scaled_convert( - k_vec_quant, k_scale); - } - } - - // Compute dot product. - // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot( - q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= seq_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - } - } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. 
-#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = VLLM_SHFL_SYNC(qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - - // If partitioning is enabled, store the max logit and exp_sum. - if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - *exp_sums_ptr = exp_sum; - } - - // Each thread will fetch 16 bytes from the value cache at a time. 
- constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); - using V_vec = typename Vec::Type; - using L_vec = typename Vec::Type; - using V_quant_vec = typename Vec::Type; - using Float_L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = - DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); - - // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. - float accs[NUM_ROWS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - scalar_t zero_value; - zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). - // For blocksparse attention: skip computation on blocks that are not - // attended - if constexpr (IS_BLOCK_SPARSE) { - int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; - if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && - !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { - continue; - } - } - const int64_t physical_block_number = - static_cast(block_table[block_idx]); - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - - start_token_idx)); - - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx 
* BLOCK_SIZE + physical_block_offset; - V_vec v_vec; - - if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - v_vec = *reinterpret_cast(v_ptr + offset); - } else { - V_quant_vec v_quant_vec = - *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8::scaled_convert(v_quant_vec, - v_scale); - } - if (block_idx == num_seq_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the - // context, we should explicitly zero out the values since they may - // contain NaNs. See - // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 - scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); -#pragma unroll - for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; - } - } - accs[i] += dot(logits_vec, v_vec); - } - } - } - - // Perform reduction within each warp. -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; -#pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += VLLM_SHFL_XOR_SYNC(acc, mask); - } - accs[i] = acc; - } - - // NOTE(woosuk): A barrier is required because the shared memory space for - // logits is reused for the output. - __syncthreads(); - - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); -#pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; - } - } - } - __syncthreads(); - - // Lower warps update the output. 
- if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; - } - } - } - __syncthreads(); - } - - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = - out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - from_float(*(out_ptr + row_idx), accs[i]); - } - } - } -} - -// Grid: (num_heads, num_seqs, 1). -template -__global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, - v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, 
alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); -} - -// Grid: (num_heads, num_seqs, max_num_partitions). -template -__global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); -} - -// Grid: (num_heads, num_seqs). 
-template -__global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int seq_len = seq_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); - if (num_partitions == 1) { - // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; - for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { - out_ptr[i] = tmp_out_ptr[i]; - } - // Terminate the thread block. - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warp_idx = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - - // Size: 2 * num_partitions. - extern __shared__ char shared_mem[]; - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - float max_logit = -FLT_MAX; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - const float l = max_logits_ptr[i]; - shared_max_logits[i] = l; - max_logit = fmaxf(max_logit, l); - } - __syncthreads(); - - // Get the global max logit. - // Reduce within the warp. 
-#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = max_logit; - } - __syncthreads(); - // Reduce across warps. - max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); - } - // Broadcast the max value to all threads. - max_logit = VLLM_SHFL_SYNC(max_logit, 0); - - // Load rescaled exp sums to shared memory. - float* shared_exp_sums = - reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - float l = shared_max_logits[i]; - float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); - global_exp_sum += rescaled_exp_sum; - shared_exp_sums[i] = rescaled_exp_sum; - } - __syncthreads(); - global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); - const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); - - // Aggregate tmp_out to out. 
- const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { - float acc = 0.0f; - for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * - inv_global_exp_sum; - } - from_float(out_ptr[i], acc); - } -} - -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); - -// TODO(woosuk): Tune NUM_THREADS. 
-template -void paged_attention_v1_launcher( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seq_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len - // Keep that in sync with the logic here! 
- int shared_mem_size = std::max(logits_size, outputs_size); - - dim3 grid(num_heads, num_seqs, 1); - dim3 block(NUM_THREADS); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V1(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V1(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V1(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V1(112); - break; - case 120: - LAUNCH_PAGED_ATTENTION_V1(120); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V1(128); - break; - case 192: - LAUNCH_PAGED_ATTENTION_V1(192); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V1(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ - paged_attention_v1_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); - -#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ - } - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. 
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t num_kv_heads, // [num_heads] - double scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - const bool is_block_sparse = (blocksparse_vert_stride > 1); - - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - CALL_V1_LAUNCHER_BLOCK_SIZE) -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ - value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ 
- max_num_partitions); - -template -void paged_attention_v2_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - - // For paged attention v2 kernel. 
- dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); - // For paged attention v2 reduce kernel. - dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); - - dim3 block(NUM_THREADS); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V2(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V2(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V2(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V2(112); - break; - case 120: - LAUNCH_PAGED_ATTENTION_V2(120); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V2(128); - break; - case 192: - LAUNCH_PAGED_ATTENTION_V2(192); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V2(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ - paged_attention_v2_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); - -#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ - } - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. 
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& - tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t num_kv_heads, // [num_heads] - double scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - const bool is_block_sparse = (blocksparse_vert_stride > 1); - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - CALL_V2_LAUNCHER_BLOCK_SIZE) -} - -#undef WARP_SIZE -#undef MAX -#undef MIN -#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh new file mode 100644 index 000000000000..eb216dc8baf1 --- /dev/null +++ b/csrc/attention/attention_kernels.cuh @@ -0,0 +1,676 @@ +/* + * Adapted from + * 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#ifdef USE_ROCM + #include + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. 
+ __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. 
+ return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. 
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. 
+ // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. 
+ logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, *k_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 
0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. 
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx 
* BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + *v_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. 
+ if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). +template +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, 
alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. 
+#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = VLLM_SHFL_SYNC(max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
+ const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index cdcee4274899..826b0edffae6 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -34,7 +34,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { A_vec qk_vec = mul(q[0], k[0]); #pragma unroll for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); + qk_vec = vllm::fma(q[ii], k[ii], qk_vec); } // Finalize the reduction across lanes. diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 3cdcb95e0809..97a25baa1fc0 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #else return __bfloat1622float2(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { @@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #else return __bfloat162bfloat162(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } // Vector addition. 
@@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { return __hadd(a, b); #endif #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { @@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hadd2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { @@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #else return __hmul(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hmul2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, #else return __hfma2(a, b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, @@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, #else return __hfma2(bf162bf162(a), b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu new file mode 100644 index 000000000000..9b3a5c4b1014 --- /dev/null +++ b/csrc/attention/paged_attention_v1.cu @@ -0,0 +1,196 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. 
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "attention_kernels.cuh" + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +// TODO(woosuk): Tune NUM_THREADS. 
+template +void paged_attention_v1_launcher( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! 
+ int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V1(32); + break; + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu new file mode 100644 index 000000000000..9935359e02fb --- /dev/null +++ b/csrc/attention/paged_attention_v2.cu @@ -0,0 +1,206 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. 
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "attention_kernels.cuh" + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template +void paged_attention_v2_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int 
blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. 
+ dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V2(32); + break; + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daa..0970b704be3a 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -15,19 +15,34 @@ void copy_blocks(std::vector const& key_caches, std::vector const& value_caches, const torch::Tensor& block_mapping); 
+void copy_blocks_mla(std::vector const& kv_caches, + const torch::Tensor& block_mapping); + void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale); + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const double k_scale, const double v_scale); + torch::Tensor& k_scale, torch::Tensor& v_scale); + +void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, + torch::Tensor& kv_cache, torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& scale); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); + +void gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43..0b3f6fc8c19a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,6 +2,7 @@ #include #include +#include "cuda_utils.h" #include "cuda_compat.h" #include "dispatch_utils.h" @@ -46,7 +47,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, char* src_ptr = static_cast(src.data_ptr()); char* dst_ptr = static_cast(dst.data_ptr()); - const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + // We use the stride instead of numel in case the cache is padded for memory + // alignment reasons, we assume the blocks data (inclusive of any padding) + // is contiguous in memory + const int64_t block_size_in_bytes = src.element_size() * src.stride(0); const at::cuda::OptionalCUDAGuard device_guard( src_device.is_cuda() ? 
src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -93,6 +97,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, } } +// Kernel for MLA, which works on a single joint kv_cache +// Grid: (num_layers, num_pairs) +template +__global__ void copy_blocks_mla_kernel( + int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping, + const int mem_footprint_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + scalar_t* cache = reinterpret_cast(cache_ptrs[layer_idx]); + int64_t src_block = block_mapping[2 * pair_idx]; + int64_t dst_block = block_mapping[2 * pair_idx + 1]; + int64_t src_offset = src_block * mem_footprint_per_block; + int64_t dst_offset = dst_block * mem_footprint_per_block; + for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) { + cache[dst_offset + i] = cache[src_offset + i]; + } +} + } // namespace vllm // Note: the key_caches and value_caches vectors are constant but @@ -147,6 +169,42 @@ void copy_blocks(std::vector const& key_caches, })); } +// copy blocks kernel for MLA (assumes a joint KV-cache) +void copy_blocks_mla(std::vector const& kv_caches, + const torch::Tensor& block_mapping) { + int num_layers = kv_caches.size(); + if (num_layers == 0) { + return; + } + torch::Device cache_device = kv_caches[0].device(); + TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA"); + + std::vector cache_ptrs(num_layers); + for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + cache_ptrs[layer_idx] = + reinterpret_cast(kv_caches[layer_idx].data_ptr()); + } + torch::Tensor cache_ptrs_tensor = + torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64) + .to(cache_device); + + int num_pairs = block_mapping.size(0); + // We use the stride instead of numel in case the cache is padded for memory + // alignment reasons, we assume the blocks data (inclusive of any padding) + // is contiguous in memory + int mem_footprint_per_block = 
kv_caches[0].stride(0); + dim3 grid(num_layers, num_pairs); + dim3 block(std::min(1024, mem_footprint_per_block)); + const at::cuda::OptionalCUDAGuard device_guard(cache_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( + kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] { + vllm::copy_blocks_mla_kernel<<>>( + cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), mem_footprint_per_block); + })); +} + namespace vllm { template @@ -159,8 +217,8 @@ __global__ void reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, const float k_scale, - const float v_scale) { + const int head_size, const int block_size, const int x, + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +254,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -214,7 +272,7 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, - const float k_scale, const float v_scale) { + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -239,16 +297,61 @@ __global__ void reshape_and_cache_flash_kernel( value_cache[tgt_key_value_idx] = tgt_value; } else { 
key_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } + +template +__global__ void concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, + int src_stride, int dst_stride, int size, int offset) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = + block_idx * block_stride + block_offset * entry_stride + i + offset; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = + fp8::scaled_convert(src[src_idx], *scale); + } + } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + } // namespace vllm -// KV_T is the stored data type of kv-cache. -// CACHE_T is the data type of key and value tensors. +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. 
// KV_DTYPE is the real data type of kv-cache. #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ vllm::reshape_and_cache_kernel \ @@ -258,7 +361,9 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, k_scale, v_scale); + num_heads, head_size, block_size, x, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -268,9 +373,9 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { - int num_tokens = key.size(0); + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + int num_tokens = slot_mapping.size(0); int num_heads = key.size(1); int head_size = key.size(2); int block_size = key_cache.size(3); @@ -288,8 +393,8 @@ void reshape_and_cache( CALL_RESHAPE_AND_CACHE) } -// KV_T is the stored data type of kv-cache. -// CACHE_T is the data type of key and value tensors. +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. // KV_DTYPE is the real data type of kv-cache. 
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ vllm::reshape_and_cache_flash_kernel \ @@ -299,7 +404,9 @@ void reshape_and_cache( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, k_scale, v_scale); + value_stride, num_heads, head_size, block_size, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -307,10 +414,20 @@ void reshape_and_cache_flash( torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { - int num_tokens = key.size(0); + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); int num_heads = key.size(1); int head_size = key.size(2); int block_size = key_cache.size(1); @@ -329,6 +446,57 @@ void reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH); } +// KV_T is the data type of key and value tensors. 
+// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + +void concat_and_cache_mla( + torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_tokens, pe_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. 
+ int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + int entry_stride = kv_cache.stride(1); + + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); +} + namespace vllm { template @@ -403,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } + +namespace vllm { + +// grid is launched with dimensions (batch, num_splits) +template +__global__ void gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRIES...] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] 
+ const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per + // batch + + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size); + const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits); + + const int32_t split_start = split * split_blocks; + const int32_t split_end = min((split + 1) * split_blocks, tot_blocks); + + const bool is_active_split = (split_start < tot_blocks); + const bool is_last_split = (split_end == tot_blocks); + + if (!is_active_split) return; + + int32_t full_blocks_end = split_end; + int32_t partial_block_size = 0; + + // Adjust the pointer for the block_table for this batch. + // If seq_starts is provided, compute an offset based on (seq_starts[bid] / + // page_size) + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = 0; + if (seq_starts != nullptr) { + offset = seq_starts[bid] / block_size; + } + const int32_t* batch_block_table = block_table + batch_offset + offset; + + // Adjust dst pointer based on the cumulative sequence lengths. 
+ dst += seq_start * dst_entry_stride; + + if (is_last_split) { + partial_block_size = seq_len % block_size; + if (partial_block_size) full_blocks_end -= 1; + } + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < full_blocks_end; ++pid) { + auto block_id = batch_block_table[pid]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * block_size * dst_entry_stride; + for (int eid = 0; eid < block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } + + if (partial_block_size) { + auto block_id = batch_block_table[full_blocks_end]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride; + for (int eid = 0; eid < partial_block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } +} + +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +#define CALL_GATHER_CACHE(CPY_DTYPE) \ + vllm::gather_cache<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting block index by +// (seq_starts[bid] / page_size) +void gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] 
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + TORCH_CHECK(src_cache.dtype() == dst.dtype(), + "src_cache and dst must have the same dtype"); + + const int dtype_bits = src_cache.element_size() * 8; + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? 
seq_starts.value().data_ptr() : nullptr; + + if (dtype_bits == 32) { + CALL_GATHER_CACHE(uint32_t); + } else if (dtype_bits == 16) { + CALL_GATHER_CACHE(uint16_t); + } else if (dtype_bits == 8) { + CALL_GATHER_CACHE(uint8_t); + } else { + TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); + } +} diff --git a/csrc/core/exception.hpp b/csrc/core/exception.hpp new file mode 100644 index 000000000000..f3b2ffaef6cc --- /dev/null +++ b/csrc/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 000000000000..b8171133f6aa --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +inline constexpr uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} diff --git a/csrc/registration.h b/csrc/core/registration.h similarity index 79% rename from csrc/registration.h rename to csrc/core/registration.h index e5396e9a8b13..4d0ce1c572c1 100644 --- a/csrc/registration.h +++ b/csrc/core/registration.h @@ -12,6 +12,11 @@ // could be a macro instead of a literal token. #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + // REGISTER_EXTENSION allows the shared library to be loaded and initialized // via python's import statement. 
#define REGISTER_EXTENSION(NAME) \ diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp new file mode 100644 index 000000000000..c2ae554c9f8e --- /dev/null +++ b/csrc/core/scalar_type.hpp @@ -0,0 +1,347 @@ +#pragma once + +// For TORCH_CHECK +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + 
TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, + nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... 
rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, 
std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = 
(uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + 
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +}; // namespace vllm diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index abb4e3bea14b..0257d8ff16ba 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -22,6 +22,24 @@ struct KernelVecType { using v_load_vec_type = vec_op::FP32Vec16; }; +template <> +struct KernelVecType { +#if defined(__powerpc64__) || defined(__s390x__) + // Power and s390x architecture-specific vector types + using q_load_vec_type = vec_op::FP32Vec8; + using k_load_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures, including x86 + using q_load_vec_type = vec_op::FP16Vec8; + 
using k_load_vec_type = vec_op::FP16Vec16; + using v_load_vec_type = vec_op::FP16Vec16; +#endif + using q_vec_type = vec_op::FP32Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; +}; + #ifdef __AVX512BF16__ template <> struct KernelVecType { @@ -33,6 +51,21 @@ struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else + #ifdef __aarch64__ + #ifndef ARM_BF16_SUPPORT + // pass + #else +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::BF16Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; + #endif + #else template <> struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; @@ -42,6 +75,7 @@ struct KernelVecType { using qk_acc_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::BF16Vec16; }; + #endif #endif template @@ -352,7 +386,7 @@ void paged_attention_v1_impl_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes) { + const std::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -375,6 +409,9 @@ void paged_attention_v1_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); + break; case 64: LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; @@ -422,12 +459,12 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const 
std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -665,7 +702,7 @@ void paged_attention_v2_impl_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes) { + int max_seq_len, const std::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -692,6 +729,9 @@ void paged_attention_v2_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); + break; case 64: LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; @@ -741,12 +781,12 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& 
v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", @@ -755,4 +795,4 @@ void paged_attention_v2( CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) }); -} +} \ No newline at end of file diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 31d454328b2c..d726ee9307fe 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -3,6 +3,12 @@ #include "cpu_types.hpp" +#if defined(__x86_64__) + #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2 +#else + #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES +#endif + namespace { template void copy_blocks_cpu_impl(std::vector const& key_caches, @@ -95,22 +101,19 @@ void copy_blocks(std::vector const& key_caches, } const int element_num_per_block = key_caches[0][0].numel(); - VLLM_DISPATCH_FLOATING_TYPES( - key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, - element_num_per_block, num_layers); - CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) - }); + DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) + copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, + element_num_per_block, num_layers); + CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) + }); } void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double k_scale, - double v_scale) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); - + 
const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -120,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, int key_stride = key.stride(0); int value_stride = value.stride(0); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) - reshape_and_cache_cpu_impl( - key.data_ptr(), value.data_ptr(), - key_cache.data_ptr(), value_cache.data_ptr(), - slot_mapping.data_ptr(), num_tokens, key_stride, - value_stride, num_heads, head_size, block_size, x); - CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) - }); + DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) + reshape_and_cache_cpu_impl( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), value_cache.data_ptr(), + slot_mapping.data_ptr(), num_tokens, key_stride, value_stride, + num_heads, head_size, block_size, x); + CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) + }); } void swap_blocks(torch::Tensor& src, torch::Tensor& dst, diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 0213be09105e..17bbe04eef94 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -1,15 +1,20 @@ - #ifndef CPU_TYPES_HPP #define CPU_TYPES_HPP #if defined(__x86_64__) - //x86 implementation + // x86 implementation #include "cpu_types_x86.hpp" #elif defined(__POWER9_VECTOR__) - //ppc implementation + // ppc implementation #include "cpu_types_vsx.hpp" +#elif defined(__s390x__) + // s390 implementation + #include "cpu_types_vxe.hpp" +#elif defined(__aarch64__) + // arm implementation + #include "cpu_types_arm.hpp" #else #warning "unsupported vLLM cpu implementation" #endif -#endif +#endif \ No newline at end of file diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp new file mode 100644 index 
000000000000..65ffe524af73 --- /dev/null +++ b/csrc/cpu/cpu_types_arm.hpp @@ -0,0 +1,595 @@ +#include +#include +#include + +#if defined(__APPLE__) + #include "omp.h" +#endif + +namespace vec_op { + +#ifdef ARM_BF16_SUPPORT + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) +#endif + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) +#else + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." 
<< std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { + (f(std::integral_constant{}), ...); +}; +}; // namespace + +template >> +constexpr void unroll_loop(F&& f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template +struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }; +}; + +struct FP32Vec8; +struct FP32Vec16; + +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + float16x8_t reg; + + explicit FP16Vec8(const void* ptr) + : reg(vld1q_f16(static_cast(ptr))) {}; + + explicit FP16Vec8(const FP32Vec8&); + + void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); } +}; + +struct FP16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + float16x8x2_t reg; + + explicit FP16Vec16(const void* ptr) { + reg.val[0] = vld1q_f16(reinterpret_cast(ptr)); + reg.val[1] = vld1q_f16(reinterpret_cast(ptr) + 8); + } + + explicit FP16Vec16(const FP32Vec16& vec); + + void save(void* ptr) const { + vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); + vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); + } + + void save(void* ptr, const int elem_num) const { + int full_blocks = elem_num / 8; + int remainder = elem_num % 8; + + if (full_blocks > 0) { + vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); + if (full_blocks > 1) { + vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); + } + } + + // Note: below is the unrolled version of the following code: + // + // for (int i = 0; i < remainder; ++i) { + // reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = + // vgetq_lane_f16(temp, i); + // } + // + // For macOS build (Clang), the arm/neon intrinsics function + // `vgetq_lane_f16` needs the parameter `i` to be constant at compile + // time. 
+ + if (remainder > 0) { + float16x8_t temp = reg.val[full_blocks]; + __fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr); + switch (remainder) { + case 1: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + break; + case 2: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + break; + case 3: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + break; + case 4: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + break; + case 5: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); + break; + case 6: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); + fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5); + break; + case 7: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); + fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5); + fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6); + break; + + default: + break; + } + } + } +}; + +#ifdef ARM_BF16_SUPPORT 
+struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + bfloat16x8_t reg; + + explicit BF16Vec8(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec8(bfloat16x8_t data) : reg(data) {}; + + explicit BF16Vec8(const FP32Vec8&); + + explicit BF16Vec8(float32x4x2_t v) + : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {}; + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + bfloat16x8x2_t reg; + + explicit BF16Vec16(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {}; + + explicit BF16Vec16(const FP32Vec16&); + + explicit BF16Vec16(float32x4x4_t v) + : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]), + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {}; + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }; +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + bfloat16x8x4_t reg; + + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {}; + + explicit BF16Vec32(const BF16Vec8& vec8_data) + : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}; + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }; +}; +#endif + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + + union AliasReg { + float32x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + float32x4_t reg; + + explicit FP32Vec4(float v) : reg(vdupq_n_f32(v)) {}; + + explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {}; + + explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {}; + + explicit FP32Vec4(float32x4_t data) : reg(data) {}; + + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}; +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + float32x4x2_t reg; 
+ float values[VEC_ELEM_NUM]; + }; + + float32x4x2_t reg; + + explicit FP32Vec8(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v)}) {}; + + explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}; + + explicit FP32Vec8(const float* ptr) + : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {}; + + explicit FP32Vec8(float32x4x2_t data) : reg(data) {}; + + explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}; + + explicit FP32Vec8(const FP16Vec8& v) { + reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg)); + reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg)); + }; + + explicit FP32Vec8(float16x8_t v) + : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {}; + +#ifdef ARM_BF16_SUPPORT + + explicit FP32Vec8(bfloat16x8_t v) + : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {}; + + explicit FP32Vec8(const BF16Vec8& v) + : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {}; + +#endif + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float answer = 0; + unroll_loop( + [&answer, &ar](int i) { answer += ar.values[i]; }); + + return answer; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + + float32x2_t exp_vec0 = {expf(ar.values[0]), expf(ar.values[1])}; + float32x2_t exp_vec1 = {expf(ar.values[2]), expf(ar.values[3])}; + float32x2_t exp_vec2 = {expf(ar.values[4]), expf(ar.values[5])}; + float32x2_t exp_vec3 = {expf(ar.values[6]), expf(ar.values[7])}; + + float32x4_t result0 = vcombine_f32(exp_vec0, exp_vec1); + float32x4_t result1 = vcombine_f32(exp_vec2, exp_vec3); + + float32x4x2_t result; + result.val[0] = result0; + result.val[1] = result1; + + return FP32Vec8(result); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + + float32x2_t tanh_vec0 = {tanhf(ar.values[0]), tanhf(ar.values[1])}; + float32x2_t tanh_vec1 = {tanhf(ar.values[2]), tanhf(ar.values[3])}; + float32x2_t tanh_vec2 = {tanhf(ar.values[4]), tanhf(ar.values[5])}; + float32x2_t tanh_vec3 = {tanhf(ar.values[6]), tanhf(ar.values[7])}; 
+ + float32x4_t result0 = vcombine_f32(tanh_vec0, tanh_vec1); + float32x4_t result1 = vcombine_f32(tanh_vec2, tanh_vec3); + + float32x4x2_t result; + result.val[0] = result0; + result.val[1] = result1; + + return FP32Vec8(result); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + + float32x2_t er_vec0 = {static_cast(erf(ar.values[0])), + static_cast(erf(ar.values[1]))}; + float32x2_t er_vec1 = {static_cast(erf(ar.values[2])), + static_cast(erf(ar.values[3]))}; + float32x2_t er_vec2 = {static_cast(erf(ar.values[4])), + static_cast(erf(ar.values[5]))}; + float32x2_t er_vec3 = {static_cast(erf(ar.values[6])), + static_cast(erf(ar.values[7]))}; + + float32x4_t result0 = vcombine_f32(er_vec0, er_vec1); + float32x4_t result1 = vcombine_f32(er_vec2, er_vec3); + + float32x4x2_t result; + result.val[0] = result0; + result.val[1] = result1; + + return FP32Vec8(result); + } + + FP32Vec8 operator*(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), + vmulq_f32(reg.val[1], b.reg.val[1])})); + } + + FP32Vec8 operator+(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), + vaddq_f32(reg.val[1], b.reg.val[1])})); + } + + FP32Vec8 operator-(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), + vsubq_f32(reg.val[1], b.reg.val[1])})); + } + + FP32Vec8 operator/(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), + vdivq_f32(reg.val[1], b.reg.val[1])})); + } + + void save(float* ptr) const { + vst1q_f32(ptr, reg.val[0]); + vst1q_f32(ptr + 4, reg.val[1]); + } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + float32x4x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + float32x4x4_t reg; + + explicit FP32Vec16(float v) + : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {} + + explicit FP32Vec16() + : reg({vmovq_n_f32(0.0), 
vmovq_n_f32(0.0), vmovq_n_f32(0.0), + vmovq_n_f32(0.0)}) {} + + explicit FP32Vec16(const float* ptr) + : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), + vld1q_f32(ptr + 12)}) {} + + explicit FP32Vec16(float32x4x4_t data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec8& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; + } + + explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {} + + explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {} + +#ifdef ARM_BF16_SUPPORT + explicit FP32Vec16(bfloat16x8x2_t v) + : reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]), + vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {}; +#endif + + explicit FP32Vec16(const FP32Vec4& data) { + reg.val[0] = data.reg; + reg.val[1] = data.reg; + reg.val[2] = data.reg; + reg.val[3] = data.reg; + }; + +#ifdef ARM_BF16_SUPPORT + explicit FP32Vec16(const BF16Vec16& v) + : reg({vcvtq_low_f32_bf16(v.reg.val[0]), + vcvtq_high_f32_bf16(v.reg.val[0]), + vcvtq_low_f32_bf16(v.reg.val[1]), + vcvtq_high_f32_bf16(v.reg.val[1])}) {}; + + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; +#endif + + explicit FP32Vec16(const FP16Vec16& v) { + reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0])); + reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0])); + reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); + reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); + }; + + FP32Vec16 operator+(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]), + vaddq_f32(reg.val[1], b.reg.val[1]), + vaddq_f32(reg.val[2], b.reg.val[2]), + vaddq_f32(reg.val[3], b.reg.val[3])})); + }; + + FP32Vec16 operator*(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]), + vmulq_f32(reg.val[1], b.reg.val[1]), + vmulq_f32(reg.val[2], b.reg.val[2]), + vmulq_f32(reg.val[3], 
b.reg.val[3])})); + }; + + FP32Vec16 operator-(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]), + vsubq_f32(reg.val[1], b.reg.val[1]), + vsubq_f32(reg.val[2], b.reg.val[2]), + vsubq_f32(reg.val[3], b.reg.val[3])})); + }; + + FP32Vec16 operator/(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]), + vdivq_f32(reg.val[1], b.reg.val[1]), + vdivq_f32(reg.val[2], b.reg.val[2]), + vdivq_f32(reg.val[3], b.reg.val[3])})); + }; + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float answer = 0; + unroll_loop( + [&answer, &ar](int i) { answer += ar.values[i]; }); + + return answer; + }; + + template + float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float answer = 0; + const int start = idx * group_size; + unroll_loop( + [&answer, &start, ar](int i) { answer += ar.values[start + i]; }); + + return answer; + }; + + void save(float* ptr) const { + vst1q_f32(ptr, reg.val[0]); + vst1q_f32(ptr + 4, reg.val[1]); + vst1q_f32(ptr + 8, reg.val[2]); + vst1q_f32(ptr + 12, reg.val[3]); + }; +}; + +template +struct VecType { + using vec_type = void; +}; + +template +using vec_t = typename VecType::vec_type; + +template <> +struct VecType { + using vec_type = FP32Vec8; +}; + +template <> +struct VecType { + using vec_type = FP16Vec8; +}; + +#ifdef ARM_BF16_SUPPORT +template <> +struct VecType { + using vec_type = BF16Vec8; +}; +#endif + +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} + +template <> +inline void storeFP32(float v, c10::Half* ptr) { + *reinterpret_cast<__fp16*>(ptr) = v; +} + +inline FP16Vec16::FP16Vec16(const FP32Vec16& v) { + float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]); + float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]); + float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]); + float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]); + + reg.val[0] = vcombine_f16(low_0, high_0); + reg.val[1] = 
vcombine_f16(low_1, high_1); +}; + +inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) { + float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]); + float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]); + + reg = vcombine_f16(lower_half, upper_half); +}; + +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { + acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]); + acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]); + acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]); + acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]); +}; + +#ifdef ARM_BF16_SUPPORT +inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) { + float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0])); + float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0])); + float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1])); + float32x4_t a1_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[1])); + + float32x4_t b0_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[0])); + float32x4_t b0_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[0])); + float32x4_t b1_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[1])); + float32x4_t b1_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[1])); + + acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a0_low, b0_low); + acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a0_high, b0_high); + acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a1_low, b1_low); + acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a1_high, b1_high); +}; +#endif + +#ifdef ARM_BF16_SUPPORT +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) + : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) { + }; + +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) + : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]), + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), + v.reg.val[3])}) {}; +#endif + +inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); }; + +#ifdef ARM_BF16_SUPPORT 
+template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + *reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v); +}; +#endif +}; // namespace vec_op \ No newline at end of file diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index b50bdadc5713..a8e1be37eb41 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -9,38 +9,40 @@ namespace vec_op { // FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) #ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) #else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." 
<< std::endl; #endif #define FORCE_INLINE __attribute__((always_inline)) inline namespace { template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { (f(std::integral_constant{}), ...); } -}; // namespace +}; // namespace template >> -constexpr void unroll_loop(F &&f) { +constexpr void unroll_loop(F&& f) { unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); } -template struct Vec { +template +struct Vec { constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } }; @@ -68,12 +70,14 @@ struct BF16Vec8 : public Vec { __vector signed short reg; - explicit BF16Vec8(const void *ptr) - : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} + explicit BF16Vec8(const void* ptr) + : reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {} - explicit BF16Vec8(const FP32Vec8 &); + explicit BF16Vec8(const FP32Vec8&); - void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } + void save(void* ptr) const { + *reinterpret_cast<__vector signed short*>(ptr) = reg; + } }; struct BF16Vec16 : public Vec { @@ -81,18 +85,18 @@ struct BF16Vec16 : public Vec { ss16x8x2_t reg; - explicit BF16Vec16(const void *ptr) { + explicit BF16Vec16(const void* ptr) { // Load 256 bits in two parts - reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); - reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); + reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr); } - explicit BF16Vec16(const FP32Vec16 &); + explicit BF16Vec16(const FP32Vec16&); - void save(void *ptr) const { + void save(void* ptr) const { // Save 256 bits in two parts - vec_xst(reg.val[0], 0, (signed short *)ptr); - vec_xst(reg.val[1], 16, (signed short *)ptr); + vec_xst(reg.val[0], 0, (signed short*)ptr); + vec_xst(reg.val[1], 16, (signed short*)ptr); } }; @@ -102,19 
+106,15 @@ struct BF16Vec32 : public Vec { constexpr static int VEC_ELEM_NUM = 32; ss16x8x4_t reg; - explicit BF16Vec32(const void *ptr) - : reg(*reinterpret_cast(ptr)) {} + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {} explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} - explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ - vec8_data.reg, - vec8_data.reg, - vec8_data.reg, - vec8_data.reg - }) {} + explicit BF16Vec32(const BF16Vec8& vec8_data) + : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {} - void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } }; struct FP32Vec4 : public Vec { @@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec { explicit FP32Vec4() : reg(vec_splats(0.0f)) {} - explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} + explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {} explicit FP32Vec4(__vector float data) : reg(data) {} - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} }; struct FP32Vec8 : public Vec { @@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec { reg.val[1] = vec_splats(0.0f); } - explicit FP32Vec8(const float *ptr) { + explicit FP32Vec8(const float* ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); } explicit FP32Vec8(f32x4x2_t data) : reg(data) {} - explicit FP32Vec8(const FP32Vec8 &data) { + explicit FP32Vec8(const FP32Vec8& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; } - explicit FP32Vec8(const BF16Vec8 &v) { + explicit FP32Vec8(const BF16Vec8& v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); reg.val[1] = (__vector float)vec_mergel(zero, v.reg); } @@ -177,7 +177,8 @@ struct FP32Vec8 : public Vec { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); return 
result; } @@ -230,23 +231,27 @@ struct FP32Vec8 : public Vec { return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); } - FP32Vec8 operator*(const FP32Vec8 &b) const { - return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator*(const FP32Vec8& b) const { + return FP32Vec8( + {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator+(const FP32Vec8& b) const { + return FP32Vec8( + {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator-(const FP32Vec8& b) const { + return FP32Vec8( + {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator/(const FP32Vec8& b) const { + return FP32Vec8( + {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); } - void save(float *ptr) const { + void save(float* ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); } @@ -275,7 +280,7 @@ struct FP32Vec16 : public Vec { reg.val[3] = vec_splats(0.0f); } - explicit FP32Vec16(const float *ptr) { + explicit FP32Vec16(const float* ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); reg.val[2] = vec_xl(32, ptr); @@ -284,78 +289,76 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(f32x4x4_t data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16 &data) { + explicit FP32Vec16(const FP32Vec16& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[2]; reg.val[3] = data.reg.val[3]; } - explicit FP32Vec16(const FP32Vec4 &data) { + 
explicit FP32Vec16(const FP32Vec4& data) { reg.val[0] = data.reg; reg.val[1] = data.reg; reg.val[2] = data.reg; reg.val[3] = data.reg; } - explicit FP32Vec16(const FP32Vec8 &data) { + explicit FP32Vec16(const FP32Vec8& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[0]; reg.val[3] = data.reg.val[1]; } - explicit FP32Vec16(const BF16Vec16 &v) { + explicit FP32Vec16(const BF16Vec16& v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); } - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_mul(reg.val[0], b.reg.val[0]), - vec_mul(reg.val[1], b.reg.val[1]), - vec_mul(reg.val[2], b.reg.val[2]), - vec_mul(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator*(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_add(reg.val[0], b.reg.val[0]), - vec_add(reg.val[1], b.reg.val[1]), - vec_add(reg.val[2], b.reg.val[2]), - vec_add(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator+(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_sub(reg.val[0], b.reg.val[0]), - vec_sub(reg.val[1], b.reg.val[1]), - vec_sub(reg.val[2], b.reg.val[2]), - vec_sub(reg.val[3], b.reg.val[3])})); + FP32Vec16 
operator-(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_div(reg.val[0], b.reg.val[0]), - vec_div(reg.val[1], b.reg.val[1]), - vec_div(reg.val[2], b.reg.val[2]), - vec_div(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator/(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); } float reduce_sum() const { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); return result; } - template float reduce_sub_sum(int idx) { + template + float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); AliasReg ar; @@ -368,7 +371,7 @@ struct FP32Vec16 : public Vec { return result; } - void save(float *ptr) const { + void save(float* ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); vec_xst(reg.val[2], 32, ptr); @@ -376,43 +379,62 @@ struct FP32Vec16 : public Vec { } }; -template struct VecType { using vec_type = void; }; +template +struct VecType { + using vec_type = void; +}; -template using vec_t = typename VecType::vec_type; +template +using vec_t = typename VecType::vec_type; -template <> struct VecType { using vec_type = FP32Vec8; }; +template <> +struct VecType { + using vec_type = FP32Vec8; +}; -template <> struct VecType { using vec_type = BF16Vec8; }; +template <> +struct VecType { + using vec_type = BF16Vec8; +}; -template void storeFP32(float v, T *ptr) { *ptr = v; } +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { +inline void 
fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { acc = acc + a * b; } -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + c10::BFloat16 __attribute__((__may_alias__))* v_ptr = + reinterpret_cast(&v); *ptr = *(v_ptr + 1); } #ifndef __VEC_CLASS_FP_NAN -#define __VEC_CLASS_FP_NAN (1 << 6) + #define __VEC_CLASS_FP_NAN (1 << 6) #endif -const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; +const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13, + 16, 17, 20, 21, 24, 25, 28, 29}; #ifndef _ARCH_PWR10 -const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; -const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; -const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; -const static __vector unsigned int one = { 1, 1, 1, 1 }; +const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff, + 0x00007fff}; +const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000, + 0x7fc00000}; +const static __vector unsigned int sh16 = {16, 16, 16, 16}; +const static __vector unsigned int one = {1, 1, 1, 1}; #endif -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { #ifdef _ARCH_PWR10 __vector signed short ret[2]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[1]); reg = vec_perm(ret[0], ret[1], omask); #elif defined(_ARCH_PWR9) __vector unsigned int inp0 = (__vector 
unsigned int)(v.reg.val[0]); @@ -425,8 +447,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { __vector unsigned int rnd1 = vec_add(lsb1, bias); inp0 = vec_add(inp0, rnd0); inp1 = vec_add(inp1, rnd1); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel0 = + vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = + vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); inp0 = vec_sel(inp0, nan, sel0); inp1 = vec_sel(inp1, nan, sel1); inp0 = vec_sr(inp0, sh16); @@ -435,13 +459,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { #endif } -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { #ifdef _ARCH_PWR10 __vector signed short ret[4]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); - ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); - ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[1]); + ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[2]); + ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[3]); reg.val[0] = vec_perm(ret[0], ret[1], omask); reg.val[1] = vec_perm(ret[2], ret[3], omask); #elif defined(_ARCH_PWR9) @@ -465,10 +493,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { inp1 = vec_add(inp1, rnd1); inp2 = vec_add(inp2, rnd2); inp3 = vec_add(inp3, rnd3); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], 
__VEC_CLASS_FP_NAN); - __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); - __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); - __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); + __vector __bool int sel0 = + vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = + vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel2 = + vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); + __vector __bool int sel3 = + vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); inp0 = vec_sel(inp0, nan, sel0); inp1 = vec_sel(inp1, nan, sel1); inp2 = vec_sel(inp2, nan, sel2); @@ -482,10 +514,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { #endif } -inline void prefetch(const void *addr) { +inline void prefetch(const void* addr) { __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); } -}; // namespace vec_op +}; // namespace vec_op #endif diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp new file mode 100644 index 000000000000..ab8cbbbf4ec4 --- /dev/null +++ b/csrc/cpu/cpu_types_vxe.hpp @@ -0,0 +1,480 @@ + +#ifndef CPU_TYPES_VXE_HPP +#define CPU_TYPES_VXE_HPP + +#include +#include +#include +namespace vec_op { + +#define vec_neg(a) (-(a)) +#define vec_add(a, b) ((a) + (b)) +#define vec_sub(a, b) ((a) - (b)) +#define vec_mul(a, b) ((a) * (b)) +#define vec_div(a, b) ((a) / (b)) +#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic +#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) +#else + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F&& f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template +struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +typedef struct ss16x8x2_t { + __vector signed short val[2]; +} ss16x8x2_t; + +typedef struct ss16x8x4_t { + __vector signed short val[4]; +} ss16x8x4_t; + +typedef struct f32x4x2_t { + __vector float val[2]; +} f32x4x2_t; + +typedef struct f32x4x4_t { + __vector float val[4]; +} f32x4x4_t; + +struct FP32Vec8; +struct FP32Vec16; + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __vector signed short reg; + + explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {} + explicit BF16Vec8(const FP32Vec8&); + + void save(void* ptr) const { + *reinterpret_cast<__vector signed short*>(ptr) = reg; + } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + ss16x8x2_t reg; + + explicit BF16Vec16(const void* ptr) { + // Load 256 bits in two parts + reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr); + } + + explicit BF16Vec16(const FP32Vec16&); + + void save(void* ptr) const { + // Save 256 bits in two parts + vec_xst(reg.val[0], 0, (signed short*)ptr); + vec_xst(reg.val[1], 16, (signed short*)ptr); + } +}; + +const static __vector 
signed short zero = vec_splats((signed short)0); + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + ss16x8x4_t reg; + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} + + explicit BF16Vec32(const BF16Vec8& vec8_data) + : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {} + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __vector float reg; + float values[VEC_ELEM_NUM]; + }; + + __vector float reg; + + explicit FP32Vec4(float v) : reg(vec_splats(v)) {} + + explicit FP32Vec4() : reg(vec_splats(0.0f)) {} + + explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {} + + explicit FP32Vec4(__vector float data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + f32x4x2_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x2_t reg; + + explicit FP32Vec8(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + } + + explicit FP32Vec8() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + } + + explicit FP32Vec8(const float* ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + } + + explicit FP32Vec8(f32x4x2_t data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + } + + explicit FP32Vec8(const BF16Vec8& v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t 
ret; + ret.val[0][0] = std::exp(ar.values[0]); + ret.val[0][1] = std::exp(ar.values[1]); + ret.val[0][2] = std::exp(ar.values[2]); + ret.val[0][3] = std::exp(ar.values[3]); + ret.val[1][0] = std::exp(ar.values[4]); + ret.val[1][1] = std::exp(ar.values[5]); + ret.val[1][2] = std::exp(ar.values[6]); + ret.val[1][3] = std::exp(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 tanh() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::tanh(ar.values[0]); + ret.val[0][1] = std::tanh(ar.values[1]); + ret.val[0][2] = std::tanh(ar.values[2]); + ret.val[0][3] = std::tanh(ar.values[3]); + ret.val[1][0] = std::tanh(ar.values[4]); + ret.val[1][1] = std::tanh(ar.values[5]); + ret.val[1][2] = std::tanh(ar.values[6]); + ret.val[1][3] = std::tanh(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 er() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::erf(ar.values[0]); + ret.val[0][1] = std::erf(ar.values[1]); + ret.val[0][2] = std::erf(ar.values[2]); + ret.val[0][3] = std::erf(ar.values[3]); + ret.val[1][0] = std::erf(ar.values[4]); + ret.val[1][1] = std::erf(ar.values[5]); + ret.val[1][2] = std::erf(ar.values[6]); + ret.val[1][3] = std::erf(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 operator*(const FP32Vec8& b) const { + return FP32Vec8( + {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator+(const FP32Vec8& b) const { + return FP32Vec8( + {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator-(const FP32Vec8& b) const { + return FP32Vec8( + {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator/(const FP32Vec8& b) const { + return FP32Vec8( + {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + } + + 
void save(float* ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + f32x4x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x4_t reg; + + explicit FP32Vec16(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + reg.val[2] = vec_splats(v); + reg.val[3] = vec_splats(v); + } + + explicit FP32Vec16() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + reg.val[2] = vec_splats(0.0f); + reg.val[3] = vec_splats(0.0f); + } + + explicit FP32Vec16(const float* ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + reg.val[2] = vec_xl(32, ptr); + reg.val[3] = vec_xl(48, ptr); + } + + explicit FP32Vec16(f32x4x4_t data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[2]; + reg.val[3] = data.reg.val[3]; + } + + explicit FP32Vec16(const FP32Vec4& data) { + reg.val[0] = data.reg; + reg.val[1] = data.reg; + reg.val[2] = data.reg; + reg.val[3] = data.reg; + } + + explicit FP32Vec16(const FP32Vec8& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; + } + + explicit FP32Vec16(const BF16Vec16& v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); + reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); + reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); + } + + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator+(const FP32Vec16& b) const { + return 
FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator-(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator/(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + template + float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float result = 0; + const int start = idx * group_size; + unroll_loop( + [&result, &start, ar](int i) { result += ar.values[start + i]; }); + + return result; + } + + void save(float* ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + vec_xst(reg.val[2], 32, ptr); + vec_xst(reg.val[3], 48, ptr); + } +}; + +template +struct VecType { + using vec_type = void; +}; + +template +using vec_t = typename VecType::vec_type; + +template <> +struct VecType { + using vec_type = FP32Vec8; +}; + +template <> +struct VecType { + using vec_type = BF16Vec8; +}; + +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} + +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { + acc = acc + a * b; +} + +namespace c10 { +struct BFloat16 { + uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit + // value. 
+}; +} // namespace c10 + +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + c10::BFloat16 __attribute__((__may_alias__))* v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifndef __VEC_CLASS_FP_NAN + #define __VEC_CLASS_FP_NAN (1 << 6) +#endif + +const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15, + 18, 19, 22, 23, 26, 27, 30, 31}; +const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff, + 0x00007fff}; +const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000, + 0x7fc00000}; +const static __vector unsigned int sh16 = {16, 16, 16, 16}; +const static __vector unsigned int one = {1, 1, 1, 1}; + +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + int cc; + __vector __bool int sel0 = + vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel1 = + vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc); + inp0 = vec_sel(inp0, nan, sel0) >> sh16; + inp1 = vec_sel(inp1, nan, sel1) >> sh16; + reg = (__vector signed short)vec_perm(inp0, inp1, omask); +} + +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); + __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); + int cc; + __vector __bool int sel0 = + vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel1 = + vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel2 = + vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel3 = + vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc); + inp0 = vec_sel(inp0, nan, sel0) >> sh16; + inp1 = 
vec_sel(inp1, nan, sel1) >> sh16; + inp2 = vec_sel(inp2, nan, sel2) >> sh16; + inp3 = vec_sel(inp3, nan, sel3) >> sh16; + reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); + reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); +} + +inline void prefetch(const void* addr) { void __dcbt(const void* addr); } + +}; // namespace vec_op + +#endif \ No newline at end of file diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index f50620a5287d..a9369e1fd101 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -11,88 +11,98 @@ static_assert(false, "AVX2 must be supported for the current implementation."); namespace vec_op { -// FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__)) + #ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) #else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." 
<< std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + #define CPU_KERNEL_GUARD_IN(NAME) \ + RECORD_FUNCTION(#NAME, c10::ArrayRef({})); + #define CPU_KERNEL_GUARD_OUT(NAME) #endif #define FORCE_INLINE __attribute__((always_inline)) inline namespace { template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { (f(std::integral_constant{}), ...); } -}; // namespace +}; // namespace template >> -constexpr void unroll_loop(F &&f) { +constexpr void unroll_loop(F&& f) { unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); } -template struct Vec { +template +struct Vec { constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } }; struct FP32Vec8; struct FP32Vec16; -#ifdef __AVX512FP16__ struct FP16Vec8 : public Vec { constexpr static int VEC_ELEM_NUM = 8; - __m128h reg; + __m128i reg; - explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + explicit FP16Vec8(const void* ptr) + : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {} - explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + explicit FP16Vec8(const FP32Vec8&); - explicit FP16Vec8(__m128h data) : reg(data) {} + void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; } +}; - FP16Vec8 operator*(const FP16Vec8 &b) const { - return FP16Vec8(_mm_mul_ph(reg, b.reg)); - } +struct FP16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; - FP16Vec8 operator+(const FP16Vec8 &b) const { - return FP16Vec8(_mm_add_ph(reg, b.reg)); - } + __m256i reg; - FP16Vec8 operator-(const FP16Vec8 &b) const { - return FP16Vec8(_mm_sub_ph(reg, b.reg)); - } + explicit FP16Vec16(const void* ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} - FP16Vec8 operator/(const FP16Vec8 &b) const { - return FP16Vec8(_mm_div_ph(reg, b.reg)); - } + explicit FP16Vec16(const FP32Vec16&); + + void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } - void save(void 
*ptr) const { _mm_storeu_ph(ptr, reg); } + void save(void* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } }; -#endif struct BF16Vec8 : public Vec { constexpr static int VEC_ELEM_NUM = 8; __m128i reg; - explicit BF16Vec8(const void *ptr) - : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + explicit BF16Vec8(const void* ptr) + : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {} - explicit BF16Vec8(const FP32Vec8 &); + explicit BF16Vec8(const FP32Vec8&); - void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; } }; struct BF16Vec16 : public Vec { @@ -100,12 +110,18 @@ struct BF16Vec16 : public Vec { __m256i reg; - explicit BF16Vec16(const void *ptr) - : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + explicit BF16Vec16(const void* ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} + + explicit BF16Vec16(const FP32Vec16&); - explicit BF16Vec16(const FP32Vec16 &); + void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } - void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } + void save(void* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } }; #ifdef __AVX512F__ @@ -114,11 +130,11 @@ struct BF16Vec32 : public Vec { __m512i reg; - explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} explicit BF16Vec32(__m512i data) : reg(data) {} - explicit BF16Vec32(BF16Vec8 &vec8_data) + explicit BF16Vec32(BF16Vec8& vec8_data) : reg((__m512i)_mm512_inserti32x4( _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( (__m128i)vec8_data.reg), @@ -126,7 +142,7 @@ struct BF16Vec32 : public Vec { 
(__m128i)vec8_data.reg, 2), (__m128i)vec8_data.reg, 3)) {} - void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; } }; #else struct BF16Vec32 : public Vec { @@ -135,24 +151,24 @@ struct BF16Vec32 : public Vec { __m256i reg_low; __m256i reg_high; - explicit BF16Vec32(const void *ptr) - : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), - reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} + explicit BF16Vec32(const void* ptr) + : reg_low(_mm256_loadu_si256((__m256i const*)ptr)), + reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {} - explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), - reg_high(high) {} + explicit BF16Vec32(__m256i low, __m256i high) + : reg_low(low), reg_high(high) {} - explicit BF16Vec32(BF16Vec8 &vec8_data) + explicit BF16Vec32(BF16Vec8& vec8_data) : reg_low((__m256i)_mm256_inserti32x4( - _mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)), + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)), reg_high((__m256i)_mm256_inserti32x4( - _mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)) {} + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)) {} - void save(void *ptr) const { - *reinterpret_cast<__m256i *>(ptr) = reg_low; - *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; + void save(void* ptr) const { + *reinterpret_cast<__m256i*>(ptr) = reg_low; + *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; } }; #endif @@ -170,11 +186,11 @@ struct FP32Vec4 : public Vec { explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} - explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) {} explicit FP32Vec4(__m128 data) : reg(data) {} - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} }; struct 
FP32Vec8 : public Vec { @@ -190,17 +206,15 @@ struct FP32Vec8 : public Vec { explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} - explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {} explicit FP32Vec8(__m256 data) : reg(data) {} - explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {} -#ifdef __AVX512FP16__ - explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} -#endif + explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {} - explicit FP32Vec8(const BF16Vec8 &v) + explicit FP32Vec8(const BF16Vec8& v) : reg(_mm256_castsi256_ps( _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} @@ -208,7 +222,8 @@ struct FP32Vec8 : public Vec { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); return result; } @@ -240,25 +255,48 @@ struct FP32Vec8 : public Vec { erf(ar.values[1]), erf(ar.values[0]))); } - FP32Vec8 operator*(const FP32Vec8 &b) const { + FP32Vec8 operator*(const FP32Vec8& b) const { return FP32Vec8(_mm256_mul_ps(reg, b.reg)); } - FP32Vec8 operator+(const FP32Vec8 &b) const { + FP32Vec8 operator+(const FP32Vec8& b) const { return FP32Vec8(_mm256_add_ps(reg, b.reg)); } - FP32Vec8 operator-(const FP32Vec8 &b) const { + FP32Vec8 operator-(const FP32Vec8& b) const { return FP32Vec8(_mm256_sub_ps(reg, b.reg)); } - FP32Vec8 operator/(const FP32Vec8 &b) const { + FP32Vec8 operator/(const FP32Vec8& b) const { return FP32Vec8(_mm256_div_ps(reg, b.reg)); } - void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } + void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); } }; +#ifdef __AVX512F__ +struct INT32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512i reg; + int32_t values[VEC_ELEM_NUM]; + }; + + __m512i reg; 
+ + explicit INT32Vec16(const void* data_ptr) + : reg(_mm512_loadu_epi32(data_ptr)) {} + + void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); } + + void save(int32_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_epi32(ptr, mask, reg); + } +}; +#endif + #ifdef __AVX512F__ struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; @@ -273,13 +311,11 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} - explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} explicit FP32Vec16(__m512 data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} - - explicit FP32Vec16(const FP32Vec4 &data) + explicit FP32Vec16(const FP32Vec4& data) : reg((__m512)_mm512_inserti32x4( _mm512_inserti32x4( _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), @@ -287,42 +323,87 @@ struct FP32Vec16 : public Vec { (__m128i)data.reg, 2), (__m128i)data.reg, 3)) {} - explicit FP32Vec16(const FP32Vec8 &data) + explicit FP32Vec16(const FP32Vec8& data) : reg((__m512)_mm512_inserti32x8( _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} - explicit FP32Vec16(const BF16Vec16 &v) + explicit FP32Vec16(const BF16Vec16& v) : reg(_mm512_castsi512_ps( _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {} + + explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - FP32Vec16 operator*(const FP32Vec16 &b) const { + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + + explicit FP32Vec16(const INT32Vec16& v) + : reg(_mm512_cvt_roundepi32_ps( + v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} + + FP32Vec16 operator*(const FP32Vec16& b) const { 
return FP32Vec16(_mm512_mul_ps(reg, b.reg)); } - FP32Vec16 operator+(const FP32Vec16 &b) const { + FP32Vec16 operator+(const FP32Vec16& b) const { return FP32Vec16(_mm512_add_ps(reg, b.reg)); } - FP32Vec16 operator-(const FP32Vec16 &b) const { + FP32Vec16 operator-(const FP32Vec16& b) const { return FP32Vec16(_mm512_sub_ps(reg, b.reg)); } - FP32Vec16 operator/(const FP32Vec16 &b) const { + FP32Vec16 operator/(const FP32Vec16& b) const { return FP32Vec16(_mm512_div_ps(reg, b.reg)); } + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg))); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(_mm512_max_ps(reg, b.reg)); + } + + FP32Vec16 max(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(_mm512_min_ps(reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); } + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } - template float reduce_sub_sum(int idx) { + float reduce_max() const { return _mm512_reduce_max_ps(reg); } + + float reduce_min() const { return _mm512_reduce_min_ps(reg); } + + template + float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); return _mm512_mask_reduce_add_ps(mask, reg); } - void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } + void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); } + + void 
save(float* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_ps(ptr, mask, reg); + } }; #else struct FP32Vec16 : public Vec { @@ -336,32 +417,40 @@ struct FP32Vec16 : public Vec { __m256 reg_low; __m256 reg_high; - explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), - reg_high(_mm256_set1_ps(v)) {} + explicit FP32Vec16(float v) + : reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {} - explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), - reg_high(_mm256_set1_ps(0.0)) {} + explicit FP32Vec16() + : reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {} - explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), - reg_high(_mm256_loadu_ps(ptr + 8)) {} + explicit FP32Vec16(const float* ptr) + : reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {} explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} - explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), - reg_high(data.reg_high) {} + explicit FP32Vec16(const FP32Vec16& data) + : reg_low(data.reg_low), reg_high(data.reg_high) {} - explicit FP32Vec16(const FP32Vec4 &data) + explicit FP32Vec16(const FP32Vec4& data) : reg_low((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)), + _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)), reg_high((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)) {} + _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {} - explicit FP32Vec16(const FP32Vec8 &data) + explicit FP32Vec16(const FP32Vec8& data) : reg_low(data.reg), reg_high(data.reg) {} - explicit FP32Vec16(const BF16Vec16 &v) { + explicit FP32Vec16(const FP16Vec16& v) { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i high = _mm256_extractf128_si256(v.reg, 1); + + reg_low = _mm256_cvtph_ps(low); + reg_high = 
_mm256_cvtph_ps(high); + } + + explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + + explicit FP32Vec16(const BF16Vec16& v) { __m128i low = _mm256_extractf128_si256(v.reg, 0); __m128i high = _mm256_extractf128_si256(v.reg, 1); @@ -375,24 +464,24 @@ struct FP32Vec16 : public Vec { reg_high = _mm256_castsi256_ps(v_high_shifted); } - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - FP32Vec16 operator*(const FP32Vec16 &b) const { + FP32Vec16 operator*(const FP32Vec16& b) const { return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), _mm256_mul_ps(reg_high, b.reg_high)); } - FP32Vec16 operator+(const FP32Vec16 &b) const { + FP32Vec16 operator+(const FP32Vec16& b) const { return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), _mm256_add_ps(reg_high, b.reg_high)); } - FP32Vec16 operator-(const FP32Vec16 &b) const { + FP32Vec16 operator-(const FP32Vec16& b) const { return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), _mm256_sub_ps(reg_high, b.reg_high)); } - FP32Vec16 operator/(const FP32Vec16 &b) const { + FP32Vec16 operator/(const FP32Vec16& b) const { return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), _mm256_div_ps(reg_high, b.reg_high)); } @@ -403,7 +492,8 @@ struct FP32Vec16 : public Vec { return low.reduce_sum() + high.reduce_sum(); } - template float reduce_sub_sum(int idx) { + template + float reduce_sub_sum(int idx) { float sum = 0.0; static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); @@ -426,68 +516,123 @@ struct FP32Vec16 : public Vec { return sum; } - void save(float *ptr) const { + void save(float* ptr) const { _mm256_storeu_ps(ptr, reg_low); _mm256_storeu_ps(ptr + 8, reg_high); } }; #endif -template struct VecType { using vec_type = void; }; +#ifdef __AVX512F__ +struct INT8Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m128i reg; + int8_t values[VEC_ELEM_NUM]; + }; + + 
__m128i reg; -template using vec_t = typename VecType::vec_type; + explicit INT8Vec16(const FP32Vec16& vec) + : reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32( + vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {} -template <> struct VecType { using vec_type = FP32Vec8; }; + void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); } -#ifdef __AVX512FP16__ -template <> struct VecType { using vec_type = FP16Vec16; }; + void save(int8_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm_mask_storeu_epi8(ptr, mask, reg); + } +}; #endif -template <> struct VecType { using vec_type = BF16Vec8; }; +template +struct VecType { + using vec_type = void; +}; + +template +using vec_t = typename VecType::vec_type; + +template <> +struct VecType { + using vec_type = FP32Vec8; +}; + +template <> +struct VecType { + using vec_type = FP16Vec8; +}; -template void storeFP32(float v, T *ptr) { *ptr = v; } +template <> +struct VecType { + using vec_type = BF16Vec8; +}; -#ifdef __AVX512FP16__ -template <> inline void storeFP32(float v, c10::Half *ptr) { - *reinterpret_cast<_Float16 *>(ptr) = v; +template +void storeFP32(float v, T* ptr) { + *ptr = v; } -#endif -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { acc = acc + a * b; } +template <> +inline void storeFP32(float v, c10::Half* ptr) { + *reinterpret_cast(ptr) = + _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); +} + +inline FP16Vec8::FP16Vec8(const FP32Vec8& v) + : reg(_mm256_cvtps_ph(v.reg, + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} + +#ifdef __AVX512F__ +inline FP16Vec16::FP16Vec16(const FP32Vec16& v) + : reg(_mm512_cvtps_ph(v.reg, + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} +#else +inline FP16Vec16::FP16Vec16(const FP32Vec16& v) + : reg(_mm256_insertf128_si256( + _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), + 
FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {} +#endif + #ifdef __AVX512BF16__ -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + *reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v); } -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} -inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { +inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) { acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); } #else -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + c10::BFloat16 __attribute__((__may_alias__))* v_ptr = + reinterpret_cast(&v); *ptr = *(v_ptr + 1); } -#ifdef __AVX512F__ -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + #ifdef __AVX512F__ +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg(_mm256_cvtepi32_epi16( _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) : reg(_mm512_cvtepi32_epi16( _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} -#else -namespace{ + #else +namespace { __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { __m256i ai = _mm256_castps_si256(a); ai = _mm256_srli_epi32(ai, 16); @@ -495,21 +640,21 @@ __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { ai = _mm256_permute4x64_epi64(ai, 0b00111001); return _mm256_extracti128_si256(ai, 0); } -} +} // namespace -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} 
-inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); } -#endif // __AVX512F__ -#endif // __AVX512BF16__ + #endif // __AVX512F__ +#endif // __AVX512BF16__ -inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } +inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); } -}; // namespace vec_op +}; // namespace vec_op #endif diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp new file mode 100644 index 000000000000..8b5011dc065f --- /dev/null +++ b/csrc/cpu/dnnl_helper.hpp @@ -0,0 +1,174 @@ +#ifndef DNNL_HELPER_HPP +#define DNNL_HELPER_HPP + +#include +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace { +template +struct DNNLType { + static constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} +}; // namespace + +template +class DNNLPrimitiveHelper { + public: + // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) + // A: [M, K], row-major + // B: [K, N], column-major + // C: [M, N], row-major + // bias: [N], row-major, optional + // a_scales: [MS] + // b_scales: [NS] + // 
Note: Due to the limitation of oneDNN + // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is + // not supported. + template + static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, + const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, + dnnl_dim_t K, const float* a_scales, + const float* b_scales, dnnl_dim_t MS, + dnnl_dim_t NS) { + auto&& OutputType = get_dnnl_type(); + auto&& BiasType = get_dnnl_type(); + + dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); + dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); + dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); + + dnnl::primitive_attr attr; + if constexpr (!InputNoScale) { + if (MS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_SRC, 0); + } else { + // per-token + TORCH_CHECK(false, "per-token quantization is unsupported."); + } + } + + if (NS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else { + // per-channel + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + dnnl::matmul::primitive_desc matmul_pd; + if (bias) { + dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + bias_md, c_md, attr); + } else { + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + c_md, attr); + } + dnnl::matmul matmul(matmul_pd); + + auto& engine = default_engine(); + + dnnl::memory a_m(a_md, engine, (void*)a); + dnnl::memory b_m(b_md, engine, (void*)b); + dnnl::memory c_m(c_md, engine, (void*)c); + dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)a_scales); + dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)b_scales); + + auto& stream = default_stream(); + if constexpr (InputNoScale) { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + 
{DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } else { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } + stream.wait(); + } + + private: + static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; + } + + static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; + } +}; + +#endif diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 96bce7dda013..8a59e884d6c8 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl( void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox) { - int num_tokens = query.numel() / query.size(-1); + int num_tokens = positions.numel(); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; int num_kv_heads = key.size(-1) / head_size; diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp new file mode 100644 index 000000000000..6751e7e55fc5 --- /dev/null +++ b/csrc/cpu/quant.cpp @@ -0,0 +1,613 @@ +#include "cpu_types.hpp" 
+#include "dnnl_helper.hpp" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using azp_adj_load_vec_type = void; + using cvt_vec_type = void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template <> +struct KernelVecType { +#if defined(__powerpc64__) || defined(__s390x__) + // Power architecture-specific vector type + using load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures + using load_vec_type = vec_op::FP16Vec16; +#endif + using azp_adj_load_vec_type = vec_op::INT32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512F__ +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = 
elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, 
hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} + +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t 
b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + 
elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} +#else +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int 
hidden_size) { + TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_with_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") +} +#endif +} // namespace + +void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const std::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && + bias->dim() == 1); + } + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + // Ideally we want to fuse the GEMM and the scale procedure with oneDNN + // JIT, the intermediate data is cached in registers or L1. 
But for now + // the oneDNN GEMM code generation only supports two quantization + // patterns: per-tensor or per-output-channel of weight. + // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * + // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN + // GEMM, then the per-token scale (and bias) is applied with the epilogue + // C=s_a * C_inter + bias. + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter + bias + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Compute C=s_a * C_inter + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, + c.size(0), c.size(1)); + } + } else { + // per-tensor + if (bias.has_value()) { + // Compute C=s_a * s_b * (A@B) + bias + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + bias->data_ptr(), a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } else { + // Compute C=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + nullptr, a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + } + }); +} + +void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const torch::Tensor& azp_adj, // [OC] + 
const std::optional& azp, // [1] or [M] + const std::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_azp only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); + } + if (azp) { + TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); + } + TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); + + // azp & bias types + TORCH_CHECK(azp_adj.dtype() == torch::kInt32); + TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); + TORCH_CHECK(!bias || bias->dtype() == c.dtype(), + "currently bias dtype must match output dtype ", c.dtype()); + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { + torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias + if (b_scales.numel() != 1) { + // Per-Channel + 
dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } + } + } else { + // per-tensor + if (bias.has_value()) { + // Compute C_inter=s_a * s_b * (A@B) + bias + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), bias->data_ptr(), + a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), + b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); + } else { + // Compute C_inter=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + + // Compute C=C_inter - s_a * s_b * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } else { + // Per-Tensor + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } + } + }); +} + +// 
static-per-tensor quantization. +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + const torch::Tensor& scale, + std::optional const& azp) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scale, // [..., 1] + std::optional const& azp) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } + }); +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 7d549e271a30..5d1c5f4c83d3 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -1,10 +1,22 @@ #include 
"cache.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include -void init_cpu_threads_env(const std::string& cpu_ids); +std::string init_cpu_threads_env(const std::string& cpu_ids); + +void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias); + +void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const torch::Tensor& azp_adj, + const std::optional& azp, + const std::optional& bias); TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -18,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -27,12 +39,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // PagedAttention V2. ops.def( "paged_attention_v2(" - " Tensor! out, Tensor exp_sums, Tensor max_logits," - " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor! out, Tensor! exp_sums, Tensor! max_logits," + " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -84,6 +96,37 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); + + // Quantization +#ifdef __AVX512F__ + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { @@ -95,8 +138,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Copy the cache blocks from src to dst. cache_ops.def( - "copy_blocks(Tensor[]! key_caches, Tensor[]! 
value_caches, Tensor " - "block_mapping) -> ()"); + "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " + "Tensor block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); // Reshape the key and value tensors and cache them. @@ -105,13 +148,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { // CPU utils - utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env); + utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 5782580baa86..42a1c1d924ba 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -1,11 +1,23 @@ -#include -#include -#include -#include +#ifndef VLLM_NUMA_DISABLED + #include + #include + #include + #include +#endif #include "cpu_types.hpp" -void init_cpu_threads_env(const std::string& cpu_ids) { +#ifdef VLLM_NUMA_DISABLED +std::string init_cpu_threads_env(const std::string& cpu_ids) { + return std::string( + "Warning: NUMA is not enabled in this build. 
`init_cpu_threads_env` has " + "no effect to setup thread affinity."); +} + +#endif + +#ifndef VLLM_NUMA_DISABLED +std::string init_cpu_threads_env(const std::string& cpu_ids) { bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); TORCH_CHECK(omp_cpu_mask->size > 0); std::vector omp_cpu_ids; @@ -51,15 +63,41 @@ void init_cpu_threads_env(const std::string& cpu_ids) { torch::set_num_threads((int)omp_cpu_ids.size()); TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads()); TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); -#pragma omp parallel for schedule(static, 1) + + std::vector> thread_core_mapping; + thread_core_mapping.reserve(omp_cpu_ids.size()); + omp_lock_t writelock; + omp_init_lock(&writelock); + + #pragma omp parallel for schedule(static, 1) for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { - cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size); - size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size); - CPU_ZERO_S(size, mask); - CPU_SET_S(omp_cpu_ids[i], size, mask); - sched_setaffinity(0, sizeof(cpu_set_t), mask); - CPU_FREE(mask); + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(omp_cpu_ids[i], &mask); + int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); + if (ret == -1) { + TORCH_CHECK(false, + "sched_setaffinity failed. 
errno: " + std::to_string(errno)); + } + + omp_set_lock(&writelock); + thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]); + omp_unset_lock(&writelock); } + omp_destroy_lock(&writelock); + numa_free_nodemask(omp_cpu_mask); + + std::stringstream ss; + ss << "OMP threads binding of Process " << getpid() << ":\n"; + std::sort(thread_core_mapping.begin(), thread_core_mapping.end(), + [](auto&& a, auto&& b) { return a.second < b.second; }); + for (auto&& item : thread_core_mapping) { + ss << "\t" + << "OMP tid: " << item.first << ", core " << item.second << "\n"; + } + + return ss.str(); } +#endif \ No newline at end of file diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 73944f4c1489..6e62ea208db8 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,5 +1,41 @@ #pragma once +#include + +#if defined(__HIPCC__) + #define HOST_DEVICE_INLINE __host__ __device__ + #define DEVICE_INLINE __device__ + #define HOST_INLINE __host__ +#elif defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__ + #define DEVICE_INLINE __device__ __forceinline__ + #define HOST_INLINE __host__ __forceinline__ +#else + #define HOST_DEVICE_INLINE inline + #define DEVICE_INLINE inline + #define HOST_INLINE inline +#endif + +#define CUDA_CHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); + +namespace cuda_utils { + +template +HOST_DEVICE_INLINE constexpr std::enable_if_t, T> +ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +}; // namespace cuda_utils \ No newline at end of file diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index d6f9eb646fad..0627a42675b5 100644 --- a/csrc/cuda_utils_kernels.cu 
+++ b/csrc/cuda_utils_kernels.cu @@ -1,16 +1,22 @@ +#include "cuda_utils.h" #ifdef USE_ROCM #include #include #endif + int64_t get_device_attribute(int64_t attribute, int64_t device_id) { - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), - device); + // Return the cached value on subsequent calls + static int value = [=]() { + int device = static_cast(device_id); + if (device < 0) { + CUDA_CHECK(cudaGetDevice(&device)); + } + int value; + CUDA_CHECK(cudaDeviceGetAttribute( + &value, static_cast(attribute), device)); + return static_cast(value); + }(); + return value; } diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp new file mode 100644 index 000000000000..fab6ca36d422 --- /dev/null +++ b/csrc/cumem_allocator.cpp @@ -0,0 +1,349 @@ +// A CUDAPluggableAllocator based on cumem* APIs. +// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* +// need to be unsigned long long +#include + +extern "C" { + +#define PY_SSIZE_T_CLEAN +#include + +#include +#include +#include + +char error_msg[10240]; // 10KB buffer to store error messages +CUresult no_error = CUresult(0); +CUresult error_code = no_error; // store error code + +#define CUDA_CHECK(condition) \ + do { \ + CUresult error = condition; \ + if (error != 0) { \ + error_code = error; \ + char* error_string; \ + cuGetErrorString(error, (const char**)&error_string); \ + snprintf(error_msg, sizeof(error_msg), "CUDA Error: %s at %s:%d", \ + error_string, __FILE__, __LINE__); \ + std::cerr << error_msg << std::endl; \ + } \ + } while (0) + +// Global references to Python callables +// NOTE: this is borrowed reference, so we don't need to DECREF them. +// This brings the limitation that the allocator needs to be singleton. 
+static PyObject* g_python_malloc_callback = nullptr; +static PyObject* g_python_free_callback = nullptr; + +// --------------------------------------------------------------------------- +// Helper functions: + +void ensure_context(unsigned long long device) { + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } +} + +void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + ensure_context(device); + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Allocate memory using cuMemCreate + CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); + if (error_code != 0) { + return; + } + CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); + if (error_code != 0) { + return; + } + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); + if (error_code != 0) { + return; + } + // std::cout << "create_and_map: device=" << device << ", size=" << size << ", + // d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; +} + +void unmap_and_release(unsigned long long device, ssize_t size, + CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + // std::cout << "unmap_and_release: device=" << device << ", size=" << size << + // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; + ensure_context(device); + CUDA_CHECK(cuMemUnmap(d_mem, size)); + if (error_code != 0) { + return; + } + CUDA_CHECK(cuMemRelease(*p_memHandle)); + if 
(error_code != 0) { + return; + } +} + +PyObject* create_tuple_from_c_integers(unsigned long long a, + unsigned long long b, + unsigned long long c, + unsigned long long d) { + // Create a new tuple of size 4 + PyObject* tuple = PyTuple_New(4); + if (!tuple) { + return NULL; // Return NULL on failure + } + + // Convert integers to Python objects and set them in the tuple + PyTuple_SetItem( + tuple, 0, + PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong + PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); + PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); + PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); + + // Note: PyTuple_SetItem "steals" a reference to each object, + // so we do not need to Py_DECREF the PyLong objects explicitly. + + return tuple; // Return the created tuple +} + +// --------------------------------------------------------------------------- +// Our exported C functions that call Python: + +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h +void* my_malloc(ssize_t size, int device, CUstream stream) { + ensure_context(device); + + // first allocation, align the size, and reserve an address, and also allocate + // a CUmemGenericAllocationHandle + + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Check if the allocation is supported + size_t granularity; + CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + if (error_code != 0) { + return nullptr; + } + size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; + + CUdeviceptr d_mem; + CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); + if (error_code != 0) { + return nullptr; + } + // allocate the CUmemGenericAllocationHandle 
+ CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)malloc( + sizeof(CUmemGenericAllocationHandle)); + + if (!g_python_malloc_callback) { + std::cerr << "ERROR: g_python_malloc_callback not set.\n"; + return nullptr; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* arg_tuple = create_tuple_from_c_integers( + (unsigned long long)device, (unsigned long long)alignedSize, + (unsigned long long)d_mem, (unsigned long long)p_memHandle); + + // Call g_python_malloc_callback + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); + Py_DECREF(arg_tuple); + + if (!py_result) { + PyErr_Print(); + PyGILState_Release(gstate); + return nullptr; + } + + PyGILState_Release(gstate); + + // do the final mapping + create_and_map(device, alignedSize, d_mem, p_memHandle); + + return (void*)d_mem; +} + +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h +void my_free(void* ptr, ssize_t size, int device, CUstream stream) { + // get memory handle from the pointer + if (!g_python_free_callback) { + std::cerr << "ERROR: g_python_free_callback not set.\n"; + return; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* py_ptr = + PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); + + if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, + &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return; 
+ } + + PyGILState_Release(gstate); + + // recv_size == size + // recv_device == device + + // Free memory + + CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + unmap_and_release(device, size, d_mem, p_memHandle); + + // free address and the handle + CUDA_CHECK(cuMemAddressFree(d_mem, size)); + if (error_code != 0) { + return; + } + free(p_memHandle); +} + +// --------------------------------------------------------------------------- +// Python extension boilerplate: + +// Python-exposed function: init_module(python_malloc, python_free) +static PyObject* py_init_module(PyObject* self, PyObject* args) { + PyObject* malloc_callback = nullptr; + PyObject* free_callback = nullptr; + + if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { + return nullptr; + } + + if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { + PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); + return nullptr; + } + + // Save the Python callables + // This module does not handle GC of these objects, so they must be kept alive + // outside of this module. 
+ g_python_malloc_callback = malloc_callback; + g_python_free_callback = free_callback; + + Py_RETURN_NONE; +} + +static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + + if (error_code != 0) { + error_code = no_error; + PyErr_SetString(PyExc_RuntimeError, error_msg); + return nullptr; + } + + Py_RETURN_NONE; +} + +static PyObject* python_create_and_map(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + + if (error_code != 0) { + error_code = no_error; + PyErr_SetString(PyExc_RuntimeError, error_msg); + return nullptr; + } + + Py_RETURN_NONE; +} + +static PyMethodDef module_methods[] = { + {"init_module", 
(PyCFunction)py_init_module, METH_VARARGS, + "Initialize module with python_malloc and python_free callables."}, + {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS, + "Create and map memory on the device."}, + {"python_unmap_and_release", (PyCFunction)python_unmap_and_release, + METH_VARARGS, "Unmap and release memory on the device."}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef cumem_allocator_module = { + PyModuleDef_HEAD_INIT, "cumem_allocator", + "cumem-based allocator for CUDAPluggableAllocator", -1, module_methods}; + +PyMODINIT_FUNC PyInit_cumem_allocator(void) { + // Initialize the module + PyObject* module = PyModule_Create(&cumem_allocator_module); + if (!module) { + return NULL; + } + return module; +} +} // extern "C" diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82a3563979f1..123278bfed71 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -5,32 +5,29 @@ #include "custom_all_reduce.cuh" -// fake pointer type, must match fptr_t type in ops.h +// Fake pointer type, must match fptr_t type in ops.h. +// We use this type alias to indicate when pointers are passed in as int64_t. 
using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, - const std::vector& handles, - const std::vector& offsets, int64_t rank, +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, + torch::Tensor& rank_data, int64_t rank, bool full_nvlink) { - int world_size = offsets.size(); + int world_size = fake_ipc_ptrs.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); - if (world_size != handles.size()) - throw std::invalid_argument( - "handles length should equal to offsets length"); if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); - cudaIpcMemHandle_t ipc_handles[8]; + vllm::Signal* ipc_ptrs[8]; for (int i = 0; i < world_size; i++) { - std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), - rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); + return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), + rank_data.numel(), rank, world_size, + full_nvlink); } /** @@ -55,38 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink) { - auto inp_size = inp.numel() * inp.element_size(); - // custom allreduce requires input byte size to be multiples of 16 - if (inp_size % 16 != 0) return false; - if (!_is_weak_contiguous(inp)) return false; - if (world_size == 2 || full_nvlink) return inp_size <= max_size; - // for 4 or more non NVLink-capable GPUs, custom allreduce provides little - // performance improvement over NCCL. 
- return false; -} - -void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - cudaStream_t stream) { +/** + * Performs an out-of-place allreduce and stores result in out. + * + * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. + * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first + * copied into _reg_buffer. + */ +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); TORCH_CHECK(_is_weak_contiguous(out)); + TORCH_CHECK(_is_weak_contiguous(inp)); + auto input_size = inp.numel() * inp.element_size(); + auto reg_buffer = reinterpret_cast(_reg_buffer); + if (reg_buffer) { + TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, + cudaMemcpyDeviceToDevice, stream)); + } else { + reg_buffer = inp.data_ptr(); + } switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + fa->allreduce(stream, reinterpret_cast(reg_buffer), reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + fa->allreduce(stream, reinterpret_cast(reg_buffer), reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), + stream, reinterpret_cast(reg_buffer), reinterpret_cast(out.data_ptr()), out.numel()); break; } @@ -97,57 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, 
torch::Tensor& out) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - _all_reduce(_fa, inp, out, stream); -} - -void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, - torch::Tensor& out) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - - auto input_size = inp.numel() * inp.element_size(); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), - "registered buffer is too small to contain the input"); - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), - input_size, cudaMemcpyDeviceToDevice, stream)); - _all_reduce(_fa, reg_buffer, out, stream); -} - void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); - delete fa; + delete reinterpret_cast(_fa); } int64_t meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor& t, - const std::vector& handles, - const std::vector& offsets) { +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { auto fa = reinterpret_cast(_fa); - fa->register_buffer(handles, offsets, t.data_ptr()); + TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); + void* ipc_ptrs[8]; + for (int i = 0; i < fake_ipc_ptrs.size(); i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + fa->register_buffer(ipc_ptrs); } -std::tuple> get_graph_buffer_ipc_meta( - fptr_t _fa) { +// Use vector to represent byte data for python binding compatibility. 
+std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa) { auto fa = reinterpret_cast(_fa); - auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); - auto options = - torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto handles = - torch::empty({static_cast(handle_bytes.size())}, options); - std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); - return {handles, std::move(offsets)}; + auto [handle, offsets] = fa->get_graph_buffer_ipc_meta(); + std::vector bytes(handle.begin(), handle.end()); + return std::make_tuple(bytes, offsets); } -void register_graph_buffers(fptr_t _fa, const std::vector& handles, +// Use vector to represent byte data for python binding compatibility. +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, const std::vector>& offsets) { auto fa = reinterpret_cast(_fa); - fa->register_graph_buffers(handles, offsets); + std::vector bytes; + bytes.reserve(handles.size()); + for (int i = 0; i < handles.size(); i++) { + bytes.emplace_back(handles[i].begin(), handles[i].end()); + } + bytes.reserve(handles.size()); + fa->register_graph_buffers(bytes, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1ed49b8aa9ca..b9df4ed160b0 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -23,17 +24,27 @@ namespace vllm { -constexpr int kMaxBlocks = 64; -// note: we don't want to use atomics for signals because peer atomics are no -// supported on PCIe links +constexpr int kMaxBlocks = 36; +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; struct Signal { - alignas(128) uint32_t start[kMaxBlocks][8]; - alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. 
The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. We use + // alternating counter array to avoid this possibility. + alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; }; -struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; +struct __align__(16) RankData { + const void* __restrict__ ptrs[8]; +}; -struct __align__(16) RankSignals { volatile Signal* signals[8]; }; +struct __align__(16) RankSignals { + Signal* signals[8]; +}; // like std::array, but aligned template @@ -123,47 +134,71 @@ DINLINE O downcast(array_t val) { } } -// This function is meant to be used as the first synchronization in the all -// reduce kernel. Thus, it doesn't need to make any visibility guarantees for -// prior memory accesses. Note: volatile writes will not be reordered against -// other volatile writes. -template -DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->end[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. 
- // Latency = 1 p2p write - sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]); - } - __syncthreads(); +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#else + asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#endif +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { + FlagType flag; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#else + asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" + : "=r"(flag) + : "l"(flag_addr)); +#endif + return flag; +} + +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); } -// This function is meant to be used as the second or the final synchronization -// barrier in the all reduce kernel. If it's the final synchronization barrier, -// we don't need to make any visibility guarantees for prior memory accesses. -template -DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - __syncthreads(); - // eliminate the case that prior writes are not visible after signals become - // visible. Note that I did not managed to make this happen through a lot of - // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. - if constexpr (!final_sync) __threadfence_system(); +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; +} + +// is_start: whether this is the very first synchronization barrier. 
+// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. +template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, + int rank) { + if constexpr (!is_start) __syncthreads(); + static_assert( + !(is_start && need_fence)); // Start barrier shouldn't need fence. if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->start[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]); + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. 
+ auto peer_counter_ptr = + &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = + &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val); + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val); + } } - if constexpr (!final_sync) __syncthreads(); + if constexpr (is_start || need_fence) __syncthreads(); } template @@ -178,33 +213,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); } template -DINLINE P* get_tmp_buf(volatile Signal* sg) { +DINLINE P* get_tmp_buf(Signal* sg) { return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P 
= typename packed_t::P; @@ -222,12 +255,12 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed @@ -256,46 +289,52 @@ class CustomAllreduce { int world_size_; bool full_nvlink_; - // below are device pointers RankSignals sg_; + // Stores an map from a pointer to its peer pointters from all ranks. std::unordered_map buffers_; Signal* self_sg_; - // stores the registered device pointers from all ranks + // Stores rank data from all ranks. This is mainly for cuda graph purposes. + // For cuda graph to work, all kernel arguments must be fixed during graph + // capture time. However, the peer pointers are not known during graph capture + // time. Therefore, during capture, we increment the rank data pointer and use + // that as the argument to the kernel. The kernel arguments are stored in + // graph_unreg_buffers_. The actual peer pointers will be filled in at the + // memory pointed to by the pointers in graph_unreg_buffers_ when + // the IPC handles are exchanged between ranks. + // + // The overall process looks like this: + // 1. Graph capture. + // 2. Each rank obtains the IPC handles for each addresses used during cuda + // graph capture using get_graph_buffer_ipc_meta. + // 3. (In Python) all gather the IPC handles. + // 4. Obtain the peer pointers by opening the IPC handles, and store them in + // the rank data array at corresponding positions. RankData *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers std::map ipc_handles_; /** - * meta is a pointer to device metadata and temporary buffer for allreduce. 
- * - * There's a total of sizeof(Signal) of prefix before the actual data, - * so meta + 1 points to actual temporary buffer. + * Signals are an array of ipc-enabled buffers from all ranks. + * For each of the buffer, the layout is as follows: + * | -- sizeof(Signal) -- | ------ a few MB ----- | + * The first section is for allreduce synchronization, and the second section + * is for storing the intermediate results required by some allreduce algos. * - * note: this class does not own any device memory. Any required buffers - * are passed in from the constructor + * Note: this class does not own any device memory. Any required buffers + * are passed in from the constructor. */ - CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t* handles, - const std::vector& offsets, int rank, - bool full_nvlink = true) + CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, + int rank, int world_size, bool full_nvlink = true) : rank_(rank), - world_size_(offsets.size()), + world_size_(world_size), full_nvlink_(full_nvlink), - self_sg_(meta), + self_sg_(signals[rank]), d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal* rank_sg; - if (i != rank_) { - char* handle = open_ipc_handle(&handles[i]); - handle += offsets[i]; - rank_sg = (Signal*)handle; - } else { - rank_sg = self_sg_; - } - sg_.signals[i] = rank_sg; + sg_.signals[i] = signals[i]; } } @@ -312,11 +351,10 @@ class CustomAllreduce { return it->second; } - std::pair, std::vector> - get_graph_buffer_ipc_meta() { + std::pair> get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); auto handle_sz = sizeof(cudaIpcMemHandle_t); - std::vector handles(handle_sz * num_buffers, 0); + std::string handles(handle_sz * num_buffers, static_cast(0)); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = 
graph_unreg_buffers_[i]; @@ -341,26 +379,22 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector& handles, - const std::vector& offsets, void* self) { + /** + * Register already-shared IPC pointers. + */ + void register_buffer(void** ptrs) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { - if (i != rank_) { - char* handle = open_ipc_handle(handles[i].data()); - handle += offsets[i]; - data.ptrs[i] = handle; - } else { - data.ptrs[i] = self; - } + data.ptrs[i] = ptrs[i]; } auto d_data = d_rank_data_base_++; CUDACHECK( cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); - buffers_[self] = d_data; + buffers_[ptrs[rank_]] = d_data; } - // note: when registering graph buffers, we intentionally choose to not + // Note: when registering graph buffers, we intentionally choose to not // deduplicate the addresses. That means if the allocator reuses some // addresses, they will be registered again. This is to account for the remote // possibility of different allocation patterns between ranks. For example, @@ -395,11 +429,13 @@ class CustomAllreduce { } /** - * This is the result after careful grid search. Using 36 blocks give the best - * or close to the best runtime on the devices I tried: A100, A10, A30, T4, - * V100. You'll notice that NCCL kernels also only take a small amount of SMs. - * Not quite sure the underlying reason, but my guess is that too many SMs - * will cause contention on NVLink bus. + * Performs allreduce, assuming input has already been registered. + * + * Block and grid default configs are results after careful grid search. Using + * 36 blocks give the best or close to the best runtime on the devices I + * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only + * take a small amount of SMs. Not quite sure the underlying reason, but my + * guess is that too many SMs will cause contention on NVLink bus. 
*/ template void allreduce(cudaStream_t stream, T* input, T* output, int size, @@ -437,6 +473,8 @@ class CustomAllreduce { #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index f7868233076c..b59ea40d980f 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,15 +1,15 @@ /** * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. - * export MPI_HOME=XXX + * export MPI_HOME=xxx * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o - * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi + * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. 
* * To run: - * mpirun -np 8 ./custom_all_reduce_test + * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test */ #include #include @@ -44,7 +44,14 @@ } while (0) __global__ void dummy_kernel() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms +#else + for (int i = 0; i < 100; i++) { + long long int start = clock64(); + while (clock64() - start < 150000000); // approximately 98.4ms on P40 + } +#endif } template @@ -128,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); - std::vector offsets(nRanks, 0); - vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, - offsets, myRank); + vllm::Signal* ipc_ptrs[8]; + for (int i = 0; i < nRanks; i++) { + if (i == myRank) + ipc_ptrs[i] = buffer; + else + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i], + cudaIpcMemLazyEnablePeerAccess)); + } + vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks); auto* self_data = reinterpret_cast(reinterpret_cast(buffer) + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { - std::vector handles; - handles.reserve(nRanks); + void* data[8]; for (int i = 0; i < nRanks; i++) { - char* begin = (char*)&data_handles[i]; - char* end = (char*)&data_handles[i + 1]; - handles.emplace_back(begin, end); + data[i] = + ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T); } - std::vector offsets(nRanks, - sizeof(vllm::Signal) + data_size * sizeof(T)); - fa.register_buffer(handles, offsets, self_data); + fa.register_buffer(data); } double* ground_truth; @@ -302,15 +311,19 @@ int main(int argc, char** argv) { bool performance_test = true; cudaProfilerStart(); - // for (int threads : {256, 512}) { + // Uncomment to scan through different block size configs. 
+ // for (int threads : {256, 512, 1024}) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) { - // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, + // performance_test); // } // } + // Scan through different sizes to test performance. for (int sz = 512; sz <= (8 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); } cudaProfilerStop(); + MPICHECK(MPI_Finalize()); return EXIT_SUCCESS; } diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp new file mode 100644 index 000000000000..3d2093ab9429 --- /dev/null +++ b/csrc/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp new file mode 100644 index 000000000000..febc4eccd956 --- /dev/null +++ b/csrc/cutlass_extensions/common.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + 
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); + +/** + * A wrapper for a kernel that is used to guard against compilation on + * architectures that will never use the kernel. The purpose of this is to + * reduce the size of the compiled binary. + * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef + * into code that will be executed on the device where it is defined. + */ +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; \ No newline at end of file diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh new file mode 100644 index 000000000000..f61fe3ceb978 --- /dev/null +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) { + return true; + } else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + 
+//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp new file mode 100644 index 000000000000..7aa87feb4cce --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -0,0 +1,497 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either +// row/column or scalar broadcasting where the tensor being loaded from is +// always passed in via a device pointer. This lets one compiled kernel handle +// all cases of per-tensor or per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. 
This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + 
params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->row_broadcast) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + 
decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + 
auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class 
Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->col_broadcast) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = *(params_ptr->ptr_col); 
+ } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 000000000000..58b1e8ff159f --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. 
+// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. 
+ struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , 
tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row)); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. 
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = 
make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. 
+ struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + + CUTLASS_DEVICE void + begin() { + Tensor 
pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col)); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + params + ); + } +}; + +} diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp 
b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 000000000000..64b7ddae3d2d --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,321 @@ +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. 
+ template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(std::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
+*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. 
+ */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
+ */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the 
correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename 
EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +}; // namespace vllm::c2x diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 000000000000..0a812dc56a99 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,384 @@ +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. 
+*/ + +namespace vllm::c3x { + +using namespace cute; + +template +struct identity { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { return lhs; } +}; + +template +struct TrivialEpilogue { + private: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using Compute = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + template + static ArgumentType prepare_args(Args... args) { + return {}; + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. 
+ template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(std::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
+*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. 
+ */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogueBias, but the + * bias is a column vector instead of a row vector. Useful e.g. if we are + * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. 
+ */ +template +struct ScaledEpilogueColumnBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template ColLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
+ */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. 
If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor 
const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +}; // namespace vllm::c3x diff --git a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp new file mode 100644 index 000000000000..ec75c29e54f4 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp @@ -0,0 +1,123 @@ +// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl +// clang-format off +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" + +#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS (BlockScaled Builders) +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + int ScaleGranularityM +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + 
ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, + cute::enable_if_t< + not detail::is_use_rmem_A()> +> { + using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; + + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert((!IsFP8Input || !IsArrayOfPointersGemm), + "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsCooperative = cute::is_any_of_v>; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), 
decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp new file mode 100644 index 000000000000..13b90e998625 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp @@ -0,0 +1,183 @@ +// clang-format off +// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cute/algorithm/clear.hpp" +#include "cute/tensor.hpp" + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////FP8 Accumulation/////////////////////////// +////////////////////////////////////////////////////////////////////////////// +/// This class provides API to promote (add) or scale (multiply_add) the results +/// from the tensor core accumulators to the main accumulators when the number +/// of MMAs reaches the max number of MMA interval specified by user, after that +/// the tensor core accumulators are zeroed. +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +template < + class EngineAccum, + class LayoutAccum> +struct GmmaFP8AccumulationWithScale { + using TensorAccum = cute::Tensor; + using ElementAccumulator = typename EngineAccum::value_type; + + static_assert(is_static::value, "Accumulator Layout should be static"); + static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + +private: + TensorAccum& accum_; + TensorAccum accum_temp_; + + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. + uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + + // promote or `add` the partial accumulators to main accumulator (FADD). + CUTLASS_DEVICE + void promote_core() { + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i); + } + } + + // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). 
+ template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_core(const cute::Tensor &scale) { + using TensorScale = cute::Tensor; + + static_assert(is_static::value, "Scale Layout should be static"); + static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scale(i); + } + } + +public: + CUTLASS_DEVICE + GmmaFP8AccumulationWithScale( + TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) + : accum_(accum), + accum_promotion_interval_(accum_promotion_interval), + mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), + mma_count_(0), + reset_accum_flag_(0) + { + accum_temp_ = cute::make_fragment_like(accum); + } + + // + // Methods (Common) + // + + CUTLASS_DEVICE + TensorAccum& operator()() { + return accum_temp_; + } + + /// prepare the MMA accumulators when initialization or zeroing is required. + CUTLASS_DEVICE + bool prepare_if_needed() { + return reset_accum_flag_; + } + + // + // Methods (for FADD version) + // + + /// promote (add) the results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_if_needed() { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + promote_core(); + mma_count_ = 0; + } + } + + /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_residue_if_needed() { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + promote_core(); + } + } + + // + // Methods (for FFMA version) + // + + /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. 
+ template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_if_needed(const cute::Tensor &scale) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scale); + mma_count_ = 0; + } + } + + /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_residue_if_needed(const cute::Tensor &scale) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scale); + } + } +}; + +} // namespace cutlass::gemm::collective diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp new file mode 100644 index 000000000000..d922a3349e1e --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -0,0 +1,730 @@ +// clang-format off +// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm80.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + int ScaleGranularityM_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; 
+ using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementBlockScale = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + // Two threads per CTA are producers (1 for operand tile and 32 for scales) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; // mxk + cute::array_aligned> smem_B; // nxk + cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k + cute::array_aligned> smem_scale_B; // 1xk + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + 
StrideA dA; + ElementB const* ptr_B; + StrideB dB; + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + 
TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.ptr_scale_A, + args.ptr_scale_B + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + constexpr auto scales_m = Int{}; + auto tM = get<2>(gA_mkl.shape()); + auto tN = get<2>(gB_nkl.shape()); + auto tK = get<3>(gA_mkl.shape()); + + // Make the tiled views of scale tensors + auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) + auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); + + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and + // 
gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. + Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorScaleA, class TensorScaleB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + // Blockscaling: Tma loads for load_input and CpAsync for load_scale + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + 
+ // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mScaleA_mkl = get<2>(load_inputs); + Tensor mScaleB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mScaleA_mkl.shape()); + + Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); + + Tensor gScaleA = local_tile( + mScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cScaleA = local_tile( + cScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) + + // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); // (1,1,1) + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); // (1,1,1) + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); + + Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); + Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); + Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); + + Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); + Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads for GEMM operands A/B and CpAsync for 
scale tensors + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Allocate predicate tensors for a_scales (since we can't guarantee that + // all scales are valid, since we could have a partial tiles along M) + Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); + #pragma unroll + for (int i = 0; i < size(tApA_ScaleA); ++i) { + tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + // Copy operands A and B from global memory to shared memory + if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + // Copy scale tensors from global memory to shared memory + copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); + copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + + ++k_tile_iter; + + // Advance smem_pipe_write 
+ ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) + Tensor sScaleB = 
make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. 
+ + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Per block scale values for operand A and B + + using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. 
+ using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above + + Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) + ElementBlockScale scale_b; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers. + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
+ } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
+ } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // 
namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp new file mode 100644 index 000000000000..df809e27a3ef --- /dev/null +++ b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "cutlass/gemm/dispatch_policy.hpp" + +namespace cutlass::gemm { + +////////////////////////////////////////////////////////////////////////////// + +// FP8 related policies (including Blocked Scaled Accumulation) +// `ScaleGranularityM` specifies scaling granularity along M, while zero-value +// `ScaleGranularityM` indicates that scaling granularity is +// `size<0>(TileShape_MNK{})` along M. +template +struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum + : KernelTmaWarpSpecializedCooperative {}; + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp +// specialized dynamic schedule For FP8 kernels with Block Scaling +template , + class KernelSchedule = KernelTmaWarpSpecialized, + int ScaleGranularityM = + 0 // `ScaleGranularityM` specifies scaling granularity along M, + // while zero-value `ScaleGranularityM` indicates that scaling + // granularity is `size<0>(TileShape_MNK{})` along M. 
+ > +struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 + : MainloopSm90TmaGmmaWarpSpecialized { + static_assert( + cute::is_same_v< + KernelSchedule, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>>, + "KernelSchedule must be one of the warp specialized policies"); +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm \ No newline at end of file diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp new file mode 100644 index 000000000000..a1ff933cce63 --- /dev/null +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,160 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... 
a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. +template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + std::optional const& tensor, + std::string_view name = "tensor") { + using Layout 
= decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh new file mode 100644 index 000000000000..e7fbba4cd4b0 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for +// for custom kernel tags, allowing you to build custom collectives. 
Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct VLLMCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct VLLMCollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_custom_types.cuh b/csrc/cutlass_extensions/vllm_custom_types.cuh new file mode 100644 index 000000000000..6146bdc1f08c --- /dev/null +++ b/csrc/cutlass_extensions/vllm_custom_types.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct vllm_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + vllm_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) + : Base(value) {} 
#
# Extend cutlass library with custom types, and missing values
#


class VLLMDataType(enum.Enum):
    """vLLM-specific element types that `cutlass_library.DataType` lacks."""

    u4b8 = enum_auto()
    u8b128 = enum_auto()


class MixedInputKernelScheduleType(enum.Enum):
    """Kernel schedules selectable for mixed-input kernels."""

    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedPingpong = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()


# Short type names: the stock cutlass table plus the vLLM types.
VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
    **DataTypeNames,  # type: ignore
    VLLMDataType.u4b8: "u4b8",
    VLLMDataType.u8b128: "u8b128",
}

# C++ type spelled out for each data type.
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
    **DataTypeTag,  # type: ignore
    VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
    VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}

# Size of each data type in bits.
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
    **DataTypeSize,  # type: ignore
    VLLMDataType.u4b8: 4,
    VLLMDataType.u8b128: 8,
}

# `vllm::k...` scalar-type constant for each supported data type.
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
    VLLMDataType.u4b8: "vllm::kU4B8",
    VLLMDataType.u8b128: "vllm::kU8B128",
    DataType.u4: "vllm::kU4",
    DataType.u8: "vllm::kU8",
    DataType.s4: "vllm::kS4",
    DataType.s8: "vllm::kS8",
    DataType.f16: "vllm::kFloat16",
    DataType.bf16: "vllm::kBfloat16",
}

# `at::ScalarType` constant for each supported data type.
VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
    DataType.u8: "at::ScalarType::Byte",
    DataType.s8: "at::ScalarType::Char",
    DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
    DataType.s32: "at::ScalarType::Int",
    DataType.f16: "at::ScalarType::Half",
    DataType.bf16: "at::ScalarType::BFloat16",
    DataType.f32: "at::ScalarType::Float",
}

# C++ schedule tag: the stock cutlass table plus the mixed-input schedules.
VLLMKernelScheduleTag: dict[Union[MixedInputKernelScheduleType,
                                  KernelScheduleType], str] = {
    **KernelScheduleTag,  # type: ignore
    MixedInputKernelScheduleType.TmaWarpSpecialized:
    "cutlass::gemm::KernelTmaWarpSpecialized",
    MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
    "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
    MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
    "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}
+// (interleaved numeric array converters can be more efficient for subbyte +// types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. +template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + if (cute::elect_one_sync()) { + if constexpr (std::is_same_v) { + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); + } else { + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); + } + __brkpt(); + } + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + 
static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { + if constexpr (sizeof(PackedSrc) == 1) { + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 2) { + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 4) { + return Array{reinterpret_cast(src)}; + } else { + static_assert(sizeof(PackedSrc) == 8); + return reinterpret_cast const&>(src); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. 
+ template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_regs(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed +// into 2 32bit register. +template +CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( + uint32_t src) { + cutlass::AlignedArray r; + // Determines if the value is in the top half of the LUT if set or + // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move + // into bit position 0x4 of each nibble so when or'd with final_prmt_base it + // selects the correct candidate. When elements in final_prmt_base + // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements + // are < 0x4, the low candidate is selected (i.e. 
LUT[0:7]) + uint32_t high_bit = (src & 0x88888888) >> 1; + + // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT + // (selects correct high or low candidate) + const uint32_t final_prmt_base = 0x32103210; + + // Ignore the high bit when indexing into LUT, for each 4bit value + // we index into both the high and low candidates then use + // high_bit | final_prmt_base to select the correct candidate + uint32_t lut_idx = (src & 0x77777777); + + auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | + (uint32_t(d) << 24); + }; + + static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); + static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); + static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); + static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { + uint32_t final_prmt_idx = final_prmt_base | high_bit; + + // This uses a look up table to convert packed int4s to packed int8s, + // using the int4 value as the index to prmt. It first select both the + // high and low candidates, then uses the high bit (i.e. `high_bit`) to + // select the correct candidate. 
+ asm volatile( + "{\n" + " .reg .b32 low, high;\n" + " prmt.b32 low, %1, %2, %5;\n" + " prmt.b32 high, %3, %4, %5;\n" + " prmt.b32 %0, low, high, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), + "r"(final_prmt_idx)); + } + + return r; +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s + auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // + 0xFC, 0xFD, 0xFE, 0xFF, // + 0x00, 0x01, 0x02, 0x03, // + 0x04, 0x05, 0x06, 0x07>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s + auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // + 0xC8, 0xC4, 0xC0, 0xB8, // + 0x00, 0x38, 0x40, 0x44, // + 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using 
source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. 
+ static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + 
static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % 
size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type 
operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = 
reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src_reg = src_[0]; + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. 
We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is 
large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = 
(0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. 
ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +// for Array <= Array +// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 + template + CUTLASS_DEVICE static PackedResultType convert( + Array src) { + // Hold output int8s in reg. 
We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray< + uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; + RegArray r; + + static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; + auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); + + *reinterpret_cast(&src[0]) = + __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); + + if constexpr (src_regs > 1) { + *reinterpret_cast(&src[1]) = + __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); + } + + static_assert(PackedResultType::kElements <= 4); + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), + "n"(MASK_0246)); + + uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); + + return reinterpret_cast(int8s); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/vllm_type_utils.cuh b/csrc/cutlass_extensions/vllm_type_utils.cuh new file mode 100644 index 000000000000..500ed508c830 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_type_utils.cuh @@ -0,0 +1,42 @@ +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" +#include "cuda_bf16.h" + +#include "cutlass_extensions/vllm_custom_types.cuh" + +namespace cutlass { + +template +struct nameof { + static constexpr char const* value = "unknown"; +}; + +template +inline constexpr auto nameof_v = nameof::value; + +#define NAMEOF_TYPE(T) \ + template <> \ + struct nameof { \ + static constexpr char const* 
value = #T; \ + }; + +NAMEOF_TYPE(float_e4m3_t) +NAMEOF_TYPE(float_e5m2_t) +NAMEOF_TYPE(half_t) +NAMEOF_TYPE(nv_bfloat16) +NAMEOF_TYPE(bfloat16_t) +NAMEOF_TYPE(float) + +NAMEOF_TYPE(int4b_t) +NAMEOF_TYPE(int8_t) +NAMEOF_TYPE(int32_t) +NAMEOF_TYPE(int64_t) + +NAMEOF_TYPE(vllm_uint4b8_t) +NAMEOF_TYPE(uint4b_t) +NAMEOF_TYPE(uint8_t) +NAMEOF_TYPE(vllm_uint8b128_t) +NAMEOF_TYPE(uint32_t) +NAMEOF_TYPE(uint64_t) + +}; // namespace cutlass \ No newline at end of file diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index a634e1c3d488..dc6e0769b878 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -6,6 +6,11 @@ #include +// Need a special dispatch case macro since we will nest the FP8 dispatch. +// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'. +#define AT_DISPATCH_FP8_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__) + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ @@ -14,6 +19,35 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +// ROCm devices might use either fn or fnuz, so set up dispatch table for both. +// A host-based check at runtime will create a preferred FP8 type for ROCm +// such that the correct kernel is dispatched. +#ifdef USE_ROCM + #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) + + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_FP8_TYPES(...) 
\ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) + + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#endif + +// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'. +// See AT_DISPATCH_FP8_CASE above. +#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index ca1c04bd880d..fb6882f3e7c3 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,18 +1,13 @@ -#include -#include +#include "type_convert.cuh" +#include "dispatch_utils.h" + +#include #include -#include "dispatch_utils.h" -#include "reduction_utils.cuh" #ifndef USE_ROCM - #include - #include + #include #else - #include - #include - -using __nv_bfloat16 = __hip_bfloat16; -using __nv_bfloat162 = __hip_bfloat162; + #include #endif namespace vllm { @@ -31,7 +26,11 @@ __global__ void rms_norm_kernel( const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -44,155 +43,6 @@ __global__ void rms_norm_kernel( } } -/* Converter structs for the conversion from torch types to HIP/CUDA types, - and the associated type conversions within HIP/CUDA. 
These helpers need - to be implemented for now because the relevant type conversion - operators/constructors are not consistently implemented by HIP/CUDA, so - a generic conversion via type casts cannot be implemented. - - Each struct should have the member static constexpr bool `exists`: - If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. - */ -template -struct _typeConvert { - static constexpr bool exists = false; -}; - -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) -// CUDA < 12.0 runs into issues with packed type conversion -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __half; - using packed_hip_type = __half2; - - __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { - return __half22float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2half_rn(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22half2_rn(x); - } -}; - - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// CUDA_ARCH < 800 does not have BF16 support -// TODO: Add in ROCm support once public headers handle bf16 maturely -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __nv_bfloat16; - using packed_hip_type = __nv_bfloat162; - - __device__ static inline float convert(hip_type x) { - return __bfloat162float(x); - } - __device__ static inline float2 convert(packed_hip_type x) { - return __bfloat1622float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2bfloat16(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22bfloat162_rn(x); - } -}; - #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && 
(CUDA_VERSION >= - // 12000)) - -/* Vector POD struct to generate vectorized and packed FP16/BF16 ops - for appropriate specializations of fused_add_rms_norm_kernel. - Only functions that are necessary in that kernel are implemented. - Alignment to 16 bytes is required to use 128-bit global memory ops. - */ -template -struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ - static_assert(width > 0 && (width & (width - 1)) == 0, - "Width is not a positive power of 2!"); - using Converter = _typeConvert; - using T1 = typename Converter::hip_type; - using T2 = typename Converter::packed_hip_type; - T1 data[width]; - - __device__ _f16Vec& operator+=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] += other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] *= other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const float scale) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); - temp_f.x *= scale; - temp_f.y *= scale; - T2 temp = Converter::convert(temp_f); - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float temp = Converter::convert(data[i]) * scale; - data[i] = Converter::convert(temp); - } - } - 
return *this; - } - - __device__ float sum_squares() const { - float result = 0.0f; - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i + 1]}); - result += z.x * z.x + z.y * z.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float x = Converter::convert(data[i]); - result += x * x; - } - } - return result; - } -}; - /* Function specialization in the case of FP16/BF16 tensors. Additional optimizations we can make in this case are packed and vectorized operations, which help with the @@ -228,12 +78,11 @@ fused_add_rms_norm_kernel( variance += temp.sum_squares(); residual_v[id] = temp; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -268,12 +117,11 @@ fused_add_rms_norm_kernel( variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu new file mode 100644 index 000000000000..d595b9e889c8 --- /dev/null +++ b/csrc/layernorm_quant_kernels.cu @@ -0,0 +1,242 
@@ +/* + * This file contains the CUDA kernels for the fused quantized layernorm. + * The kernels correspond to the kernels in layernorm_kernels.cu, except they + * also produce quantized output directly. + * Currently, only static fp8 quantization is supported. + */ + +#include "type_convert.cuh" +#include "quantization/fp8/common.cuh" +#include "dispatch_utils.h" + +#include +#include + +#ifndef USE_ROCM + #include +#else + #include +#endif + +namespace vllm { + +// TODO(woosuk): Further optimize this kernel. +template +__global__ void rms_norm_static_fp8_quant_kernel( + fp8_type* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float)input[blockIdx.x * hidden_size + idx]; + variance += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; + out[blockIdx.x * hidden_size + idx] = + scaled_fp8_conversion(out_norm, scale_inv); + } +} + +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. 
*/ +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_static_fp8_quant_kernel( + fp8_type* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, const int num_tokens, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16Vec>); + static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); + + const int vec_hidden_size = hidden_size / width; + __shared__ float s_variance; + float variance = 0.0f; + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = input_v[id]; + temp += residual_v[id]; + variance += temp.sum_squares(); + residual_v[id] = temp; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = residual_v[id]; + temp *= s_variance; + temp *= weight_v[idx]; +#pragma unroll + for (int i = 0; 
i < width; ++i) { + out[id * width + i] = + scaled_fp8_conversion(float(temp.data[i]), scale_inv); + } + } +} + +/* Generic fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. + */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_static_fp8_quant_kernel( + fp8_type* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + scalar_t z = input[blockIdx.x * hidden_size + idx]; + z += residual[blockIdx.x * hidden_size + idx]; + float x = (float)z; + variance += x * x; + residual[blockIdx.x * hidden_size + idx] = z; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)residual[blockIdx.x * hidden_size + idx]; + float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; + out[blockIdx.x * hidden_size + idx] = + scaled_fp8_conversion(out_norm, scale_inv); + } +} + +} // namespace vllm + +void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + 
+ dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_kernel_scalar_type", [&] { + VLLM_DISPATCH_FP8_TYPES( + out.scalar_type(), "rms_norm_kernel_fp8_type", [&] { + vllm::rms_norm_static_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), scale.data_ptr(), + epsilon, num_tokens, hidden_size); + }); + }); +} + +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \ + VLLM_DISPATCH_FP8_TYPES( \ + out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \ + vllm::fused_add_rms_norm_static_fp8_quant_kernel \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), scale.data_ptr(), \ + epsilon, num_tokens, hidden_size); \ + }); \ + }); +void fused_add_rms_norm_static_fp8_quant( + torch::Tensor& out, // [..., hidden_size], + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. 
+ Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto res_ptr = reinterpret_cast(residual.data_ptr()); + auto wt_ptr = reinterpret_cast(weight.data_ptr()); + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { + LAUNCH_FUSED_ADD_RMS_NORM(0); + } +} diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu new file mode 100644 index 000000000000..f0e5533bcae6 --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -0,0 +1,662 @@ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu +// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu +#include +#include +#include + +#include "causal_conv1d.h" +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "static_switch.h" + + + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +void set_conv_params_fwd(ConvParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + const std::optional& bias, + bool silu_activation, + int64_t pad_slot_id, + const std::optional& query_start_loc = std::nullopt, + const std::optional& cache_indices = std::nullopt, + const std::optional& has_initial_state = std::nullopt) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + params.pad_slot_id = pad_slot_id; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; + params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; + params.has_initial_state_ptr = has_initial_state.has_value() ? 
has_initial_state.value().data_ptr() : nullptr; + const bool varlen = params.query_start_loc_ptr != nullptr; + params.x_batch_stride = x.stride(varlen ? 1 : 0); + params.x_c_stride = x.stride(varlen ? 0 : 1); + params.x_l_stride = x.stride(varlen ? 1 : -1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(varlen ? 1 : 0); + params.out_c_stride = out.stride(varlen ? 0 : 1); + params.out_l_stride = out.stride(varlen ? 1 : -1); +} + + +void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const std::optional &bias_, + const std::optional &conv_states, + const std::optional &query_start_loc, + const std::optional &cache_indices, + const std::optional &has_initial_state, + bool silu_activation, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const bool varlen = query_start_loc.has_value() ? true : false; + const auto sizes = x.sizes(); + const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; + const int dim = varlen ? sizes[0] : sizes[1]; + const int seqlen = varlen ? 
sizes[1] : sizes[2]; + const int width = weight.size(-1); + if (varlen){ + CHECK_SHAPE(x, dim, seqlen); + } + else { + CHECK_SHAPE(x, batch_size, dim, seqlen); + } + CHECK_SHAPE(weight, dim, width); + + + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + + if (has_initial_state.has_value()) { + auto has_initial_state_ = has_initial_state.value(); + TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); + TORCH_CHECK(has_initial_state_.is_cuda()); + CHECK_SHAPE(has_initial_state_, batch_size); + } + + + if (query_start_loc.has_value()) { + auto query_start_loc_ = query_start_loc.value(); + TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(query_start_loc_.is_cuda()); + } + + + if (cache_indices.has_value()) { + auto cache_indices_ = cache_indices.value(); + TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cache_indices_.is_cuda()); + CHECK_SHAPE(cache_indices_, batch_size); + } + + at::Tensor out = x; + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_, + silu_activation, + pad_slot_id, + query_start_loc, + cache_indices, + has_initial_state + ); + + if (conv_states.has_value()) { + auto conv_states_ = conv_states.value(); + TORCH_CHECK(conv_states_.scalar_type() == input_type); + TORCH_CHECK(conv_states_.is_cuda()); + params.conv_states_ptr = conv_states_.data_ptr(); + params.conv_states_batch_stride = conv_states_.stride(0); + params.conv_states_c_stride = conv_states_.stride(1); + params.conv_states_l_stride = conv_states_.stride(2); + } else { + params.conv_states_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto 
stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + causal_conv1d_fwd_cuda(params, stream); + }); +} + + +void causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const std::optional &bias_, + bool silu_activation, + const std::optional &cache_seqlens_, + const std::optional &conv_state_indices_, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + const int conv_state_len = conv_state.size(2); + TORCH_CHECK(conv_state_len >= width - 1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = x; + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, 
out, + bias_, + silu_activation, + pad_slot_id); + params.conv_state_ptr = conv_state.data_ptr(); + params.conv_state_len = conv_state_len; + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + if (cache_seqlens_.has_value()) { + auto cache_seqlens = cache_seqlens_.value(); + TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); + TORCH_CHECK(cache_seqlens.is_cuda()); + TORCH_CHECK(cache_seqlens.stride(-1) == 1); + CHECK_SHAPE(cache_seqlens, batch_size); + params.cache_seqlens = cache_seqlens.data_ptr(); + } else { + params.cache_seqlens = nullptr; + } + + if (conv_state_indices_.has_value()) { + auto conv_state_indices = conv_state_indices_.value(); + TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) + TORCH_CHECK(conv_state_indices.is_cuda()); + TORCH_CHECK(conv_state_indices.stride(0) == 1) + CHECK_SHAPE(conv_state_indices, batch_size); + + int conv_state_entries = conv_state.size(0); + CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); + + params.conv_state_indices_ptr = conv_state_indices.data_ptr(); + } else { + CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); + params.conv_state_indices_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); +} + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || 
kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + static constexpr int kSmemIOSize = kIsVecLoad + ? 0 + : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const bool kVarlen = params.query_start_loc_ptr != nullptr; + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + const int *query_start_loc = kVarlen ? reinterpret_cast(params.query_start_loc_ptr) : nullptr; + const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id; + const int seqlen = kVarlen ? 
query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen; + + input_t *x = reinterpret_cast(params.x_ptr) + sequence_start_index * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + bool has_initial_state = params.has_initial_state_ptr == nullptr ? false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + + int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr + : reinterpret_cast(params.cache_indices_ptr); + int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } + input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr + : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. 
+ if (tidx == 0) { + input_t initial_state[kNElts] = {0}; + if (has_initial_state) { + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; } + } + smem_exchange[kNThreads - 1] = reinterpret_cast(initial_state)[0]; + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + if constexpr(kIsVecLoad) { + typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts); + } else { + __syncthreads(); + typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); + } + x += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. 
+ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); + } + out += kChunkSize; + + int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); + // in case the final state is separated between the last "smem_exchange" and + // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), + // (which occurs when `final_state_position` is a non-positivie index) + // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it + if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){ + input_t vals_load[kNElts] = {0}; + if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ + // chunk = n_chunks - 2, a segment of the final state sits in the last index + reinterpret_cast(vals_load)[0] = smem_exchange[kNThreads - 1]; + #pragma unroll + for (int w = 0; w < -final_state_position; ++w){ + conv_states[w] = vals_load[kNElts + final_state_position + w]; + } + } + if ((chunk == n_chunks - 1) && tidx == 0){ + // chunk = 
n_chunks - 1, the second segment of the final state first positions + reinterpret_cast(vals_load)[0] = smem_exchange[0]; + for (int w = -final_state_position; w < kWidth - 1; ++w){ + conv_states[w] = vals_load[w + final_state_position]; + } + return; + } + } + } + // Final state is stored in the smem_exchange last token slot, + // in case seqlen < kWidth, we would need to take the final state from the + // initial state which is stored in conv_states + // in case seqlen > kWidth, we would need to load the last kWidth - 1 data + // and load it into conv_state accordingly + int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; + if (conv_states != nullptr && tidx == last_thread) { + input_t x_vals_load[kNElts * 2] = {0}; + // in case we are on the first kWidth tokens + if (last_thread == 0 && seqlen < kWidth){ + // Need to take the initial state + reinterpret_cast(x_vals_load)[0] = smem_exchange[0]; + const int offset = seqlen - (kWidth - 1); + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + // pad the existing state + if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; } + else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); } + } + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + if (offset + w >= 0) + conv_states[w] = x_vals_load[offset + w ]; + } + } + else { + // in case the final state is in between the threads data + const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); + if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){ + // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a + // illegal access error on H100. 
+ // Therefore, we access last_thread + 1, only if the final state data sits there + reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; + } + reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + conv_states[w] = x_vals_load[offset + w ]; + } + } + + } +} + + +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + const bool kVarlen = params.query_start_loc_ptr != nullptr; + BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + + auto kernel = &causal_conv1d_fwd_kernel; + + if (kSmemSize >= 48 * 1024) { + #ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + #else + // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. 
\n" << std::endl; + #endif + } + kernel<<>>(params); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + + +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + + + + +template +struct Causal_conv1d_update_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + if (channel_id >= params.dim) return; + + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + + // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor + // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. + const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr + ? 
batch_id + : params.conv_state_indices_ptr[batch_id]; + // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early + if (conv_state_batch_coord == params.pad_slot_id){ + return; + } + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + + conv_state_batch_coord * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + int state_len = params.conv_state_len; + int advance_len = params.seqlen; + int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; + int update_idx = cache_seqlen - (kWidth - 1); + update_idx = update_idx < 0 ? update_idx + state_len : update_idx; + + float weight_vals[kWidth] = {0}; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + float x_vals[kWidth] = {0}; + if constexpr (!kIsCircularBuffer) { + #pragma unroll 2 + for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { + conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; + } + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { + input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; + if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { + conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; + } + x_vals[i] = float(state_val); + } + } else { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? 
update_idx + 1 - state_len : update_idx + 1) { + input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; + x_vals[i] = float(state_val); + } + } + #pragma unroll 2 + for (int i = 0; i < params.seqlen; ++i) { + input_t x_val = x[i * params.x_l_stride]; + if constexpr (!kIsCircularBuffer) { + if (i < advance_len && state_len - advance_len + i >= 0) { + conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; + } + } else { + conv_state[update_idx * params.conv_state_l_stride] = x_val; + ++update_idx; + update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; + } + x_vals[kWidth - 1] = float(x_val); + float out_val = bias_val; + #pragma unroll + for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + out[i * params.out_l_stride] = input_t(out_val); + // Shift the input buffer by 1 + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } + } +} + +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = params.cache_seqlens == nullptr + ? 
&causal_conv1d_update_kernel + : &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h new file mode 100644 index 000000000000..e26684a2b98b --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h +#pragma once + +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ConvParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, width; + int64_t pad_slot_id; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + int conv_state_len; + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. 
+ void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; + void *__restrict__ query_start_loc_ptr; + void *__restrict__ has_initial_state_ptr; + void *__restrict__ cache_indices_ptr; + int32_t *__restrict__ cache_seqlens; + + // For the continuous batching case. Makes it so that the mamba state for + // the current batch doesn't need to be a contiguous tensor. + int32_t *__restrict__ conv_state_indices_ptr; + + void *__restrict__ seq_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. + void * initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void * final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; + + void * conv_states_ptr; + index_t conv_states_batch_stride; + index_t conv_states_l_stride; + index_t conv_states_c_stride; +}; + + +#ifndef USE_ROCM + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor_sync(uint32_t(-1), val, offset); + } + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor(val, offset); + } + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? 
a : b; + } +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h new file mode 100644 index 000000000000..ef74bf447f84 --- /dev/null +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the 
constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h new file mode 100644 index 000000000000..563d2fe4ef65 --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -0,0 +1,266 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h + +#pragma once + +#ifndef USE_ROCM + #include +#else + #include +#endif +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + int64_t pad_slot_id; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. 
+ void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ ssm_states_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; + + void *__restrict__ query_start_loc_ptr; + void *__restrict__ cache_indices_ptr; + void *__restrict__ has_initial_state_ptr; + +}; + + + + +#ifndef USE_ROCM + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + + +#define MAX_DSTATE 256 + + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
+ __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + #ifdef USE_ROCM + , Ktraits::kNThreads * Ktraits::kNLoads + #endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename 
Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu new file mode 100644 index 000000000000..bd0a34119c82 --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -0,0 +1,658 @@ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh +#include +#include +#include +#include "selective_scan.h" + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + namespace cub = hipcub; +#endif + +#include "selective_scan.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kVarlen_ ? 
false : kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kVarlen = kVarlen_; + + static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kVarlen = Ktraits::kVarlen; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
+ extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + int seqlen = params.seqlen; + int sequence_start_index = batch_id; + if constexpr (kVarlen){ + int *query_start_loc = reinterpret_cast(params.query_start_loc_ptr); + sequence_start_index = query_start_loc[batch_id]; + seqlen = query_start_loc[batch_id + 1] - sequence_start_index; + } + const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + + const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr + : reinterpret_cast(params.cache_indices_ptr); + const int cache_index = cache_indices == nullptr ? 
batch_id : cache_indices[batch_id]; + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } + input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; + input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + const int n_chunks = (seqlen + 2048 - 1) / 2048; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], 
delta_vals_load[kNRows][kNItems]; + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize); + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. 
+ weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 )); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + + scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? 
float(ssm_states[state_idx]): 0.0); + + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + if (chunk == n_chunks - 1) { + ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); + } + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + sequence_start_index * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + } + } + 
+ Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size + constexpr bool kIsVariableB = true; + constexpr bool kIsVariableC = true; + constexpr bool kHasZ = true; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + + #ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + 
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #endif +} + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const torch::Tensor u, + const torch::Tensor delta, + const torch::Tensor A, + const torch::Tensor B, + const torch::Tensor C, + const torch::Tensor out, + const torch::Tensor z, + const torch::Tensor out_z, + const std::optional& D, + const std::optional& delta_bias, + const torch::Tensor ssm_states, + bool has_z, + bool delta_softplus, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, + bool varlen, + int64_t pad_slot_id) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + 
params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.dim_ngroups_ratio = dim / n_groups; + params.pad_slot_id = pad_slot_id; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr; + params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr; + params.out_ptr = out.data_ptr(); + params.ssm_states_ptr = ssm_states.data_ptr(); + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; + params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; + params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; + + + // All stride are in elements, not bytes. 
+ params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + + if (varlen){ + params.B_batch_stride = B.stride(2); + params.B_group_stride = B.stride(0); + params.B_dstate_stride = B.stride(1); + params.C_batch_stride = C.stride(2); + params.C_group_stride = C.stride(0); + params.C_dstate_stride = C.stride(1); + + params.u_batch_stride = u.stride(1); + params.u_d_stride = u.stride(0); + params.delta_batch_stride = delta.stride(1); + params.delta_d_stride = delta.stride(0); + if (has_z) { + params.z_batch_stride = z.stride(1); + params.z_d_stride = z.stride(0); + params.out_z_batch_stride = out_z.stride(1); + params.out_z_d_stride = out_z.stride(0); + } + params.out_batch_stride = out.stride(1); + params.out_d_stride = out.stride(0); + + } + else{ + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? 
C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); + } +} + +void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const std::optional &D_, + const std::optional &z_, + const std::optional &delta_bias_, + bool delta_softplus, + const std::optional &query_start_loc, + const std::optional &cache_indices, + const std::optional &has_initial_state, + const torch::Tensor &ssm_states, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const bool varlen = query_start_loc.has_value(); + const int batch_size = varlen ? 
query_start_loc.value().sizes()[0] - 1 : sizes[0]; + const int dim = varlen ? sizes[0] : sizes[1]; + const int seqlen = varlen ? sizes[1] : sizes[2]; + const int dstate = A.size(1); + const int n_groups = varlen ? B.size(0) : B.size(1); + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + if (varlen) { + CHECK_SHAPE(u, dim, seqlen); + CHECK_SHAPE(delta, dim, seqlen); + } else { + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + } + CHECK_SHAPE(A, dim, dstate); + TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") + if (varlen) { + CHECK_SHAPE(B, n_groups, dstate, seqlen); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + } + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + + TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") + if (varlen) { + CHECK_SHAPE(C, n_groups, dstate, seqlen); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + } + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + + if (has_initial_state.has_value()) { + auto has_initial_state_ = has_initial_state.value(); + TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); + TORCH_CHECK(has_initial_state_.is_cuda()); + CHECK_SHAPE(has_initial_state_, batch_size); + } + + + if (query_start_loc.has_value()) { + auto query_start_loc_ = query_start_loc.value(); + 
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(query_start_loc_.is_cuda()); + } + + + if (cache_indices.has_value()) { + auto cache_indices_ = cache_indices.value(); + TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cache_indices_.is_cuda()); + CHECK_SHAPE(cache_indices_, batch_size); + } + + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + if (varlen){ + CHECK_SHAPE(z, dim, seqlen); + } else { + CHECK_SHAPE(z, batch_size, dim, seqlen); + } + + out_z = z; + + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = delta; + TORCH_CHECK(ssm_states.scalar_type() == input_type); + TORCH_CHECK(ssm_states.is_cuda()); + TORCH_CHECK(ssm_states.stride(-1) == 1); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_, + delta_bias_, + ssm_states, + has_z, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + varlen, + pad_slot_id + ); + + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); +} + diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h new file mode 100644 index 000000000000..840cb2374a2f --- /dev/null +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// 
https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h new file mode 100644 index 000000000000..47ecf109d0f5 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -0,0 +1,1616 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/scalar_type.hpp" + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. 
+template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales +using FragZP = Vec; + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
+                           FragC& frag_c) {
+  // View each fragment as the raw registers consumed by the PTX instruction:
+  // four 32-bit registers of A, two of B and four fp32 accumulators of C.
+  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  float* c = reinterpret_cast<float*>(&frag_c);
+  // `c` is bound both as output ("=f") and as input ("f"): the product is
+  // accumulated on top of the values already stored in `frag_c`.
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
+        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared
+// memory, directly in tensor core layout.
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  // `ldmatrix` takes a 32-bit shared-state-space address, so the generic
+  // pointer is converted and narrowed first.
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+               : "r"(smem));
+}
+
+// Lookup-table based 3-input logical operation; explicitly used for
+// dequantization as the compiler does not seem to automatically recognize it in
+// all cases.
+// `lut` is passed as an immediate ("n") operand, hence a template parameter.
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
+// The first source is the runtime value `a`; the second source (`start_byte`)
+// and the byte selector (`mask`) are compile-time immediates.
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+// Dequantizes one packed 32-bit word `q` into a B-fragment of 4 fp16 values;
+// one explicit specialization per supported weight type.
+// NOTE(review): the template parameter type below and the specialization
+// arguments (`vllm::kU4B8.id()`, `vllm::kU8B128.id()`, `vllm::kU4.id()`,
+// `vllm::kU8.id()`) are missing from the patch text as received and have been
+// reconstructed from how `w_type_id` is used further down
+// (`vllm::ScalarType::from_id(w_type_id)`) and from the bias constants of each
+// body -- confirm against the upstream header.
+template <vllm::ScalarTypeId w_type_id>
+__device__ inline FragB dequant(int q);
+
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
+template <>
+__device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  // 0x6400 is fp16 1024.0; OR-ing a nibble into its mantissa yields
+  // 1024 + v (low nibble) or 1024 + 16 * v (high nibble).
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
+  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
+  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+  // directly into `SUB` and `ADD`.
+  // SUB = 1032.0: (1024 + v) - 1032 = v - 8.
+  // MUL = 1/16, ADD = -72.0: (1024 + 16 * v) / 16 - 72 = v - 8.
+  const int SUB = 0x64086408;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd480d480;
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16
+// Reference:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
+template <>
+__device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
+  // Byte selectors for `prmt`: each 16-bit half of the result pairs one byte
+  // of `q` (low byte) with the constant byte 0x64 (high byte), i.e. the fp16
+  // value 1024 + byte. 0x5250 picks bytes 0 and 2 of `q`, 0x5351 bytes 1 and 3.
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  // 0x6480 is fp16 1152.0 = 1024 + 128, so the subtraction removes the fp16
+  // offset and applies the -128 bias in one step.
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+// Same scheme as the biased 4-bit variant above, but without the fused `-8`:
+// the constants only remove the fp16 offset, so the unsigned nibble value is
+// returned as is.
+template <>
+__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
+  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
+
+  // SUB = 1024.0: (1024 + v) - 1024 = v.
+  // MUL = 1/16, ADD = -64.0: (1024 + 16 * v) / 16 - 64 = v.
+  const int SUB = 0x64006400;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd400d400;
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+// Same scheme as the biased 8-bit variant above, but without the `-128` bias:
+// only the fp16 offset of 1024 is subtracted.
+template <>
+__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  // 0x6400 is fp16 1024.0.
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+  // Pick the `i`-th half out of the scale fragment and broadcast it into both
+  // lanes of a half2 so that it can be applied to both entries of `frag_b`.
+  const __half* scale_halves = reinterpret_cast<__half*>(&frag_s);
+  const half2 scale_pair = __half2half2(scale_halves[i]);
+  #pragma unroll
+  for (int v = 0; v < 2; v++) {
+    frag_b[v] = __hmul2(frag_b[v], scale_pair);
+  }
+}
+
+// Subtract the `i`-th half of `frag_zp` (broadcast to both half2 lanes) from
+// both entries of `frag_b`.
+__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
+  const __half* zp_halves = reinterpret_cast<__half*>(&frag_zp);
+  const half2 zp_pair = __half2half2(zp_halves[i]);
+  #pragma unroll
+  for (int v = 0; v < 2; v++) {
+    frag_b[v] = __hsub2(frag_b[v], zp_pair);
+  }
+}
+
+// Variant of `scale` for act_order (each K is multiplied individually): the
+// `i`-th half of four separate scale fragments is used, fragments 1/2 for
+// `frag_b[0]` and fragments 3/4 for `frag_b[1]`.
+__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
+                              FragS& frag_s_3, FragS& frag_s_4, int i) {
+  const __half* halves_1 = reinterpret_cast<__half*>(&frag_s_1);
+  const __half* halves_2 = reinterpret_cast<__half*>(&frag_s_2);
+  const __half* halves_3 = reinterpret_cast<__half*>(&frag_s_3);
+  const __half* halves_4 = reinterpret_cast<__half*>(&frag_s_4);
+
+  __half2 pair_1_2;
+  pair_1_2.x = halves_1[i];
+  pair_1_2.y = halves_2[i];
+
+  __half2 pair_3_4;
+  pair_3_4.x = halves_3[i];
+  pair_3_4.y = halves_4[i];
+
+  frag_b[0] = __hmul2(frag_b[0], pair_1_2);
+  frag_b[1] = __hmul2(frag_b[1], pair_3_4);
+}
+
+// Multiply the two floats `c[0]`, `c[1]` by the first two halves of `s`.
+__device__ inline void scale_float(float* c, FragS& s) {
+  const __half* scale_halves = reinterpret_cast<__half*>(&s);
+  #pragma unroll
+  for (int idx = 0; idx < 2; idx++) {
+    c[idx] = __fmul_rn(c[idx], __half2float(scale_halves[idx]));
+  }
+}
+
+// Spin until the value stored at `lock` equals `count`; only thread 0 polls,
+// the remaining threads of the block wait at the `__syncthreads()` below.
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int observed = -1;
+    do {
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(observed)
+                   : "l"(lock));
+    } while (observed != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  // Every thread of the block has to arrive here before thread 0 updates the
+  // lock.
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    // With `reset` the counter is written back to 0 with a plain store and the
+    // function returns early; no fence is issued on this path.
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
+    asm volatile("fence.acq_rel.gpu;\n");
+    // Add 1 to the counter at `lock` with a reduction (no value is returned).
+    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+                 :
+                 : "l"(lock), "r"(val));
+  }
+}
+
+template shared
+                             // fetch pipeline
+          const bool has_act_order,    // whether act_order is enabled
+          const bool has_zp,           // whether zero-points are enabled
+          const int group_blocks = -1  // number of consecutive 16x16 blocks
+                                       // with a separate quantization scale
+          >
+__device__ void MarlinMoESingle(
+    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
+    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
+    int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    const int* __restrict__ sorted_ids,      // int32 sorted ids of experts
+    const float* __restrict__ topk_weights,  // float topk weights
+    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
+                                          // (k/groupsize)xn
+    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
+                                          // (k/groupsize)x(n/pack_factor)
+    const int* __restrict__ g_idx,        // int32 group indices of shape k
+    const int* __restrict__ expert_offsets,
+    int num_groups,        // number of scale groups per output channel
+    int expert_idx,        // idx of current expert
+    int num_experts,       // number of experts
+    int topk,              // topk parameter of moe
+    int prob_m,            // batch dimension m
+    int prob_n,            // output dimension n
+    int prob_k,            // reduction dimension k
+    int tot_m,             // total number of rows in A and C
+    int* locks,            // extra global storage for barrier synchronization
+    bool replicate_input,  // do we use the same input for each expert?
+ bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor 
* 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. 
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. 
+ int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. 
Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 
1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } 
+ } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= 
thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); 
+ } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + 
}; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp(frag_b0, frag_zp[j], 0); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma 
unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. 
+ + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
+ constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = 
sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, 
frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. 
+ #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + 
barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + 
int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; /* NOTE(review): skips 4 * max_par ids per m block, which matches the 16-row blocks used below only if max_par == 4 - confirm against the caller */ + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; /* width of this expert's range in expert_offsets */ + if (tot_its == 0) { /* empty range: nothing to compute for this expert */ + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); /* m blocks are 16 rows each */ + int pad = 16 * tot_m_blocks - tot_its; /* unused rows in the last 16-row block */ + + if (m_block_ctr >= tot_m_blocks) { /* starting block lies past this expert's range */ + return; + } + + int max_block = tot_m_blocks - m_block_ctr; /* m blocks left from the starting block onwards */ + prob_m = tot_its - 16 * m_block_ctr; /* rows left from the starting block onwards */ + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; /* clamp to the configured upper bound */ + } + + if (max_block == 1) { /* NOTE(review): the calls in this if/else chain read identically in this text; any template arguments that distinguish them are not visible here */ + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, +
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // 
do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per scheduler allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE,
N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu new file mode 100644 index 000000000000..77bc0dd90edd --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu @@ -0,0 +1,31 @@ +#include "marlin_moe_kernel_ku4.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = true; /* read by the __CALL_IF_MOE condition; the AWQ variants are instantiated with HAS_ZP == true */ + + if (false) { /* empty head so that every macro below can expand to an else-if branch */ + } + AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) + else { /* no instantiated configuration matched the runtime arguments */ + return false; + } + return true; /* one branch matched and launched its kernel */ +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h new file mode 100644 index 000000000000..833fadf37721 --- /dev/null +++
b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu new file mode 100644 index 000000000000..f7e57b037594 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -0,0 +1,31 @@ +#include "marlin_moe_kernel_ku4b8.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. 
+bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = false; /* read by the __CALL_IF_MOE condition; the GPTQ variants are instantiated with HAS_ZP == false */ + + if (false) { /* empty head so that every macro below can expand to an else-if branch */ + } + GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + else { /* no instantiated configuration matched the runtime arguments */ + return false; + } + return true; /* one branch matched and launched its kernel */ +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h new file mode 100644 index 000000000000..494da8f10e26 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's.
+bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu new file mode 100644 index 000000000000..a901f0b11cd7 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -0,0 +1,31 @@ +#include "marlin_moe_kernel_ku8b128.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. 
+bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = false; /* read by the __CALL_IF_MOE condition; the GPTQ variants are instantiated with HAS_ZP == false */ + + if (false) { /* empty head so that every macro below can expand to an else-if branch */ + } + GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + else { /* no instantiated configuration matched the runtime arguments */ + return false; + } + return true; /* one branch matched and launched its kernel */ +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h new file mode 100644 index 000000000000..f3018aa0c1ab --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -0,0 +1,18 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
new file mode 100644 index 000000000000..5f12483e951e --- /dev/null +++ b/csrc/moe/marlin_moe_ops.cu @@ -0,0 +1,588 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/exception.hpp" +#include "core/scalar_type.hpp" +#include "core/registration.h" +#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" +#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" +#include "marlin_kernels/marlin_moe_kernel_ku4.h" + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin_moe { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
+__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; /* each block handles block_rows consecutive rows */ + int finish_row = start_row + block_rows; + if (finish_row > size_m) { /* clamp the last block to the matrix height */ + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; /* row length in 16-byte units, used to offset the int4 pointers */ + + auto permute_row = [&](int row) { + int iters = size_k / blockDim.x; /* full passes in which every thread copies one column */ + int rest = size_k % blockDim.x; /* columns left over after the full passes */ + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; /* output column cur_k takes input column perm[cur_k] */ + + base_k += blockDim.x; + } + + if (rest) { + if (threadIdx.x < rest) { /* tail pass: only the first rest threads have a column left */ + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + int expert_id = threadIdx.x; /* one thread per expert; assumes a single-block launch - TODO confirm at the launch site */ + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; /* raw count for now; thread 0 turns the counts into offsets below */ + __syncthreads(); + + if (threadIdx.x == 0) { + int tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; /* each count is rounded up to a multiple of block_size */ + expert_offsets[i + 1] = tot_offset; + } + } + __syncthreads(); +} + +#else + +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, +
int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N + {64, 64, 128}, // Reduce both 2X +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 64, 128}, // Reduce N 4X, same K +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = ceildiv(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 4; + + } else { + 
int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * STAGES; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = ceildiv(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * STAGES; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, 
is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION( \ + q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ + group_blocks, num_threads, blocks, max_shared_mem, stream, \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks)) { \ + } + +void marlin_mm_moe(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, void* zp, + const void* g_idx, const void* perm, void* a_tmp, + void* expert_offsets, int prob_m, int prob_n, int prob_k, + void* workspace, vllm::ScalarType const& q_type, + bool has_act_order, bool is_k_full, bool has_zp, + int 
num_groups, int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int num_bits = q_type.size_bits(); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; + } else { + // Auto config + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); + } + + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", 
prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int tot_m = prob_m; + + const int* topk_ids_ptr = (const int*)topk_ids; + int* expert_offsets_ptr = (int*)expert_offsets; + compute_expert_offsets<<<1, num_experts, 0, stream>>>( + topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); + + bool do_permute_a = has_act_order; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + int pack_factor = 32 / q_type.size_bits(); + + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const int4* A_ptr = (const int4*)A; + int4* a_tmp_ptr = (int4*)a_tmp; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; + int4* C_ptr = (int4*)C; + const float* topk_weights_ptr = (const float*)topk_weights; + const int* sorted_ids_ptr = (const int*)sorted_ids; + const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; + const int4* zp_ptr = + (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; + const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; + const int* perm_ptr = (const int*)perm + prob_k * expert_idx; + int* locks = (int*)workspace; + + if (do_permute_a) { + // Permute A columns + int topk_rows = replicate_input ? 
tot_m : tot_m * topk; + int block_rows = ceildiv(topk_rows, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + int tot_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < tot_m_blocks; + m_block += 4 * exec_cfg.max_m_blocks) { + if (false) { + } + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + } + } +} + +} // namespace marlin_moe + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + torch::Tensor& b_zeros, const torch::Tensor& g_idx, + const torch::Tensor& perm, torch::Tensor& workspace, + vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights) { + vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); + bool has_zp = b_zeros.size(1) != 0; + if (has_zp) { + TORCH_CHECK( + b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + } else { + TORCH_CHECK( + b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type.str()); + } + + int pack_factor = 32 / b_q_type.size_bits(); + + int max_par = 4; + + int dev = a.get_device(); + + auto options_dtype = + torch::TensorOptions().dtype(a.dtype()).device(a.device()); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); + torch::Tensor a_tmp = + replicate_input ? torch::zeros({size_m, size_k}, options_dtype) + : torch::zeros({size_m, topk, size_k}, options_dtype); + torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(1) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), + "if is_k_full is false, has_act_order must be true"); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + // Verify b_zeros + if (has_zp) { 
+ int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + TORCH_CHECK(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + + marlin_moe::marlin_mm_moe( + a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), + topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), + b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, + num_experts, topk, moe_block_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, + replicate_input, apply_weights); + return c; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_gemm_moe", &marlin_gemm_moe); +} diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu new file mode 100644 index 000000000000..d7be769458e3 --- /dev/null +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -0,0 +1,463 @@ +#include +#include +#include + +#include +#include + +#include "../cuda_compat.h" +#include "../dispatch_utils.h" + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +namespace vllm { +namespace moe { + +namespace { +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} +} // namespace + +template +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + 
const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) + token_cnts_t* tokens_cnts = + (token_cnts_t*)(shared_mem + num_experts + + 1); // 2d tensor with shape (blockDim.x + 1, num_experts) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. 
Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} + +// TODO(simon): this is temporarily adapted from +// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 +// we did this to unblock Deepseek V3 but there should be a better +// implementation to manage shared memory. +template +__global__ void moe_align_block_size_global_mem_kernel( + scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. 
+ */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. 
+ */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} + +// taken from +// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 +template +__global__ void sgl_moe_align_block_size_kernel( + scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* cumsum) { + __shared__ int32_t shared_counts[32][8]; + + const int warp_id = threadIdx.x / 32; + const int experts_per_warp = 8; + const int my_expert_start = warp_id * experts_per_warp; + + // Initialize shared_counts for this warp's experts + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < num_experts) { + shared_counts[warp_id][i] = 0; + } + } + + __syncthreads(); + + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int expert_id = topk_ids[i]; + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + atomicAdd(&shared_counts[warp_idx][expert_offset], 1); + } + + __syncthreads(); + + // Single thread computes cumulative sum and total tokens + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + int expert_count = 0; + int warp_idx = (i - 1) / experts_per_warp; + int expert_offset = (i - 1) % experts_per_warp; + expert_count = shared_counts[warp_idx][expert_offset]; + + cumsum[i] = + cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + // Assign expert IDs to blocks + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 
1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } +} + +// taken from +// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 +template +__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* cumsum_buffer, + size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +template +__global__ void moe_sum_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., topk, d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + scalar_t x = 0.0; +#pragma unroll + for (int k = 0; k < TOPK; ++k) { + x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]); + } + out[token_idx * d + idx] = x; + } +} + +} // namespace moe +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_i32 = + ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + const int32_t shared_mem_i16 = + ((num_thread + 1) * num_experts) * sizeof(uint16_t) + + (num_experts + 1) * sizeof(int32_t); + + bool use_global_memory = false; + bool use_i16 = false; // Use uint16_t for shared memory token counts + if (shared_mem_i32 < 
device_max_shared_mem) { + // Do nothing in this case. We're all set to use int32_t token counts + } else if (shared_mem_i16 < device_max_shared_mem && + topk_ids.numel() <= 65535) { + // when nelements of topk_ids is smaller than 65535 (max value of uint16), + // element value of token_cnts would also smaller than 65535, + // so we can use uint16 as dtype of token_cnts + use_i16 = true; + } else { + use_global_memory = true; + } + + if (use_global_memory) { + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + + auto options_int = torch::TensorOptions() + .dtype(torch::kInt) + .device(topk_ids.device()); + torch::Tensor token_cnts_buffer = + torch::empty({(num_experts + 1) * num_experts}, options_int); + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); + + auto kernel = + vllm::moe::moe_align_block_size_global_mem_kernel; + kernel<<<1, num_thread, 0, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel(), token_cnts_buffer.data_ptr(), + cumsum_buffer.data_ptr()); + }); + } else if (use_i16) { + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // set dynamic shared mem + auto kernel = + vllm::moe::moe_align_block_size_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem_i16)); + kernel<<<1, num_thread, shared_mem_i16, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + }); + } else { + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + auto kernel = + 
vllm::moe::moe_align_block_size_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem_i32)); + kernel<<<1, num_thread, shared_mem_i32, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + }); + } +} + +void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TORCH_CHECK(num_experts == 256, + "sgl_moe_align_block_size kernel only supports deepseek v3."); + + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `cumsum` tensors + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor cumsum_buffer = + torch::zeros({num_experts + 1}, options_int); + + auto align_kernel = + vllm::moe::sgl_moe_align_block_size_kernel; + align_kernel<<<1, 1024, 0, stream>>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel(), cumsum_buffer.data_ptr()); + + const int block_threads = 256; + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel; + sort_kernel<<>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), topk_ids.numel()); + }); +} + +void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // [num_tokens, hidden_size] +{ + const int hidden_size = input.size(-1); + const int num_tokens = output.numel() / hidden_size; + const int 
topk = input.size(1); + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (topk) { + case 2: + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + case 3: + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + case 4: + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + default: + at::sum_out(output, input, 1); + break; + } +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a251730aa765..0bae119a7c46 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -5,3 +5,27 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, torch::Tensor& gating_output); + +void moe_sum(torch::Tensor& input, torch::Tensor& output); + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); + +void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); +#ifndef USE_ROCM +torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, + torch::Tensor b_qweight, torch::Tensor b_scales, + std::optional b_qzeros, + std::optional topk_weights, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, int64_t top_k, + int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + 
int64_t BLOCK_SIZE_K, int64_t bit); +#endif \ No newline at end of file diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu new file mode 100644 index 000000000000..51ae76c1ec88 --- /dev/null +++ b/csrc/moe/moe_wna16.cu @@ -0,0 +1,346 @@ + +#include +#include +#include +#include + +#include +#include +#include "moe_wna16_utils.h" + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +template +__global__ void moe_wna16_gemm_kernel( + const scalar_t* __restrict__ input, scalar_t* __restrict__ output, + + const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales, + const uint32_t* __restrict__ qzeros, + + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_token_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ num_tokens_post_pad, + + uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m, + uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M, + uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp, + bool mul_topk_weight) { +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + if constexpr (std::is_same::value) { + return; + } else { +#endif + + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + + if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return; + + const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x; + const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K; + + const int32_t expert_id = expert_ids[blockIdx.x]; + + int32_t num_valid_tokens = 0; + extern __shared__ uint16_t block_input_tmp[]; + scalar_t* block_input = reinterpret_cast(block_input_tmp); + scalar_t2* block_input_half2 = reinterpret_cast(block_input); + + // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory + for (int m = 0; m < BLOCK_SIZE_M; m++) { + const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m; + const int32_t token_index = sorted_token_ids[offset_m]; + if (token_index / top_k >= size_m) break; + + num_valid_tokens = m + 1; + if 
(blockIdx.z == 0 && offset_n < size_n) + output[token_index * size_n + offset_n] = Dtype::int2num(0); + + if (expert_id != -1) { + int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N); + for (int i = 0; i < k_per_thread; i++) { + int k = BLOCK_SIZE_N * i + threadIdx.x; + if (k >= BLOCK_SIZE_K) break; + if (offset_k + k >= size_k) break; + + // load input to shared memory + // use a special layout to fit the layout of dequanted-weight + int origin_k; + if constexpr (bit == 4) { + // [0, 4, 1, 5, 2, 6, 3, 7] + int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order; + } else { + // [0, 2, 1, 3] + int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order; + } + + origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K; + block_input[m * BLOCK_SIZE_K + k] = input[origin_k]; + } + } + } + + if (expert_id == -1) return; + __syncthreads(); + if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return; + + float res[64]; // assume BLOCK_SIZE_M <= 64 + scalar_t2 res2; + scalar_t2 scale_f2; + scalar_t2 qzero_f2; + + // note that (size_n * size_k * expert_id) may greater than 2 ** 31 + constexpr int8_t pack_factor = 32 / bit; + const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id; + const uint32_t* expert_qweight = qweight + expert_offset / pack_factor; + const scalar_t* expert_scales = scales + expert_offset / group_size; + const uint32_t* expert_qzeros = + qzeros + expert_offset / group_size / pack_factor; + + // load 4*int32 one time: 4 int32 = 128 bit = 1 float4 + // weight would be loaded in loop + uint32_t expert_qweight_tmp[4]; + float4* expert_qweight_tmp_float4 = + reinterpret_cast(expert_qweight_tmp); + + // load all required scales one time + scalar_t expert_scales_groups[GROUPS]; + int scales_offset_tmp = + (offset_n * size_k + offset_k) / group_size / GROUPS; + if constexpr (GROUPS == 1) { + 
*expert_scales_groups = expert_scales[scales_offset_tmp]; + } else if constexpr (GROUPS == 2) { + float* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } else if constexpr (GROUPS == 4) { + float2* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } else if constexpr (GROUPS == 8) { + float4* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } + + // load all required qzeros one time + uint8_t expert_qzeros_groups[GROUPS]; + if (!has_zp) { + if constexpr (bit == 4) { + qzero_f2 = Dtype::num2num2(Dtype::int2num(8)); + } else { + qzero_f2 = Dtype::num2num2(Dtype::int2num(128)); + } + } else { + int qzeros_offset_tmp = + (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) + + offset_k / group_size / GROUPS; + if constexpr (GROUPS == 1) { + uint8_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 2) { + uint16_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 4) { + uint32_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 8) { + uint64_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } + } + + for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) { + int k = offset_k + tmp_k * pack_factor; + if (k >= size_k) break; + const 
int32_t weight_offset = offset_n * size_k + k; + + if (tmp_k % 4 == 0) { + *expert_qweight_tmp_float4 = reinterpret_cast( + expert_qweight)[weight_offset / pack_factor / 4]; + } + + if (tmp_k % (group_size / pack_factor) == 0) { + scalar_t scale_f = + expert_scales_groups[tmp_k / (group_size / pack_factor)]; + scale_f2 = Dtype::num2num2(scale_f); + + if (has_zp) { + uint8_t qzero = + expert_qzeros_groups[tmp_k / (group_size / pack_factor)]; + if constexpr (bit == 4) { + qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF; + } + qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero)); + } + } + + scalar_t2 weight_half2[16 / bit]; + dequant(expert_qweight_tmp[tmp_k % 4], weight_half2); + + for (int m = 0; m < num_valid_tokens; m++) { + res2 = {}; + +#pragma unroll + for (int i = 0; i < 16 / bit; i++) { + int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i; + res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2), + block_input_half2[offset_input], res2); + } + + if (tmp_k == 0) { + res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } else { + res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } + } + } + + for (int m = 0; m < num_valid_tokens; ++m) { + const int32_t token_index = + sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m]; + if (mul_topk_weight) { + res[m] *= topk_weights[token_index]; + } + atomicAdd(&output[token_index * size_n + offset_n], + Dtype::float2num(res[m])); + } + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + } +#endif +} + +template +void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output, + const uint32_t* b_qweight, const scalar_t* b_scales, + const uint32_t* b_qzeros, const float* topk_weights, + const int32_t* sorted_token_ids, + const int32_t* expert_ids, + const int32_t* num_tokens_post_pad, int num_experts, + int group_size, int num_token_blocks, int top_k, + int size_m, int size_n, int size_k, int BLOCK_SIZE_M, + int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit, + bool has_zp, 
bool mul_topk_weight) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_SIZE_N; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = num_token_blocks; + gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N); + gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K); + + auto kernel = moe_wna16_gemm_kernel; + if (bit == 4) { + if (BLOCK_SIZE_K / group_size == 2) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 4) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 8) { + kernel = moe_wna16_gemm_kernel; + } + } else { + if (BLOCK_SIZE_K / group_size == 1) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 2) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 4) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 8) { + kernel = moe_wna16_gemm_kernel; + } + } + + const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + input, output, b_qweight, b_scales, b_qzeros, topk_weights, + sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts, + group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, has_zp, mul_topk_weight); +} + +torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, + torch::Tensor b_qweight, torch::Tensor b_scales, + std::optional b_qzeros, + std::optional topk_weights, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, int64_t top_k, + int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t BLOCK_SIZE_K, int64_t bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto options = + torch::TensorOptions().dtype(input.dtype()).device(input.device()); + + const int num_experts = b_qweight.size(0); + const int size_m = input.size(0); + const int size_n = b_qweight.size(1); + const int size_k = input.size(1); + const int group_size = size_k / 
b_scales.size(2); + + int64_t EM = sorted_token_ids.size(0); + if (size_m <= BLOCK_SIZE_M) { + EM = min(EM, size_m * BLOCK_SIZE_M * top_k); + } + const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M; + + const uint32_t* b_qzeros_ptr; + if (b_qzeros.has_value()) + b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr(); + const float* topk_weights_ptr; + if (topk_weights.has_value()) + topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); + + int groups_per_block_row = BLOCK_SIZE_K / group_size; + TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); + TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, + "size_k must divisible by BLOCK_SIZE_K"); + TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, + "BLOCK_SIZE_K must divisible by group_size"); + TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64"); + TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || + groups_per_block_row == 4 || groups_per_block_row == 8, + "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); + + if (input.scalar_type() == at::ScalarType::Half) { + run_moe_wna16_gemm( + (const half*)input.data_ptr(), + (half*)output.data_ptr(), + (const uint32_t*)b_qweight.data_ptr(), + (const half*)b_scales.data_ptr(), b_qzeros_ptr, + topk_weights_ptr, sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, + size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, + b_qzeros.has_value(), topk_weights.has_value()); + } else if (input.scalar_type() == at::ScalarType::BFloat16) { + run_moe_wna16_gemm( + (const nv_bfloat16*)input.data_ptr(), + (nv_bfloat16*)output.data_ptr(), + (const uint32_t*)b_qweight.data_ptr(), + (const nv_bfloat16*)b_scales.data_ptr(), b_qzeros_ptr, + topk_weights_ptr, sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, + size_k, 
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, + b_qzeros.has_value(), topk_weights.has_value()); + } else { + TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16"); + } + return output; +} diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h new file mode 100644 index 000000000000..4396b80240ef --- /dev/null +++ b/csrc/moe/moe_wna16_utils.h @@ -0,0 +1,200 @@ + +#include +#include + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } + + static __host__ __device__ half inline int2num(const float x) { + return __int2half_rn(x); + } + + static __host__ __device__ float2 inline num22float2(const half2 x) { + return __half22float2(x); + } + + static __host__ __device__ half2 inline float22num2(const float2 x) { + return __float22half2_rn(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } + + static __host__ __device__ nv_bfloat16 inline int2num(const float x) { + return __int2bfloat16_rn(x); + 
} + + static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) { + return __bfloat1622float2(x); + } + + static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) { + return __float22bfloat162_rn(x); + } +#endif +}; + +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* res) {} + +template <> +__device__ inline void dequant(int q, half2* res) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + q >>= 8; + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + + res[0] = __hsub2(*reinterpret_cast(&lo0), + *reinterpret_cast(&SUB)); + res[1] = __hfma2(*reinterpret_cast(&hi0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[2] = __hsub2(*reinterpret_cast(&lo1), + *reinterpret_cast(&SUB)); + res[3] = __hfma2(*reinterpret_cast(&hi1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* res) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + res[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + res[1] = 
__hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +__device__ inline void dequant(int q, nv_bfloat162* res) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + res[0] = __hfma2(*reinterpret_cast(&lo0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[1] = __hfma2(*reinterpret_cast(&hi0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[2] = __hfma2(*reinterpret_cast(&lo1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[3] = __hfma2(*reinterpret_cast(&hi1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* res) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(res); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} 
+#endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 243752b9a9e8..718418e6cd49 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,4 +1,4 @@ -#include "registration.h" +#include "core/registration.h" #include "moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { @@ -7,6 +7,53 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + + // Calculate the result of moe by summing up the partial results + // from all selected experts. + m.def("moe_sum(Tensor! input, Tensor output) -> ()"); + m.impl("moe_sum", torch::kCUDA, &moe_sum); + + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // temporarily adapted from + // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a + m.def( + "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); + +#ifndef USE_ROCM + m.def( + "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " + "Tensor b_scales, Tensor? b_qzeros, " + "Tensor? topk_weights, Tensor sorted_token_ids, " + "Tensor expert_ids, Tensor num_tokens_post_pad, " + "int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, " + "int bit) -> Tensor"); + + m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm); + + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! 
b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " + "int b_q_type, SymInt size_m, " + "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " + "topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); + // conditionally compiled so impl registration is in source file + +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu deleted file mode 100644 index 1f8d75da83bb..000000000000 --- a/csrc/moe_align_block_size_kernels.cu +++ /dev/null @@ -1,134 +0,0 @@ -#include -#include - -#include -#include - -#include "cuda_compat.h" -#include "dispatch_utils.h" - -#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) - -namespace vllm { - -namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, - int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; -} -} // namespace - -template -__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, - int32_t* sorted_token_ids, - int32_t* expert_ids, - int32_t* total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = - shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t* cumsum = - shared_mem + (num_experts + 1) * - num_experts; // 1d tensor with shape (num_experts + 1) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are - * assigned to 
expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. - */ - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; - } - - /** - * Each thread processes a token shard, calculating the index of each token - * after sorting by expert number. Given the example topk_ids = - * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, - * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a - * padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. 
- */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} -} // namespace vllm - -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t shared_mem = - ((num_experts + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); - - // set dynamic shared mem - auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); - kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); - }); -} diff --git a/csrc/ops.h b/csrc/ops.h index f075850248d1..7434aead57f0 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -3,13 +3,40 @@ #include #include +#include "core/scalar_type.hpp" + +#include + +torch::Tensor weak_ref_tensor(torch::Tensor& tensor) { + // Ensure tensor is on CUDA + if (!tensor.is_cuda()) { + throw std::runtime_error("Tensor must be on CUDA device"); + } + + // Get the raw data pointer + void* data_ptr = tensor.data_ptr(); + + // Get tensor sizes and strides + std::vector sizes = tensor.sizes().vec(); + std::vector strides = tensor.strides().vec(); + + // Get tensor options (dtype, device) + auto options = tensor.options(); + + // Create a new tensor from the raw data pointer + auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options); + + return new_tensor; +} + void 
paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -18,9 +45,10 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -30,6 +58,24 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& weight, torch::Tensor& scale, + double epsilon); + +void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& residual, + 
torch::Tensor& weight, + torch::Tensor& scale, double epsilon); + +void rms_norm_dynamic_per_token_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, + double const epsilon, + std::optional scale_ub, + std::optional residual); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); @@ -42,31 +88,47 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void mul_and_silu(torch::Tensor& out, torch::Tensor& input); + void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); +void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, + double threshold); + void gelu_new(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables); +void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables); + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& 
block_table_bounds); #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, + const std::vector& codebook_partition_sizes, const std::optional& bias); -torch::Tensor aqlm_dequant(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes); +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, const torch::Tensor& codebooks, + const std::vector& codebook_partition_sizes); torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, @@ -77,54 +139,72 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, int64_t thy); -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k); - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, - bool use_fp32_reduce); - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits); - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits); - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, 
int64_t size_m, int64_t size_n, - int64_t size_k); +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); +#endif + +torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, + int64_t n); + +torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, + int64_t type, int64_t row); + +torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, + int64_t row); + +torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_padded, int64_t type, + int64_t row, int64_t top_k, int64_t tokens); + +int64_t ggml_moe_get_block_size(int64_t type); + +#ifndef USE_ROCM +bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); +bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); + +void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, - c10::optional const& bias); + std::optional const& bias); + +void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias); + +bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability); + +void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& e, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); +std::vector cutlass_sparse_compress(torch::Tensor const& a); + +void scaled_fp4_quant(torch::Tensor& output, 
torch::Tensor const& input, + torch::Tensor& output_scale, + torch::Tensor const& input_scale); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& scale); + torch::Tensor const& scale, + std::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales); - -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table); + torch::Tensor& scales, + std::optional const& azp); torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, @@ -141,31 +221,48 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, - c10::optional const& scale_ub); - -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); + std::optional const& scale_ub); + +void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, + const torch::Tensor& A, const torch::Tensor& B, + const torch::Tensor& C, + const std::optional& D_, + const std::optional& z_, + const std::optional& delta_bias_, + bool delta_softplus, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, + const torch::Tensor& ssm_states, int64_t pad_slot_id); + +void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, + const at::Tensor& weight, + const std::optional& bias_, + bool silu_activation, + const std::optional& cache_seqlens_, + const std::optional& conv_state_indices_, + int64_t pad_slot_id); + +void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const std::optional& bias_, + const std::optional& conv_states, + const std::optional& 
query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, + bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM using fptr_t = int64_t; -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, - const std::vector& handles, - const std::vector& offsets, int64_t rank, - bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, - torch::Tensor& out); +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, + torch::Tensor& rank_data, int64_t rank, bool full_nvlink); +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); void dispose(fptr_t _fa); int64_t meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor& t, - const std::vector& handles, - const std::vector& offsets); -std::tuple> get_graph_buffer_ipc_meta( - fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector& handles, +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, const std::vector>& offsets); #endif diff --git a/csrc/permute_cols.cu b/csrc/permute_cols.cu new file mode 100644 index 000000000000..f51fa73298cc --- /dev/null +++ b/csrc/permute_cols.cu @@ -0,0 +1,88 @@ +#include + +#include +#include + +#include + +static constexpr int default_threads = 256; +static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
+// Currently only supports 16bit types (since we permute half types) +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = std::max(finish_row - start_row, 0); + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +// More efficient version of A[..., perm] +// taken from gptq_marlin.cu +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto dev = A.get_device(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, + "Currently only 16bit types are supported"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(A.size(-1) % 8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = A.view({-1, A.size(-1)}); + + torch::Tensor D = torch::empty_like(A); + int sms; + 
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int block_rows = div_ceil(A_2d.size(0), sms); + permute_cols_kernel<<>>( + reinterpret_cast(A_2d.const_data_ptr()), + perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), + A_2d.size(0), A_2d.size(1), block_rows); + return D; +} \ No newline at end of file diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 97184a873559..c085d31a3e9b 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel( void rotary_embedding( torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { - int64_t num_tokens = query.numel() / query.size(-1); + // num_tokens = batch_size * seq_len + int64_t num_tokens = positions.numel(); + int positions_ndim = positions.dim(); + + // Make sure num_tokens dim is consistent across positions, query, and key. 
+ TORCH_CHECK( + positions_ndim == 1 || positions_ndim == 2, + "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + TORCH_CHECK( + query.size(0) == positions.size(0) && key.size(0) == positions.size(0), + "query, key and positions must have the same number of tokens"); + } + if (positions_ndim == 2) { + TORCH_CHECK( + query.size(0) == positions.size(0) && + key.size(0) == positions.size(0) && + query.size(1) == positions.size(1) && + key.size(1) == positions.size(1), + "query, key and positions must have the same batch_size and seq_len"); + } + + // Make sure head_size is valid for query and key + // hidden_size = num_heads * head_size + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.numel() / num_tokens; + TORCH_CHECK(query_hidden_size % head_size == 0); + TORCH_CHECK(key_hidden_size % head_size == 0); + + // Make sure query and key have consistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key_hidden_size / head_size; + TORCH_CHECK(num_heads % num_kv_heads == 0); + int rot_dim = cos_sin_cache.size(1); - int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t query_stride = query.stride(-2); - int64_t key_stride = key.stride(-2); + int seq_dim_idx = positions_ndim - 1; + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.stride(seq_dim_idx); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -165,19 +201,58 @@ and process in batched manner. 
void batched_rotary_embedding( torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size] ) { + // num_tokens = batch_size * seq_len int64_t num_tokens = cos_sin_cache_offsets.size(0); - int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t query_stride = query.stride(-2); - int64_t key_stride = key.stride(-2); + TORCH_CHECK( + positions.size(0) == num_tokens || positions.numel() == num_tokens, + "positions must have the same num_tokens or batch_size as " + "cos_sin_cache_offsets"); + + int positions_ndim = positions.dim(); + // Make sure num_tokens dim is consistent across positions, query, and key. 
+ TORCH_CHECK( + positions_ndim == 1 || positions_ndim == 2, + "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + TORCH_CHECK( + query.size(0) == positions.size(0) && key.size(0) == positions.size(0), + "query, key and positions must have the same number of tokens"); + } + if (positions_ndim == 2) { + TORCH_CHECK( + query.size(0) == positions.size(0) && + key.size(0) == positions.size(0) && + query.size(1) == positions.size(1) && + key.size(1) == positions.size(1), + "query, key and positions must have the same batch_size and seq_len"); + } + + // Make sure head_size is valid for query and key + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.numel() / num_tokens; + TORCH_CHECK(query_hidden_size % head_size == 0); + TORCH_CHECK(key_hidden_size % head_size == 0); + + // Make sure query and key have concistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key_hidden_size / head_size; + TORCH_CHECK(num_heads % num_kv_heads == 0); + + int seq_dim_idx = positions_ndim - 1; + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.stride(seq_dim_idx); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 0e537ddd6c4c..fea4bc2ca0d8 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -12,13 +12,22 @@ namespace prepare_inputs { // template -__global__ void advance_step_kernel(int num_seqs, int num_queries, - int block_size, long* input_tokens_ptr, - long const* sampled_token_ids_ptr, - long* input_positions_ptr, - int* seq_lens_ptr, long* slot_mapping_ptr, - int const* block_tables_ptr, - int64_t const block_tables_stride) { +__global__ void advance_step_flashattn_kernel( + int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, 
long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, + int64_t const block_tables_stride) { + int const n_pad = num_seqs - num_queries; + if (n_pad && blockIdx.x == 0) { + // Handle cuda graph padding + int const offset = num_queries; + for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { + input_tokens_ptr[offset + i] = 0; + input_positions_ptr[offset + i] = 0; + slot_mapping_ptr[offset + i] = -1; + } + } + int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x >= num_query_blocks) { @@ -54,7 +63,7 @@ __global__ void advance_step_kernel(int num_seqs, int num_queries, slot_mapping_ptr[cur_query_id] = slot_num; } -inline void verify_tensor(std::string const& name, torch::Tensor& t, +inline void verify_tensor(std::string const& name, torch::Tensor const& t, int64_t const size_0, int64_t const size_1, c10::ScalarType const type) { bool size_0_cond = true; @@ -79,16 +88,117 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t, } } -void advance_step(int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables) { // type: int +/// each thread processes a block per query +__global__ void advance_step_flashinfer_kernel( + int num_threads, int num_seqs, int num_queries, int block_size, + long* input_tokens_ptr, long const* sampled_token_ids_ptr, + long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, int64_t const block_tables_stride, + int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { + int const n_pad = num_seqs - num_queries; + if (n_pad && blockIdx.x == 0) { + // Handle cuda graph padding + int const offset = num_queries; + for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { + 
input_tokens_ptr[offset + i] = 0; + input_positions_ptr[offset + i] = 0; + slot_mapping_ptr[offset + i] = -1; + } + } + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x < num_query_blocks) { + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id < num_queries) { + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + // Update paged_kv_last_page_len + paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; + + int slot_num = + seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; + block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); + } + } +} + +__global__ void advance_step_flashinfer_indptr_kernel( + int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, + int* block_table_bound_ptr) { + int idx = blockIdx.x * num_threads + threadIdx.x; + // Update paged_kv_indptr + if (idx == 0) { + paged_kv_indptr_ptr[idx] = 0; + } + if (idx < num_queries) { + int sum = 0; + for (int i = 0; i <= idx; ++i) { + sum += block_table_bound_ptr[i]; + } + paged_kv_indptr_ptr[idx + 1] = sum; + } +} + +__global__ void advance_step_flashinfer_indices_kernel( + int num_seqs, int num_queries, int const* block_tables_ptr, + int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr, + int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { + // note: max_num_blocks_per_seq = block_tables.stride(0) + int tid 
= blockIdx.x * blockDim.x + threadIdx.x; + + // when cuda graphs are enabled, paged_kv_indptr tensor + // has to be updated for the padded queries + // tid represents a query# for paged_kv_indptr tensor + if (num_queries < tid && tid <= num_seqs) { + paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries]; + } + + // each thread processes a block_ptr in block_tables + // block_tables shape: [num_queries, max_num_blocks_per_seq] + // paged_kv_indices is flattened block_tables. + for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq); + idx += (gridDim.x * blockDim.x)) { + // block_tables-row = paged_kv_indptr[queryNum] + int queryNum = idx / max_num_blocks_per_seq; + int col = idx % max_num_blocks_per_seq; + if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) { + int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col; + int block_tables_idx = queryNum * max_num_blocks_per_seq + col; + paged_kv_indices_ptr[indices_arr_idx] = + block_tables_ptr[block_tables_idx]; + } + } +} + +void advance_step_flashattn(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int if (logging) { - printf("advance_step:\n"); + printf("advance_step_flashattn:\n"); printf(" num_seqs = %d\n", num_seqs); printf(" num_queries = %d\n", num_queries); printf(" block_size = %d\n", block_size); @@ -108,24 +218,120 @@ void advance_step(int num_seqs, int num_queries, int block_size, int blocks; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - advance_step_kernel<<>>( - num_seqs, num_queries, block_size, + advance_step_flashattn_kernel + <<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + reinterpret_cast(sampled_token_ids.data_ptr()), + 
reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +void advance_step_flashinfer( + int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables, // type: int + torch::Tensor& paged_kv_indices, // type: int + torch::Tensor& paged_kv_indptr, // type: int + torch::Tensor& paged_kv_last_page_len, // type: int + torch::Tensor& block_table_bound) { // type: int + + if (logging) { + printf("advance_step_flashinfer:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0)); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + // at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt); + verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt); + verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1, + at::kInt); + + verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + int threads; + cudaDeviceGetAttribute(&blocks, 
cudaDevAttrMultiProcessorCount, dev); + cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); + + [[maybe_unused]] int block_tables_stride = block_tables.stride(0); + TORCH_CHECK((blocks * threads > num_queries), + "multi-step: not enough threads to map to num_queries = ", + num_queries, " block_tables.stride(0) = ", block_tables.stride(0), + " blocks = ", blocks, " max_threads = ", threads); + if (logging) { + printf("launching kernels with %d blocks and %d threads\n", blocks, + threads); + } + advance_step_flashinfer_kernel<<>>( + threads, num_seqs, num_queries, block_size, reinterpret_cast(input_tokens.data_ptr()), reinterpret_cast(sampled_token_ids.data_ptr()), reinterpret_cast(input_positions.data_ptr()), reinterpret_cast(seq_lens.data_ptr()), reinterpret_cast(slot_mapping.data_ptr()), reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0)); + block_tables.stride(0), + reinterpret_cast(paged_kv_last_page_len.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indptr_kernel<<>>( + threads, num_seqs, num_queries, + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indices_kernel<<>>( + num_seqs, num_queries, + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0), + reinterpret_cast(paged_kv_indices.data_ptr()), + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); } } // namespace prepare_inputs -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables) { - prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, - sampled_token_ids, input_positions, seq_lens, - slot_mapping, block_tables); -} \ No newline at end of file +void 
advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables) { + prepare_inputs::advance_step_flashattn( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables); +} + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) { + prepare_inputs::advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, + paged_kv_indptr, paged_kv_last_page_len, block_table_bound); +} diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE deleted file mode 100644 index a46e2cdcadf7..000000000000 --- a/csrc/punica/LICENSE +++ /dev/null @@ -1,217 +0,0 @@ -Contains code from https://github.com/punica-ai/punica - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. 
You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. 
Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. - - -Apache-2.0 -* third_party/nvbench (with LLVM exception) -* third_party/flashinfer - -BSD-3-Clause: -* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu deleted file mode 100644 index 86846c274c90..000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu deleted file mode 100644 index de39c3121f5d..000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h deleted file 
mode 100644 index 2c8d007d8719..000000000000 --- a/csrc/punica/bgmv/bgmv_config.h +++ /dev/null @@ -1,218 +0,0 @@ -#pragma once - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale); - -// clang-format off - -#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, narrow, 128) \ - f(in_T, out_T, W_T, narrow, 256) \ - f(in_T, out_T, W_T, narrow, 512) \ - f(in_T, out_T, W_T, narrow, 640) \ - f(in_T, out_T, W_T, narrow, 768) \ - f(in_T, out_T, W_T, narrow, 896) \ - f(in_T, out_T, W_T, narrow, 1024) \ - f(in_T, out_T, W_T, narrow, 1152) \ - f(in_T, out_T, W_T, narrow, 1216) \ - f(in_T, out_T, W_T, narrow, 1280) \ - f(in_T, out_T, W_T, narrow, 1536) \ - f(in_T, out_T, W_T, narrow, 1664) \ - f(in_T, out_T, W_T, narrow, 1728) \ - f(in_T, out_T, W_T, narrow, 1792) \ - f(in_T, out_T, W_T, narrow, 2048) \ - f(in_T, out_T, W_T, narrow, 2240) \ - f(in_T, out_T, W_T, narrow, 2304) \ - f(in_T, out_T, W_T, narrow, 2368) \ - f(in_T, out_T, W_T, narrow, 2432) \ - f(in_T, out_T, W_T, narrow, 2560) \ - f(in_T, out_T, W_T, narrow, 2752) \ - f(in_T, out_T, W_T, narrow, 2816) \ - f(in_T, out_T, W_T, narrow, 3072) \ - f(in_T, out_T, W_T, narrow, 3328) \ - f(in_T, out_T, W_T, narrow, 3456) \ - f(in_T, out_T, W_T, narrow, 3584) \ - f(in_T, out_T, W_T, narrow, 3712) \ - f(in_T, out_T, W_T, narrow, 4096) \ - f(in_T, out_T, W_T, narrow, 4480) \ - f(in_T, out_T, W_T, narrow, 4608) \ - f(in_T, out_T, W_T, narrow, 4736) \ - f(in_T, out_T, W_T, narrow, 4864) \ - f(in_T, out_T, W_T, narrow, 5120) \ - f(in_T, out_T, W_T, narrow, 5504) \ - f(in_T, out_T, W_T, narrow, 5632) \ - f(in_T, out_T, W_T, narrow, 5888) \ - f(in_T, out_T, W_T, narrow, 6144) \ - f(in_T, out_T, W_T, narrow, 6400) \ - f(in_T, out_T, W_T, narrow, 6848) \ - f(in_T, out_T, W_T, narrow, 6912) \ - f(in_T, 
out_T, W_T, narrow, 7168) \ - f(in_T, out_T, W_T, narrow, 7424) \ - f(in_T, out_T, W_T, narrow, 8192) \ - f(in_T, out_T, W_T, narrow, 8960) \ - f(in_T, out_T, W_T, narrow, 9216) \ - f(in_T, out_T, W_T, narrow, 9472) \ - f(in_T, out_T, W_T, narrow, 10240) \ - f(in_T, out_T, W_T, narrow, 11008) \ - f(in_T, out_T, W_T, narrow, 11264) \ - f(in_T, out_T, W_T, narrow, 12288) \ - f(in_T, out_T, W_T, narrow, 13696) \ - f(in_T, out_T, W_T, narrow, 13824) \ - f(in_T, out_T, W_T, narrow, 14336) \ - f(in_T, out_T, W_T, narrow, 14784) \ - f(in_T, out_T, W_T, narrow, 14848) \ - f(in_T, out_T, W_T, narrow, 15360) \ - f(in_T, out_T, W_T, narrow, 16384) \ - f(in_T, out_T, W_T, narrow, 18944) \ - f(in_T, out_T, W_T, narrow, 20480) \ - f(in_T, out_T, W_T, narrow, 22016) \ - f(in_T, out_T, W_T, narrow, 22528) \ - f(in_T, out_T, W_T, narrow, 24576) \ - f(in_T, out_T, W_T, narrow, 27392) \ - f(in_T, out_T, W_T, narrow, 27648) \ - f(in_T, out_T, W_T, narrow, 28672) \ - f(in_T, out_T, W_T, narrow, 29568) \ - f(in_T, out_T, W_T, narrow, 29696) \ - f(in_T, out_T, W_T, narrow, 32000) \ - f(in_T, out_T, W_T, narrow, 32256) \ - f(in_T, out_T, W_T, narrow, 32512) \ - f(in_T, out_T, W_T, narrow, 32768) \ - f(in_T, out_T, W_T, narrow, 33024) \ - f(in_T, out_T, W_T, narrow, 36864) \ - f(in_T, out_T, W_T, narrow, 43264) \ - f(in_T, out_T, W_T, narrow, 49152) \ - f(in_T, out_T, W_T, narrow, 49408) \ - f(in_T, out_T, W_T, narrow, 60544) \ - f(in_T, out_T, W_T, narrow, 60672) \ - f(in_T, out_T, W_T, narrow, 64000) \ - f(in_T, out_T, W_T, narrow, 64256) \ - f(in_T, out_T, W_T, narrow, 64512) \ - f(in_T, out_T, W_T, narrow, 102400) \ - f(in_T, out_T, W_T, narrow, 102656) \ - f(in_T, out_T, W_T, narrow, 102912) \ - f(in_T, out_T, W_T, narrow, 128000) \ - f(in_T, out_T, W_T, narrow, 128256) \ - f(in_T, out_T, W_T, narrow, 128512) \ - - -// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA -// and vllm/tests/lora/test_punica.py - -// Used for defining kernels going from the variety of -// 
dim in to the narrow dim out - // Using it for the fully sharded column - // parallel LoRA A which splits the rank dim -#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, 128, narrow) \ - f(in_T, out_T, W_T, 256, narrow) \ - f(in_T, out_T, W_T, 512, narrow) \ - f(in_T, out_T, W_T, 640, narrow) \ - f(in_T, out_T, W_T, 768, narrow) \ - f(in_T, out_T, W_T, 896, narrow) \ - f(in_T, out_T, W_T, 1024, narrow) \ - f(in_T, out_T, W_T, 1152, narrow) \ - f(in_T, out_T, W_T, 1216, narrow) \ - f(in_T, out_T, W_T, 1280, narrow) \ - f(in_T, out_T, W_T, 1536, narrow) \ - f(in_T, out_T, W_T, 1664, narrow) \ - f(in_T, out_T, W_T, 1728, narrow) \ - f(in_T, out_T, W_T, 1792, narrow) \ - f(in_T, out_T, W_T, 2048, narrow) \ - f(in_T, out_T, W_T, 2240, narrow) \ - f(in_T, out_T, W_T, 2304, narrow) \ - f(in_T, out_T, W_T, 2368, narrow) \ - f(in_T, out_T, W_T, 2432, narrow) \ - f(in_T, out_T, W_T, 2560, narrow) \ - f(in_T, out_T, W_T, 2752, narrow) \ - f(in_T, out_T, W_T, 2816, narrow) \ - f(in_T, out_T, W_T, 3072, narrow) \ - f(in_T, out_T, W_T, 3328, narrow) \ - f(in_T, out_T, W_T, 3456, narrow) \ - f(in_T, out_T, W_T, 3584, narrow) \ - f(in_T, out_T, W_T, 3712, narrow) \ - f(in_T, out_T, W_T, 4096, narrow) \ - f(in_T, out_T, W_T, 4480, narrow) \ - f(in_T, out_T, W_T, 4608, narrow) \ - f(in_T, out_T, W_T, 4736, narrow) \ - f(in_T, out_T, W_T, 4864, narrow) \ - f(in_T, out_T, W_T, 5120, narrow) \ - f(in_T, out_T, W_T, 5504, narrow) \ - f(in_T, out_T, W_T, 5632, narrow) \ - f(in_T, out_T, W_T, 5888, narrow) \ - f(in_T, out_T, W_T, 6144, narrow) \ - f(in_T, out_T, W_T, 6400, narrow) \ - f(in_T, out_T, W_T, 6848, narrow) \ - f(in_T, out_T, W_T, 6912, narrow) \ - f(in_T, out_T, W_T, 7168, narrow) \ - f(in_T, out_T, W_T, 7424, narrow) \ - f(in_T, out_T, W_T, 8192, narrow) \ - f(in_T, out_T, W_T, 8960, narrow) \ - f(in_T, out_T, W_T, 9216, narrow) \ - f(in_T, out_T, W_T, 9472, narrow) \ - f(in_T, out_T, W_T, 10240, narrow) \ - f(in_T, out_T, W_T, 11008, narrow) \ 
- f(in_T, out_T, W_T, 11264, narrow) \ - f(in_T, out_T, W_T, 12288, narrow) \ - f(in_T, out_T, W_T, 13696, narrow) \ - f(in_T, out_T, W_T, 13824, narrow) \ - f(in_T, out_T, W_T, 14336, narrow) \ - f(in_T, out_T, W_T, 14784, narrow) \ - f(in_T, out_T, W_T, 14848, narrow) \ - f(in_T, out_T, W_T, 15360, narrow) \ - f(in_T, out_T, W_T, 16384, narrow) \ - f(in_T, out_T, W_T, 18944, narrow) \ - f(in_T, out_T, W_T, 20480, narrow) \ - f(in_T, out_T, W_T, 22016, narrow) \ - f(in_T, out_T, W_T, 22528, narrow) \ - f(in_T, out_T, W_T, 24576, narrow) \ - f(in_T, out_T, W_T, 27392, narrow) \ - f(in_T, out_T, W_T, 27648, narrow) \ - f(in_T, out_T, W_T, 28672, narrow) \ - f(in_T, out_T, W_T, 29568, narrow) \ - f(in_T, out_T, W_T, 29696, narrow) \ - f(in_T, out_T, W_T, 32000, narrow) \ - f(in_T, out_T, W_T, 32256, narrow) \ - f(in_T, out_T, W_T, 32512, narrow) \ - f(in_T, out_T, W_T, 32768, narrow) \ - f(in_T, out_T, W_T, 33024, narrow) \ - f(in_T, out_T, W_T, 36864, narrow) \ - f(in_T, out_T, W_T, 43264, narrow) \ - f(in_T, out_T, W_T, 49152, narrow) \ - f(in_T, out_T, W_T, 49408, narrow) \ - f(in_T, out_T, W_T, 60544, narrow) \ - f(in_T, out_T, W_T, 60672, narrow) \ - f(in_T, out_T, W_T, 64000, narrow) \ - f(in_T, out_T, W_T, 64256, narrow) \ - f(in_T, out_T, W_T, 64512, narrow) \ - f(in_T, out_T, W_T, 102400, narrow) \ - f(in_T, out_T, W_T, 102656, narrow) \ - f(in_T, out_T, W_T, 102912, narrow) \ - f(in_T, out_T, W_T, 128000, narrow) \ - f(in_T, out_T, W_T, 128256, narrow) \ - f(in_T, out_T, W_T, 128512, narrow) \ -// Keep above in sync with vllm/lora/layers::SamplerWithLoRA - - -// Keep this in sync with vllm/config::LoRAConfig -#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) - - -#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \ - 
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \ - f(in_T, out_T, W_T, 8, 64) \ - f(in_T, out_T, W_T, 16, 64) \ - f(in_T, out_T, W_T, 32, 64) \ - f(in_T, out_T, W_T, 64, 64) - -// clang-format on diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu deleted file mode 100644 index d225a1eaa82b..000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu deleted file mode 100644 index b37d288a7556..000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu deleted file mode 100644 index a1ab2deecbab..000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu deleted file mode 100644 index 0b35bf569989..000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh 
b/csrc/punica/bgmv/bgmv_impl.cuh deleted file mode 100644 index 8a3b8403b4a6..000000000000 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ /dev/null @@ -1,451 +0,0 @@ -#pragma once - -#include -#ifndef USE_ROCM -#include -#else -#include -#endif -#ifndef USE_ROCM -#include -#endif -#include -#include -#include - -#include "vec_dtypes.cuh" - -namespace cg = cooperative_groups; - -#ifdef USE_ROCM -template -__host__ __device__ -inline void* memcpy_blocking(void *dst, const void *src) { - // Does not handle the case of long datatypes - char *d = reinterpret_cast(dst); - const char *s = reinterpret_cast(src); - size_t i = 0; -#pragma unroll - for (i = 0; i < len; ++i) { - d[i] = s[i]; - } - return dst; -} -#endif - -#ifndef USE_ROCM - -// nthrs = (32, 4) -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t j = blockIdx.x; - constexpr size_t num_pipeline_stages = 2; - constexpr size_t tile_size = tx * ty * vec_size; - __shared__ W_T W_shared[num_pipeline_stages * tile_size]; - __shared__ in_T X_shared[num_pipeline_stages * tile_size]; - __shared__ float y_warpwise[ty]; - - size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - auto pipe = cuda::make_pipeline(); - - // pipeline load W/X and compute WX; - pipe.producer_acquire(); - cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - X + 
(batch_idx * feat_in) + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - pipe.producer_commit(); - size_t copy_idx, compute_idx; - float y = 0.f; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; - ++tile_idx) { - copy_idx = tile_idx % num_pipeline_stages; - // pipeline stage: async copy W fragment - pipe.producer_acquire(); - if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { - cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - } - pipe.producer_commit(); - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // pipeline stage: compute WX - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = sum; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - } - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // final pipeline stage - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + 
X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = - ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) - ? sum - : 0.f; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - - // write Y; - if (block.thread_rank() == 0) { - Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); - } -} - -#else - -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - size_t j = blockIdx.x; - constexpr size_t tile_size = tx * ty * vec_size; - constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; - __shared__ float y_warpwise[ty]; - - float y = 0; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - x_vec.load(X + (batch_idx * feat_in) + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - } - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += convert_type(w_vec[i]) * 
convert_type(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += VLLM_SHFL_DOWN_SYNC(sum, offset); - } - - __syncthreads(); - - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - y += sum; - } - } - - if (threadIdx.x == 0) { - y_warpwise[threadIdx.y] = y; - } - __syncthreads(); - - float y_write = 0.f; -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y_write += y_warpwise[i]; - } - - // write Y; - if (threadIdx.x == 0 && threadIdx.y == 0) { - size_t y_idx = batch_idx * full_y_size + y_offset + j; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); - } -} - -#endif - -// nthrs = (2, 16, 4) -template -__global__ void -bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t tile_idx = blockIdx.x; - - // load X; - vec_t x_vec; - x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); - - // load W; - vec_t w_vec; - w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + - block.thread_rank() * vec_size); - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { -#ifndef USE_ROCM - sum += float(w_vec[i]) * float(x_vec[i]) * scale; -#else - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; -#endif - } - - cg::thread_block_tile g = cg::tiled_partition(block); -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += g.shfl_down(sum, offset); - } - sum = g.shfl(sum, 0); - - if (threadIdx.x == 0) { -#ifndef USE_ROCM - Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y] += static_cast(sum); -#else - 
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); -#endif - } -} - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - constexpr size_t vec_size = 8; - constexpr int tz = 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if constexpr (feat_in <= feat_out) { - static_assert(feat_in % vec_size == 0); - constexpr int tx = feat_in / vec_size; - - static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || - (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || - (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); - - if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { - constexpr int ty = 32 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { - constexpr int ty = 16 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else { - constexpr int ty = 8 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } - } else { -#ifndef USE_ROCM - static_assert(feat_in % (vec_size * 32) == 0 || - feat_in % (vec_size * 16) == 0 || - feat_in % (vec_size * 8) == 0); - - if constexpr (feat_in % (vec_size * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, 
y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { - constexpr int tx = 16; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } -#else - constexpr size_t rocm_warp_size = warpSize; - -#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ - feat_in % (rocm_warp_size * vec_size_) == 0 - -#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ - if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ - constexpr size_t vec_size_shrink = vec_size_; \ - constexpr int tx = tx_; \ - constexpr int ty = ty_; \ - dim3 nblks(feat_out, batch_size); \ - dim3 nthrs(tx, ty); \ - bgmv_shrink_kernel \ - <<>>(Y, X, W, indicies, y_offset, \ - full_y_size, num_layers, layer_idx, \ - scale); \ - } - - static_assert(CHECK_INPUT_TILEABLE_BY(32) || - CHECK_INPUT_TILEABLE_BY(16) || - CHECK_INPUT_TILEABLE_BY( 8) || - CHECK_INPUT_TILEABLE_BY( 4) || - CHECK_INPUT_TILEABLE_BY( 2) || - CHECK_INPUT_TILEABLE_BY( 1)); - - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) - -#undef CHECK_INPUT_TILEABLE_BY -#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM 
-#endif - } -} - -#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ - template void bgmv_kernel( \ - out_T * __restrict__ Y, const in_T *__restrict__ X, \ - const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ - int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ - int64_t num_layers, int64_t layer_idx, float scale); - -#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \ - INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) - -#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ - INST_BGMV(narrow, wide, in_T, out_T, W_T) \ - INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py deleted file mode 100644 index 972df5a7208c..000000000000 --- a/csrc/punica/bgmv/generator.py +++ /dev/null @@ -1,48 +0,0 @@ -DTYPES = ["fp16", "bf16", "fp32"] -DTYPE_MAP = { - "fp16": "nv_half", - "bf16": "nv_bfloat16", - "fp32": "float", -} - -TEMPLATE = """ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -""".lstrip() # noqa: E501 - -for input_dtype in DTYPES: - for output_dtype in DTYPES: - for weight_dtype in DTYPES: - if weight_dtype == "fp32": - # FP32 weights are not supported. - continue - if output_dtype == "fp32": - # LoRA A matrix. - if input_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # input and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. - continue - elif input_dtype == "fp32": - # LoRA B matrix. - if output_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # output and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. 
- continue - elif not (input_dtype == output_dtype == weight_dtype): - # NOTE(woosuk): While Punica supports mixed data types for - # input, output, and weight, we only generate the kernels with - # the same data types to reduce the binary size. - continue - - kernel_definition = TEMPLATE.format( - input_dtype=DTYPE_MAP[input_dtype], - output_dtype=DTYPE_MAP[output_dtype], - weight_dtype=DTYPE_MAP[weight_dtype]) - filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu" - with open(filename, "w") as f: - f.write(kernel_definition) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh deleted file mode 100644 index 2738892e6dc4..000000000000 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ /dev/null @@ -1,1325 +0,0 @@ -#ifndef VEC_DTYPES_CUH_ -#define VEC_DTYPES_CUH_ - -#ifdef FLASHINFER_USE_FP8 -#include -#endif -#include - -#include - -#include "../type_convert.h" -#include "../../cuda_compat.h" - -#define FLASHINFER_INLINE \ - inline __attribute__((always_inline)) __device__ __host__ - -template -struct vec_t { - FLASHINFER_INLINE float_t &operator[](size_t i); - FLASHINFER_INLINE const float_t &operator[](size_t i) const; - FLASHINFER_INLINE void fill(float_t val); - FLASHINFER_INLINE void load(const float_t *ptr); - FLASHINFER_INLINE void store(float_t *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src); - template - FLASHINFER_INLINE void cast_load(const T *ptr); - template - FLASHINFER_INLINE void cast_store(T *ptr) const; - FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); -}; - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = tgt_float_t(src[i]); - } -} - -template -FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, - vec_t &dst) { - if constexpr (std::is_same::value) { - dst.load(src_ptr); - } else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } -} - 
-template -FLASHINFER_INLINE void cast_store_impl(const vec_t &src, - tgt_float_t *dst_ptr) { - if constexpr (std::is_same::value) { - src.store(dst_ptr); - } else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } -} - -#ifdef FLASHINFER_USE_FP8 -/******************* vec_t<__nv_fp8_e4m3> *******************/ - -// __nv_fp8_e4m3 x 1 -template <> -struct vec_t<__nv_fp8_e4m3, 1> { - __nv_fp8_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( - __nv_fp8_e4m3 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *dst = *src; -} - -// __nv_fp8_e4m3 x 2 -template <> -struct vec_t<__nv_fp8_e4m3, 2> { - __nv_fp8x2_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - 
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x2_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x2_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 4 - -template <> -struct vec_t<__nv_fp8_e4m3, 4> { - __nv_fp8x4_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void 
vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x4_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x4_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 8 - -template <> -struct vec_t<__nv_fp8_e4m3, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { - ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE 
void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( - __nv_fp8_e4m3 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 16 or more -template -struct vec_t<__nv_fp8_e4m3, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void 
cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t<__nv_fp8_e5m2> *******************/ - -// __nv_fp8_e5m2 x 1 -template <> -struct vec_t<__nv_fp8_e5m2, 1> { - __nv_fp8_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( - __nv_fp8_e5m2 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *dst = *src; -} - -// __nv_fp8_e5m2 x 2 -template <> -struct vec_t<__nv_fp8_e5m2, 2> { - __nv_fp8x2_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return 
((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x2_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x2_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 4 - -template <> -struct vec_t<__nv_fp8_e5m2, 4> { - __nv_fp8x4_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - 
FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x4_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x4_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 8 - -template <> -struct vec_t<__nv_fp8_e5m2, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { - ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - 
((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( - __nv_fp8_e5m2 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 16 or more - -template -struct vec_t<__nv_fp8_e5m2, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 
*)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; -#endif - -/******************* vec_t *******************/ - -// half x 1 -template <> -struct vec_t { - half data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *dst = *src; -} - -// half x 2 -template <> -struct vec_t { - half2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return 
((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - data = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((half2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((half2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *((half2 *)dst) = *((half2 *)src); -} - -// half x 4 - -template <> -struct vec_t { - uint2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - *(half2 *)(&data.x) = make_half2(val, val); - *(half2 *)(&data.y) = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void 
vec_t::memcpy(half *dst, const half *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// half x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)data)[i]; - } - FLASHINFER_INLINE void fill(half val) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - *(half2 *)(&(data[i].x)) = make_half2(val, val); - *(half2 *)(&(data[i].y)) = make_half2(val, val); - *(half2 *)(&(data[i].z)) = make_half2(val, val); - *(half2 *)(&(data[i].w)) = make_half2(val, val); - } - } - FLASHINFER_INLINE void load(const half *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(half *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// nv_bfloat16 x 1 -template <> -struct vec_t { - nv_bfloat16 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - 
cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *dst = *src; -} - -// nv_bfloat16 x 2 -template <> -struct vec_t { - nv_bfloat162 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((nv_bfloat162 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((nv_bfloat162 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); -} - -// nv_bfloat16 x 4 - -template <> -struct vec_t { - uint2 data; - - 
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// nv_bfloat16 x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); - } - } - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { -#pragma unoll - for (size_t i = 0; i < 
vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// float x 1 - -template <> -struct vec_t { - float data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *dst = *src; -} - -// float x 2 - -template <> -struct vec_t { - float2 data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE 
const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { - data = make_float2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { - data = *((float2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { - *((float2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *((float2 *)dst) = *((float2 *)src); -} - -// float x 4 or more -template -struct vec_t { - float4 data[vec_size / 4]; - - FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(data))[i]; - } - FLASHINFER_INLINE void fill(float val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = make_float4(val, val, val, val); - } - } - FLASHINFER_INLINE void load(const float *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(float *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - 
FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)dst)[i] = ((float4 *)src)[i]; - } - } -}; - -/******************* vec_t type cast *******************/ - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = half(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = - __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = nv_bfloat16(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((nv_bfloat162 *)(&dst.data))[i] = - __float22bfloat162_rn(((float2 *)(&src.data))[i]); - } - } -} - -#ifdef FLASHINFER_USE_FP8 - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); - } - } -} - -template 
-FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = - __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { 
-#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e5m2(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = - __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -#endif // FLASHINFER_USE_FP8 - -#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu deleted file mode 100644 index dd29820144b3..000000000000 --- a/csrc/punica/punica_ops.cu +++ /dev/null @@ -1,569 +0,0 @@ -#include -#include -#include - -#include "type_convert.h" -#include "../cuda_compat.h" -#include "bgmv/bgmv_config.h" - - -//====== utils ====== - -inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, - const char *a_name, const char *b_name) { - TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", - a.dim(), " vs ", b.dim()); - for (int i = 0; i < a.dim(); ++i) { - TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, - ".size(", i, ")"); - } -} - -inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { - return (uint64_t(a) << 32) | uint64_t(b); -} - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -#define CHECK_DIM(d, x) \ - TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") - -#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) - -#define CHECK_EQ(a, b) \ - TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) - -//====== bgmv ====== - -template -inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, - const int64_t *lora_indices, - uint32_t in_features, uint32_t out_features, - int64_t y_offset, int64_t full_y_size, - int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - // NOTE(woosuk): While Punica supports various combinations of input/output - // data types, we limit the supported data types to reduce the binary size. 
- constexpr bool is_input_float = std::is_same::value; - constexpr bool is_output_float = std::is_same::value; - if (is_input_float) { - if (!std::is_same::value) { - return false; - } - } else if (is_output_float) { - if (!std::is_same::value) { - return false; - } - } else if (!(std::is_same::value && - std::is_same::value)) { - return false; - } - - switch (pack_u32(in_features, out_features)) { -#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u32(feat_in, feat_out): \ - bgmv_kernel(Y, X, W, lora_indices, y_offset, \ - full_y_size, batch_size, num_layers, \ - layer_idx, scale); \ - break; -#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) - - FOR_BGMV_WIDE_NARROW(CASE, _, _, _) - FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _) -#undef CASE -#undef CASE_ONESIDE - default: - return false; - } - return true; -} - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t h_in = x.size(1); - int64_t h_out = y.size(1); - int64_t num_layers = w.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, 
h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), 
h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = 
launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); -} - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t num_layers = w.size(1); - int64_t full_y_size = y.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested 
switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch 
(w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - 
static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h deleted file mode 100644 index 5d625d0564f7..000000000000 --- a/csrc/punica/punica_ops.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - 
-#include - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale); - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset); diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp deleted file mode 100644 index 894e229b6d9d..000000000000 --- a/csrc/punica/torch_bindings.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "registration.h" -#include "punica_ops.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - m.def( - "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " - "layer_idx, float scale) -> ()"); - m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - - m.def( - "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," - "Tensor indicies, int layer_idx," - "float scale, int h_in, int h_out," - "int y_offset) -> ()"); - m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h deleted file mode 100644 index dff7ce49283d..000000000000 --- a/csrc/punica/type_convert.h +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ -#define CSRC__PUNICA__TYPE_CONVERT_H__ - -#ifndef USE_ROCM - -#include -#include - -#else - -#include -#include - -#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ - -typedef __half nv_half; -typedef __hip_bfloat16 nv_bfloat16; -typedef __hip_bfloat162 nv_bfloat162; - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { - return __hip_bfloat162{val, val}; -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { - return __hip_bfloat162{vall, valr}; -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T_dst convert_type(T_src val) { - return 
static_cast(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__half, float>(__half val) { - return __half2float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half convert_type(float val) { - return __float2half(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { - return __bfloat162float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 convert_type(float val) { - return __float2bfloat16(val); -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T vllm_add(T a, T b) { - return a + b; -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half vllm_add<__half>(__half a, __half b) { - return __hadd(a, b); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { - return __hadd(a, b); -} - -#undef __TYPE_CONVERT__HOST_DEVICE__ - -#endif // USE_ROCM - -#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 22da5e4f08a1..79cd2c610b3c 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input, } // Accumulate the partition sizes. 
-int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { +int4 accumulate_sizes(const std::vector& codebook_partition_sizes) { int4 cumulative_sizes; auto cumulative_size = &cumulative_sizes.x; - int i = 0; + size_t i = 0; int last = 0; - assert(codebook_partition_sizes.size(0) <= 4); - for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { - *cumulative_size = codebook_partition_sizes[i].item() + last; + assert(codebook_partition_sizes.size() <= 4); + for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) { + *cumulative_size = codebook_partition_sizes[i] + last; last = *cumulative_size; } // fill in the rest with unreachable. @@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, + const std::vector& codebook_partition_sizes, const std::optional& bias) { int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); int const entries = codebooks.size(1); if (nbooks == 1 && entries == (1 << 16)) { @@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, return {}; } -torch::Tensor aqlm_dequant(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes) { +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, const torch::Tensor& codebooks, + const std::vector& codebook_partition_sizes) { int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); int const entries = 
codebooks.size(1); const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); @@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, auto in_features = codes.size(1) * 8; auto out_features = codes.size(0); - assert(out_features = codebook_partition_sizes.sum().item()); + assert(out_features == std::accumulate(codebook_partition_sizes.begin(), + codebook_partition_sizes.end(), 0)); auto weights = torch::empty({out_features, in_features}, torch::TensorOptions() diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index 813ec6716cf5..5fa4b5f64027 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { return result; #endif + __builtin_unreachable(); // Suppress missing return statement warning } } // namespace awq diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 9da724a1b43c..53c47679cdd7 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -334,7 +334,7 @@ __global__ void __launch_bounds__(64) } // TODO: Shang: Hoist loop invariance. 
- for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { + for (int ax1_0_1 = 0; ax1_0_1 < (N / 32); ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index aa9511daa277..e79785827189 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -3,16 +3,28 @@ #include #include "../../dispatch_utils.h" -#include "../../reduction_utils.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - static const float i8_min = + static constexpr auto i8_min = static_cast(std::numeric_limits::min()); - static const float i8_max = + static constexpr auto i8_max = static_cast(std::numeric_limits::max()); - // round + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. float dst = std::nearbyint(x); + // saturate dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); @@ -24,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #endif } +static inline __device__ int32_t float_to_int32_rn(float x) { +#ifdef USE_ROCM + // int32_max is not exactly representable as float. + // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. 
+ static constexpr auto i32_min = std::numeric_limits::min(); + static constexpr auto i32_min_f = static_cast(i32_min); + static constexpr auto i32_max = std::numeric_limits::max(); + static constexpr auto i32_max_f = static_cast(i32_max); + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. + float dst = std::nearbyint(x); + + // saturate on the higher end. + if (dst >= i32_max_f) { + return i32_max; + } + // saturate on the lower end. + if (dst <= i32_min_f) { + return i32_min; + } + + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int8_t int32_to_int8(int32_t x) { +#ifdef USE_ROCM + static constexpr auto i8_min = + static_cast(std::numeric_limits::min()); + static constexpr auto i8_max = + static_cast(std::numeric_limits::max()); + + // saturate + int32_t dst = std::clamp(x, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); + return reinterpret_cast(dst); +#endif +} + namespace vllm { template @@ -31,12 +96,36 @@ __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, const int hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + + // Must be performed using 64-bit math to avoid integer overflow. 
+ out += token_idx * hidden_size; + input += token_idx * hidden_size; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[i] = float_to_int8_rn(static_cast(input[i]) / scale); + } +} + +template +__global__ void static_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, azp_type const* azp_ptr, + const int hidden_size) { + int const tid = threadIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; + azp_type const azp = *azp_ptr; + + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) / scale); + auto const val = static_cast(input[i]); + auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); + out[i] = quant_val; } } @@ -45,17 +134,24 @@ __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type* scale, const int hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; float absmax_val = 0.0f; float const zero = 0.0f; + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; + for (int i = tid; i < hidden_size; i += blockDim.x) { - float val = static_cast(input[token_idx * hidden_size + i]); + float val = static_cast(input[i]); val = val > zero ? val : -val; absmax_val = val > absmax_val ? 
val : absmax_val; } - float const block_absmax_val_maybe = blockReduceMax(absmax_val); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + float const block_absmax_val_maybe = + BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); __shared__ float block_absmax_val; if (tid == 0) { block_absmax_val = block_absmax_val_maybe; @@ -65,8 +161,63 @@ __global__ void dynamic_scaled_int8_quant_kernel( float const tmp_scale = 127.0f / block_absmax_val; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + out[i] = float_to_int8_rn(static_cast(input[i]) * tmp_scale); + } +} + +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, azp_type* azp, const int hidden_size) { + int64_t const token_idx = blockIdx.x; + + // Must be performed using 64-bit math to avoid integer overflow. 
+ out += token_idx * hidden_size; + input += token_idx * hidden_size; + + // Scan for the min and max value for this token + float max_val = std::numeric_limits::min(); + float min_val = std::numeric_limits::max(); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[i]); + max_val = std::max(max_val, val); + min_val = std::min(min_val, val); + } + + // Reduce the max and min values across the block + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); + __syncthreads(); // Make sure min doesn't mess with max shared memory + min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); + + __shared__ scale_type scale_sh; + __shared__ azp_type azp_sh; + + // Compute the scale and zero point and store them, only on the first thread + if (threadIdx.x == 0) { + float const scale_val = (max_val - min_val) / 255.0f; + // Use rounding to even (same as torch.round) + auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); + auto const azp_val = static_cast(azp_float); + + // Store the scale and azp into shared and global + scale[token_idx] = scale_sh = scale_val; + azp[token_idx] = azp_sh = azp_val; + } + + // Wait for the scale and azp to be computed + __syncthreads(); + + float const scale_val = scale_sh; + azp_type const azp_val = azp_sh; + + // Quantize the values + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[i]); + auto const quant_val = + int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); + out[i] = quant_val; } } @@ -74,10 +225,12 @@ __global__ void dynamic_scaled_int8_quant_kernel( void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale) { + torch::Tensor const& scale, + std::optional const& azp) { 
TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -86,19 +239,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scale.data_ptr(), hidden_size); + if (!azp) { + vllm::static_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), hidden_size); + } else { + vllm::static_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { + torch::Tensor& scales, std::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -107,9 +270,17 @@ void dynamic_scaled_int8_quant( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), hidden_size); + if (!azp) { + vllm::dynamic_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), hidden_size); + } else { + vllm::dynamic_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } diff --git 
a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md new file mode 100644 index 000000000000..a30e1fdf3ac7 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/Epilogues.md @@ -0,0 +1,167 @@ +# CUTLASS Epilogues + +## Introduction + +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +Currently, we only support symmetric quantization for weights, +and symmetric and asymmetric quantization for activations. +Both can be quantized per-tensor or per-channel (weights) / per-token (activations). + +There are 4 epilogues: + +1. `ScaledEpilogue`: symmetric quantization for activations, no bias. +1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias. +1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias. +1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias. + +We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. +Instead, if no bias is passed, the epilogue will use 0 as the bias. +That induces a redundant addition operation (and runtime check), but the performance impact is minor. + +## Underlying Linear Algebra + +More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975). + +If $` \widehat X `$ is the quantized $` X `$, our matrices become the following + +```math +A = s_a (\widehat A - J_a z_a) +``` + +```math +B = s_b \widehat B +``` + +```math +D = A B + C +``` + +```math +D = s_a s_b \widehat D + C +``` + +Here, D is the output of the GEMM, and C is the bias. +A is the activations and supports asymmetric quantization, +and B is the weights and only supports symmetric quantization. +$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively. 
+$` z_a `$ is the zero-point for activations, and $` J_a `$ is the matrix of all ones with dimensions of A. +Additional epilogues would be required to support asymmetric quantization for weights. + +Expanding further, we can calculate $` \widehat D `$ as follows: + +```math +A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B +``` + +```math +A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) +``` + +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` + +Note that $` \widehat A \widehat B `$ is the raw output of the GEMM, +and $` J_a \widehat B `$ is known ahead of time. +Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$. + +## Epilogues + +### `ScaledEpilogue` + +This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` + +```math +D = s_a s_b \widehat D +``` + +```math +D = s_a s_b \widehat A \widehat B +``` + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). + +### `ScaledEpilogueBias` + +This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` + +```math +D = s_a s_b \widehat D + C +``` + +```math +D = s_a s_b \widehat A \widehat B + C +``` + +Epilogue parameters: + +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +### `ScaledEpilogueAzp` + +This epilogue computes the asymmetric per-tensor quantization for activations with bias. 
+The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` + +```math +D = s_a s_b \widehat D + C +``` + +```math +D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C +``` + +Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$. +That is precomputed and stored in `azp_with_adj` as a row-vector. + +Epilogue parameters: + +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-tensor as the zero-points are per-tensor. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. + +### `ScaledEpilogueAzpPerToken` + +This epilogue computes the asymmetric per-token quantization for activations with bias. + +The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. +That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. + +Epilogue parameters: + +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-token as the zero-points are per-token. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector). +- `azp` is the zero-point (`z_a`), is per-token (column-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. 
+ +The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): + +```math +out = scale_a * scale_b * (Dq - azp_adj * azp) + bias +``` diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp deleted file mode 100644 index c4c6b18654ee..000000000000 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ /dev/null @@ -1,346 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -// -// This file is a modified excerpt of -// include/cutlass/epilogue/fusion/visitor_load.hpp from -// https://github.com/NVIDIA/cutlass v3.5.0 -// It has been modified to support either -// row/column or scalar broadcasting where the tensor being loaded from is -// always passed in via a device pointer. This lets one compiled kernel handle -// all cases of per-tensor or per-channel/per-token quantization. -// -// This interface also allows the scales to be passed in as tensors that -// consistently reside on the device, which avoids an issue with a previous -// implementation where scalars needed to be on the CPU since they -// were passed in via float values. This created a potential performance hazard -// if scales were initially on the device, and caused torch.compile graph -// breaks when moving scales to the CPU. 
-// -#pragma once - -// Turn off clang-format for the entire file to keep it close to upstream -// clang-format off - -#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" -#include "cute/tensor.hpp" - -namespace cutlass::epilogue::threadblock { - -using namespace cute; -using namespace detail; - -template< - class ThreadMap, - class Element, - class StrideMNL -> -struct VisitorRowOrScalarBroadcast { - - // This struct has been modified to have a bool indicating that ptr_row is a - // scalar that must be broadcast. - struct Arguments { - Element const* ptr_row = nullptr; - bool row_broadcast = true; - StrideMNL dRow = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - struct SharedStorage {}; - - // Global load type - static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; - using VecType = uint_bit_t; - static int constexpr VecLength = sizeof(VecType) / sizeof(Element); - - CUTLASS_HOST_DEVICE - VisitorRowOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - template - struct Callbacks : EmptyCallbacks { - CUTLASS_DEVICE - Callbacks( - GTensor&& tC_gRow, - RTensor&& tC_rRow, - CTensor&& tC_cRow, - ProblemShape problem_shape, - Params const* params_ptr - ): - tC_gRow(cute::forward(tC_gRow)), - tC_rRow(cute::forward(tC_rRow)), - tC_cRow(cute::forward(tC_cRow)), - n(get<1>(problem_shape)), - params_ptr(params_ptr) { } - - GTensor tC_gRow; - RTensor tC_rRow; - CTensor tC_cRow; - Params const* params_ptr; - int n; - - // This function is modified from VisitorRowBroadcast - CUTLASS_DEVICE void - begin_epilogue() { - 
clear(tC_rRow); - auto src_v = filter(tC_gRow); - auto coord_v = filter(tC_cRow); - auto dst_v = filter(tC_rRow); - - if (params_ptr->row_broadcast) { - // In this case we are loading from a row vector and broadcasting - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(src_v); ++i) { - bool guard = get<1>(coord_v(i)) < n; - cutlass::arch::global_load( - dst_v(i), (void const*)&src_v(i), guard); - } - } else { - // In this case we are loading from a scalar and broadcasting - VecType filled_vec; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < VecLength; i++) { - reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(src_v); ++i) { - if (get<1>(coord_v(i)) < n) { - dst_v(i) = filled_vec; - } - } - } - } - - template - CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, - Array const& frg_acc) { - Tensor rRow_frg = recast>(coalesce(tC_rRow)); - return rRow_frg(column_idx); - } - }; - - template - CUTLASS_DEVICE auto - get_callbacks( - gemm::GemmCoord threadblock_tile_offset, - int thread_idx, - ProblemShape problem_shape - ) { - Tensor mRow = make_tensor( - make_gmem_ptr(params_ptr->ptr_row), - problem_shape, - params_ptr->dRow); - - // VECTOR, FRAGMENT_COLUMN - Tensor tC_gRow = recast( - ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) - )(_,_,_0{},_0{},_0{},_0{}); - Tensor tC_rRow = make_tensor_like(tC_gRow); - - // Generate the pred tensor - Tensor cRow = make_identity_tensor(mRow.shape()); - Tensor tC_cRow = outer_partition( - ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), - Shape>{}, - (_0{}) - ); - - return Callbacks< - decltype(tC_gRow), decltype(tC_rRow), - decltype(tC_cRow), ProblemShape>( - cute::move(tC_gRow), - cute::move(tC_rRow), - cute::move(tC_cRow), - problem_shape, - params_ptr - ); - } - -}; - 
-///////////////////////////////////////////////////////////////////////////////////////////////// - -// Column vector broadcast -template< - class ThreadMap, - class Element, - class StrideMNL = Stride<_1,_0,_0> -> -struct VisitorColOrScalarBroadcast { - - // This struct has been modified to have a bool indicating that ptr_col is a - // scalar that must be broadcast. - struct Arguments { - Element const* ptr_col = nullptr; - bool col_broadcast = true; - StrideMNL dCol = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - struct SharedStorage { }; - - CUTLASS_HOST_DEVICE - VisitorColOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - template - struct Callbacks : EmptyCallbacks { - CUTLASS_DEVICE - Callbacks( - GTensor&& tC_gCol, - RTensor&& tC_rCol, - CTensor&& tC_cCol, - ProblemShape problem_shape, - Params const* params_ptr - ): - tC_gCol(cute::forward(tC_gCol)), - tC_rCol(cute::forward(tC_rCol)), - tC_cCol(cute::forward(tC_cCol)), - m(get<0>(problem_shape)), - params_ptr(params_ptr) { } - - GTensor tC_gCol; - RTensor tC_rCol; - CTensor tC_cCol; - Params const* params_ptr; - int m; - - // This function is modified from VisitorColBroadcast - CUTLASS_DEVICE void - begin_epilogue() { - clear(tC_rCol); - - Tensor pred = make_tensor(shape(tC_gCol)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(pred); ++i) { - pred(i) = get<0>(tC_cCol(i)) < m; - } - - if (params_ptr->col_broadcast) { - // In this case we are loading from a column vector and broadcasting - copy_if(pred, tC_gCol, tC_rCol); - } else { - // In this case we are loading from a scalar and 
broadcasting - auto dst_v = filter(tC_rCol); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(dst_v); ++i) { - if (pred(i)) { - dst_v(i) = *(params_ptr->ptr_col); - } - } - } - } - - template - CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, - Array const& frg_acc) { - Array frg_col; - frg_col.fill(tC_rCol(row_idx,iter_idx)); - return frg_col; - } - }; - - template - CUTLASS_DEVICE auto - get_callbacks( - gemm::GemmCoord threadblock_tile_offset, - int thread_idx, - ProblemShape problem_shape - ) { - Tensor mCol = make_tensor( - make_gmem_ptr(params_ptr->ptr_col), - problem_shape, - params_ptr->dCol); - - // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER - Tensor tC_gCol = group_modes<1,4>( - ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); - Tensor tC_rCol = make_tensor_like(tC_gCol); - - // Generate the pred tensor - Tensor cCol = make_identity_tensor(mCol.shape()); - Tensor tC_cCol = group_modes<1,4>( - ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); - - return Callbacks< - decltype(tC_gCol), decltype(tC_rCol), - decltype(tC_cCol), ProblemShape>( - cute::move(tC_gCol), - cute::move(tC_rCol), - cute::move(tC_cCol), - problem_shape, - params_ptr - ); - } -}; - -} diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp deleted file mode 100644 index e4bc9752ed7d..000000000000 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp +++ /dev/null @@ -1,417 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. 
SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -// -// This file is a modified excerpt of -// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp -// from https://github.com/NVIDIA/cutlass v3.5.0 -// It has been modified to support either row/column or scalar broadcasting -// where the tensor being loaded from is always passed in via a device pointer. 
-// This lets one compiled kernel handle all cases of per-tensor or -// per-channel/per-token quantization. -// -// This interface also allows the scales to be passed in as tensors that -// consistently reside on the device, which avoids an issue with a previous -// implementation where scalars needed to be on the CPU since they -// were passed in via float values. This created a potential performance hazard -// if scales were initially on the device, and caused torch.compile graphs -// breaks when moving scales to the CPU. -// -#pragma once - -// Turn off clang-format for the entire file to keep it close to upstream -// clang-format off - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" - -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -// Row vector broadcast -template< - // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races - int Stages, - class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v -> -struct Sm90RowOrScalarBroadcast { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector broadcast, e.g. per-col alpha/bias - (cute::is_same_v>)); // batched row vector broadcast - - // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; - }; - - // This struct has been modified to have a bool indicating that ptr_row is a - // scalar that must be broadcast, instead of containing a scalar that is - // valid if ptr_row is null. 
- struct Arguments { - Element const* ptr_row = nullptr; - bool row_broadcast = true; - StrideMNL dRow = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90RowOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), - smem_row(const_cast(shared_storage.smem_row.data())) { } - - Params params; - Element* smem_row; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return true; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return (!params.row_broadcast && *(params.ptr_row) == Element(0)); - } - - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) - : gRow(cute::forward(gRow)), - sRow(cute::forward(sRow)), - params(params) {} - - GTensor gRow; // (CTA_M,CTA_N) - STensor sRow; // (CTA_M,CTA_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { - if (!params.row_broadcast) { - return; - } - - if (issue_tma_load) { - // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; - 
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); - // Filter so we don't issue redundant copies over stride-0 modes - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); - } - } - }; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(gRow), cute::move(sRow), params); - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) - : tCrRow(cute::forward(tCrRow)), - tCsRow(cute::forward(tCsRow)), - params(params) {} - - RTensor tCrRow; // (CPY,CPY_M,CPY_N) - STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - if (!params.row_broadcast) { - fill(tCrRow, *(params.ptr_row)); - return; - } - - if (epi_m == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - int bcast_pipe_index = 
(load_iteration / EpiTiles) % Stages; - copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); - } - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_row; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tCrRow(epi_v * FragmentSize + i); - } - - return frg_row; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tCrRow), cute::move(tCsRow), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Column vector broadcast -template< - int Stages, - class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v -> -struct Sm90ColOrScalarBroadcast { - static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias - (cute::is_same_v>)); // batched col vector broadcast, e.g. 
batched per-row bias - - // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem - struct SharedStorage { }; - - // This struct has been modified to have a bool indicating that ptr_col is a - // scalar that must be broadcast, instead of containing a scalar that is - // valid if ptr_col is null. - struct Arguments { - Element const* ptr_col = nullptr; - bool col_broadcast = true; - StrideMNL dCol = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return (!params.col_broadcast && *(params.ptr_col) == Element(0)); - } - - CUTLASS_HOST_DEVICE - Sm90ColOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - GTensor&& tCgCol, - RTensor&& tCrCol, - CTensor&& tCcCol, - ProblemShape problem_shape, - Params const& params - ): - tCgCol(cute::forward(tCgCol)), - tCrCol(cute::forward(tCrCol)), - tCcCol(cute::forward(tCcCol)), - m(get<0>(problem_shape)), 
- params(params) {} - - GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - RTensor tCrCol; - CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - Params const& params; - int m; - - CUTLASS_DEVICE void - begin() { - Tensor pred = make_tensor(shape(tCgCol)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(pred); ++i) { - pred(i) = get<0>(tCcCol(i)) < m; - } - - if (!params.col_broadcast) { - fill(tCrCol, *(params.ptr_col)); - return; - } - - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_if(pred, filter(tCgCol), filter(tCrCol)); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); - } - - return frg_col; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... 
Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - // Generate an identity tensor matching the shape of the global tensor and - // partition the same way, this will be used to generate the predicate - // tensor for loading - Tensor cCol = make_identity_tensor(mCol.shape()); - Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - return ConsumerStoreCallbacks( - cute::move(tCgCol), - cute::move(tCrCol), - cute::move(tCcCol), - args.problem_shape_mnkl, - params - ); - } -}; - -} diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh new file mode 100644 index 000000000000..26de32ce2b16 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -0,0 +1,107 @@ +#pragma once + +// clang-format will break include orders +// clang-format off +#include + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/util/packed_stride.hpp" + +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" +// clang-format on + +namespace vllm::c3x { + +static inline cute::Shape 
get_problem_shape( + torch::Tensor const& a, torch::Tensor const& b) { + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + return {m, n, k, 1}; +} + +template +void cutlass_gemm_caller( + torch::Device device, cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args, + typename GemmKernel::TileSchedulerArguments scheduler = {}) { + cutlass::KernelHardwareInfo hw_info; + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(device); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template +void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... 
epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; + using GemmKernel = typename Gemm::GemmKernel; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = StrideC; + using StrideAux = StrideC; + + typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b); + auto [M, N, K, L] = prob_shape; + + StrideA a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + StrideB b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + StrideC c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + StrideD d_stride = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + StrideAux aux_stride = d_stride; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + // auto d_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, d_stride}; + + cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +} // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh new file mode 100644 index 000000000000..8f4df836bcc8 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -0,0 +1,147 @@ +#pragma once + +// clang-format will break include orders +// clang-format off + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include 
"cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" +// clang-format on + +/* + Epilogues defined in, + csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp, + must contain a public type named EVTCompute of type Sm90EVT, as well as a + static prepare_args function that constructs an EVTCompute::Arguments struct. +*/ + +using namespace cute; + +namespace vllm { + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Epilogue = Epilogue_; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using EVTCompute = typename Epilogue::EVTCompute; + + // These are the minimum alignments needed for the kernels to compile + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = 4; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, + AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + 
ElementAB, cutlass::layout::RowMajor, AlignmentAB, + ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel : public KernelType {}; +}; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm_sm100 { + using ElementAB = ElementAB_; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; + + using ElementD = ElementD_; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = AlignmentC; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + using Epilogue = Epilogue_; + + // MMA type + using ElementAccumulator = float; + + // Epilogue types + using ElementBias = cutlass::half_t; + using ElementCompute = float; + using ElementAux = ElementD; + using LayoutAux = LayoutD; + using ElementAmax = float; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, 
cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; +}; + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu new file mode 100644 index 000000000000..4cd38f4975df --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm90_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias) { + if (azp) { + return cutlass_scaled_mm_sm90_int8_epilogue< + c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, + *azp, bias); + } else { + return cutlass_scaled_mm_sm90_int8_epilogue( + out, a, b, a_scales, b_scales, azp_adj, bias); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu new file mode 100644 index 000000000000..0501e6da160e --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu @@ -0,0 +1,24 @@ + +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, + 
torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh new file mode 100644 index 000000000000..e089c3d4be2c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -0,0 +1,194 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +template > +struct cutlass_3x_gemm_fp8_blockwise { + using GroupSizeM = Int; + using GroupSizeN = Int; + using GroupSizeK = Int; + using TileSizeM = Int; + + static_assert(TileSizeM_ % GroupSizeM_ == 0, + "TileSizeM must be a multiple of GroupSizeM"); + + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using 
LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementD = OutType; + using StrideD = Stride, Int<0>>; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using StrideC = StrideD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementBlockScale = float; + using ElementCompute = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = Shape; + + using KernelSchedule = cutlass::gemm:: + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + GroupSizeM_>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC, + ElementD, StrideD, AlignmentD, EpilogueSchedule, + StoreEpilogueCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, + LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + SchedulerType>>; + + struct GemmKernel : public KernelType {}; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor 
const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + auto prob_shape = c3x::get_problem_shape(a, b); + int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), + k = get<2>(prob_shape); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + // Check is the t is contiguous and is 1D or 2D with one of the dimensions + // being 1 (i.e. 
a row or column vector) + auto is_contiguous_vector = [](const torch::Tensor& t) { + auto t_sizes = t.sizes(); + return t.is_contiguous() && + (t.dim() == 1 || + (t.dim() == 2 && + *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); + }; + + // TODO(lucas): lets clean-up the kernel so that we pass in Strides so + // we don't have to deal with enforcing implicit layouts + TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); + TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), + "a_scales must be M major"); + TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); + TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), + "b_scales must be K major"); + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::TileSchedulerArguments scheduler; + + static constexpr bool UsesStreamKScheduler = + cute::is_same_v; + + if constexpr (UsesStreamKScheduler) { + using DecompositionMode = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::ReductionMode; + + scheduler.decomposition_mode = DecompositionMode::StreamK; + scheduler.reduction_mode = ReductionMode::Nondeterministic; + } + + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args, scheduler); +} + +template +void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto k = a.size(1); + auto n = 
b.size(1); + + if (k > 3 * n) { + cutlass_gemm_caller_blockwise>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise>( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp new file mode 100644 index 000000000000..85272804774d --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include + +namespace vllm { + +void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias); + +void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); + +void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu new file mode 100644 index 000000000000..cf2cccc913f6 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm100_fp8_dispatch.cuh" 
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias) { + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + TORCH_CHECK(bias->dtype() == out.dtype(), + "currently bias dtype must match output dtype ", out.dtype()); + return cutlass_scaled_mm_sm100_fp8_epilogue( + out, a, b, a_scales, b_scales, *bias); + } else { + return cutlass_scaled_mm_sm100_fp8_epilogue( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh new file mode 100644 index 000000000000..468b77d9593b --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include "scaled_mm.cuh" +#include "cutlass_gemm_caller.cuh" + +/** + * This file defines Gemm kernel configurations for SM100 (fp8) based on the + * Gemm shape. + */ + +namespace vllm { + +using c3x::cutlass_gemm_caller; + +template typename Epilogue> +struct sm100_fp8_config_default { + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_256, _128, _64>; + using ClusterShape = Shape<_2, _2, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue, + typename... EpilogueArgs> +inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... 
args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm100_fp8_config_default::Cutlass3xGemm; + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); +} + +template