
Conversation

@ganyi1996ppo (Contributor) commented Oct 23, 2025

Purpose

This PR introduces a persistent MLA kernel implementation for AiterMLABackend, as well as fp8 support.

Test Plan

Verify the accuracy on gsm8k
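
One possible invocation (a sketch assuming lm-evaluation-harness; the model name and tensor_parallel_size are placeholders, not the exact command used here):

```bash
# Illustrative only: run the 5-shot GSM8K check against the vLLM backend.
lm_eval --model vllm \
  --model_args pretrained=deepseek-ai/DeepSeek-R1,tensor_parallel_size=8 \
  --tasks gsm8k --num_fewshot 5 --batch_size auto
```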

Test Result

# bf16 acc result on gsm8k
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9522|±  |0.0059|
|     |       |strict-match    |     5|exact_match|↑  |0.9507|±  |0.0060|
# fp8 acc result on gsm8k
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.953|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.953|±  |0.0058|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

# Snippet from the diff under review: query the number of compute units on the current GPU.
gpu = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count

Contributor

There is no need for lines 100-102; the metadata tensors can be created using aiter.get_mla_metadata_info_v1, which returns the size and dtype needed for each metadata tensor.
Also, persistent mode supports different head sizes, which lets us lift the restriction of num_heads to 16 or 128 by reusing the condition aiter applies at this line.
So instead of always using persistent mode we can make it conditional on num_heads, which means we can remove these lines:

assert num_heads == 16 or num_heads == 128, (
    f"Aiter MLA only supports 16 or 128 number of heads.\n"
    f"Provided {num_heads} number of heads.\n"
    "Try adjusting tensor_parallel_size value."
)

This way we support both persistent and non-persistent mode, and we can run the deepseek-R1 model at tp4, since persistent mode supports a flexible num_heads.

So we can use logic similar to the following:

        self.persistent_mode = False
        num_heads = self.num_heads

        if num_heads == 16 or num_heads in range(32, 513, 16):
            self.persistent_mode = True

        if self.persistent_mode:
            import aiter

            (
                (work_meta_data_size, w_dtype),
                (work_indptr_size, w_indptr_dtype),
                (work_info_set_size, w_info_set_dtype),
                (reduce_indptr_size, r_indptr_dtype),
                (reduce_final_map_size, r_final_map_dtype),
                (reduce_partial_map_size, r_partial_map_dtype),
            ) = aiter.get_mla_metadata_info_v1(
                max_num_reqs,
                1,  # mtp=1 
                self.num_heads,
                vllm_config.model_config.dtype,
                kv_cache_spec.dtype,
                is_sparse=False,
                fast_mode=True,
            )
            self.work_meta_data = torch.empty(
                work_meta_data_size, dtype=w_dtype, device=device
            )
            self.work_indptr = torch.empty(
                work_indptr_size, dtype=w_indptr_dtype, device=device
            )
            self.work_info_set = torch.empty(
                work_info_set_size, dtype=w_info_set_dtype, device=device
            )
            self.reduce_indptr = torch.empty(
                reduce_indptr_size, dtype=r_indptr_dtype, device=device
            )
            self.reduce_final_map = torch.empty(
                reduce_final_map_size, dtype=r_final_map_dtype, device=device
            )
            self.reduce_partial_map = torch.empty(
                reduce_partial_map_size, dtype=r_partial_map_dtype, device=device
            )

Contributor Author

Thanks for the suggestion, and sorry for the delay. From what I learned from aiter, the persistent kernel is mainly aimed at handling requests with widely varying lengths, while the non-persistent kernel is mainly for requests with more similar lengths. So instead of deciding by the number of heads, maybe we can leave it as an environment variable for the user?
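
One way that toggle could look (a minimal sketch; VLLM_ROCM_USE_AITER_MLA_PERSISTENT is an illustrative name, not an existing vLLM environment variable):

```python
import os

# Illustrative only: let the user opt in or out of the persistent MLA kernel.
# The variable name and default are placeholders, not an existing vLLM setting.
self.persistent_mode = (
    os.environ.get("VLLM_ROCM_USE_AITER_MLA_PERSISTENT", "1") == "1"
)
```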

max_seqlen_qo = torch.max(query_lens).item()

import aiter

Contributor

As suggested above, in the case where we want to use self.persistent_mode we would gate it with the code block below, so that non-persistent mode is still supported.
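
A minimal sketch of what that gating could look like in the metadata build step (branch bodies are placeholders, not code from this PR):

```python
# Sketch only: choose between the persistent and non-persistent paths.
max_seqlen_qo = torch.max(query_lens).item()

if self.persistent_mode:
    import aiter
    # Persistent path: populate the pre-allocated work_*/reduce_* metadata
    # tensors for the persistent MLA kernel.
    ...
else:
    # Non-persistent path: keep the existing AiterMLA metadata build unchanged.
    ...
```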

@wuhuikx (Contributor) commented Nov 18, 2025

Blocked by ROCm/aiter#1420.

@ganyi1996ppo force-pushed the ganyi/rocm_fp8_mla_and_persistent_kernel branch from f23e28b to bdb40c3 on November 27, 2025 07:42

Labels

rocm (Related to AMD ROCm), v1
