Binary file modified assets/wechat.jpg
3 changes: 2 additions & 1 deletion examples/config.yaml
@@ -59,11 +59,12 @@ worker:
n: 5
temperature: 1.0
top_p: 0.99
limit_images: 0
gpu_memory_utilization: 0.6
enforce_eager: false
enable_chunked_prefill: false
tensor_parallel_size: 2
limit_images: 0
disable_tqdm: false
val_override_config:
temperature: 0.5
n: 1
1 change: 1 addition & 0 deletions verl/trainer/config.py
@@ -40,6 +40,7 @@ class DataConfig:
image_key: str = "images"
video_key: str = "videos"
image_dir: Optional[str] = None
video_fps: float = 2.0
max_prompt_length: int = 512
max_response_length: int = 512
rollout_batch_size: int = 512
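For orientation, here is a minimal sketch of constructing the extended DataConfig. Only the field names and the 2.0 default for video_fps come from the diff; the other values are illustrative, and the fields hidden by the collapsed part of the diff are assumed to keep usable defaults.

```python
# Hedged sketch: the new video_fps knob sits next to the existing multi-modal keys.
from verl.trainer.config import DataConfig

data_config = DataConfig(
    image_key="images",
    video_key="videos",
    image_dir="/data/media",   # hypothetical shared root for relative image/video paths
    video_fps=2.0,             # frames sampled per second when decoding videos
    max_prompt_length=512,
    max_response_length=512,
)
```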
3 changes: 2 additions & 1 deletion verl/trainer/data_loader.py
@@ -31,8 +31,9 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces
prompt_key=config.prompt_key,
answer_key=config.answer_key,
image_key=config.image_key,
image_dir=config.image_dir,
video_key=config.video_key,
image_dir=config.image_dir,
video_fps=config.video_fps,
max_prompt_length=config.max_prompt_length,
truncation="right",
format_prompt=config.format_prompt,
9 changes: 7 additions & 2 deletions verl/trainer/ray_trainer.py
@@ -392,6 +392,7 @@ def _validate(self) -> Dict[str, Any]:
test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
test_gen_batch.meta_info["min_pixels"] = self.config.data.min_pixels
test_gen_batch.meta_info["max_pixels"] = self.config.data.max_pixels
test_gen_batch.meta_info["video_fps"] = self.config.data.video_fps

test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_ref_wg.world_size)
test_output_gen_batch = self.actor_rollout_ref_wg.generate_sequences(test_gen_batch)
@@ -456,14 +457,18 @@ def _make_batch_data(self, metrics: Dict[str, Any]) -> DataProto:
self.data_iterator = iter(self.train_dataloader)
batch_dict = next(self.data_iterator)

meta_info = {"min_pixels": self.config.data.min_pixels, "max_pixels": self.config.data.max_pixels}
meta_info = {
"min_pixels": self.config.data.min_pixels,
"max_pixels": self.config.data.max_pixels,
"video_fps": self.config.data.video_fps,
}
new_batch: DataProto = DataProto.from_single_dict(batch_dict, meta_info=meta_info)

# pop those keys for generation
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
meta_info_keys=["min_pixels", "max_pixels"],
meta_info_keys=["min_pixels", "max_pixels", "video_fps"],
)

# generate a batch
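As a quick illustration of the contract this creates (a sketch, not code from the PR): once these keys are popped, everything the rollout worker needs for visual preprocessing travels with the generation batch instead of being read from the trainer config.

```python
# Illustrative only: the rollout worker reads these from gen_batch, not from self.config.
assert set(gen_batch.meta_info) >= {"min_pixels", "max_pixels", "video_fps"}
assert "multi_modal_data" in gen_batch.non_tensor_batch
assert "input_ids" in gen_batch.batch
```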
74 changes: 56 additions & 18 deletions verl/utils/dataset.py
@@ -16,21 +16,21 @@
import os
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from datasets import load_dataset
from jinja2 import Template
from PIL import Image
from PIL.Image import Image as ImageObject
from qwen_vl_utils.vision_process import fetch_video
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..models.transformers.qwen2_vl import get_rope_index
from . import torch_functional as VF

from qwen_vl_utils.vision_process import fetch_video

def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
tensors = defaultdict(list)
@@ -78,6 +78,13 @@ def process_image(
return image


def process_video(
video: str, min_pixels: Optional[int], max_pixels: Optional[int], video_fps: float, return_fps: bool = False
) -> Union[List[ImageObject], Tuple[List[ImageObject], List[float]]]:
vision_info = {"video": video, "min_pixels": min_pixels, "max_pixels": max_pixels, "fps": video_fps}
return fetch_video(vision_info, return_video_sample_fps=return_fps)


class RLHFDataset(Dataset):
"""
We assume the dataset contains a column that contains prompts and other information
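A usage sketch for the process_video helper introduced above; the path and pixel budgets are placeholders, and the return value follows qwen_vl_utils' fetch_video (the decoded frames, plus the actually sampled fps when return_fps=True).

```python
# Hedged usage sketch for process_video; path and pixel budgets are placeholders.
frames = process_video(
    "videos/example.mp4",           # hypothetical local video path
    min_pixels=256 * 28 * 28,
    max_pixels=1280 * 28 * 28,
    video_fps=2.0,                  # sample two frames per second
)

# With return_fps=True the helper also reports the fps that was actually sampled,
# which __getitem__ below uses to derive second_per_grid_ts.
frames, sampled_fps = process_video(
    "videos/example.mp4", 256 * 28 * 28, 1280 * 28 * 28, 2.0, return_fps=True
)
```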
@@ -91,8 +98,9 @@ def __init__(
prompt_key: str = "prompt",
answer_key: str = "answer",
image_key: str = "images",
image_dir: Optional[str] = None,
video_key: str = "videos",
image_dir: Optional[str] = None,
video_fps: float = 2.0,
max_prompt_length: int = 1024,
truncation: str = "error",
format_prompt: Optional[str] = None,
@@ -105,8 +113,9 @@
self.prompt_key = prompt_key
self.answer_key = answer_key
self.image_key = image_key
self.image_dir = image_dir
self.video_key = video_key
self.image_dir = image_dir
self.video_fps = video_fps
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.min_pixels = min_pixels
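For context, a sketch of constructing the dataset with the new video arguments. The leading parameters (data_path, tokenizer, processor) are collapsed in the diff above, so their names here are assumptions; the remaining keywords mirror the signature shown.

```python
# Hedged sketch: wiring the new video knobs into RLHFDataset.
dataset = RLHFDataset(
    data_path="hiyouga/geometry3k@train",   # hypothetical dataset spec
    tokenizer=tokenizer,                    # assumed pre-built PreTrainedTokenizer
    processor=processor,                    # assumed pre-built ProcessorMixin
    video_key="videos",
    video_fps=2.0,
    image_dir="/data/media",                # prepended to relative image/video paths
    max_prompt_length=1024,
)
```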
@@ -173,14 +182,29 @@ def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
messages = self._build_messages(example)
if self.image_key in example:
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
images = example[self.image_key] or []
images = example[self.image_key]
if self.image_dir is not None and len(images) != 0 and isinstance(images[0], str): # image paths
images = [os.path.join(self.image_dir, image) for image in images]

resized_images = [
process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels) for image in images
] or None
model_inputs = self.processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
processed_images = [] if len(images) != 0 else None # text-only data
for image in images:
processed_images.append(process_image(image, self.min_pixels, self.max_pixels))

model_inputs = self.processor(processed_images, [prompt], add_special_tokens=False, return_tensors="pt")
return model_inputs["input_ids"].size(-1) <= self.max_prompt_length
elif self.video_key in example:
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
videos = example[self.video_key]
if self.image_dir is not None and len(videos) != 0 and isinstance(videos[0], str): # video paths
videos = [os.path.join(self.image_dir, video) for video in videos]

processed_videos = [] if len(videos) != 0 else None # text-only data
for video in videos:
processed_videos.append(process_video(video, self.min_pixels, self.max_pixels, self.video_fps))

model_inputs = self.processor(
videos=processed_videos, text=[prompt], add_special_tokens=False, return_tensors="pt"
)
return model_inputs["input_ids"].size(-1) <= self.max_prompt_length
else:
input_ids = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
@@ -199,25 +223,38 @@ def __getitem__(self, index):
if self.image_dir is not None and len(images) != 0 and isinstance(images[0], str): # image paths
images = [os.path.join(self.image_dir, image) for image in images]

resized_images = [
process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels) for image in images
] or None
model_inputs = self.processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
processed_images = [] if len(images) != 0 else None # text-only data
for image in images:
processed_images.append(process_image(image, self.min_pixels, self.max_pixels))

model_inputs = self.processor(processed_images, [prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
example["multi_modal_data"] = {"images": images}
elif self.video_key in example:
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
video = example.pop(self.video_key)
vision_info = {"video": video, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels, "fps": 25}
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
videos = example.pop(self.video_key)
if self.image_dir is not None and len(videos) != 0 and isinstance(videos[0], str): # video paths
videos = [os.path.join(self.image_dir, video) for video in videos]

processed_videos = [] if len(videos) != 0 else None # text-only data
video_fps_list = []
for video in videos:
processed_video, video_fps = process_video(
video, self.min_pixels, self.max_pixels, self.video_fps, return_fps=True
)
processed_videos.append(processed_video)
video_fps_list.append(video_fps)

model_inputs = self.processor(
videos=video_input, text=[prompt], add_special_tokens=False, return_tensors="pt"
videos=processed_videos, text=[prompt], add_special_tokens=False, return_tensors="pt"
)
if "second_per_grid_ts" in self.processor.model_input_names:
model_inputs["second_per_grid_ts"] = [2.0 / video_sample_fps for video_sample_fps in video_fps_list]

input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
example["multi_modal_data"] = {"video": video_input}
example["multi_modal_data"] = {"videos": videos}
else:
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
@@ -231,6 +268,7 @@ def __getitem__(self, index):
input_ids=input_ids,
image_grid_thw=model_inputs.get("image_grid_thw", None),
video_grid_thw=model_inputs.get("video_grid_thw", None),
second_per_grid_ts=model_inputs.get("second_per_grid_ts", None),
attention_mask=attention_mask,
) # (3, seq_length)
else:
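Before moving on to the worker changes, the 2.0 / video_sample_fps expression above deserves a note: Qwen2-VL-style processors merge two sampled frames into one temporal grid step (temporal_patch_size = 2), so seconds-per-grid is the temporal patch size divided by the sampling rate, and that value is what get_rope_index consumes as second_per_grid_ts. A worked example, with the patch size stated as an assumption about the processor defaults:

```python
# Worked example (assumes Qwen2-VL's default temporal_patch_size of 2).
temporal_patch_size = 2.0
video_sample_fps = 2.0                                    # reported by process_video(..., return_fps=True)
second_per_grid_ts = temporal_patch_size / video_sample_fps
print(second_per_grid_ts)                                 # 1.0 second of video per temporal grid step

# A lower sampling rate stretches each grid step accordingly:
print(temporal_patch_size / 0.5)                          # 4.0 seconds per grid step at 0.5 fps
```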
24 changes: 16 additions & 8 deletions verl/workers/fsdp_workers.py
@@ -41,7 +41,7 @@
from ..single_controller.base import Worker
from ..single_controller.base.decorator import Dispatch, register
from ..utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from ..utils.dataset import process_image
from ..utils.dataset import process_image, process_video
from ..utils.flops_counter import FlopsCounter
from ..utils.fsdp_utils import (
get_fsdp_wrap_policy,
@@ -453,24 +453,32 @@ def _process_multi_modal_inputs(self, data: DataProto):
if "multi_modal_inputs" not in self._cache:
min_pixels = data.meta_info["min_pixels"]
max_pixels = data.meta_info["max_pixels"]
video_fps = data.meta_info["video_fps"]
batch_multi_modal_inputs = []
for multi_modal_data in data.non_tensor_batch["multi_modal_data"]:
images, videos = [], []
if "images" in multi_modal_data:
for image in multi_modal_data["images"]:
images.append(process_image(image, min_pixels, max_pixels))

if "videos" in multi_modal_data:
for video in multi_modal_data["videos"]:
videos.append(process_video(video, min_pixels, max_pixels, video_fps))

if len(images) != 0:
# it's necessary to add `dict` to properly convert batch features to dict
# otherwise the batch features will be converted to dict keys
# see https://github.com/hiyouga/EasyR1/pull/339
images = []
for image in multi_modal_data["images"]:
images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))
multi_modal_inputs = dict(self.processor.image_processor(images=images, return_tensors="pt"))
multi_modal_inputs = {k: v.to(torch.cuda.current_device()) for k, v in multi_modal_inputs.items()}
batch_multi_modal_inputs.append(multi_modal_inputs)
elif "video" in multi_modal_data:
video = multi_modal_data["video"]
multi_modal_inputs = dict(self.processor.image_processor(images=None, videos=video, return_tensors="pt"))
elif len(videos) != 0:
multi_modal_inputs = dict(
self.processor.image_processor(images=None, videos=videos, return_tensors="pt")
)
multi_modal_inputs = {k: v.to(torch.cuda.current_device()) for k, v in multi_modal_inputs.items()}
batch_multi_modal_inputs.append(multi_modal_inputs)
else:
else: # text-only data
batch_multi_modal_inputs.append({})

self._cache["uid"] = data.non_tensor_batch["uid"]
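To make the video branch concrete, here is a sketch of the per-sample payload the loop above consumes (the path is illustrative): the dataset stores raw video references under multi_modal_data, and the worker re-decodes them here with the same pixel budget and fps used on the rollout side.

```python
# Hedged sketch of one element of data.non_tensor_batch["multi_modal_data"]
# and of how the loop above turns it into processor inputs.
multi_modal_data = {"videos": ["videos/example.mp4"]}     # produced by RLHFDataset.__getitem__

videos = [
    process_video(video, min_pixels, max_pixels, video_fps)
    for video in multi_modal_data["videos"]
]
multi_modal_inputs = dict(self.processor.image_processor(images=None, videos=videos, return_tensors="pt"))
```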
1 change: 1 addition & 0 deletions verl/workers/rollout/config.py
@@ -37,6 +37,7 @@ class RolloutConfig:
max_model_len: Optional[int] = None
max_num_batched_tokens: int = 8192
disable_log_stats: bool = True
disable_tqdm: bool = False
val_override_config: Dict[str, Any] = field(default_factory=dict)
# below are auto keys
prompt_length: int = field(default=-1, init=False)
43 changes: 26 additions & 17 deletions verl/workers/rollout/vllm_rollout_spmd.py
@@ -25,7 +25,7 @@

from ...protocol import DataProto
from ...utils import torch_functional as VF
from ...utils.dataset import process_image
from ...utils.dataset import process_image, process_video
from ...utils.torch_dtypes import PrecisionType
from .base import BaseRollout
from .config import RolloutConfig
@@ -49,19 +49,24 @@ def _get_logit_bias(processor: Optional[ProcessorMixin]) -> Optional[Dict[int, f
return None


def _process_multi_modal_data(multi_modal_data: Dict[str, Any], min_pixels: int, max_pixels: int) -> Dict[str, Any]:
def _process_multi_modal_data(
multi_modal_data: Dict[str, Any], min_pixels: int, max_pixels: int, video_fps: float
) -> Dict[str, Any]:
# may convert image path to image object
# TODO: add video

images, videos = [], []
if "images" in multi_modal_data:
images = []
for image in multi_modal_data["images"]:
images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))
return {"image": images}
elif "video" in multi_modal_data:
# for image in multi_modal_data["video"]:
# images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))
return {"video": multi_modal_data["video"]}
images.append(process_image(image, min_pixels, max_pixels))

if "videos" in multi_modal_data:
for video in multi_modal_data["videos"]:
videos.append(process_video(video, min_pixels, max_pixels, video_fps))

if len(images) != 0:
return {"image": images}

if len(videos) != 0:
return {"video": videos}

return None
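For clarity, the returned dict follows vLLM's multi-modal prompt convention: an "image" key carrying a list of PIL images, or a "video" key carrying one decoded clip per video. A sketch of the two shapes, with placeholder paths:

```python
# Hedged sketch of the two return shapes (paths and pixel budgets are placeholders).
image_entry = _process_multi_modal_data(
    {"images": ["images/example.png"]},
    min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28, video_fps=2.0,
)
# -> {"image": [<PIL.Image.Image>]}

video_entry = _process_multi_modal_data(
    {"videos": ["videos/example.mp4"]},
    min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28, video_fps=2.0,
)
# -> {"video": [<decoded clip: sampled frames>]}
```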

@@ -85,6 +90,7 @@ def __init__(
self.rank = int(os.getenv("RANK", "0"))
self.config = config
self.pad_token_id = tokenizer.pad_token_id
self.use_tqdm = (self.rank == 0) and (not config.disable_tqdm)
if config.tensor_parallel_size > torch.distributed.get_world_size():
raise ValueError("Tensor parallelism size should be less than world size.")

@@ -94,9 +100,8 @@ def __init__(
engine_kwargs = {}
if processor is not None: # only VLMs have processor
engine_kwargs["disable_mm_preprocessor_cache"] = True

if processor is not None and config.limit_images:
engine_kwargs["limit_mm_per_prompt"] = {"image": config.limit_images}
if config.limit_images:
engine_kwargs["limit_mm_per_prompt"] = {"image": config.limit_images}

self.inference_engine = LLM(
model=model_path,
@@ -166,13 +171,17 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
raise RuntimeError("vllm sharding manager is not work properly.")

if batch_multi_modal_data is not None:
min_pixels, max_pixels = prompts.meta_info["min_pixels"], prompts.meta_info["max_pixels"]
vllm_inputs = []
for raw_prompt_ids, multi_modal_data in zip(batch_raw_prompt_ids, batch_multi_modal_data):
vllm_inputs.append(
{
"prompt_token_ids": list(raw_prompt_ids),
"multi_modal_data": _process_multi_modal_data(multi_modal_data, min_pixels, max_pixels),
"multi_modal_data": _process_multi_modal_data(
multi_modal_data,
prompts.meta_info["min_pixels"],
prompts.meta_info["max_pixels"],
prompts.meta_info["video_fps"],
),
}
)
else:
@@ -181,7 +190,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
# users can customize different sampling_params at different run
with self.update_sampling_params(**prompts.meta_info):
completions: List[RequestOutput] = self.inference_engine.generate(
prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=False
prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=self.use_tqdm
)
response_ids = [output.token_ids for completion in completions for output in completion.outputs]
response_ids = VF.pad_2d_list_to_length(
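Putting the rollout-side pieces together, here is a sketch of one request handed to vLLM after this change; the token ids and the clip path are placeholders, and the snippet mirrors the worker code above rather than adding anything new.

```python
# Hedged end-to-end sketch of one multi-modal vLLM request built by generate_sequences.
vllm_input = {
    "prompt_token_ids": [151644, 872, 198],           # placeholder raw prompt ids
    "multi_modal_data": _process_multi_modal_data(
        {"videos": ["videos/example.mp4"]},
        min_pixels=256 * 28 * 28,
        max_pixels=1280 * 28 * 28,
        video_fps=2.0,
    ),
}

completions = self.inference_engine.generate(
    prompts=[vllm_input], sampling_params=self.sampling_params, use_tqdm=self.use_tqdm
)
```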