diff --git a/docs/source/data_utils.md b/docs/source/data_utils.md index 26c3d368655..b6a18efb343 100644 --- a/docs/source/data_utils.md +++ b/docs/source/data_utils.md @@ -4,6 +4,10 @@ [[autodoc]] prepare_multimodal_messages +## prepare_multimodal_messages_vllm + +[[autodoc]] prepare_multimodal_messages_vllm + ## is_conversational [[autodoc]] is_conversational diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 72a4d3e993f..ac0bd788b24 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -18,7 +18,7 @@ import pytest from datasets import Dataset, DatasetDict -from transformers import AutoProcessor, AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer, is_vision_available from trl.data_utils import ( apply_chat_template, @@ -31,13 +31,19 @@ maybe_unpair_preference_dataset, pack_dataset, prepare_multimodal_messages, + prepare_multimodal_messages_vllm, truncate_dataset, unpair_preference_dataset, ) -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_vision +if is_vision_available(): + from PIL import Image + + +@require_vision class TestPrepareMultimodalMessages: def test_basic_user_assistant_conversation(self): """Test basic conversation with user and assistant messages.""" @@ -45,30 +51,46 @@ def test_basic_user_assistant_conversation(self): {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) expected = [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, ] assert messages == expected def test_first_user_message_gets_image(self): - """Test that only the first user message gets an image placeholder.""" + """Test that only the first user message gets an image.""" messages = [ {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, {"role": "user", "content": "How about the grass?"}, ] - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) expected = [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, - {"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]}, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + { + "role": "user", + "content": [{"type": "text", "text": "How about the grass?"}], + }, ] assert messages == expected @@ -79,20 +101,23 @@ def test_multiple_images(self): {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - - prepare_multimodal_messages(messages, num_images=3) + images = [Image.new("RGB", (10, 10), color=color) for color in ["red", "green", "blue"]] + messages = 
prepare_multimodal_messages(messages, images=images) expected = [ { "role": "user", "content": [ - {"type": "image"}, - {"type": "image"}, - {"type": "image"}, + {"type": "image", "image": images[0]}, + {"type": "image", "image": images[1]}, + {"type": "image", "image": images[2]}, {"type": "text", "text": "What color is the sky?"}, ], }, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, ] assert messages == expected @@ -104,11 +129,18 @@ def test_system_message_transformation(self): {"role": "user", "content": "What color is the sky?"}, ] - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) expected = [ - {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]}, - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant"}], + }, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, ] assert messages == expected @@ -121,10 +153,25 @@ def test_already_prepared_messages_unchanged(self): {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, ] - original = copy.deepcopy(messages) - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) - assert messages == original + expected = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant"}], + }, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + ] + + assert messages == expected def test_mixed_prepared_and_unprepared_messages(self): """Test handling of mixed prepared and unprepared messages.""" @@ -134,17 +181,119 @@ def test_mixed_prepared_and_unprepared_messages(self): {"role": "user", "content": "What about the grass?"}, ] - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) expected = [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, - {"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]}, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + { + "role": "user", + "content": [{"type": "text", "text": "What about the grass?"}], + }, ] assert messages == expected +@require_vision +class TestPrepareMultimodalMessagesVLLM: + def test_single_image_conversion(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + {"type": "text", "text": "What color is the sky?"}, + ], + } + ] + + result = prepare_multimodal_messages_vllm(messages) + + # Original should remain unchanged (deepcopy test) + assert messages[0]["content"][0]["type"] 
== "image" + + # Converted version should have correct structure + assert result[0]["content"][0]["type"] == "image_pil" + assert "image_pil" in result[0]["content"][0] + assert "image" not in result[0]["content"][0] + assert isinstance(result[0]["content"][0]["image_pil"], Image.Image) + assert result[0]["content"][1]["type"] == "text" + + def test_mixed_content_conversion(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + } + ] + + result = prepare_multimodal_messages_vllm(messages) + + # The image part should be converted, text should be unchanged + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][1]["type"] == "image_pil" + + def test_no_images(self): + messages = [{"role": "user", "content": [{"type": "text", "text": "What color is the sky?"}]}] + + result = prepare_multimodal_messages_vllm(messages) + + # Should be identical since there are no images + assert result == messages + # And a deepcopy — not the same object + assert result is not messages + assert result[0] is not messages[0] + + def test_multiple_messages(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + ] + + result = prepare_multimodal_messages_vllm(messages) + + assert result[0]["content"][1]["type"] == "image_pil" + assert result[1]["content"][0]["type"] == "text" + assert result[1]["content"][0]["text"] == "It is blue." + + def test_deepcopy_integrity(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + }, + ] + original = copy.deepcopy(messages) + + _ = prepare_multimodal_messages_vllm(messages) + + # Original should not be mutated + assert messages == original + + class TestIsConversational(TrlTestCase): conversational_examples = [ { # Language modeling diff --git a/trl/__init__.py b/trl/__init__.py index 8babb49039e..ee6aa770b8e 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -50,6 +50,7 @@ "maybe_unpair_preference_dataset", "pack_dataset", "prepare_multimodal_messages", + "prepare_multimodal_messages_vllm", "truncate_dataset", "unpair_preference_dataset", ], @@ -129,6 +130,7 @@ maybe_unpair_preference_dataset, pack_dataset, prepare_multimodal_messages, + prepare_multimodal_messages_vllm, truncate_dataset, unpair_preference_dataset, ) diff --git a/trl/data_utils.py b/trl/data_utils.py index e80311c2111..0209fbcb6dd 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from collections import defaultdict, deque from collections.abc import Sequence from itertools import takewhile @@ -28,19 +29,30 @@ DatasetType = TypeVar("DatasetType", Dataset, DatasetDict) -def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None: +def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> list[dict[str, Any]]: + # docstyle-ignore # because is not parsable in the code block """ - Convert messages into a structured multimodal format if needed. 
-
-    Each message's content is transformed from a raw string into a list of typed parts. The first user message is
-    prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries.
+    Convert messages into a structured multimodal format and inject the provided images into the message contents.
 
     Args:
         messages (`list[dict[str, Any]]`):
-            Messages with `"role"` and `"content"`. Content may be a raw string before transformation.
-        num_images (`int`):
-            Number of images to include in the first user message. This is used to determine how many image
-            placeholders to add.
+            List of messages, each with a `"role"` key (`"system"`, `"user"`, or `"assistant"`) and a `"content"`
+            key containing either a raw string (before transformation) or a list of structured content blocks if
+            already prepared.
+        images (`list`):
+            List of image objects (e.g. PIL images) to insert into the messages.
+
+    Returns:
+        `list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured
+        content blocks, and all `{"type": "image"}` placeholders are populated with the corresponding image objects.
+
+    Notes:
+        - When the input `messages` is not already in the structured format (i.e., all `"content"` values are
+          strings), the function converts it by wrapping text in `{"type": "text", "text": ...}` and inserting
+          `{"type": "image"}` placeholders for the images at the beginning of the first user message's content.
+        - When the input `messages` is already in the structured format (i.e., all `"content"` values are lists of
+          structured blocks), the function only fills the existing `{"type": "image"}` placeholders with the provided
+          images. If the number of placeholders does not match the number of images, an error is raised.
Example: ```python @@ -50,24 +62,28 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) {"role": "assistant", "content": "It looks like a cat."}, ] - # Output (num_images=1) + # Output, one image provided [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]}, + {"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]}, ] ``` """ - image_included = False + + messages = copy.deepcopy(messages) # avoid modifying the original messages + + # First, convert all messages to the structured format if needed, and insert image placeholders if needed + images_included = False for message in messages: if message["role"] == "system": if isinstance(message["content"], str): # if already prepared, the content will be a list message["content"] = [{"type": "text", "text": message["content"]}] elif message["role"] == "user": - if isinstance(message["content"], str) and not image_included: - placeholders = [{"type": "image"}] * num_images - message["content"] = [*placeholders, {"type": "text", "text": message["content"]}] - image_included = True - elif isinstance(message["content"], str) and image_included: + if isinstance(message["content"], str) and not images_included: + image_entries = [{"type": "image"} for _ in range(len(images))] + message["content"] = [*image_entries, {"type": "text", "text": message["content"]}] + images_included = True + elif isinstance(message["content"], str) and images_included: message["content"] = [{"type": "text", "text": message["content"]}] elif message["role"] == "assistant": if isinstance(message["content"], str): @@ -75,6 +91,56 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) else: raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.") + # Then, check that the number of image placeholders matches the number of images provided + num_placeholders = sum(sum(1 for part in message["content"] if part["type"] == "image") for message in messages) + if num_placeholders != len(images): + raise ValueError( + f"Number of images provided ({len(images)}) does not match number of image placeholders ({num_placeholders})." + ) + + # Then, fill in the actual images in the placeholders + img_idx = 0 + for message in messages: + for part in message["content"]: + if part["type"] == "image": + part["image"] = images[img_idx] + img_idx += 1 + + return messages + + +def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + # docstyle-ignore # because is not parsable in the code block + """ + Convert structured multimodal messages into a format compatible with vLLM. Replaces `"type": "image"` blocks with + `"type": "image_pil"` blocks, and `"image": Image` with `"image_pil": Image`. + + Args: + messages (`list[dict[str, Any]]`): + Messages with `"role"` and `"content"`. Content is expected to be a list of structured blocks. + + Returns: + `list[dict[str, Any]]`: + A deep-copied list of messages compatible with vLLM's expected input format. 
+ + Example: + ```python + # Input + [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}] + + # Output + [{"role": "user", "content": [{"type": "image_pil", "image_pil": }, {"type": "text", "text": "What's in this image?"}]}] + ``` + """ + messages = copy.deepcopy(messages) # avoid modifying the original messages + for message in messages: + if isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image": + part["type"] = "image_pil" # vLLM expects 'image_pil' key for images + part["image_pil"] = part.pop("image") + return messages + def is_conversational(example: dict[str, Any]) -> bool: r""" diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index eaeb6eb5a34..3150f947f3a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -44,9 +44,14 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available -from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages +from ..data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, + prepare_multimodal_messages_vllm, +) from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_liger_kernel_available, is_vllm_available @@ -1069,23 +1074,9 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_single_turn(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list): device = self.accelerator.device - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] - # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -1098,38 +1089,35 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): self._move_model_to_vllm() self._last_loaded_step = self.state.global_step + prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) + all_prompts = gather_object(prompts) if self.accelerator.is_main_process: # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate # num_generations outputs for each one. 
This is faster than generating outputs for each duplicate # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - + ordered_set_of_prompts = all_prompts[:: self.num_generations] + + sampling_params = { + "n": self.num_generations, + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": -1 if self.top_k is None else self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "truncate_prompt_tokens": self.max_prompt_length, + "guided_decoding_regex": self.guided_decoding_regex, + "generation_kwargs": self.args.generation_kwargs, + } with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + else: + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) else: payload = None @@ -1176,31 +1164,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): if self.vllm_tensor_parallel_size > 1: # Gather prompts from all ranks in the TP group and flatten. # Each rank starts with its own prompts; after gathering, all ranks see the full group set. 
- orig_size = len(prompts_text) + orig_size = len(prompts) gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] else: - vllm_inputs = all_prompts_text + all_prompts = prompts with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) + if is_conversational({"prompt": prompts[0]}): + all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) all_prompt_ids = [output.prompt_token_ids for output in all_outputs] all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] @@ -1227,15 +1202,15 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): self.llm.sleep(level=1) elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" + processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False} + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True + ) else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) + generate_inputs["inputs"] = generate_inputs.pop("input_ids") + with ( profiling_context(self, "transformers.generate_batch"), unwrap_model_for_generation( @@ -1251,27 +1226,29 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): unwrapped_model.to(torch.float16) with torch.inference_mode(): all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False + **generate_inputs, generation_config=self.generation_config, progress_bar=False ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = paged_prompt_inputs.input_ids - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn + prompt_ids = generate_inputs["inputs"] logprobs = None # not used in this case 
else: # Regular generation path - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - **kwargs, - ) + processor_kwargs = { + "return_tensors": "pt", + "padding": True, + "padding_side": "left", + "max_length": self.max_prompt_length, + "truncation": True, + "add_special_tokens": False, + } + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True + ) + else: + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) generate_inputs = super()._prepare_inputs(generate_inputs) with ( @@ -1302,11 +1279,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): return prompt_ids, completion_ids, logprobs - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1358,8 +1335,14 @@ def _generate_and_score_completions( if images is not None and all(img_list == [] for img_list in images): images = None + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)] + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( - prompts, images + prompts ) # Convert lists of token IDs to padded tensors diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index a6ff57aa8f2..78cf9555376 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -45,9 +45,14 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available -from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages +from ..data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, + prepare_multimodal_messages_vllm, +) from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_vllm_available @@ -953,23 +958,9 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_single_turn(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list): device = self.accelerator.device - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is 
the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] - # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -982,38 +973,35 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): self._move_model_to_vllm() self._last_loaded_step = self.state.global_step + prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) + all_prompts = gather_object(prompts) if self.accelerator.is_main_process: # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - + ordered_set_of_prompts = all_prompts[:: self.num_generations] + + sampling_params = { + "n": self.num_generations, + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": -1 if self.top_k is None else self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "truncate_prompt_tokens": self.max_prompt_length, + "guided_decoding_regex": self.guided_decoding_regex, + "generation_kwargs": self.args.generation_kwargs, + } with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + else: + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) else: payload = None @@ -1058,31 +1046,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): if self.vllm_tensor_parallel_size > 1: # Gather prompts from all ranks in the TP group and flatten. # Each rank starts with its own prompts; after gathering, all ranks see the full group set. 
- orig_size = len(prompts_text) + orig_size = len(prompts) gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] else: - vllm_inputs = all_prompts_text + all_prompts = prompts with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) + if is_conversational({"prompt": prompts[0]}): + all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) all_prompt_ids = [output.prompt_token_ids for output in all_outputs] all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] @@ -1102,15 +1077,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): self.llm.sleep(level=1) elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" + processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False} + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, + **processor_kwargs, + tokenize=True, + return_dict=True, + ) else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) + generate_inputs["inputs"] = generate_inputs.pop("input_ids") + with ( profiling_context(self, "transformers.generate_batch"), unwrap_model_for_generation( @@ -1126,26 +1104,28 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): unwrapped_model.to(torch.float16) with torch.inference_mode(): all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False + **generate_inputs, generation_config=self.generation_config, progress_bar=False ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = paged_prompt_inputs.input_ids - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn + prompt_ids = generate_inputs["inputs"] else: # Regular generation path - 
generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - **kwargs, - ) + processor_kwargs = { + "return_tensors": "pt", + "padding": True, + "padding_side": "left", + "max_length": self.max_prompt_length, + "truncation": True, + "add_special_tokens": False, + } + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True + ) + else: + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) generate_inputs = super()._prepare_inputs(generate_inputs) with ( @@ -1175,11 +1155,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): return prompt_ids, completion_ids - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids = self._generate_single_turn(prompts) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1232,7 +1212,13 @@ def _generate_and_score_completions( if images is not None and all(img_list == [] for img_list in images): images = None - prompt_ids_list, completion_ids_list = self._generate(prompts, images) + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)] + + prompt_ids_list, completion_ids_list = self._generate(prompts) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 9014470b55c..6ffe70e324a 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -353,9 +353,7 @@ def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str images = None if "messages" in examples[0]: # conversational case - for example in examples: - prepare_multimodal_messages(example["messages"], len(example["images"])) - messages = [example["messages"] for example in examples] + messages = [prepare_multimodal_messages(example["messages"], example["images"]) for example in examples] texts = self.processor.apply_chat_template(messages) elif self.dataset_text_field in examples[0]: # standard case texts = [example[self.dataset_text_field] for example in examples] @@ -396,7 +394,8 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str images = None if is_conversational(examples[0]): # conversational case for example in examples: - prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"])) + example["prompt"] = prepare_multimodal_messages(example["prompt"], images=example["images"]) + example["completion"] = prepare_multimodal_messages(example["completion"], images=[]) examples = [apply_chat_template(example, self.processor) for example in examples] prompts = 
[example["prompt"] for example in examples] @@ -951,10 +950,13 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo output = {} if is_conversational(example): if self._is_vlm: - prepare_multimodal_messages(example["prompt"], num_images=0) - prepare_multimodal_messages(example["completion"], num_images=0) + prompt = prepare_multimodal_messages(example["prompt"], images=[]) + completion = prepare_multimodal_messages(example["completion"], images=[]) + else: + prompt = example["prompt"] + completion = example["completion"] prompt_ids = processing_class.apply_chat_template( - example["prompt"], + prompt, tokenize=True, add_generation_prompt=True, tools=example.get("tools"), @@ -964,7 +966,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo # even for single examples, while for LLMs it returns lists of ints. prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids prompt_completion_processed = processing_class.apply_chat_template( - example["prompt"] + example["completion"], + prompt + completion, return_dict=True, tokenize=True, return_assistant_tokens_mask=assistant_only_loss, @@ -1002,9 +1004,11 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo else: # language modeling case if is_conversational(example): if self._is_vlm: - prepare_multimodal_messages(example["messages"], num_images=0) + messages = prepare_multimodal_messages(example["messages"], images=[]) + else: + messages = example["messages"] processed = processing_class.apply_chat_template( - example["messages"], + messages, return_dict=True, tokenize=True, return_assistant_tokens_mask=assistant_only_loss,