1 change: 1 addition & 0 deletions verl/trainer/config.py
@@ -38,6 +38,7 @@ class DataConfig:
prompt_key: str = "prompt"
answer_key: str = "answer"
image_key: str = "images"
video_key: str = "videos"
image_dir: Optional[str] = None
max_prompt_length: int = 512
max_response_length: int = 512
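For context, a minimal sketch of how the new field is meant to be used: `video_key` names the dataset column that holds the video(s) for each example, analogous to `image_key`. The column name and the example row below are hypothetical, not part of this PR.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DataConfig:
    # keys that select each field from a dataset example (excerpt)
    prompt_key: str = "prompt"
    answer_key: str = "answer"
    image_key: str = "images"
    video_key: str = "videos"          # new: column with the video path(s)/frames
    image_dir: Optional[str] = None
    max_prompt_length: int = 512
    max_response_length: int = 512


# A hypothetical dataset row consumed with the defaults above: the "videos"
# column is picked up via video_key, and the <video> placeholder in the prompt
# is later expanded into a video content block by the dataset.
example = {
    "prompt": "<video> How many people appear in the clip?",
    "answer": "3",
    "videos": "clips/demo_0001.mp4",
}
```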
1 change: 1 addition & 0 deletions verl/trainer/data_loader.py
@@ -32,6 +32,7 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces
answer_key=config.answer_key,
image_key=config.image_key,
image_dir=config.image_dir,
video_key=config.video_key,
max_prompt_length=config.max_prompt_length,
truncation="right",
format_prompt=config.format_prompt,
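The loader change is a straight pass-through of the new config field. A hedged usage sketch, assuming the remaining parameters of `create_dataloader` are the tokenizer and processor shown in the signature above and that no other arguments are required:

```python
from transformers import AutoProcessor, AutoTokenizer

from verl.trainer.config import DataConfig
from verl.trainer.data_loader import create_dataloader

model_id = "Qwen/Qwen2-VL-7B-Instruct"  # assumed checkpoint, not part of this PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

config = DataConfig(video_key="videos", max_prompt_length=1024)
dataloader = create_dataloader(config, tokenizer, processor)
```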
28 changes: 27 additions & 1 deletion verl/utils/dataset.py
@@ -30,6 +30,7 @@
from ..models.transformers.qwen2_vl import get_rope_index
from . import torch_functional as VF

from qwen_vl_utils.vision_process import fetch_video

def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
tensors = defaultdict(list)
@@ -91,6 +92,7 @@ def __init__(
answer_key: str = "answer",
image_key: str = "images",
image_dir: Optional[str] = None,
video_key: str = "videos",
max_prompt_length: int = 1024,
truncation: str = "error",
format_prompt: Optional[str] = None,
@@ -104,6 +106,7 @@ def __init__(
self.answer_key = answer_key
self.image_key = image_key
self.image_dir = image_dir
self.video_key = video_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.min_pixels = min_pixels
@@ -152,6 +155,16 @@ def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
if content:
content_list.append({"type": "text", "text": content})

return [{"role": "user", "content": content_list}]
elif self.video_key in example:
content_list = []
for i, content in enumerate(prompt_str.split("<video>")):
if i != 0:
content_list.append({"type": "video"})

if content:
content_list.append({"type": "text", "text": content})

return [{"role": "user", "content": content_list}]
else:
return [{"role": "user", "content": prompt_str}]
@@ -193,6 +206,18 @@ def __getitem__(self, index):
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
example["multi_modal_data"] = {"images": images}
elif self.video_key in example:
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
video = example.pop(self.video_key)
vision_info = {"video": video, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels, "fps": 25}
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)

model_inputs = self.processor(
videos=video_input, text=[prompt], add_special_tokens=False, return_tensors="pt"
)
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
example["multi_modal_data"] = {"video": video_input}
else:
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
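A standalone sketch of what the video branch of `__getitem__` does at runtime, useful for sanity-checking the preprocessing in isolation. The model id, video path, and pixel budgets are assumptions; `fetch_video` and its `return_video_sample_fps` flag come from `qwen_vl_utils`, as in the diff.

```python
from qwen_vl_utils import fetch_video
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed checkpoint

messages = [{"role": "user", "content": [
    {"type": "video"},
    {"type": "text", "text": "How many people appear in the clip?"},
]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

# fetch_video decodes and resamples the clip; with return_video_sample_fps=True
# it also reports the effective sampling rate (recent qwen_vl_utils releases).
vision_info = {"video": "clips/demo_0001.mp4", "min_pixels": 256 * 28 * 28,
               "max_pixels": 1280 * 28 * 28, "fps": 25}
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)

model_inputs = processor(videos=video_input, text=[prompt],
                         add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs["input_ids"][0]          # (seq_len,)
attention_mask = model_inputs["attention_mask"][0]
video_grid_thw = model_inputs["video_grid_thw"]   # consumed by get_rope_index below
```

Note that the sample's `multi_modal_data` stores the decoded frames under `"video"`, while the image branch uses `"images"`; the worker and rollout code later in this PR branch on exactly those keys.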
@@ -204,7 +229,8 @@ def __getitem__(self, index):
position_ids = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=model_inputs.get("image_grid_thw"),
image_grid_thw=model_inputs.get("image_grid_thw", None),
video_grid_thw=model_inputs.get("video_grid_thw", None),
attention_mask=attention_mask,
) # (3, seq_length)
else:
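Qwen2-VL uses multimodal RoPE, so position ids depend on the temporal/height/width grid of the visual tokens; the change above simply forwards `video_grid_thw` alongside `image_grid_thw`. Continuing the preprocessing sketch from the previous block (so `processor`, `input_ids`, `attention_mask`, and `model_inputs` are reused):

```python
from verl.models.transformers.qwen2_vl import get_rope_index

position_ids = get_rope_index(
    processor,
    input_ids=input_ids,
    image_grid_thw=model_inputs.get("image_grid_thw", None),   # None for a video-only sample
    video_grid_thw=model_inputs.get("video_grid_thw", None),
    attention_mask=attention_mask,
)  # (3, seq_length)
```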
14 changes: 9 additions & 5 deletions verl/workers/fsdp_workers.py
@@ -455,17 +455,21 @@ def _process_multi_modal_inputs(self, data: DataProto):
max_pixels = data.meta_info["max_pixels"]
batch_multi_modal_inputs = []
for multi_modal_data in data.non_tensor_batch["multi_modal_data"]:
images = []
for image in multi_modal_data["images"]:
images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))

if len(images) != 0:
if "images" in multi_modal_data:
# it's necessary to add `dict` to properly convert batch features to dict
# otherwise the batch features will be converted to dict keys
# see https://github.com/hiyouga/EasyR1/pull/339
images = []
for image in multi_modal_data["images"]:
images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))
multi_modal_inputs = dict(self.processor.image_processor(images=images, return_tensors="pt"))
multi_modal_inputs = {k: v.to(torch.cuda.current_device()) for k, v in multi_modal_inputs.items()}
batch_multi_modal_inputs.append(multi_modal_inputs)
elif "video" in multi_modal_data:
video = multi_modal_data["video"]
multi_modal_inputs = dict(self.processor.image_processor(images=None, videos=video, return_tensors="pt"))
multi_modal_inputs = {k: v.to(torch.cuda.current_device()) for k, v in multi_modal_inputs.items()}
batch_multi_modal_inputs.append(multi_modal_inputs)
else:
batch_multi_modal_inputs.append({})
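A hedged, standalone version of the branching above, assuming a Qwen2-VL style image processor whose `__call__` accepts a `videos` argument (as the diff relies on) and that `multi_modal_data["video"]` already holds the frames produced by `fetch_video` in the dataset:

```python
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed checkpoint


def build_multi_modal_inputs(multi_modal_data: dict) -> dict:
    if "video" in multi_modal_data:
        video = multi_modal_data["video"]
        # wrapping in dict() keeps the BatchFeature's items rather than only its keys,
        # per the comment above referencing hiyouga/EasyR1#339
        inputs = dict(processor.image_processor(images=None, videos=video, return_tensors="pt"))
        # typically {"pixel_values_videos": ..., "video_grid_thw": ...}
        return {k: v.to(torch.cuda.current_device()) for k, v in inputs.items()}
    return {}
```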

16 changes: 10 additions & 6 deletions verl/workers/rollout/vllm_rollout_spmd.py
@@ -52,12 +52,16 @@ def _get_logit_bias(processor: Optional[ProcessorMixin]) -> Optional[Dict[int, f
def _process_multi_modal_data(multi_modal_data: Dict[str, Any], min_pixels: int, max_pixels: int) -> Dict[str, Any]:
# may convert image path to image object
# TODO: add video
images = []
for image in multi_modal_data["images"]:
images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))

if len(images) != 0:
return {"image": images}

if "images" in multi_modal_data:
images = []
for image in multi_modal_data["images"]:
images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))
return {"image": images}
elif "video" in multi_modal_data:
# for image in multi_modal_data["video"]:
# images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))
return {"video": multi_modal_data["video"]}

return None
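Downstream, the dict returned here is attached to each prompt that vLLM generates from. A hedged sketch of that consumption, assuming vLLM's multi-modal prompt format for Qwen2-VL; the model id is a placeholder, and `prompt`/`video_input` are the values produced in the dataset sketch above:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct", limit_mm_per_prompt={"video": 1})

outputs = llm.generate(
    {
        "prompt": prompt,                            # chat-templated text with video tokens
        "multi_modal_data": {"video": video_input},  # same key as returned above
    },
    sampling_params=SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)
```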
