[Feature]: Need a Gather-GEMM kernel before MHA for the DeepSeek model #1354

@junhaha666

Suggestion Description

In the prefix-cache and chunked-prefill scenarios, the latent KV produced by the earlier tokens of a sequence already resides in the KV cache and must be gathered by index. A linear layer (kv_b_proj) is then needed to project it back to the full K/V dimensions, and the result is concatenated with the positional-encoding part. The K and V obtained this way can participate in the subsequent MHA computation.

Kernel input:

  1. kv_cache [num_block, block_size, 576] bf16/fp8
  2. kv_indptr [batch_size+1] int32
  3. kv_indices [xxx] int32
  4. cu_seqlens_k [batch_size+1] int32
  5. kv_b_proj.weight [2*128/TP * 128, 512] fp8
  6. kv_b_proj.scale [2*128/TP, 4] fp32
  7. kv_cache_scale [1] None/fp32

Kernel output:

  1. K [token_num, 128/TP * 192] bf16
  2. V [token_num, 128/TP * 128] bf16

Performance target: comparable to a standard fp8 block-scale GEMM.
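For clarity, here is a minimal unfused NumPy sketch of the semantics the fused kernel would implement: gather latent tokens by index, run the kv_b_proj GEMM on the 512-dim latent part, and concatenate the 64-dim RoPE part into K per head. The function name is made up, the 512+64 split and head dims are inferred from the shapes above, quantization (fp8 weights/scales) is omitted, and the assumption that the RoPE part is shared across heads reflects MLA as commonly described; treat all of these as assumptions, not the authoritative kernel contract.

```python
import numpy as np

# Assumed model dims, inferred from the shapes listed above.
NUM_HEADS = 128           # total heads before tensor-parallel split
QK_NOPE, QK_ROPE, V_DIM = 128, 64, 128
LATENT = 512              # compressed latent width; cache row is 512 + 64 = 576

def gather_gemm_ref(kv_cache, kv_indices, w_kv_b, tp=1):
    """Unfused reference (hypothetical name, fp32, no fp8 scales).

    kv_cache   : [num_block, block_size, 576]
    kv_indices : flat token indices into the cache (per-request ranges,
                 delimited by kv_indptr / cu_seqlens_k, are omitted here)
    w_kv_b     : [(NUM_HEADS // tp) * (QK_NOPE + V_DIM), LATENT]
    """
    heads = NUM_HEADS // tp
    flat = kv_cache.reshape(-1, LATENT + QK_ROPE)
    tokens = flat[kv_indices]                        # gather step
    latent, k_pe = tokens[:, :LATENT], tokens[:, LATENT:]
    # GEMM: expand latent back to per-head no-PE K and V
    kv = (latent @ w_kv_b.T).reshape(-1, heads, QK_NOPE + V_DIM)
    k_nope, v = kv[:, :, :QK_NOPE], kv[:, :, QK_NOPE:]
    # Concatenate the RoPE slice (assumed shared across heads) onto K
    k_pe = np.broadcast_to(k_pe[:, None, :],
                           (k_pe.shape[0], heads, QK_ROPE))
    k = np.concatenate([k_nope, k_pe], axis=-1)      # [token_num, heads, 192]
    return k.reshape(k.shape[0], -1), v.reshape(v.shape[0], -1)
```

With tp=8 this yields K of shape [token_num, 16 * 192] and V of [token_num, 16 * 128], matching the output spec above up to the TP factor.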

Operating System

No response

GPU

MI308

ROCm Component

No response
