 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import os
+import weakref
+from contextlib import ExitStack
 
 import pytest
 
+from tests.utils import wait_for_gpu_memory_to_clear
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
 
-MODEL = "Qwen/Qwen2-1.5B-Instruct"
-
 
 @contextlib.contextmanager
 def temporary_environ(env_vars):
@@ -31,64 +32,119 @@ def temporary_environ(env_vars):
                 os.environ[k] = v
 
 
-@pytest.fixture(scope="module")
-def full_cudagraph_llm():
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }):
-        return LLM(model=MODEL,
-                   gpu_memory_utilization=0.3,
-                   compilation_config=CompilationConfig(full_cuda_graph=True))
-
+@pytest.fixture(scope="class")
+def llm_pair(request):
+    model = request.param
 
-@pytest.fixture(scope="module")
-def piecewise_llm():
     with temporary_environ({
             "VLLM_USE_V1": "1",
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
-        return LLM(model=MODEL,
-                   gpu_memory_utilization=0.6,
-                   compilation_config=CompilationConfig())
-
-
-def generate_text(llm: LLM, batch_size: int, max_tokens: int):
-    prompts = ["Hi my name is"] * batch_size
-    sampling_params = SamplingParams(temperature=0.0,
-                                     max_tokens=max_tokens,
-                                     top_p=0.95)
-
-    return llm.generate(prompts, sampling_params)
-
-
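+        # Both engines are alive at the same time inside this fixture, so
+        # each is capped at 0.45 of GPU memory to leave room for the pair.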
+        full = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(full_cuda_graph=True),
+        )
+        piecewise = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(),
+        )
+
+    # PyTest caches the fixture values so we use weakref.proxy to enable GC
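+    # Yielding proxies (rather than the LLMs themselves) means the del
+    # statements below drop the last strong references, so both engines can
+    # be garbage collected as soon as the test class is done with them.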
+    yield weakref.proxy(full), weakref.proxy(piecewise)
+    del full
+    del piecewise
+
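+    # Block until the GPU memory held by the released engines is actually
+    # freed before the next parametrization constructs a fresh pair.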
+    wait_for_gpu_memory_to_clear(
+        devices=[0],
+        threshold_ratio=0.1,
+    )
+
+
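+# With indirect=True, each model name below is passed to the class-scoped
+# llm_pair fixture (via request.param) instead of directly to the tests.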
+@pytest.mark.parametrize(
+    "llm_pair",
+    [
+        # Model names for the llm_pair fixture
+        "deepseek-ai/DeepSeek-V2-Lite",
+        "Qwen/Qwen2-1.5B-Instruct"
+    ],
+    indirect=True)
 @pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FlashAttention 3")
-@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
-                                                        (16, 10), (25, 10),
-                                                        (32, 10), (45, 10),
-                                                        (64, 10), (8, 5),
-                                                        (8, 20), (8, 200)])
-def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
-                        piecewise_llm):
+                    reason="Only Hopper GPUs support FA3 and FlashMLA")
+class TestFullCUDAGraph:
     """
-    Load full cudagraph model and piecewise model once, and at the same time to
-    reuse them across various test cases.
+    Use a class such that an llm pair is constructed once for all
+    batch_size/max_tokens combinations and released immediately after.
 
-    Test various batch sizes and max_tokens to ensure that the full cudagraph
-    compilation works for padded cases too.
+    Module-scope fixtures would stick around the whole time,
+    meaning there would be multiple LLM instances hogging memory simultaneously.
     """
-    piecewise_responses = generate_text(piecewise_llm,
-                                        batch_size=batch_size,
-                                        max_tokens=max_tokens)
-    full_cudagraph_responses = generate_text(full_cudagraph_llm,
-                                             batch_size=batch_size,
-                                             max_tokens=max_tokens)
 
-    # Check that all responses are the same
-    for i in range(len(piecewise_responses)):
-        assert piecewise_responses[i].outputs[
-            0].text == full_cudagraph_responses[i].outputs[0].text
+    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
+        (1, 10),
+        (7, 10),
+        (16, 10),
+        (25, 10),
+        (32, 10),
+        (45, 10),
+        (64, 10),
+        (123, 10),
+        (8, 5),
+        (8, 30),
+    ])
+    def test_full_cudagraph(self, batch_size, max_tokens,
+                            llm_pair: tuple[LLM, LLM]):
+        """
+        Test various batch sizes and max_tokens to ensure that the
+        full cudagraph compilation works for padded cases too.
+        """
+
+        piecewise_llm, full_cudagraph_llm = llm_pair
+
+        prompts = ["Hello, my name is"] * batch_size
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens,
+                                         top_p=0.95)
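+        # temperature=0.0 means greedy decoding, so the full-cudagraph and
+        # piecewise engines are expected to produce identical text.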
+
+        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
+        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
+
+        # Check that all responses are the same
+        for piecewise_res, full_res in zip(piecewise_responses,
+                                           full_responses):
+            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
+
+
+@pytest.mark.parametrize(
+    "model, supported",
+    [
+        ("Qwen/Qwen2-1.5B-Instruct", True),
+        # MLA does not support capturing CUDA Graphs with size > max_num_seqs
+        ("deepseek-ai/DeepSeek-V2-Lite", False),
+    ])
+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FA3 and FlashMLA")
+def test_lower_max_num_seqs(model, supported):
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_FLASH_ATTN_VERSION": "3"
+    }), ExitStack() as stack:
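+        # Capture sizes larger than max_num_seqs (512 > 256 here) are not
+        # supported by MLA, so the MLA model should raise a RuntimeError.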
+        if not supported:
+            stack.enter_context(pytest.raises(RuntimeError))
+
+        llm = LLM(model=model,
+                  max_num_seqs=256,
+                  trust_remote_code=True,
+                  max_model_len=1024,
+                  compilation_config=CompilationConfig(
+                      full_cuda_graph=True,
+                      cudagraph_capture_sizes=[64, 256, 512]))
+        llm.generate(["Hello, my name is"] * 10)
 
 
 def test_full_cudagraph_with_invalid_backend():
@@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
97153 "VLLM_FLASH_ATTN_VERSION" :
98154 "2" #FA2 not supported with full_cuda_graph
99155 }), pytest .raises (RuntimeError ):
100- LLM (model = MODEL ,
156+ LLM (model = "Qwen/Qwen2-1.5B-Instruct" ,
101157 compilation_config = CompilationConfig (full_cuda_graph = True ))