
Commit 7cb7eba

Merge branch 'main' of https://github.com/vllm-project/vllm into svij-ultravox-lora-dec-16
2 parents: 575b5dc + df450aa

31 files changed: 852 additions, 234 deletions

benchmarks/backend_request_func.py

Lines changed: 6 additions & 3 deletions
@@ -22,6 +22,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    model_name: Optional[str] = None
     best_of: int = 1
     logprobs: Optional[int] = None
     extra_body: Optional[dict] = None
@@ -78,7 +79,7 @@ async def async_request_tgi(
                         continue
                     chunk_bytes = chunk_bytes.decode("utf-8")

-                    #NOTE: Sometimes TGI returns a ping response without
+                    # NOTE: Sometimes TGI returns a ping response without
                     # any data, we should skip it.
                     if chunk_bytes.startswith(":"):
                         continue
@@ -235,7 +236,8 @@ async def async_request_openai_completions(
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+                if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
             "best_of": request_func_input.best_of,
@@ -328,7 +330,8 @@ async def async_request_openai_chat_completions(
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+                if request_func_input.model_name else request_func_input.model,
             "messages": [
                 {
                     "role": "user",

benchmarks/benchmark_serving.py

Lines changed: 13 additions & 0 deletions
@@ -525,6 +525,7 @@ async def benchmark(
     api_url: str,
     base_url: str,
     model_id: str,
+    model_name: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     logprobs: Optional[int],
@@ -553,6 +554,7 @@ async def benchmark(
             "Multi-modal content is only supported on 'openai-chat' backend.")
     test_input = RequestFuncInput(
         model=model_id,
+        model_name=model_name,
         prompt=test_prompt,
         api_url=api_url,
         prompt_len=test_prompt_len,
@@ -573,6 +575,7 @@
     if profile:
         print("Starting profiler...")
         profile_input = RequestFuncInput(model=model_id,
+                                         model_name=model_name,
                                          prompt=test_prompt,
                                          api_url=base_url + "/start_profile",
                                          prompt_len=test_prompt_len,
@@ -616,6 +619,7 @@ async def limited_request_func(request_func_input, pbar):
     async for request in get_request(input_requests, request_rate, burstiness):
         prompt, prompt_len, output_len, mm_content = request
         request_func_input = RequestFuncInput(model=model_id,
+                                              model_name=model_name,
                                               prompt=prompt,
                                               api_url=api_url,
                                               prompt_len=prompt_len,
@@ -780,6 +784,7 @@ def main(args: argparse.Namespace):

     backend = args.backend
     model_id = args.model
+    model_name = args.served_model_name
     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
     tokenizer_mode = args.tokenizer_mode

@@ -877,6 +882,7 @@ def main(args: argparse.Namespace):
         api_url=api_url,
         base_url=base_url,
         model_id=model_id,
+        model_name=model_name,
         tokenizer=tokenizer,
         input_requests=input_requests,
         logprobs=args.logprobs,
@@ -1222,5 +1228,12 @@ def main(args: argparse.Namespace):
         'always use the slow tokenizer. \n* '
         '"mistral" will always use the `mistral_common` tokenizer.')

+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The model name used in the API. "
+                        "If not specified, the model name will be the "
+                        "same as the ``--model`` argument. ")
+
     args = parser.parse_args()
     main(args)
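
Taken together with the request-function change, the new `--served-model-name` flag decouples the local model path (tokenizer, profiling) from the name the OpenAI-compatible endpoint expects. A simplified sketch of the flag's behavior, not the full benchmark parser, with a hypothetical checkpoint path:

import argparse

# Simplified sketch (not the full benchmark parser): when --served-model-name
# is omitted, model_name stays None and the request payload built in
# backend_request_func.py falls back to the --model value.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--served-model-name", type=str, default=None)

args = parser.parse_args(["--model", "/path/to/checkpoint"])
model_id = args.model                # used locally, e.g. for the tokenizer
model_name = args.served_model_name  # sent as the API "model" field
print(model_name if model_name else model_id)  # -> /path/to/checkpoint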

docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
   - `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
   - ✅︎
   - ✅︎
-  -
+  - ✅︎
 * - `UltravoxModel`
   - Ultravox
   - T + A<sup>E+</sup>

tests/models/decoder_only/vision_language/test_qwen2_vl.py

Lines changed: 8 additions & 10 deletions
@@ -105,7 +105,7 @@ def batch_make_image_embeddings(
     pixel_values = preprocess_result["pixel_values"]
     image_grid_thw = preprocess_result["image_grid_thw"]

-    # pixel values to embeddinds & grid_thws
+    # pixel values to embeddings & grid_thws
     with torch.no_grad():
         visual = llm.llm_engine.model_executor.driver_worker. \
             model_runner.model.visual
@@ -124,11 +124,10 @@ def batch_make_image_embeddings(
     for image_batch in image_batches_:
         cur_batch_image_count = len(image_batch)
         merge_size = image_processor.merge_size
-        cur_batch_embed_len = sum([
-            grid_thw.prod() // merge_size // merge_size
+        cur_batch_embed_len = sum(
+            grid_thw.prod(-1) // merge_size // merge_size
             for grid_thw in image_grid_thw[image_counter:image_counter +
-                                           cur_batch_image_count]
-        ])
+                                           cur_batch_image_count])

         result.append({
             "image_embeds":
@@ -187,7 +186,7 @@ def batch_make_video_embeddings(
     pixel_values = preprocess_result["pixel_values_videos"]
     video_grid_thw = preprocess_result["video_grid_thw"]

-    # pixel values to embeddinds & grid_thws
+    # pixel values to embeddings & grid_thws
     with torch.no_grad():
         visual = llm.llm_engine.model_executor.driver_worker.\
             model_runner.model.visual
@@ -206,11 +205,10 @@ def batch_make_video_embeddings(
     for video_batch in video_batches_:
         cur_batch_video_count = len(video_batch)
         merge_size = image_processor.merge_size
-        cur_batch_embed_len = sum([
-            grid_thw.prod() // merge_size // merge_size
+        cur_batch_embed_len = sum(
+            grid_thw.prod(-1) // merge_size // merge_size
             for grid_thw in video_grid_thw[video_counter:video_counter +
-                                           cur_batch_video_count]
-        ])
+                                           cur_batch_video_count])

         result.append({
             "video_embeds":

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ class _HfExamplesInfo:
     "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3",  # noqa: E501
                                              trust_remote_code=True),
     "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
+    "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),  # noqa: E501
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),

tests/weight_loading/models.txt

Lines changed: 2 additions & 1 deletion
@@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
 marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
 qqq, HandH1998/QQQ-Llama-3-8b-g128, main
 qqq, HandH1998/QQQ-Llama-3-8b, main
-hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
+hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
+None, mgleize/fairseq2-dummy-Llama-3.2-1B, main

tests/weight_loading/test_weight_loading.py

Lines changed: 7 additions & 6 deletions
@@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner):
     """
     Test parameter weight loading with tp>1.
     """
-    with vllm_runner(model_name=MODEL_NAME,
-                     revision=REVISION,
-                     dtype=torch.half if QUANTIZATION == "gptq" else "auto",
-                     quantization=QUANTIZATION,
-                     max_model_len=MAX_MODEL_LEN,
-                     tensor_parallel_size=2) as model:
+    with vllm_runner(
+            model_name=MODEL_NAME,
+            revision=REVISION,
+            dtype=torch.half if QUANTIZATION == "gptq" else "auto",
+            quantization=None if QUANTIZATION == "None" else QUANTIZATION,
+            max_model_len=MAX_MODEL_LEN,
+            tensor_parallel_size=2) as model:

         output = model.generate_greedy("Hello world!", max_tokens=20)
         print(output)
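
The new `quantization=None if QUANTIZATION == "None" else QUANTIZATION` guard exists because the test matrix (see tests/weight_loading/models.txt above, which now contains a literal `None` entry) passes the quantization method around as text. A minimal sketch, assuming the value arrives via an environment variable:

import os

# Minimal sketch, assuming the value arrives as an environment variable: env
# vars (and the models.txt matrix) carry the quantization method as text, so
# the literal string "None" must become Python's None before reaching
# vllm_runner.
QUANTIZATION = os.getenv("QUANTIZATION", "None")  # e.g. "gptq", "hqq", "None"
quantization = None if QUANTIZATION == "None" else QUANTIZATION
print(repr(quantization))  # None when unset or explicitly "None"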

vllm/compilation/backends.py

Lines changed: 58 additions & 22 deletions
@@ -25,23 +25,30 @@
 logger = init_logger(__name__)


+@dataclasses.dataclass
+class InductorArtifact:
+    hash_str: str = ""
+    file_path: str = ""
+
+
 class InductorHashCache:
     """
     Disk format: a Python list of tuples, each tuple is
-    (runtime_shape, graph_index, hash_str)
+    (runtime_shape, graph_index, hash_str, file_path)
     We use list of tuple for readability.

     In-memory format: a defaultdict of dict, where the key is
     runtime_shape, and the value is a dict of graph_index to hash_str.

-    The data is essentially `Dict[Optional[int], Dict[int, str]]`,
+    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
     we don't use json here because json doesn't support int as key.

     TODO: better off-the-shelf solution to serialize the data?
     """

     def __init__(self, cache_dir: str, disabled: bool = False):
-        self.cache: defaultdict = defaultdict(dict)
+        self.cache: Dict[Optional[int],
+                         Dict[int, InductorArtifact]] = defaultdict(dict)
         self.disabled = disabled
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir,
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
         # because it is a safe way to parse Python literals.
         # do not use eval(), it is unsafe.
         list_data = ast.literal_eval(data)
-        for runtime_shape, graph_index, hash_str in list_data:
-            self.cache[runtime_shape][graph_index] = hash_str
+        for item in list_data:
+            runtime_shape = item[0]
+            graph_index = item[1]
+            hash_str = item[2]
+            # for compatibility of old version,
+            # where we don't have file_path.
+            # NOTE: after running the new code, the file_path
+            # will be updated.
+            file_path = "" if len(item) == 3 else item[3]
+            self.cache[runtime_shape][graph_index] = InductorArtifact(
+                hash_str=hash_str, file_path=file_path)

     def serialize(self) -> str:
         data = []
-        for runtime_shape, graph_index_to_hash_str in self.cache.items():
-            for graph_index, hash_str in graph_index_to_hash_str.items():
-                data.append((runtime_shape, graph_index, hash_str))
+        for runtime_shape, value in self.cache.items():
+            for graph_index, inductor_artifact in value.items():
+                data.append(
+                    (runtime_shape, graph_index, inductor_artifact.hash_str,
+                     inductor_artifact.file_path))
         printer = pprint.PrettyPrinter(indent=4)
         return printer.pformat(data)

@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
         return runtime_shape in self.cache and graph_index in self.cache[
             runtime_shape]

-    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
         if self.disabled:
             raise KeyError("cannot read from disabled cache")
         runtime_shape, graph_index = key
         return self.cache[runtime_shape][graph_index]

-    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+    def __setitem__(self, key: Tuple[Optional[int], int],
+                    value: InductorArtifact):
         # setitem for disabled cache is fine, because we
         # don't actually write to the disk
         runtime_shape, graph_index = key
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
-        hash_str = cache_data[(runtime_shape, graph_index)]
+        inductor_artifact = cache_data[(runtime_shape, graph_index)]
+        hash_str = inductor_artifact.hash_str
         if graph_index == 0:
             # adds some info logging for the first graph
             logger.info(
@@ -199,6 +219,7 @@
                 "Inductor cache lookup failed. Please remove"
                 f"the cache file {cache_data.cache_file_path} and try again."  # noqa
             )
+        inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa

         # Inductor calling convention (function signature):
         # f(list) -> tuple
@@ -224,19 +245,20 @@ def compiled_graph(*args):
     # the assumption is that we don't have nested Inductor compilation.
     # compiled_fx_graph_hash will only be called once, and we can hook
     # it to get the hash of the compiled graph directly.
-    from torch._inductor.codecache import compiled_fx_graph_hash
+
+    inductor_artifact = InductorArtifact()
+    from torch._inductor.codecache import (FxGraphCache,
+                                           compiled_fx_graph_hash)
+    original_load = FxGraphCache.load
+
+    def hijack_load(*args, **kwargs):
+        inductor_compiled_graph = original_load(*args, **kwargs)
+        inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
+        return inductor_compiled_graph

     def hijack_compiled_fx_graph_hash(*args, **kwargs):
         out = compiled_fx_graph_hash(*args, **kwargs)
-        # store the hash in the cache
-        nonlocal cache_data
-        cache_data[(runtime_shape, graph_index)] = out[0]
-        if graph_index == 0:
-            # adds some info logging for the first graph
-            logger.info("Cache the graph of shape %s for later use",
-                        str(runtime_shape))
-        logger.debug("store the %s-th graph for shape %s via hash %s",
-                     graph_index, str(runtime_shape), out[0])
+        inductor_artifact.hash_str = out[0]
         return out

     def _check_can_cache(*args, **kwargs):
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
     if not cache_data.disabled:
         # compilation cache is enabled, patch several functions

+        # hijack to get the compiled graph itself
+        stack.enter_context(
+            patch("torch._inductor.codecache.FxGraphCache.load",
+                  hijack_load))
+
         # for hijacking the hash of the compiled graph
         stack.enter_context(
             patch("torch._inductor.codecache.compiled_fx_graph_hash",
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
         compiled_graph = compile_fx(graph,
                                     example_inputs,
                                     config_patches=current_config)
-
+        # store the inductor_artifact in the cache
+        cache_data[(runtime_shape, graph_index)] = inductor_artifact
+        if graph_index == 0:
+            # adds some info logging for the first graph
+            logger.info("Cache the graph of shape %s for later use",
+                        str(runtime_shape))
+        logger.debug(
+            "store the %s-th graph for shape %s via hash %s from file %s",
+            graph_index, str(runtime_shape), inductor_artifact.hash_str,
+            inductor_artifact.file_path)
     # after compiling the last graph, record the end time
     if graph_index == num_graphs - 1:
         now = time.time()
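
The net effect of the backends.py changes is that the Inductor hash cache now records a file path alongside each hash while still reading caches written in the old 3-tuple format. A standalone sketch of that round-trip, mirroring the deserialization logic in the diff (the hash strings and file path below are made up):

import ast
from collections import defaultdict
from dataclasses import dataclass


# Standalone sketch of the on-disk format after this change. Old caches stored
# 3-tuples (runtime_shape, graph_index, hash_str); new caches append a
# file_path, and deserialization accepts both.
@dataclass
class InductorArtifact:
    hash_str: str = ""
    file_path: str = ""


def deserialize(data: str):
    cache = defaultdict(dict)
    for item in ast.literal_eval(data):  # safe parsing of Python literals
        runtime_shape, graph_index, hash_str = item[0], item[1], item[2]
        file_path = "" if len(item) == 3 else item[3]  # old entries lack it
        cache[runtime_shape][graph_index] = InductorArtifact(hash_str, file_path)
    return cache


# One legacy 3-tuple entry and one new-style 4-tuple entry (made-up values).
text = "[(None, 0, 'abc123'), (16, 1, 'def456', '/tmp/fx_graph.py')]"
print(deserialize(text))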

vllm/compilation/decorators.py

Lines changed: 12 additions & 2 deletions
@@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
     During runtime, when we actually mark dimensions of tensors,
     it depends on the value of arguments:

-    - if it is a single integer, the corresponding dimension of the argument
-      will be marked as dynamic.
+    - if it is a single integer (can be negative), the corresponding dimension
+      of the argument will be marked as dynamic.
     - if it is `None`, ignored.
     - if it is `IntermediateTensors`, all the tensors in the intermediate
       tensors will be marked as dynamic.
@@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs):
                 for k, dims in dynamic_arg_dims.items():
                     arg = bound_args.arguments.get(k)
                     if arg is not None:
+                        dims = [dims] if isinstance(dims, int) else dims
                         if isinstance(arg, torch.Tensor):
+                            # In case dims is specified with negative indexing
+                            dims = [
+                                arg.ndim + dim if dim < 0 else dim for dim in dims
+                            ]
                            torch._dynamo.mark_dynamic(arg, dims)
                         elif isinstance(arg, IntermediateTensors):
                             for tensor in arg.tensors.values():
+                                # In case dims is specified with negative indexing
+                                dims = [
+                                    tensor.ndim + dim if dim < 0 else dim
+                                    for dim in dims
+                                ]
                                 torch._dynamo.mark_dynamic(tensor, dims)
                         else:
                             raise ValueError(
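
With this change, `support_torch_compile` accepts negative dimension indices and normalizes them against each tensor's rank before calling `torch._dynamo.mark_dynamic`. A minimal sketch of that normalization on a hypothetical tensor:

import torch

# Minimal sketch of the normalization: a negative dim index is converted to
# its positive equivalent relative to the tensor's rank before being handed
# to torch._dynamo.mark_dynamic.
x = torch.randn(2, 3, 4)   # hypothetical activation tensor
dims = -1                  # "the last dimension is dynamic"
dims = [dims] if isinstance(dims, int) else dims
dims = [x.ndim + d if d < 0 else d for d in dims]
assert dims == [2]
torch._dynamo.mark_dynamic(x, dims)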
