88 changes: 85 additions & 3 deletions QEfficient/generation/embedding_handler.py
@@ -12,13 +12,14 @@
operations, separating them from the main text generation logic.
"""

from typing import Any, Dict, Optional, Tuple
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers import AutoImageProcessor, AutoTokenizer

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils.logging_utils import logger
@@ -37,6 +38,9 @@ def __init__(
qeff_model: Optional[QAICInferenceSession],
vision_session: Optional[QAICInferenceSession],
processor: Optional[AutoImageProcessor],
tokenizer: Optional[AutoTokenizer],
image_height: Optional[int] = None,
image_width: Optional[int] = None,
config: Optional[Dict[str, Any]] = None,
lang_session: Optional[QAICInferenceSession] = None,
):
@@ -46,12 +50,16 @@
Args:
qeff_model: QEfficient model wrapper (used to inspect the model config)
vision_session: QAICInferenceSession for vision model
processor: AutoImageProcessor for image preprocessing
tokenizer: AutoTokenizer for text tokenization
image_height: Optional target image height for resizing
image_width: Optional target image width for resizing
config: Configuration dictionary with vision model parameters
lang_session: Optional language session for coordination (to avoid resource conflicts)
"""
self._qeff_model = qeff_model
self._vision_session = vision_session
self._processor = processor
self._tokenizer = tokenizer
self._image_height = image_height
self._image_width = image_width
self._config = config or {}
self._lang_session = lang_session # Store language session for coordination

@@ -70,6 +78,71 @@ def is_available(self) -> bool:
"""
return self._vision_session is not None and self._processor is not None

def prepare_internVL_inputs(self, img_url: str, query: str) -> Tuple[Dict[str, np.ndarray], Dict[str, torch.Tensor]]:
"""
Prepare vision and language inputs for the InternVL model.

Args:
img_url: URL of the image to download and preprocess
query: Text query to process with the image

Returns:
Tuple of (vision_inputs, lang_inputs) dictionaries
"""
if not self._tokenizer:
raise ValueError("Tokenizer is required for InternVL input preparation")
prompt = query
pixel_values = []
num_patches_list = []
questions = []
img = requests.get(img_url, stream=True)
image = Image.open(BytesIO(img.content)).convert("RGB")

if self._image_height and self._image_width:
# PIL's Image.resize expects (width, height)
image = image.resize((self._image_width, self._image_height))
else:
logger.warning("Image height and width not specified; using the default 1000x747 size (num_patches = 13).")
image = image.resize((1000, 747))

# preprocess the resized image
pixel_value = self._processor.load_image(image, max_num=12)
num_patches_list.append(pixel_value.shape[0])
pixel_values.append(pixel_value)

question = "<image>\n" + prompt
questions.append(question)

pixel_values = torch.cat(pixel_values, dim=0)

# Chat Template information for prompt preprocessing
messages: List[List[str]] = []
roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)

inputs = self._tokenizer(prompt, return_tensors="pt")
inputs["pixel_values"] = pixel_values.clone()

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs

def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
"""
Download and preprocess image into model inputs
@@ -95,6 +168,9 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
else:
image = Image.open(image_url)

if "mistral3" in self._qeff_model.model.config.model_type:
image = image.resize((1540, 1540))

# Prepare conversation format
conversation = [
{
@@ -323,7 +399,13 @@ def get_processed_inputs(

try:
## Get vlm inputs ##
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
if (
hasattr(self._qeff_model.model.config, "model_type")
and self._qeff_model.model.config.model_type == "internvl_chat"
):
vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
else:
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)

# Handle padding for language model
pad_token_id = 1
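A minimal usage sketch of the new InternVL input path. The handler class name and the session, processor, and tokenizer objects below are placeholders inferred from this diff, not a confirmed public API.

# Hypothetical usage of the InternVL-specific preparation added above.
# `ImageTextEmbeddingHandler` is a placeholder name for the handler class in
# QEfficient/generation/embedding_handler.py.
handler = ImageTextEmbeddingHandler(
    qeff_model=qeff_model,          # QEfficient VLM wrapper (provides model config)
    vision_session=vision_session,  # QAICInferenceSession for the vision QPC
    processor=processor,            # InternVL processor exposing load_image()
    tokenizer=tokenizer,            # AutoTokenizer for the language model
    image_height=747,               # optional; falls back to a 1000x747 resize otherwise
    image_width=1000,
    lang_session=lang_session,      # language session shared for coordination
)

vision_inputs, lang_inputs = handler.prepare_internVL_inputs(
    img_url="https://example.com/sample.jpg",
    query="Describe this image.",
)
# vision_inputs: numpy arrays (pixel_values cast to float16)
# lang_inputs: tokenized prompt tensors (input_ids, attention_mask, ...)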
8 changes: 8 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -86,6 +86,8 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
full_batch_size: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
@@ -139,6 +141,9 @@ def __init__(
)
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
self.image_height = image_height
self.image_width = image_width
self._vision_qpc_path = vision_qpc_path
self.device_id = device_id # Store device_id for vision components
self.enable_debug_logs = enable_debug_logs # Store for vision components
@@ -169,6 +174,9 @@ def _init_vision_components(self):
qeff_model=self.qeff_model,
vision_session=self._vision_session,
processor=self.processor,
tokenizer=self.tokenizer,
image_height=self.image_height,
image_width=self.image_width,
config=vision_config,
lang_session=self._session, # Pass language session for coordination
)
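A minimal sketch of passing the new image size arguments through the generation wrapper so they reach the embedding handler. The class name VisionLanguageGeneration is assumed from the file name vlm_generation.py, and only keyword arguments visible in this diff are shown; other defaults are omitted.

# Hypothetical construction; the class name and unlisted parameters are assumptions.
vlm = VisionLanguageGeneration(
    qeff_model=qeff_model,
    processor=processor,
    tokenizer=tokenizer,
    vision_qpc_path=vision_qpc_path,
    device_id=device_id,
    image_height=747,   # forwarded to the embedding handler for image resizing
    image_width=1000,
    full_batch_size=4,
)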
93 changes: 67 additions & 26 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -592,7 +592,15 @@ def __init__(self, model):
self.config = self.model.config
self.lm_head = self.model.lm_head

def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
def forward(
self,
input_ids,
vision_embeds,
position_ids,
image_idx,
past_key_values,
batch_index: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
B, N, C = inputs_embeds.shape
selected = input_ids == self.model.config.image_token_index
@@ -603,7 +611,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
outputs = self.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
batch_index=batch_index,
use_cache=True,
)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
@@ -648,6 +660,9 @@ def get_specializations(
ctx_len: int,
img_size: int,
kv_offload: bool = False,
continuous_batching: bool = False,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
**compiler_options,
):
prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
@@ -667,44 +682,63 @@
"ctx_len": ctx_len,
}
]
lang = [
{
"batch_size": batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
},
{
"batch_size": batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
},
]
lang_prefill = {
"batch_size": 1 if continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_batch_size": batch_size,
}
if continuous_batching:
lang_prefill["full_batch_size"] = kv_cache_batch_size
else:
lang_prefill["batch_size"] = kv_cache_batch_size
if full_batch_size:
lang_prefill["full_batch_exec_size"] = full_batch_size

lang_decode = {
"batch_size": full_batch_size if continuous_batching else batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_batch_size": batch_size,
}
if continuous_batching:
lang_decode["full_batch_size"] = kv_cache_batch_size
else:
lang_decode["batch_size"] = kv_cache_batch_size
lang = [lang_prefill, lang_decode]

specializations = {}

if kv_offload:
specializations["vision"] = vision
specializations["lang"] = lang
return specializations, compiler_options
else:
lang[0].pop("vision_size")
lang[1].pop("vision_size")
return lang, compiler_options

def get_onnx_dynamic_axes(self, kv_offload: bool = False):
def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False):
# Define dynamic axes
vision_dynamic_axes = {}
lang_dynamic_axes = {}
lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"}
lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"}
if continuous_batching:
lang_dynamic_axes["batch_index"] = {0: "batch_size"}
vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
pkv_dynamic_sliding_axes = {0: "batch_size", 2: "sliding_window"}
pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}
pkv_dynamic_sliding_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"}
layer_switch = (
self.language_model.config.sliding_window_pattern
if hasattr(self.language_model.config, "sliding_window_pattern")
@@ -767,7 +801,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
past_key_values.append(pkv)
return past_key_values

def get_dummy_inputs(self, kv_offload: bool = False):
def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False):
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", 896)
else:
@@ -806,13 +840,20 @@ def get_dummy_inputs(self, kv_offload: bool = False):
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)

bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS

# Add data for KV
lang_inputs["past_key_values"] = self.get_dummy_pkv_cache(
config=self.language_model.config,
batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
batch_size=fbs if continuous_batching else bs,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

if continuous_batching:
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)

inputs = {}
if kv_offload:
inputs["vision"] = vision_inputs
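A standalone sketch that mirrors the continuous-batching specialization logic added to get_specializations, so the resulting prefill/decode shapes can be inspected without loading a model. The helper name and example values are illustrative only; model-derived fields (sliding_window, img_size, mm_tokens_per_image) are omitted for brevity.

def sketch_lang_specializations(batch_size=1, prefill_seq_len=128, ctx_len=4096,
                                continuous_batching=True, kv_cache_batch_size=4,
                                full_batch_size=4):
    # Prefill runs a single sequence at a time when continuous batching is enabled.
    lang_prefill = {
        "batch_size": 1 if continuous_batching else batch_size,
        "seq_len": prefill_seq_len,
        "ctx_len": ctx_len,
        "vision_batch_size": batch_size,
    }
    if continuous_batching:
        lang_prefill["full_batch_size"] = kv_cache_batch_size
    else:
        lang_prefill["batch_size"] = kv_cache_batch_size
    if full_batch_size:
        lang_prefill["full_batch_exec_size"] = full_batch_size

    # Decode steps over all active batch slots.
    lang_decode = {
        "batch_size": full_batch_size if continuous_batching else batch_size,
        "seq_len": "1",
        "ctx_len": ctx_len,
        "vision_batch_size": batch_size,
    }
    if continuous_batching:
        lang_decode["full_batch_size"] = kv_cache_batch_size
    else:
        lang_decode["batch_size"] = kv_cache_batch_size
    return [lang_prefill, lang_decode]


print(sketch_lang_specializations())
# [{'batch_size': 1, 'seq_len': 128, 'ctx_len': 4096, 'vision_batch_size': 1,
#   'full_batch_size': 4, 'full_batch_exec_size': 4},
#  {'batch_size': 4, 'seq_len': '1', 'ctx_len': 4096, 'vision_batch_size': 1,
#   'full_batch_size': 4}]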