diff --git a/.github/requirements-test.txt b/.github/requirements-test.txt index 50b8db80..482ec386 100644 --- a/.github/requirements-test.txt +++ b/.github/requirements-test.txt @@ -2,6 +2,7 @@ codetiming datasets pillow pytest +qwen-vl-utils ray[default] ruff tensordict diff --git a/verl/trainer/config.py b/verl/trainer/config.py index 0ded93ea..66566811 100644 --- a/verl/trainer/config.py +++ b/verl/trainer/config.py @@ -38,6 +38,7 @@ class DataConfig: prompt_key: str = "prompt" answer_key: str = "answer" image_key: str = "images" + video_key: str = "videos" image_dir: Optional[str] = None max_prompt_length: int = 512 max_response_length: int = 512 diff --git a/verl/trainer/data_loader.py b/verl/trainer/data_loader.py index 6491856d..e3a66768 100644 --- a/verl/trainer/data_loader.py +++ b/verl/trainer/data_loader.py @@ -32,6 +32,7 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces answer_key=config.answer_key, image_key=config.image_key, image_dir=config.image_dir, + video_key=config.video_key, max_prompt_length=config.max_prompt_length, truncation="right", format_prompt=config.format_prompt, diff --git a/verl/utils/dataset.py b/verl/utils/dataset.py index e8f909b3..e1b174ad 100644 --- a/verl/utils/dataset.py +++ b/verl/utils/dataset.py @@ -30,6 +30,7 @@ from ..models.transformers.qwen2_vl import get_rope_index from . import torch_functional as VF +from qwen_vl_utils.vision_process import fetch_video def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]: tensors = defaultdict(list) @@ -91,6 +92,7 @@ def __init__( answer_key: str = "answer", image_key: str = "images", image_dir: Optional[str] = None, + video_key: str = "videos", max_prompt_length: int = 1024, truncation: str = "error", format_prompt: Optional[str] = None, @@ -104,6 +106,7 @@ def __init__( self.answer_key = answer_key self.image_key = image_key self.image_dir = image_dir + self.video_key = video_key self.max_prompt_length = max_prompt_length self.truncation = truncation self.min_pixels = min_pixels @@ -152,6 +155,16 @@ def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]: if content: content_list.append({"type": "text", "text": content}) + return [{"role": "user", "content": content_list}] + elif self.video_key in example: + content_list = [] + for i, content in enumerate(prompt_str.split("