diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 3e3034a02f0f..1ceec026b319 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -24,25 +24,30 @@
# Unless specified, these settings have been tested to work on a single L4.
-# Ultravox 0.5-1B
-def run_ultravox(question: str, audio_count: int):
- model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
+# MiniCPM-O
+def run_minicpmo(question: str, audio_count: int):
+ model_name = "openbmb/MiniCPM-o-2_6"
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+ trust_remote_code=True)
+ llm = LLM(model=model_name,
+ trust_remote_code=True,
+ max_model_len=4096,
+ max_num_seqs=5,
+ limit_mm_per_prompt={"audio": audio_count})
- tokenizer = AutoTokenizer.from_pretrained(model_name)
+ stop_tokens = ['<|im_end|>', '<|endoftext|>']
+ stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
+ audio_placeholder = "()" * audio_count
+ audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
messages = [{
'role': 'user',
- 'content': "<|audio|>\n" * audio_count + question
+ 'content': f'{audio_placeholder}\n{question}'
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
- add_generation_prompt=True)
-
- llm = LLM(model=model_name,
- max_model_len=4096,
- max_num_seqs=5,
- trust_remote_code=True,
- limit_mm_per_prompt={"audio": audio_count})
- stop_token_ids = None
+ add_generation_prompt=True,
+ chat_template=audio_chat_template)
return llm, prompt, stop_token_ids
@@ -68,36 +73,49 @@ def run_qwen2_audio(question: str, audio_count: int):
return llm, prompt, stop_token_ids
-def run_minicpmo(question: str, audio_count: int):
- model_name = "openbmb/MiniCPM-o-2_6"
- tokenizer = AutoTokenizer.from_pretrained(model_name,
- trust_remote_code=True)
- llm = LLM(model=model_name,
- trust_remote_code=True,
- max_model_len=4096,
- max_num_seqs=5,
- limit_mm_per_prompt={"audio": audio_count})
-
- stop_tokens = ['<|im_end|>', '<|endoftext|>']
- stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+# Ultravox 0.5-1B
+def run_ultravox(question: str, audio_count: int):
+ model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
- audio_placeholder = "()" * audio_count
- audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role': 'user',
- 'content': f'{audio_placeholder}\n{question}'
+ 'content': "<|audio|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
- add_generation_prompt=True,
- chat_template=audio_chat_template)
+ add_generation_prompt=True)
+
+ llm = LLM(model=model_name,
+ max_model_len=4096,
+ max_num_seqs=5,
+ trust_remote_code=True,
+ limit_mm_per_prompt={"audio": audio_count})
+ stop_token_ids = None
+ return llm, prompt, stop_token_ids
+
+
+# Whisper
+def run_whisper(question: str, audio_count: int):
+ assert audio_count == 1, (
+ "Whisper only support single audio input per prompt")
+ model_name = "openai/whisper-large-v3-turbo"
+
+ prompt = "<|startoftranscript|>"
+
+ llm = LLM(model=model_name,
+ max_model_len=448,
+ max_num_seqs=5,
+ limit_mm_per_prompt={"audio": audio_count})
+ stop_token_ids = None
return llm, prompt, stop_token_ids
model_example_map = {
- "ultravox": run_ultravox,
+ "minicpmo": run_minicpmo,
"qwen2_audio": run_qwen2_audio,
- "minicpmo": run_minicpmo
+ "ultravox": run_ultravox,
+ "whisper": run_whisper,
}
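
For context, each run_* helper above returns an (llm, prompt, stop_token_ids) triple that the example's main() (not part of this hunk) feeds into SamplingParams and llm.generate. A minimal sketch of that flow, assuming the helpers above are in scope and reusing the AudioAsset helper that appears later in this patch; the exact sampling settings here are illustrative, not the example's:

    from vllm import SamplingParams
    from vllm.assets.audio import AudioAsset

    # Build the model, prompt, and stop tokens via one of the helpers above.
    llm, prompt, stop_token_ids = run_whisper("Transcribe this audio.", 1)

    # stop_token_ids may be None (e.g. Whisper/Ultravox); that is accepted.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                # A single (waveform, sampling_rate) tuple, matching audio_count == 1.
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            },
        },
        sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)
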
diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py
new file mode 100644
index 000000000000..f44bc423658e
--- /dev/null
+++ b/examples/offline_inference/encoder_decoder_multimodal.py
@@ -0,0 +1,158 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+This example shows how to use vLLM for running offline inference with
+the explicit/implicit prompt format on enc-dec LMMs for text generation.
+"""
+import time
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.assets.image import ImageAsset
+from vllm.utils import FlexibleArgumentParser
+
+
+def run_florence2():
+ # Create a Florence-2 encoder/decoder model instance
+ llm = LLM(
+ model="microsoft/Florence-2-large",
+ tokenizer="facebook/bart-large",
+ max_num_seqs=8,
+ trust_remote_code=True,
+ limit_mm_per_prompt={"image": 1},
+ dtype="half",
+ )
+
+ prompts = [
+ { # implicit prompt with task token
+ "prompt": "",
+ "multi_modal_data": {
+ "image": ImageAsset("stop_sign").pil_image
+ },
+ },
+ { # explicit encoder/decoder prompt
+ "encoder_prompt": {
+ "prompt": "Describe in detail what is shown in the image.",
+ "multi_modal_data": {
+ "image": ImageAsset("cherry_blossom").pil_image
+ },
+ },
+ "decoder_prompt": "",
+ },
+ ]
+ return llm, prompts
+
+
+def run_mllama():
+ # Create a Mllama encoder/decoder model instance
+ llm = LLM(
+ model="meta-llama/Llama-3.2-11B-Vision-Instruct",
+ max_model_len=4096,
+ max_num_seqs=2,
+ limit_mm_per_prompt={"image": 1},
+ dtype="half",
+ )
+
+ prompts = [
+ { # Implicit prompt
+ "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501
+ "multi_modal_data": {
+ "image": ImageAsset("stop_sign").pil_image,
+ },
+ },
+ { # Explicit prompt
+ "encoder_prompt": {
+ "prompt": "<|image|>",
+ "multi_modal_data": {
+ "image": ImageAsset("stop_sign").pil_image,
+ },
+ },
+ "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501
+ },
+ ]
+ return llm, prompts
+
+
+def run_whisper():
+ # Create a Whisper encoder/decoder model instance
+ llm = LLM(
+ model="openai/whisper-large-v3-turbo",
+ max_model_len=448,
+ max_num_seqs=16,
+ limit_mm_per_prompt={"audio": 1},
+ dtype="half",
+ )
+
+ prompts = [
+ { # Test implicit prompt
+ "prompt": "<|startoftranscript|>",
+ "multi_modal_data": {
+ "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ },
+ },
+ { # Test explicit encoder/decoder prompt
+ "encoder_prompt": {
+ "prompt": "",
+ "multi_modal_data": {
+ "audio": AudioAsset("winning_call").audio_and_sample_rate,
+ },
+ },
+ "decoder_prompt": "<|startoftranscript|>",
+ }
+ ]
+ return llm, prompts
+
+
+model_example_map = {
+ "florence2": run_florence2,
+ "mllama": run_mllama,
+ "whisper": run_whisper,
+}
+
+
+def main(args):
+ model = args.model_type
+ if model not in model_example_map:
+ raise ValueError(f"Model type {model} is not supported.")
+
+ llm, prompts = model_example_map[model]()
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(
+ temperature=0,
+ top_p=1.0,
+ max_tokens=64,
+ )
+
+ start = time.time()
+
+ # Generate output tokens from the prompts. The output is a list of
+ # RequestOutput objects that contain the prompt, generated
+ # text, and other information.
+ outputs = llm.generate(prompts, sampling_params)
+
+ # Print the outputs.
+ for output in outputs:
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Decoder prompt: {prompt!r}, "
+ f"Generated text: {generated_text!r}")
+
+ duration = time.time() - start
+
+ print("Duration:", duration)
+ print("RPS:", len(prompts) / duration)
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(
+ description='Demo on using vLLM for offline inference with '
+ 'encoder/decoder multimodal models for text generation')
+ parser.add_argument('--model-type',
+ '-m',
+ type=str,
+ default="mllama",
+ choices=model_example_map.keys(),
+ help='Huggingface "model_type".')
+
+ args = parser.parse_args()
+ main(args)
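
The Whisper path above uses a bare <|startoftranscript|> decoder prompt. As a hedged variant (not part of this patch), the decoder prompt can also carry Whisper's language/task special tokens; the token names below follow the upstream Whisper tokenizer and are an assumption of this sketch, not something this file sets:

    from vllm.assets.audio import AudioAsset

    # Force English transcription without timestamps via decoder special tokens.
    whisper_transcribe_prompt = {
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt":
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
    }

When run as a script, the model path is selected with --model-type {florence2,mllama,whisper} (default mllama), as defined by the parser above.
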
diff --git a/examples/offline_inference/florence2_inference.py b/examples/offline_inference/florence2_inference.py
deleted file mode 100644
index 27aceee43cbf..000000000000
--- a/examples/offline_inference/florence2_inference.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""
-Demonstrate prompting of text-to-text
-encoder/decoder models, specifically Florence-2
-"""
-# TODO(Isotr0py):
-# Move to offline_inference/vision_language.py
-# after porting vision backbone
-from vllm import LLM, SamplingParams
-from vllm.assets.image import ImageAsset
-
-# Create a Florence-2 encoder/decoder model instance
-llm = LLM(
- model="microsoft/Florence-2-large",
- tokenizer="facebook/bart-large",
- max_num_seqs=8,
- trust_remote_code=True,
-)
-
-prompts = [
- { # implicit prompt with task token
- "prompt": "",
- "multi_modal_data": {
- "image": ImageAsset("stop_sign").pil_image
- },
- },
- { # explicit encoder/decoder prompt
- "encoder_prompt": {
- "prompt": "Describe in detail what is shown in the image.",
- "multi_modal_data": {
- "image": ImageAsset("cherry_blossom").pil_image
- },
- },
- "decoder_prompt": "",
- },
-]
-# Create a sampling params object.
-sampling_params = SamplingParams(
- temperature=0,
- top_p=1.0,
- min_tokens=0,
- max_tokens=128,
-)
-
-# Generate output tokens from the prompts. The output is a list of
-# RequestOutput objects that contain the prompt, generated
-# text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-
-# Print the outputs.
-for output in outputs:
- generated_text = output.outputs[0].text
- print(f"Generated text: {generated_text!r}")
diff --git a/examples/offline_inference/whisper.py b/examples/offline_inference/whisper.py
deleted file mode 100644
index 59c119a772da..000000000000
--- a/examples/offline_inference/whisper.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import time
-
-from vllm import LLM, SamplingParams
-from vllm.assets.audio import AudioAsset
-
-# Create a Whisper encoder/decoder model instance
-llm = LLM(
- model="openai/whisper-large-v3",
- max_model_len=448,
- max_num_seqs=400,
- limit_mm_per_prompt={"audio": 1},
- kv_cache_dtype="fp8",
-)
-
-prompts = [
- {
- "prompt": "<|startoftranscript|>",
- "multi_modal_data": {
- "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
- },
- },
- { # Test explicit encoder/decoder prompt
- "encoder_prompt": {
- "prompt": "",
- "multi_modal_data": {
- "audio": AudioAsset("winning_call").audio_and_sample_rate,
- },
- },
- "decoder_prompt": "<|startoftranscript|>",
- }
-] * 1024
-
-# Create a sampling params object.
-sampling_params = SamplingParams(
- temperature=0,
- top_p=1.0,
- max_tokens=200,
-)
-
-start = time.time()
-
-# Generate output tokens from the prompts. The output is a list of
-# RequestOutput objects that contain the prompt, generated
-# text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-
-# Print the outputs.
-for output in outputs:
- prompt = output.prompt
- encoder_prompt = output.encoder_prompt
- generated_text = output.outputs[0].text
- print(f"Encoder prompt: {encoder_prompt!r}, "
- f"Decoder prompt: {prompt!r}, "
- f"Generated text: {generated_text!r}")
-
-duration = time.time() - start
-
-print("Duration:", duration)
-print("RPS:", len(prompts) / duration)
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 656e5fc6dcf3..c5a55e300c46 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -748,11 +748,11 @@ def _create_fake_bias_for_k_proj(
weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
"""
- Create full zeros bias for k_proj weight in self-attention layers.
+ Create all-zeros bias for k_proj weight in self-attn and x-attn layers.
So that the bias for k_proj in qkv_proj can be initialized with zeros.
"""
for name, weight in weights:
- if name.endswith(".self_attn.k_proj.weight"):
+ if name.endswith(".k_proj.weight"):
bias = torch.zeros(weight.size(0))
bias_name = name.replace("weight", "bias")
yield from [(name, weight), (bias_name, bias)]
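
The widened suffix match means the zero k_proj bias is now synthesized for cross-attention (encoder_attn) layers as well, not just self_attn. A standalone sketch of the effect; the else branch and the exact checkpoint key names are assumed here (modeled on the HF Whisper layout) rather than shown in this hunk:

    import torch

    def _create_fake_bias_for_k_proj(weights):
        for name, weight in weights:
            if name.endswith(".k_proj.weight"):
                bias = torch.zeros(weight.size(0))
                bias_name = name.replace("weight", "bias")
                yield from [(name, weight), (bias_name, bias)]
            else:
                yield name, weight

    fake_weights = [
        ("model.decoder.layers.0.self_attn.k_proj.weight", torch.randn(8, 8)),
        ("model.decoder.layers.0.encoder_attn.k_proj.weight", torch.randn(8, 8)),
        ("model.decoder.layers.0.self_attn.q_proj.weight", torch.randn(8, 8)),
    ]
    names = [name for name, _ in _create_fake_bias_for_k_proj(fake_weights)]

    # Previously only the self_attn k_proj got a fake bias; now both do.
    assert "model.decoder.layers.0.self_attn.k_proj.bias" in names
    assert "model.decoder.layers.0.encoder_attn.k_proj.bias" in names
    assert "model.decoder.layers.0.self_attn.q_proj.bias" not in names
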