diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py
index e192e655ef7..4781c8b155a 100644
--- a/recipe/one_step_off_policy/ray_trainer.py
+++ b/recipe/one_step_off_policy/ray_trainer.py
@@ -60,24 +60,28 @@ class GenerationBatchFuture:
     Wrapper class for encapsulating batch generation results
     """

-    def __init__(self, epoch, batch, gen_batch_output):
+    def __init__(self, epoch, batch, gen_batch_output, future_reward=None):
         """
         :param epoch: current epoch
         :param batch: Input batch data
         :param gen_batch_output: Generated sequences from the main model (DataProtoFuture)
+        :param future_reward: Future for reward computation (optional)
         """
         self.epoch = epoch
         self.batch = batch
         self.gen_batch_output = gen_batch_output
+        self.future_reward = future_reward

     def get(self):
         """
         Get the actual results by calling get() method on gen_batch_output

         Returns:
-            tuple: (batch, gen_batch_result)
+            tuple: (epoch, batch, gen_batch_result, future_reward)
+                - epoch: Current epoch
                 - batch: Original input batch data
                 - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself
+                - future_reward: Future for reward computation if available, else None
         """
         # Call get() method on gen_batch_output if available
         if hasattr(self.gen_batch_output, "get"):
@@ -85,7 +89,7 @@ def get(self):
         else:
             gen_batch_result = self.gen_batch_output

-        return self.epoch, self.batch, gen_batch_result
+        return self.epoch, self.batch, gen_batch_result, self.future_reward


 class OneStepOffRayTrainer(RayPPOTrainer):
@@ -315,7 +319,10 @@ def _async_gen_next_batch(self, continuous_iterator):
         except Exception as e:
             print(f"Error in async_gen_next_batch: {e}")
             return None
+
+        # Create the initial batch from the data loader
         batch = DataProto.from_single_dict(batch_dict)
+
         # pop those keys for generation
         batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
         non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
@@ -327,16 +334,68 @@ def _async_gen_next_batch(self, continuous_iterator):
             non_tensor_batch_keys_to_pop.append("tools_kwargs")
         if "interaction_kwargs" in batch.non_tensor_batch:
             non_tensor_batch_keys_to_pop.append("interaction_kwargs")
+
         gen_batch = batch.pop(
             batch_keys=batch_keys_to_pop,
             non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
         )
         gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
         # sync weights from actor to rollout
         self.sync_rollout_weights()
+
         # async generation
         gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
-        return GenerationBatchFuture(epoch, batch, gen_batch_output)
+
+        # Launch individual reward computations as each generation completes
+        future_reward = None
+        if self.config.reward_model.launch_reward_fn_async:
+            # Store the object reference and set up callback
+            future_reward = self._launch_individual_rewards.remote(
+                gen_batch_output, self.config, self.tokenizer, batch.non_tensor_batch
+            )
+
+        # Return the original, now-modified `batch` and the `future_reward`
+        return GenerationBatchFuture(epoch, batch, gen_batch_output, future_reward)
+
+    @staticmethod
+    @ray.remote
+    def _launch_individual_rewards(gen_batch_output, config, tokenizer, original_non_tensor_batch):
+        # Get generation results
+        gen_batch_result = gen_batch_output.get()
+
+        # Repeat non_tensor_batch to match the number of responses
+        n = config.actor_rollout_ref.rollout.n
+        repeated_non_tensor_batch = {}
+        for key, value in original_non_tensor_batch.items():
+            repeated_non_tensor_batch[key] = np.repeat(value, n, axis=0)
+
+        # Split into individual responses with preserved non_tensor_batch
+        responses_split = []
+        for i in range(len(gen_batch_result)):
+            response_data = gen_batch_result[i : i + 1]  # Get single response
+            # Add repeated non_tensor_batch values
+            for key in repeated_non_tensor_batch:
+                response_data.non_tensor_batch[key] = repeated_non_tensor_batch[key][i : i + 1]
+            responses_split.append(response_data)
+
+        # Launch async reward computation
+        reward_futures = [
+            compute_reward_async.remote(response_data, config, tokenizer) for response_data in responses_split
+        ]
+
+        # Wait for results and combine
+        results = ray.get(reward_futures)
+        rewards_list = [r[0] for r in results]
+        extras_list = [r[1] for r in results]
+
+        combined_reward_tensor = torch.cat(rewards_list, dim=0)
+        combined_extras_dict = {}
+        if extras_list and extras_list[0]:
+            for key in extras_list[0].keys():
+                combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
+
+        return combined_reward_tensor, combined_extras_dict

     def fit(self):
         """
@@ -345,6 +404,7 @@ def fit(self):
         to construct the PPO dataflow.
         The light-weight advantage computation is done on the driver process.
         """
+
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking
@@ -408,7 +468,7 @@ def fit(self):
             with marked_timer("step", timing_raw):
                 # wait for the previous batch
                 with marked_timer("wait_prev_gen", timing_raw, color="red"):
-                    epoch, batch, gen_batch_output = batch_data_future.get()
+                    epoch, batch, gen_batch_output, future_reward = batch_data_future.get()
                     timing_raw.update(gen_batch_output.meta_info["timing"])
                     gen_batch_output.meta_info.pop("timing", None)
@@ -442,8 +502,10 @@ def fit(self):
                         reward_tensor = self.rm_wg.compute_rm_score(batch)
                         batch = batch.union(reward_tensor)

+                    # Use the pre-launched future reward if available
                     if self.config.reward_model.launch_reward_fn_async:
-                        future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
+                        # future_reward was already started in _async_gen_next_batch
+                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                     else:
                         reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
@@ -501,8 +563,6 @@ def fit(self):
                 with marked_timer("adv", timing_raw, color="brown"):
                     # we combine with rule-based rm
                     reward_extra_infos_dict: dict[str, list]
-                    if self.config.reward_model.launch_reward_fn_async:
-                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                     batch.batch["token_level_scores"] = reward_tensor

                     if reward_extra_infos_dict:
@@ -552,7 +612,17 @@ def fit(self):
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                 if rollout_data_dir:
-                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+                    with marked_timer("dump_rollout_generations", timing_raw, color="green"):
+                        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
+                        outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
+                        scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+                        self._dump_generations(
+                            inputs=inputs,
+                            outputs=outputs,
+                            scores=scores,
+                            reward_extra_infos_dict=reward_extra_infos_dict,
+                            dump_path=rollout_data_dir,
+                        )

                 # validate
                 if (
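Note on the ray_trainer.py changes above: reward scoring is now launched as a Ray task while the rollout future is still in flight, and fit() merely ray.get()s the pre-launched future instead of starting one after generation. Below is a minimal, self-contained sketch of this launch-early/collect-late pattern; toy_reward is a hypothetical stand-in for verl's compute_reward_async, not part of the codebase.

    import ray
    import torch

    ray.init(ignore_reinit_error=True)

    @ray.remote
    def toy_reward(responses):
        # Stand-in scoring function: reward each response by its length.
        return torch.tensor([float(len(r)) for r in responses])

    # Launched early, as _async_gen_next_batch does with future_reward ...
    future_reward = toy_reward.remote(["a", "bb", "ccc"])
    # ... other driver-side work (e.g. the policy update) overlaps here ...
    rewards = ray.get(future_reward)  # collected later, as fit() now does
    print(rewards)  # tensor([1., 2., 3.])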
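The np.repeat call in _launch_individual_rewards depends on np.repeat interleaving its repeats, which matches gen_batch.repeat(..., interleave=True): response i belongs to prompt i // n. A small illustration of that ordering, with toy values rather than verl data:

    import numpy as np

    n = 2  # rollout.n: responses generated per prompt
    non_tensor = {"data_source": np.array(["gsm8k", "math"])}
    # Interleaved repeat, so per-prompt metadata lines up with per-response rows
    repeated = {k: np.repeat(v, n, axis=0) for k, v in non_tensor.items()}
    print(repeated["data_source"])  # ['gsm8k' 'gsm8k' 'math' 'math']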
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 795e79d273b..9c633b949c2 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -883,12 +883,12 @@ def generate_sequences(self, prompts: DataProto):
         prompts = prompts.to(get_device_id())

         meta_info = {
-            "eos_token_id": self.model_config.generation_config.eos_token_id
-            if self.model_config.generation_config is not None
-            else self.model_config.tokenizer.eos_token_id,
-            "pad_token_id": self.model_config.generation_config.pad_token_id
-            if self.model_config.generation_config is not None
-            else self.model_config.tokenizer.pad_token_id,
+            "eos_token_id": self.generation_config.eos_token_id
+            if self.generation_config is not None
+            else self.tokenizer.eos_token_id,
+            "pad_token_id": self.generation_config.pad_token_id
+            if self.generation_config is not None
+            else self.tokenizer.pad_token_id,
         }
         prompts.meta_info.update(meta_info)
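The fsdp_workers.py hunk keeps the same fallback logic, only reading generation_config and tokenizer from the worker itself instead of through model_config: prefer the GenerationConfig's token ids when one exists, else fall back to the tokenizer's. A sketch of that fallback using plain transformers; gpt2 is only an example checkpoint, and building the config via from_model_config is an assumption, not how verl loads it:

    from transformers import AutoConfig, AutoTokenizer, GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # verl keeps None here when the checkpoint provides no usable generation config
    generation_config = GenerationConfig.from_model_config(AutoConfig.from_pretrained("gpt2"))

    eos_token_id = generation_config.eos_token_id if generation_config is not None else tokenizer.eos_token_id
    pad_token_id = generation_config.pad_token_id if generation_config is not None else tokenizer.pad_token_id
    print(eos_token_id, pad_token_id)  # 50256 and None for gpt2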