-
-
Notifications
You must be signed in to change notification settings - Fork 13k
[CI/Build] Fix AMD CI: test_cpu_gpu.py #27388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,11 +8,29 @@ | |
|
|
||
| from vllm.platforms import current_platform | ||
| from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend | ||
| from vllm.v1.attention.backends.flashinfer import FlashInferBackend | ||
| from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend | ||
| from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec | ||
| from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler | ||
|
|
||
| BACKENDS_TO_TEST = [FlashAttentionBackend] | ||
|
|
||
| if current_platform.is_cuda(): | ||
|
||
| from vllm.v1.attention.backends.flashinfer import FlashInferBackend | ||
|
|
||
| BACKENDS_TO_TEST.append(FlashInferBackend) | ||
|
|
||
| from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend | ||
|
|
||
| BACKENDS_TO_TEST.append(FlashAttnMLABackend) | ||
|
|
||
| if current_platform.is_rocm(): | ||
| from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend | ||
|
|
||
| BACKENDS_TO_TEST.append(RocmAttentionBackend) | ||
|
|
||
| from vllm.v1.attention.backends.mla.triton_mla import TritonMLABackend | ||
|
|
||
| BACKENDS_TO_TEST.append(TritonMLABackend) | ||
|
|
||
| NUM_GPU_BLOCKS = [64] | ||
| NUM_CPU_BLOCKS = [256] | ||
| GPU_BLOCK_SIZES = [16] | ||
|
|
@@ -26,6 +44,10 @@ | |
| NUM_MAPPINGS = [3] | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| len(BACKENDS_TO_TEST) < 2, | ||
| reason="Need at least 2 backends to test heterogeneous KV cache layouts", | ||
| ) | ||
| @pytest.mark.parametrize("gpu_to_cpu", [True, False]) | ||
| @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) | ||
| @pytest.mark.parametrize("head_size", HEAD_SIZES) | ||
|
|
@@ -55,8 +77,7 @@ def test_transfer( | |
| ) -> None: | ||
| current_platform.seed_everything(seed) | ||
|
|
||
| # create per-layer GPU KV caches | ||
| attn_backends_list = [FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend] | ||
| attn_backends_list = BACKENDS_TO_TEST | ||
|
|
||
| gpu_caches = {} | ||
| attn_backends = {} | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.