Merged. Changes from 3 commits.
7 changes: 2 additions & 5 deletions vllm/model_executor/models/phi3v.py
@@ -439,7 +439,7 @@ def input_processor_for_phi3v(ctx: InputContext,
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
     elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[1] for item in image_data]
+        image_feature_size = [item.shape[0] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
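For context on why the indexing changed: each item in a list of image-embedding tensors has shape (image_feature_size, hidden_size), mirroring the 3D batched case above, so the feature size lives in dim 0. A minimal sketch, with dimensions chosen purely for illustration:

```python
import torch

# Hypothetical per-image embeddings with different feature sizes
# (sizes are illustrative, not taken from the model).
image_data = [torch.randn(1921, 3072), torch.randn(757, 3072)]

# Each item is (image_feature_size, hidden_size), matching the 3D
# batched layout (num_images, image_feature_size, hidden_size),
# so dim 0 holds the feature size.
image_feature_size = [item.shape[0] for item in image_data]
print(image_feature_size)  # [1921, 757]
```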

@@ -577,9 +577,6 @@ def _parse_and_validate_image_input(
         image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)

-        if pixel_values is None:
-            return None
-
         if pixel_values is None and image_embeds is None:
             return None

@@ -616,7 +613,7 @@ def _process_image_input(
     ) -> torch.Tensor:

         if image_input["type"] == "image_embeds":
-            return image_input["data"]
+            return list(torch.unbind(image_input["data"], dim=0))
[Inline review thread on the changed line above]

Member: Is this necessary?

Contributor Author: Yes -- with the original code, with multiple images, it's a 3D tensor when we want a list of 2D tensors. This transforms it into a list of 2D tensors.

Member (@DarkLight1337, Oct 1, 2024): What if image_input["data"] is already a list? (This may happen when the feature size of each image is different.)

Contributor Author: That's fair; the images I use for testing all have the same feature size, so it seems they got concatenated into a single tensor. I added a case to catch that and preserve the original functionality: https://github.com/vllm-project/vllm/pull/8979/files#diff-be2abd8a08916397663c17c0ecb4036d478e4cc5f770b2f270b121d11414f080R618. I suspect that in that case some sort of error would occur here: https://github.com/vllm-project/vllm/pull/8979/files#diff-be2abd8a08916397663c17c0ecb4036d478e4cc5f770b2f270b121d11414f080L602, since it expects a torch Tensor and not a list.


         assert self.vision_embed_tokens is not None
         image_embeds = self.vision_embed_tokens(image_input["data"],
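To make the thread above concrete, here is a minimal sketch of what torch.unbind does here and why an input that is already a list needs its own branch (tensor sizes are illustrative only):

```python
import torch

hidden_size = 3072  # illustrative

# Same feature size for every image: inputs get batched into one
# 3D tensor of shape (num_images, feature_size, hidden_size).
stacked = torch.randn(2, 757, hidden_size)
as_list = list(torch.unbind(stacked, dim=0))
print(len(as_list), as_list[0].shape)  # 2 torch.Size([757, 3072])

# Different feature sizes: the inputs cannot be stacked, so the data
# is already a list of 2D tensors and unbinding must be skipped.
ragged = [torch.randn(1921, hidden_size), torch.randn(757, hidden_size)]
if isinstance(ragged, torch.Tensor):
    ragged = list(torch.unbind(ragged, dim=0))
print(len(ragged), ragged[0].shape)  # 2 torch.Size([1921, 3072])
```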
48 changes: 29 additions & 19 deletions vllm/model_executor/models/ultravox.py
@@ -39,6 +39,7 @@
     repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils import is_list_of

 _AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25
@@ -118,6 +119,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
     if not isinstance(data, list):
         data = [data]

+    # If the audio inputs are embeddings, no need for preprocessing
+    if is_list_of(data, torch.Tensor, check="all"):
+        return MultiModalInputs({"audio_embeds": data})
+
     audio_features = []
     for audio_input in data:
         if not isinstance(audio_input, tuple):
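A rough sketch of the short-circuit the mapper now performs; the helper below is a stand-in for vllm.utils.is_list_of, and the input shapes and 16 kHz rate are assumptions for illustration:

```python
import torch

def all_tensors(data: list) -> bool:
    # Stand-in for vllm.utils.is_list_of(data, torch.Tensor, check="all")
    return bool(data) and all(isinstance(x, torch.Tensor) for x in data)

# Pre-computed embeddings arrive as tensors; raw audio arrives as
# (waveform, sample_rate) tuples, per the loop below the new check.
embeds = [torch.randn(1, 50, 1024)]  # assumed embedding shape
raw = [(torch.randn(16000), 16000)]  # 1 s of assumed 16 kHz audio

print(all_tensors(embeds))  # True  -> returned as "audio_embeds" as-is
print(all_tensors(raw))     # False -> falls through to feature extraction
```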
@@ -164,25 +169,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         audios = [audios]

     audio_token_counts = []
-    for audio_data, sample_rate in audios:
-        audio_length = audio_data.shape[0]
-        if sample_rate != feature_extractor.sampling_rate:
-            # Account for resampling.
-            adjustment = feature_extractor.sampling_rate / sample_rate
-            audio_length = math.ceil(adjustment * audio_length)
-
-        feature_extractor_output_length = math.ceil(
-            (audio_length - (feature_extractor.hop_length - 1)) /
-            feature_extractor.hop_length)
-
-        uv_config = ctx.get_hf_config(UltravoxConfig)
-        audio_num_tokens = min(
-            max(
-                1,
-                math.ceil(feature_extractor_output_length /
-                          (uv_config.stack_factor * 2))),
-            get_ultravox_max_audio_tokens(ctx))
-        audio_token_counts.append(audio_num_tokens)
+    for audio in audios:
+        if isinstance(audio, torch.Tensor):
+            audio_num_tokens = audio.shape[1]
+            audio_token_counts.append(audio_num_tokens)
+        else:
+            audio_data, sample_rate = audio
+            audio_length = audio_data.shape[0]
+            if sample_rate != feature_extractor.sampling_rate:
+                # Account for resampling.
+                adjustment = feature_extractor.sampling_rate / sample_rate
+                audio_length = math.ceil(adjustment * audio_length)
+
+            feature_extractor_output_length = math.ceil(
+                (audio_length - (feature_extractor.hop_length - 1)) /
+                feature_extractor.hop_length)
+
+            uv_config = ctx.get_hf_config(UltravoxConfig)
+            audio_num_tokens = min(
+                max(
+                    1,
+                    math.ceil(feature_extractor_output_length /
+                              (uv_config.stack_factor * 2))),
+                get_ultravox_max_audio_tokens(ctx))
+            audio_token_counts.append(audio_num_tokens)

     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
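As a sanity check on the token-count math, which this diff moves into the else branch unchanged, here is a worked example under assumed parameters: a Whisper-style feature extractor at 16 kHz with hop_length=160, and stack_factor=8. The real values come from the feature extractor and UltravoxConfig at runtime, and the real code also clamps the result to get_ultravox_max_audio_tokens(ctx):

```python
import math

# Assumed parameters (illustrative only).
sampling_rate = 16000  # feature extractor's target rate
hop_length = 160       # samples per feature frame
stack_factor = 8       # frames stacked before the projector

# Two seconds of 8 kHz audio, adjusted for resampling to 16 kHz.
audio_length = 2 * 8000
adjustment = sampling_rate / 8000
audio_length = math.ceil(adjustment * audio_length)  # 32000 samples

output_length = math.ceil((audio_length - (hop_length - 1)) / hop_length)
# ceil(31841 / 160) = 200 feature frames

audio_num_tokens = max(1, math.ceil(output_length / (stack_factor * 2)))
# ceil(200 / 16) = 13 tokens for 2 s of audio, i.e. ~6.25 tokens/second,
# consistent with _AUDIO_TOKENS_PER_SECOND above (modulo rounding).
```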