133 changes: 95 additions & 38 deletions src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
@@ -47,7 +47,7 @@
from trl.trainer.utils import generate_model_card, get_comet_experiment_url

import copy

from .utils import compute_logps_with_prompt_cache

if is_peft_available():
from peft import PeftConfig, get_peft_model
@@ -165,7 +165,7 @@ def __init__(
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = GRPOConfig(f"{model_name}-GRPO")

self.gradient_checkpointing = False
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
@@ -334,17 +334,30 @@ def _set_signature_columns_if_needed(self):


# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw):
logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
# Compute the log probabilities for the input tokens. Use a loop over mini-batches to reduce the memory peak.
def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, num_logits_to_keep, mini_batch_size):
mini_batch_size = input_ids.size(0) if mini_batch_size == 0 else mini_batch_size
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)

for i in range(0, input_ids.size(0), mini_batch_size):
mini_batch_input_ids = input_ids[i : i + mini_batch_size, :] # (B_mini, P+C)
mini_batch_attention_mask = attention_mask[i : i + mini_batch_size, :] # (B_mini, P+C)
mini_pixel_values = pixel_values[i : i + mini_batch_size, :]
mini_image_grid_thw = image_grid_thw[i : i + mini_batch_size, :]
logits = model(
input_ids=mini_batch_input_ids,
attention_mask=mini_batch_attention_mask,
pixel_values=mini_pixel_values,
image_grid_thw=mini_image_grid_thw,
# num_logits_to_keep=num_logits_to_keep + 1,
).logits[:, -num_logits_to_keep - 1 : -1] # (B_mini, C, Vocab_size)

token_index = mini_batch_input_ids[:, -num_logits_to_keep:].unsqueeze(-1) # (B_mini, C, 1)
token_logits = torch.gather(logits, dim=-1, index=token_index).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
del logits
token_log_prob = token_logits - logsumexp_values
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
return torch.cat(per_token_logps, dim=0)


# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
@@ -356,7 +369,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")


mini_batch_size = self.args.logit_computation_mini_batch_size if hasattr(self.args, 'logit_computation_mini_batch_size') else 1 # `logit_computation_mini_batch_size` requires a newer trl GRPOConfig; fall back to 1 otherwise

prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
@@ -371,48 +384,92 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)

prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
pixel_values = prompt_inputs["pixel_values"]
image_grid_thw = prompt_inputs["image_grid_thw"]


if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]

# Generate completions
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)

prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]

device = self.accelerator.device

# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
device = self.accelerator.device
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
pixel_values = prompt_inputs["pixel_values"].repeat(self.num_generations, 1)
image_grid_thw = prompt_inputs["image_grid_thw"].repeat_interleave(self.num_generations, dim=0)

per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
per_token_logps = per_token_logps[:, prompt_length - 1 :]

with torch.inference_mode():

if not self.gradient_checkpointing:
# Current policy logprobs (with grad)
per_token_logps = compute_logps_with_prompt_cache(
model=model,
prompt_inputs=prompt_inputs,
completion_ids=completion_ids,
mini_batch_size=mini_batch_size,
requires_grad_for_completion=True,
)
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
ref_per_token_logps = compute_logps_with_prompt_cache(
model=self.ref_model,
prompt_inputs=prompt_inputs,
completion_ids=completion_ids,
mini_batch_size=mini_batch_size,
requires_grad_for_completion=False,
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
ref_per_token_logps = compute_logps_with_prompt_cache(
model=model,
prompt_inputs=prompt_inputs,
completion_ids=completion_ids,
mini_batch_size=mini_batch_size,
requires_grad_for_completion=False,
)
else: # untested: the gradient_checkpointing issues remain unresolved, see https://github.com/Deep-Agent/R1-V/issues/31

prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens

# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
pixel_values = prompt_inputs["pixel_values"].repeat(self.num_generations, 1)
image_grid_thw = prompt_inputs["image_grid_thw"].repeat_interleave(self.num_generations, dim=0)

per_token_logps = self._get_per_token_logps(
model=model,
input_ids=prompt_completion_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
num_logits_to_keep=num_logits_to_keep,
mini_batch_size=mini_batch_size,
)

with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
model=self.ref_model,
input_ids=prompt_completion_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
num_logits_to_keep=num_logits_to_keep,
mini_batch_size=mini_batch_size,
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
model=model,
input_ids=prompt_completion_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
num_logits_to_keep=num_logits_to_keep,
mini_batch_size=mini_batch_size,
)

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
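For reference, both the rewritten `_get_per_token_logps` above and the new `compute_logps_with_prompt_cache` below rely on the same memory-saving trick from https://github.com/huggingface/trl/pull/2773: instead of materializing a full `log_softmax` over the vocabulary, they gather the logit of each target token and subtract a per-position `logsumexp`. A minimal, self-contained sketch with toy shapes (not part of the PR):

import torch

# Toy shapes for illustration only: 2 sequences, 5 completion tokens, vocabulary of 11.
B, C, V = 2, 5, 11
logits = torch.randn(B, C, V)             # logits at the positions that predict the completion tokens
target_ids = torch.randint(0, V, (B, C))  # the completion tokens that were actually generated

# Memory-saving path: gather the target-token logits, then subtract the per-position logsumexp.
token_logits = torch.gather(logits, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)  # (B, C)
logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])         # (B, C)
per_token_logps = token_logits - logsumexp_values

# Reference path: full log_softmax over the vocabulary, then gather. Same values, higher peak memory.
reference = torch.gather(logits.log_softmax(dim=-1), dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(per_token_logps, reference, atol=1e-5)

Looping row by row for the logsumexp, as the PR does, trades a little speed for a lower peak memory footprint.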
95 changes: 95 additions & 0 deletions src/open-r1-multimodal/src/open_r1/trainer/utils.py
@@ -0,0 +1,95 @@
import torch
def compute_logps_with_prompt_cache(
model: torch.nn.Module,
prompt_inputs: dict,
completion_ids: torch.LongTensor,
mini_batch_size: int,
requires_grad_for_completion: bool = True,
) -> torch.FloatTensor:
"""
Compute the log probabilities of the completion tokens by reusing the prompt KV cache.
1) Forward pass on the prompt with torch.no_grad() to get `past_key_values`.
2) Forward pass (with or without grad) on the completion tokens using that cache.
3) Compute per-token log probabilities for the completion.
Args:
model (`nn.Module`): A causal LM (transformers.AutoModelForCausalLM) or similar.
prompt_inputs (`dict`): The dict of prompt tensors, e.g. {"input_ids", "attention_mask", ...}.
completion_ids (`torch.LongTensor`): Shape [B*G, completion_len].
mini_batch_size (`int`): The number of completion rows to process at once.
requires_grad_for_completion (`bool`): Whether to enable gradient for the completion pass.
Returns:
per_token_logps (`torch.FloatTensor`): shape [B*G, completion_len],
where per_token_logps[i, t] is the log probability of the i-th completion's t-th token,
given all preceding tokens in the prompt + the partial completion up to t-1.
"""

# Get the batch size (B), number of completions (G), and completion length (C)
B = prompt_inputs["input_ids"].size(0)
G = completion_ids.size(0) // B
C = completion_ids.size(1)

# If the user did not specify a mini_batch_size, use the full batch size (B*G)
if mini_batch_size <= 0:
mini_batch_size = completion_ids.size(0)
# print(model)
# Forward pass over the prompt tokens with torch.no_grad to get two things:
# 1) `past_key_values` (the KV cache)
# 2) `prompt_last_logps` (the log probs of the first completion token prediction)
with torch.no_grad():
# print(prompt_inputs)
prompt_out = model(**prompt_inputs, use_cache=True, return_dict=True)
# print(prompt_out)

# Only keep the last prompt logit, immediately convert to log probabilities and expand to B*G
prompt_last_logps = prompt_out.logits[:, -1:].log_softmax(dim=-1).repeat_interleave(G, dim=0)

# Gather these log probs as they relate to the first completion token
first_completion_token_logps = torch.gather(
prompt_last_logps, dim=-1, index=completion_ids[:, :1].unsqueeze(-1)
).squeeze(-1)

# Expand the KV Cache `G` times to match the dimension of completion_ids (B -> B*G) and split into mini-batches
repeated_kv_cache = prompt_out.past_key_values # a DynamicCache
repeated_kv_cache.batch_repeat_interleave(G)
mini_batch_kv_caches = repeated_kv_cache.batch_split(full_batch_size=B * G, split_size=mini_batch_size)

# Process completion tokens in mini-batches
completion_token_logps = []

for batch_idx, mini_batch_kv_cache in enumerate(mini_batch_kv_caches):
start_idx = batch_idx * mini_batch_size
end_idx = start_idx + mini_batch_size
mini_batch_ids = completion_ids[start_idx:end_idx] # (mini_batch_size, C)

with torch.set_grad_enabled(requires_grad_for_completion):
mini_batch_logits = model(
input_ids=mini_batch_ids,
past_key_values=mini_batch_kv_cache,
# num_logits_to_keep=C,
use_cache=False,
).logits[:, -C:-1, :]

# # Original method
# mini_batch_log_probs = mini_batch_logits.log_softmax(dim=-1)
# del mini_batch_logits

# mini_batch_token_log_prob = torch.gather(mini_batch_log_probs, dim=-1, index=mini_batch_index).squeeze(-1)
# del mini_batch_log_probs

# More optimized method (https://github.com/huggingface/trl/pull/2773)
# Get the corresponding completion token ids and gather the logits for completion_ids w/ idx >= 1
mini_batch_index = mini_batch_ids[:, 1:].unsqueeze(-1) # (mini_batch_size, C-1, 1)
mini_batch_token_logits = torch.gather(mini_batch_logits, dim=-1, index=mini_batch_index).squeeze(
-1
) # (mini_batch_size, C-1)
mini_batch_logsumexp_values = torch.stack(
[torch.logsumexp(l, dim=-1) for l in mini_batch_logits]
) # (mini_batch_size, C-1)
del mini_batch_logits
mini_batch_token_log_prob = mini_batch_token_logits - mini_batch_logsumexp_values # (mini_batch_size, C-1)
completion_token_logps.append(mini_batch_token_log_prob)
del mini_batch_token_logits, mini_batch_logsumexp_values, mini_batch_token_log_prob

# Combine results
all_completion_token_logps = torch.cat(completion_token_logps, dim=0) # (B*G, C-1)
return torch.cat([first_completion_token_logps, all_completion_token_logps], dim=1) # (B*G, C)
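
A hedged usage sketch of the new helper, mirroring the call sites in `compute_loss` above. The surrounding variables (`model`, `prompt_inputs`, `completion_ids`) are assumed to already exist as in the trainer, `ref_model` stands in for the trainer's `self.ref_model`, and the mini-batch size of 4 is an arbitrary illustrative choice:

# `prompt_inputs` holds "input_ids", "attention_mask", "pixel_values" and "image_grid_thw" for B prompts,
# and `completion_ids` has shape (B * num_generations, completion_len), as in the trainer above.
policy_logps = compute_logps_with_prompt_cache(
    model=model,                        # current policy; gradients flow through the completion pass
    prompt_inputs=prompt_inputs,
    completion_ids=completion_ids,
    mini_batch_size=4,                  # completion rows per forward pass; <= 0 means the full batch
    requires_grad_for_completion=True,
)
ref_logps = compute_logps_with_prompt_cache(
    model=ref_model,                    # frozen reference model; no gradients needed
    prompt_inputs=prompt_inputs,
    completion_ids=completion_ids,
    mini_batch_size=4,
    requires_grad_for_completion=False,
)
# Both tensors have shape (B * num_generations, completion_len) and line up token for token,
# so they can be plugged directly into the per-token KL term in compute_loss.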