1 change: 0 additions & 1 deletion recipe/one_step_off_policy/main_ppo.py
@@ -32,7 +32,6 @@

from .ray_trainer import OneStepOffRayTrainer


@hydra.main(config_path="config", config_name="one_step_off_ppo_trainer", version_base=None)
def main(config):
run_ppo(config)
85 changes: 76 additions & 9 deletions recipe/one_step_off_policy/ray_trainer.py
@@ -60,32 +60,36 @@ class GenerationBatchFuture:
Wrapper class for encapsulating batch generation results
"""

def __init__(self, epoch, batch, gen_batch_output):
def __init__(self, epoch, batch, gen_batch_output, future_reward=None):
"""
:param epoch: current epoch
:param batch: Input batch data
:param gen_batch_output: Generated sequences from the main model (DataProtoFuture)
:param future_reward: Future for reward computation (optional)
"""
self.epoch = epoch
self.batch = batch
self.gen_batch_output = gen_batch_output
self.future_reward = future_reward

def get(self):
"""
Get the actual results by calling the get() method on gen_batch_output

Returns:
tuple: (batch, gen_batch_result)
tuple: (epoch, batch, gen_batch_result, future_reward)
- epoch: Current epoch
- batch: Original input batch data
- gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself
- future_reward: Future for reward computation if available, else None
"""
# Call get() method on gen_batch_output if available
if hasattr(self.gen_batch_output, "get"):
gen_batch_result = self.gen_batch_output.get()
else:
gen_batch_result = self.gen_batch_output

return self.epoch, self.batch, gen_batch_result
return self.epoch, self.batch, gen_batch_result, self.future_reward
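
A minimal usage sketch of the wrapper (illustrative only; the dummy future class below is invented for this example and is not part of the diff):

class _FakeGenFuture:
    """Stand-in for the DataProtoFuture returned by async generation."""
    def get(self):
        return "generated-sequences"

future = GenerationBatchFuture(epoch=0, batch={"prompts": ["hi"]}, gen_batch_output=_FakeGenFuture())
epoch, batch, gen_result, future_reward = future.get()
# gen_result == "generated-sequences"; future_reward is None since none was supplied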


class OneStepOffRayTrainer(RayPPOTrainer):
@@ -304,7 +308,10 @@ def _async_gen_next_batch(self, continuous_iterator):
except Exception as e:
print(f"Error in async_gen_next_batch: {e}")
return None

# Create the initial batch from the data loader
batch = DataProto.from_single_dict(batch_dict)

# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
@@ -316,16 +323,69 @@
non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("interaction_kwargs")

gen_batch = batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

# sync weights from actor to rollout
self.sync_rollout_weights()

# async generation
gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
return GenerationBatchFuture(epoch, batch, gen_batch_output)

# Launch per-response reward computation in the background so it overlaps with training
future_reward = None
if self.config.reward_model.launch_reward_fn_async:
# Launch the remote reward task and keep its object reference
future_reward = self._launch_individual_rewards.remote(
gen_batch_output, self.config, self.tokenizer, batch.non_tensor_batch
)

# Return the original, now-modified `batch` and the `future_reward`
return GenerationBatchFuture(epoch, batch, gen_batch_output, future_reward)

@staticmethod
@ray.remote
def _launch_individual_rewards(gen_batch_output, config, tokenizer, original_non_tensor_batch):
# Get generation results
gen_batch_result = gen_batch_output.get()

# Repeat non_tensor_batch to match the number of responses
n = config.actor_rollout_ref.rollout.n
repeated_non_tensor_batch = {}
for key, value in original_non_tensor_batch.items():
repeated_non_tensor_batch[key] = np.repeat(value, n, axis=0)

# Split into individual responses with preserved non_tensor_batch
responses_split = []
for i in range(len(gen_batch_result)):
response_data = gen_batch_result[i:i+1] # Get single response
# Add repeated non_tensor_batch values
for key in repeated_non_tensor_batch:
response_data.non_tensor_batch[key] = repeated_non_tensor_batch[key][i:i+1]
responses_split.append(response_data)

# Launch async reward computation
reward_futures = [
compute_reward_async.remote(response_data, config, tokenizer)
for response_data in responses_split
]

# Wait for results and combine
results = ray.get(reward_futures)
rewards_list = [r[0] for r in results]
extras_list = [r[1] for r in results]

combined_reward_tensor = torch.cat(rewards_list, dim=0)
combined_extras_dict = {}
if extras_list and extras_list[0]:
for key in extras_list[0].keys():
combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
Comment on lines +393 to +396
Contributor

Severity: high

The current implementation for combining extras_list into combined_extras_dict has a potential bug and is inefficient. It only considers keys from the first dictionary in extras_list, ignoring any unique keys present in other dictionaries. Additionally, it iterates through extras_list for each key, which is inefficient for large lists.

A more robust and efficient approach would be to iterate through each dictionary once and collect all key-value pairs.

Suggested change:
-    combined_extras_dict = {}
-    if extras_list and extras_list[0]:
-        for key in extras_list[0].keys():
-            combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
+    combined_extras_dict = {}
+    for extras in extras_list:
+        if extras:
+            for key, value in extras.items():
+                combined_extras_dict.setdefault(key, []).append(value)
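
For a concrete sense of the difference, a toy comparison (the values below are invented for illustration and do not come from this PR):

extras_list = [{"acc": 1.0}, {"acc": 0.0, "format_score": 0.5}]

# Current approach: only keys present in extras_list[0] survive.
first_keys_only = {k: [d[k] for d in extras_list if k in d] for k in extras_list[0]}
# -> {"acc": [1.0, 0.0]}  ("format_score" is silently dropped)

# Suggested approach: one pass over every dict collects all keys.
merged = {}
for extras in extras_list:
    if extras:
        for key, value in extras.items():
            merged.setdefault(key, []).append(value)
# -> {"acc": [1.0, 0.0], "format_score": [0.5]}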


return combined_reward_tensor, combined_extras_dict

def fit(self):
"""
@@ -334,6 +394,7 @@ def fit(self):
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""

from omegaconf import OmegaConf

from verl.utils.tracking import Tracking
@@ -397,7 +458,7 @@
with marked_timer("step", timing_raw):
# wait for the previous batch
with marked_timer("wait_prev_gen", timing_raw, color="red"):
epoch, batch, gen_batch_output = batch_data_future.get()
epoch, batch, gen_batch_output, future_reward = batch_data_future.get()
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

@@ -431,8 +492,10 @@
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# Use the pre-launched future reward if available
if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
# future_reward was already started in _async_gen_next_batch
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

@@ -490,8 +553,6 @@
with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor

if reward_extra_infos_dict:
@@ -552,6 +613,12 @@ def fit(self):
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=rollout_data_dir,
)
#########################################################
# Print timing info for this step
print(f"\nStep {self.global_steps} timing:")
for k, v in timing_raw.items():
print(f" {k}: {v:.4f} s")
#########################################################

# validate
if (
@@ -605,4 +672,4 @@ def fit(self):
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
return
12 changes: 6 additions & 6 deletions verl/workers/fsdp_workers.py
@@ -829,12 +829,12 @@ def generate_sequences(self, prompts: DataProto):
prompts = prompts.to(get_device_id())

meta_info = {
"eos_token_id": self.model_config.generation_config.eos_token_id
if self.model_config.generation_config is not None
else self.model_config.tokenizer.eos_token_id,
"pad_token_id": self.model_config.generation_config.pad_token_id
if self.model_config.generation_config is not None
else self.model_config.tokenizer.pad_token_id,
"eos_token_id": self.generation_config.eos_token_id
if self.generation_config is not None
else self.tokenizer.eos_token_id,
"pad_token_id": self.generation_config.pad_token_id
if self.generation_config is not None
else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
