diff --git a/verl/utils/dataset.py b/verl/utils/dataset.py
index 95046c29..91a9b419 100644
--- a/verl/utils/dataset.py
+++ b/verl/utils/dataset.py
@@ -20,8 +20,8 @@
 
 import numpy as np
 import torch
-from jinja2 import Template
 from datasets import load_dataset
+from jinja2 import Template
 from PIL import Image
 from PIL.Image import Image as ImageObject
 from torch.utils.data import Dataset
@@ -127,7 +127,7 @@ def __init__(
                 self.format_prompt = f.read()
 
         if self.filter_overlong_prompts:
-            self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc=f"Filtering overlong prompts")
+            self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts")
 
     def _get_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
         prompt_str: str = example[self.prompt_key]
@@ -152,7 +152,9 @@
     def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
         messages = self._get_messages(example)
         processing_class = self.processor if self.processor is not None else self.tokenizer
-        return len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
+        return (
+            len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
+        )
 
     def __len__(self):
         return len(self.dataset)
@@ -198,9 +200,9 @@ def __getitem__(self, index):
         raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
         if len(raw_prompt_ids) > self.max_prompt_length:
             if self.truncation == "left":
-                raw_prompt_ids = raw_prompt_ids[..., -self.max_prompt_length:]
+                raw_prompt_ids = raw_prompt_ids[..., -self.max_prompt_length :]
             elif self.truncation == "right":
-                raw_prompt_ids = raw_prompt_ids[..., :self.max_prompt_length]
+                raw_prompt_ids = raw_prompt_ids[..., : self.max_prompt_length]
             elif self.truncation == "error":
                 raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
 