67 changes: 41 additions & 26 deletions verl/utils/dataset.py
@@ -49,27 +49,33 @@ def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
     return {**tensors, **non_tensors}
 
 
-def process_image(image: Union[Dict[str, Any], ImageObject], max_pixels: int, min_pixels: int) -> ImageObject:
-    if isinstance(image, dict):
-        image = Image.open(BytesIO(image["bytes"]))
+class ImageProcessMixin:
+    max_pixels: int
+    min_pixels: int
 
-    if (image.width * image.height) > max_pixels:
-        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
-        image = image.resize((width, height))
+    def process_image(self, image: Union[Dict[str, Any], ImageObject]) -> ImageObject:
+        if isinstance(image, dict):
+            image = Image.open(BytesIO(image["bytes"]))
+        elif isinstance(image, bytes):
+            image = Image.open(BytesIO(image))
 
-    if (image.width * image.height) < min_pixels:
-        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
-        image = image.resize((width, height))
+        if (image.width * image.height) > self.max_pixels:
+            resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
+            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+            image = image.resize((width, height))
 
-    if image.mode != "RGB":
-        image = image.convert("RGB")
+        if (image.width * image.height) < self.min_pixels:
+            resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
+            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+            image = image.resize((width, height))
 
-    return image
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        return image
 
 
-class RLHFDataset(Dataset):
+class RLHFDataset(Dataset, ImageProcessMixin):
     """
     We assume the dataset contains a column that contains prompts and other information
     """
@@ -116,30 +116,39 @@ def __len__(self):

     def __getitem__(self, index):
         row_dict: dict = self.dataset[index]
-        messages = [{"role": "user", "content": row_dict[self.prompt_key]}]
+        prompt_str: str = row_dict[self.prompt_key]
         if self.system_prompt:
-            messages.insert(0, {"role": "system", "content": self.system_prompt})
-
-        prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+            prompt_str = " ".join((self.system_prompt.strip(), prompt_str))
 
         if self.image_key in row_dict:
-            prompt = prompt.replace("<image>", "<|vision_start|><|image_pad|><|vision_end|>")
-            row_dict["multi_modal_data"] = {
-                "image": [
-                    process_image(image, self.max_pixels, self.min_pixels) for image in row_dict.pop(self.image_key)
-                ]
-            }
-            model_inputs = self.processor(row_dict["multi_modal_data"]["image"], prompt, return_tensors="pt")
+            # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
+            content_list = []
+            for i, content in enumerate(prompt_str.split("<image>")):
+                if i != 0:
+                    content_list.append({"type": "image"})
+
+                if content:
+                    content_list.append({"type": "text", "text": content})
+
+            messages = [{"role": "user", "content": content_list}]
+            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
+            model_inputs = self.processor(images, prompt, return_tensors="pt")
             input_ids = model_inputs.pop("input_ids")[0]
             attention_mask = model_inputs.pop("attention_mask")[0]
+            row_dict["multi_modal_data"] = {"image": images}
+            row_dict["multi_modal_inputs"] = dict(model_inputs)
 
             # qwen2vl mrope
             position_ids = get_rope_index(
                 self.processor,
                 input_ids=input_ids,
                 image_grid_thw=model_inputs["image_grid_thw"],
                 attention_mask=attention_mask,
             )  # (3, seq_length)
         else:
+            messages = [{"role": "user", "content": prompt_str}]
+            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
             model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
             input_ids = model_inputs.pop("input_ids")[0]
             attention_mask = model_inputs.pop("attention_mask")[0]
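The key change in `__getitem__` is that the old string-level replacement of `<image>` with Qwen2-VL's `<|vision_start|><|image_pad|><|vision_end|>` tokens is gone; the prompt is instead split on `<image>` and rebuilt as a structured content list, so the processor's chat template inserts whatever vision tokens the model expects. A quick trace of that loop on a made-up prompt:

# Trace of the splitting loop above; the prompt text here is invented.
prompt_str = "Compare <image> with <image> briefly."

content_list = []
for i, content in enumerate(prompt_str.split("<image>")):
    if i != 0:
        content_list.append({"type": "image"})  # one entry per <image> placeholder
    if content:
        content_list.append({"type": "text", "text": content})

# content_list is now:
# [{"type": "text", "text": "Compare "},
#  {"type": "image"},
#  {"type": "text", "text": " with "},
#  {"type": "image"},
#  {"type": "text", "text": " briefly."}]
messages = [{"role": "user", "content": content_list}]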
4 changes: 3 additions & 1 deletion verl/workers/fsdp_workers.py
@@ -94,7 +94,9 @@ def __init__(
             self._use_param_offload = self.config.ref.offload.offload_params
             self._init_config(self.config.ref, "ref")
 
-    def _init_config(self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"]):
+    def _init_config(
+        self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"]
+    ):
         world_size = dist.get_world_size()
         fsdp_size = config.fsdp.fsdp_size
         if fsdp_size <= 0 or fsdp_size >= world_size:
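This hunk only reflows the `_init_config` signature; it ends at the visible guard on `fsdp_size`, and the body is truncated here. The check itself suggests that values outside (0, world_size) collapse to full-world sharding. A hedged sketch of that normalization pattern; the fallback is an assumption, not taken from this diff:

# Assumed normalization pattern; the actual body is truncated in the hunk above.
import torch.distributed as dist

def normalize_fsdp_size(fsdp_size: int) -> int:
    world_size = dist.get_world_size()
    if fsdp_size <= 0 or fsdp_size >= world_size:
        return world_size  # assumed fallback: shard across every rank
    return fsdp_size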