diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py
index 8c79fa325b0..a4d5c098353 100644
--- a/recipe/dapo/dapo_ray_trainer.py
+++ b/recipe/dapo/dapo_ray_trainer.py
@@ -332,6 +332,11 @@ def fit(self):
                         actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                         metrics.update(actor_output_metrics)
 
+                # Log rollout generations if enabled
+                rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
+                if rollout_data_dir:
+                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+
                 # validate
                 if (
                     self.val_reward_fn is not None
diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py
index 662f1852646..e192e655ef7 100644
--- a/recipe/one_step_off_policy/ray_trainer.py
+++ b/recipe/one_step_off_policy/ray_trainer.py
@@ -552,22 +552,7 @@ def fit(self):
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                 if rollout_data_dir:
-                    with marked_timer("dump_rollout_generations", timing_raw, color="green"):
-                        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
-                        outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
-                        scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
-                        sample_gts = [
-                            item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch
-                        ]
-
-                        self._dump_generations(
-                            inputs=inputs,
-                            outputs=outputs,
-                            gts=sample_gts,
-                            scores=scores,
-                            reward_extra_infos_dict=reward_extra_infos_dict,
-                            dump_path=rollout_data_dir,
-                        )
+                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
 
                 # validate
                 if (
diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py
index baf1786964a..e075beec8c3 100644
--- a/recipe/sppo/sppo_ray_trainer.py
+++ b/recipe/sppo/sppo_ray_trainer.py
@@ -308,22 +308,7 @@ def fit(self):
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                 if rollout_data_dir:
-                    with simple_timer("dump_rollout_generations", timing_raw):
-                        print(batch.batch.keys())
-                        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
-                        outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
-                        scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
-                        sample_gts = [
-                            item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch
-                        ]
-                        self._dump_generations(
-                            inputs=inputs,
-                            outputs=outputs,
-                            scores=scores,
-                            reward_extra_infos_dict=reward_extra_infos_dict,
-                            gts=sample_gts,
-                            dump_path=rollout_data_dir,
-                        )
+                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
 
                 # validate
                 if (
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index bb945c0451f..0e3b1b77c5f 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -441,6 +441,38 @@ def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dic
 
         print(f"Dumped generations to {filename}")
 
+    def _log_rollout_data(
+        self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str
+    ):
+        """Log rollout data to disk.
+
+        Args:
+            batch (DataProto): The batch containing rollout data
+            reward_extra_infos_dict (dict): Additional reward information to log
+            timing_raw (dict): Timing information for profiling
+            rollout_data_dir (str): Directory path to save the rollout data
+        """
+        with marked_timer("dump_rollout_generations", timing_raw, color="green"):
+            inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
+            outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
+            scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+            sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch]
+
+            reward_extra_infos_to_dump = reward_extra_infos_dict.copy()
+            if "request_id" in batch.non_tensor_batch:
+                reward_extra_infos_to_dump.setdefault(
+                    "request_id",
+                    batch.non_tensor_batch["request_id"].tolist(),
+                )
+
+            self._dump_generations(
+                inputs=inputs,
+                outputs=outputs,
+                gts=sample_gts,
+                scores=scores,
+                reward_extra_infos_dict=reward_extra_infos_to_dump,
+                dump_path=rollout_data_dir,
+            )
 
     def _maybe_log_val_generations(self, inputs, outputs, scores):
         """Log a table of validation samples to the configured logger (wandb or swanlab)"""
 
@@ -1111,29 +1143,7 @@ def fit(self):
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                 if rollout_data_dir:
-                    with marked_timer("dump_rollout_generations", timing_raw, color="green"):
-                        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
-                        outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
-                        scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
-                        sample_gts = [
-                            item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None)
-                            for item in batch
-                        ]
-
-                        if "request_id" in batch.non_tensor_batch:
-                            reward_extra_infos_dict.setdefault(
-                                "request_id",
-                                batch.non_tensor_batch["request_id"].tolist(),
-                            )
-
-                        self._dump_generations(
-                            inputs=inputs,
-                            outputs=outputs,
-                            gts=sample_gts,
-                            scores=scores,
-                            reward_extra_infos_dict=reward_extra_infos_dict,
-                            dump_path=rollout_data_dir,
-                        )
+                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
 
                 # validate
                 if (
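Note on the extracted helper: unlike the pre-refactor inline code in fit(), which called setdefault on the shared reward_extra_infos_dict, _log_rollout_data copies the dict before injecting request_id, so only the dumped copy gains the key and the caller's dict is never mutated. The following is a minimal standalone sketch of that copy-then-setdefault pattern, not part of the patch; make_dump_extras and the plain dicts standing in for DataProto's non_tensor_batch are illustrative names only.

def make_dump_extras(reward_extras: dict, non_tensor_batch: dict) -> dict:
    """Augment reward extras for dumping without mutating the caller's dict."""
    to_dump = reward_extras.copy()  # shallow copy: the original dict is never modified
    if "request_id" in non_tensor_batch:
        # setdefault keeps a "request_id" entry the reward fn may already have set
        to_dump.setdefault("request_id", list(non_tensor_batch["request_id"]))
    return to_dump


extras = {"acc": [1.0, 0.0]}
dumped = make_dump_extras(extras, {"request_id": ["req-a", "req-b"]})
assert "request_id" not in extras  # caller's dict unchanged
assert dumped["request_id"] == ["req-a", "req-b"]

Using setdefault rather than plain assignment preserves any request_id a custom reward function may already have recorded in the extras.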