[ROCm][MLA] Enable MLA persistent kernel with fp8 and bf16 support #27380
Conversation
gpu = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count
There is no need for lines 100-102: the metadata tensors can be created using aiter.get_mla_metadata_info_v1, which returns the sizes needed for each metadata tensor along with the respective dtypes.
Also, the persistent mode supports different head sizes, which allows us to remove the restriction of num_heads to 16 or 128 by using the same condition aiter applies at this line.
So instead of always using persistent mode we can make it conditional on num_heads, which means we can remove these lines:
vllm/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Lines 228 to 232 in d4acf51
assert num_heads == 16 or num_heads == 128, (
    f"Aiter MLA only supports 16 or 128 number of heads.\n"
    f"Provided {num_heads} number of heads.\n"
    "Try adjusting tensor_parallel_size value."
)
This way we support both persistent and non-persistent mode, and we can run the deepseek-R1 model on tp4 since persistent mode supports a flexible num_heads.
So we can use logic similar to the following:
self.persistent_mode = False
num_heads = self.num_heads
if num_heads == 16 or num_heads in range(32, 513, 16):
    self.persistent_mode = True
if self.persistent_mode:
    import aiter

    (
        (work_meta_data_size, w_dtype),
        (work_indptr_size, w_indptr_dtype),
        (work_info_set_size, w_info_set_dtype),
        (reduce_indptr_size, r_indptr_dtype),
        (reduce_final_map_size, r_final_map_dtype),
        (reduce_partial_map_size, r_partial_map_dtype),
    ) = aiter.get_mla_metadata_info_v1(
        max_num_reqs,
        1,  # mtp=1
        self.num_heads,
        vllm_config.model_config.dtype,
        kv_cache_spec.dtype,
        is_sparse=False,
        fast_mode=True,
    )
    self.work_meta_data = torch.empty(
        work_meta_data_size, dtype=w_dtype, device=device
    )
    self.work_indptr = torch.empty(
        work_indptr_size, dtype=w_indptr_dtype, device=device
    )
    self.work_info_set = torch.empty(
        work_info_set_size, dtype=w_info_set_dtype, device=device
    )
    self.reduce_indptr = torch.empty(
        reduce_indptr_size, dtype=r_indptr_dtype, device=device
    )
    self.reduce_final_map = torch.empty(
        reduce_final_map_size, dtype=r_final_map_dtype, device=device
    )
    self.reduce_partial_map = torch.empty(
        reduce_partial_map_size, dtype=r_partial_map_dtype, device=device
    )
Thanks for the suggestion, and sorry for the delay. From what I learned from aiter, the persistent kernel is mainly aimed at handling requests with varying lengths, while the non-persistent kernel is mainly for requests of more similar lengths. So instead of deciding by head count, maybe we can leave it as an env variable for the user?
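For illustration, a minimal sketch of what such an env-variable gate could look like; the variable name VLLM_ROCM_AITER_MLA_PERSISTENT and the helper function are hypothetical, not part of this PR:

import os


def select_persistent_mode(num_heads: int) -> bool:
    """Hypothetical gate: let the user opt into aiter's persistent MLA kernel
    via an env variable, while still honoring the head-count condition."""
    # Hypothetical env var name, not defined by this PR or by vLLM.
    if os.environ.get("VLLM_ROCM_AITER_MLA_PERSISTENT", "0") != "1":
        return False
    # Same head-count condition as in the suggestion above.
    return num_heads == 16 or num_heads in range(32, 513, 16)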
max_seqlen_qo = torch.max(query_lens).item()

import aiter
As suggested above, in the case where we want to use self.persistent_mode we would have the code block below, so that we still support non-persistent mode.
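A rough sketch of the kind of dispatch being described; the function and the builder callables are placeholders for illustration, not the actual backend or aiter API:

import torch


def build_decode_metadata(
    query_lens: torch.Tensor,
    persistent_mode: bool,
    build_persistent,
    build_non_persistent,
):
    # max_seqlen_qo is needed by both paths (see the quoted line above).
    max_seqlen_qo = torch.max(query_lens).item()
    if persistent_mode:
        # Persistent path: metadata tensors sized via
        # aiter.get_mla_metadata_info_v1, as in the earlier suggestion.
        return build_persistent(max_seqlen_qo)
    # Non-persistent path: keep the existing behavior for 16/128 heads.
    return build_non_persistent(max_seqlen_qo)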
Blocked by this issue: ROCm/aiter#1420
Signed-off-by: ganyi <[email protected]>
Force-pushed from f23e28b to bdb40c3
Purpose
This PR introduces a persistent MLA kernel implementation for AiterMLABackend, as well as fp8 support.

Test Plan
Verify the accuracy on gsm8k
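One possible way to reproduce the gsm8k check with lm-evaluation-harness; the model path, tensor-parallel size, and few-shot setting below are placeholders rather than the exact settings used for this PR:

import lm_eval

# Placeholder model and parallelism settings; adjust to the actual test setup.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=deepseek-ai/DeepSeek-R1,"
        "tensor_parallel_size=8"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])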
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.