
Commit 87ec0c7

Merge pull request vllm-project#22 from ROCm/csrikris_pa_opt_shomy_1_16
Integrate PagedAttention Optimization custom kernel into vLLM
2 parents: c47256c + c774517

7 files changed: +1313 −41 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -196,6 +196,7 @@ set(CUSTOM_SRC
   "csrc/custom/custom_kernels.cu"
   "csrc/custom/fused_kernels.cu"
   "csrc/custom/custom.cu"
+  "csrc/custom/paged_attention/attention_ll4mi.cu"
 )

 define_gpu_extension_target(

ROCm_performance.md

Lines changed: 6 additions & 0 deletions
@@ -12,3 +12,9 @@ The default attention function on ROCm is using triton attention kernel. To fall
 ## Tunable ops
 Pytorch tunable ops are supported.
 Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order to enable both the runtime tuning and the subsequent use of tuned results. To only use the tuned results without tuning any newly encountered shapes, also define `PYTORCH_TUNABLEOP_TUNING=0`.
+
+## Custom PagedAttention
+
+On ROCm, a custom paged attention kernel is available for better performance. It is controlled by the environment variable `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
+Currently, this environment variable is enabled by default. To fall back to the PagedAttention v2 kernel, set it to 0.
+The custom PagedAttention kernel is enabled for dtype fp16, block-size 16, head-size 128, and max context length <= 16k, with a GQA ratio (num_heads // num_kv_heads) between 1 and 16. In all other cases, we fall back to the PagedAttention v2 kernel.
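The documentation above lists the conditions under which the custom kernel is used. For illustration only, a minimal Python sketch of that gating logic follows; the helper name and the direct environment-variable check are assumptions, not the actual vLLM implementation.

```python
# Illustrative sketch only -- not the actual vLLM source. It mirrors the
# conditions documented above: fp16, block-size 16, head-size 128,
# context length <= 16k, and a GQA ratio between 1 and 16.
import os
import torch


def _should_use_rocm_custom_paged_attention(dtype: torch.dtype,
                                            head_size: int,
                                            block_size: int,
                                            gqa_ratio: int,
                                            max_context_len: int) -> bool:
    # VLLM_USE_ROCM_CUSTOM_PAGED_ATTN defaults to enabled (1).
    enabled = os.environ.get("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0"
    return (enabled
            and dtype == torch.half
            and head_size == 128
            and block_size == 16
            and 1 <= gqa_ratio <= 16
            and max_context_len <= 16384)
```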

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 45 additions & 19 deletions
@@ -6,10 +6,11 @@
 import torch

 from vllm._C import ops
+from vllm._custom_C import paged_attention_custom
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

 NUM_BLOCKS = 1024
-PARTITION_SIZE = 512
+PARTITION_SIZE = 256


 @torch.inference_mode()
@@ -77,6 +78,9 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
+        if not args.custom_paged_attn:
+            global PARTITION_SIZE
+            PARTITION_SIZE = 512
         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                           PARTITION_SIZE)
         tmp_output = torch.empty(
@@ -118,24 +122,43 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     kv_scale,
                 )
             elif version == "v2":
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    context_lens,
-                    block_size,
-                    max_context_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    kv_scale,
-                )
+                if not args.custom_paged_attn:
+                    ops.paged_attention_v2(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        context_lens,
+                        block_size,
+                        max_context_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        kv_scale,
+                    )
+                else:
+                    paged_attention_custom(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        context_lens,
+                        block_size,
+                        max_context_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                    )
             else:
                 raise ValueError(f"Invalid version: {version}")
         torch.cuda.synchronize()
@@ -191,6 +214,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         'FP8_E5M2 (without scaling) is only supported on cuda version greater '
         'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
         'common inference criteria.')
+    parser.add_argument("--custom-paged-attn",
+                        action="store_true",
+                        help="Use custom paged attention")
     args = parser.parse_args()
     print(args)

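In the benchmark, the new `--custom-paged-attn` flag routes the v2 branch to `paged_attention_custom` (which takes no `kv_scale` argument), while `PARTITION_SIZE` is restored to 512 whenever the stock `ops.paged_attention_v2` kernel is used. As a rough illustration (not part of the commit), the snippet below shows how the partition width changes the number of partitions and the size of the intermediate buffers the benchmark allocates; the sequence count, head count, and head size are assumed example values.

```python
# Illustrative only: 256-wide partitions (custom kernel) vs. 512-wide
# partitions (paged_attention_v2). The tmp_output shape mirrors the
# benchmark's (num_seqs, num_query_heads, num_partitions, head_size) buffer.
max_context_len = 16384
num_seqs, num_query_heads, head_size = 8, 32, 128

for partition_size in (256, 512):
    num_partitions = (max_context_len + partition_size - 1) // partition_size
    tmp_output_shape = (num_seqs, num_query_heads, num_partitions, head_size)
    print(f"partition_size={partition_size}: "
          f"{num_partitions} partitions, tmp_output shape {tmp_output_shape}")
```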

csrc/custom/custom.cu

Lines changed: 25 additions & 0 deletions
@@ -64,11 +64,36 @@ void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
       at::cuda::getCurrentCUDAStream());
 }

+void paged_attention_custom(
+    torch::Tensor& out,
+    torch::Tensor& exp_sums,
+    torch::Tensor& max_logits,
+    torch::Tensor& tmp_out,
+    torch::Tensor& query,
+    torch::Tensor& key_cache,
+    torch::Tensor& value_cache,
+    int num_kv_heads,
+    float scale,
+    torch::Tensor& block_tables,
+    torch::Tensor& context_lens,
+    int block_size,
+    int max_context_len,
+#if 0
+    torch::Tensor& qk_out,
+    torch::Tensor& softmax_out,
+#endif
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype);
+
 // declare the extension module with the AddGPU function:
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
   m.doc() = "pybind11 example plugin";
   m.def("LLMM1", &LLMM1);
   m.def("LLMM_Silu", &LLMM_Silu);
   m.def("LLZZ", &LLZZ);
+  m.def(
+    "paged_attention_custom",
+    &paged_attention_custom,
+    "PagedAttention LL4Mi Custom.");
   //m.def("MMCustomGPU", &MMCustomGPU);
 }
