88 changes: 79 additions & 9 deletions recipe/one_step_off_policy/ray_trainer.py
@@ -60,32 +60,36 @@ class GenerationBatchFuture:
Wrapper class for encapsulating batch generation results
"""

- def __init__(self, epoch, batch, gen_batch_output):
+ def __init__(self, epoch, batch, gen_batch_output, future_reward=None):
"""
:param epoch: current epoch
:param batch: Input batch data
:param gen_batch_output: Generated sequences from the main model (DataProtoFuture)
:param future_reward: Future for reward computation (optional)
"""
self.epoch = epoch
self.batch = batch
self.gen_batch_output = gen_batch_output
self.future_reward = future_reward

def get(self):
"""
Get the actual results by calling get() method on gen_batch_output

Returns:
- tuple: (batch, gen_batch_result)
+ tuple: (epoch, batch, gen_batch_result, future_reward)
- epoch: Current epoch
- batch: Original input batch data
- gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself
- future_reward: Future for reward computation if available, else None
"""
# Call get() method on gen_batch_output if available
if hasattr(self.gen_batch_output, "get"):
gen_batch_result = self.gen_batch_output.get()
else:
gen_batch_result = self.gen_batch_output

- return self.epoch, self.batch, gen_batch_result
+ return self.epoch, self.batch, gen_batch_result, self.future_reward
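
For context, a minimal sketch of how this future is typically produced and consumed in the one-step-off-policy loop (an illustration assembled from the surrounding code, not an exact excerpt): generation for the next batch is launched asynchronously, training proceeds on the current batch, and get() only blocks once the results are actually needed.

# Illustrative usage sketch, not part of this diff
batch_data_future = trainer._async_gen_next_batch(continuous_iterator)  # launches rollout (and, optionally, reward futures)
# ... run the PPO update for the previously generated batch here ...
epoch, batch, gen_batch_output, future_reward = batch_data_future.get()  # blocks only when the results are needed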


class OneStepOffRayTrainer(RayPPOTrainer):
@@ -315,7 +319,10 @@ def _async_gen_next_batch(self, continuous_iterator):
except Exception as e:
print(f"Error in async_gen_next_batch: {e}")
return None

# Create the initial batch from the data loader
batch = DataProto.from_single_dict(batch_dict)

# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
@@ -327,16 +334,68 @@ def _async_gen_next_batch(self, continuous_iterator):
non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("interaction_kwargs")

gen_batch = batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

# sync weights from actor to rollout
self.sync_rollout_weights()

# async generation
gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
return GenerationBatchFuture(epoch, batch, gen_batch_output)

# Launch individual reward computations as each generation completes
future_reward = None
if self.config.reward_model.launch_reward_fn_async:
# Store the object reference and set up callback
future_reward = self._launch_individual_rewards.remote(
gen_batch_output, self.config, self.tokenizer, batch.non_tensor_batch
)

# Return the original, now-modified `batch` and the `future_reward`
return GenerationBatchFuture(epoch, batch, gen_batch_output, future_reward)

@staticmethod
@ray.remote
def _launch_individual_rewards(gen_batch_output, config, tokenizer, original_non_tensor_batch):
# Get generation results
gen_batch_result = gen_batch_output.get()

# Repeat non_tensor_batch to match the number of responses
n = config.actor_rollout_ref.rollout.n
repeated_non_tensor_batch = {}
for key, value in original_non_tensor_batch.items():
repeated_non_tensor_batch[key] = np.repeat(value, n, axis=0)

# Split into individual responses with preserved non_tensor_batch
responses_split = []
for i in range(len(gen_batch_result)):
response_data = gen_batch_result[i : i + 1] # Get single response
# Add repeated non_tensor_batch values
for key in repeated_non_tensor_batch:
response_data.non_tensor_batch[key] = repeated_non_tensor_batch[key][i : i + 1]
responses_split.append(response_data)

# Launch async reward computation
reward_futures = [
compute_reward_async.remote(response_data, config, tokenizer) for response_data in responses_split
]

# Wait for results and combine
results = ray.get(reward_futures)
rewards_list = [r[0] for r in results]
extras_list = [r[1] for r in results]

combined_reward_tensor = torch.cat(rewards_list, dim=0)
combined_extras_dict = {}
if extras_list and extras_list[0]:
for key in extras_list[0].keys():
combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
Comment on lines +393 to +396
Contributor (severity: high):
The current implementation for combining extras_list into combined_extras_dict has a potential bug and is inefficient. It only considers keys from the first dictionary in extras_list, ignoring any unique keys present in other dictionaries. Additionally, it iterates through extras_list for each key, which is inefficient for large lists.

A more robust and efficient approach would be to iterate through each dictionary once and collect all key-value pairs.

Suggested change:
- combined_extras_dict = {}
- if extras_list and extras_list[0]:
-     for key in extras_list[0].keys():
-         combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
+ combined_extras_dict = {}
+ for extras in extras_list:
+     if extras:
+         for key, value in extras.items():
+             combined_extras_dict.setdefault(key, []).append(value)
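
A concrete illustration of the difference, with made-up values:

# extras_list as returned by the per-response reward calls (values invented for illustration)
extras_list = [{"acc": 1.0}, {"acc": 0.0, "fmt": 1.0}]
# keys-of-first-dict merge -> {"acc": [1.0, 0.0]}               ("fmt" is silently dropped)
# setdefault-based merge   -> {"acc": [1.0, 0.0], "fmt": [1.0]}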


return combined_reward_tensor, combined_extras_dict
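
One subtle point worth spelling out: np.repeat(value, n, axis=0) keeps the n copies of each element adjacent, which lines up with gen_batch.repeat(repeat_times=n, interleave=True) in _async_gen_next_batch, so the [i : i + 1] slice above pairs each response with its own prompt's metadata. A standalone sketch with made-up data:

import numpy as np

# Two prompts, n = 3 responses each; the interleaved ordering is p0, p0, p0, p1, p1, p1,
# and np.repeat reproduces exactly that ordering for the per-prompt metadata.
prompt_meta = np.array(["p0", "p1"])
n = 3
repeated = np.repeat(prompt_meta, n, axis=0)
print(repeated.tolist())  # ['p0', 'p0', 'p0', 'p1', 'p1', 'p1']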

def fit(self):
"""
@@ -345,6 +404,7 @@ def fit(self):
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""

from omegaconf import OmegaConf

from verl.utils.tracking import Tracking
@@ -408,7 +468,7 @@ def fit(self):
with marked_timer("step", timing_raw):
# wait for the previous batch
with marked_timer("wait_prev_gen", timing_raw, color="red"):
- epoch, batch, gen_batch_output = batch_data_future.get()
+ epoch, batch, gen_batch_output, future_reward = batch_data_future.get()
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

@@ -442,8 +502,10 @@ def fit(self):
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

+ # Use the pre-launched future reward if available
if self.config.reward_model.launch_reward_fn_async:
- future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
+ # future_reward was already started in _async_gen_next_batch
+ reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

@@ -501,8 +563,6 @@ def fit(self):
with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
- if self.config.reward_model.launch_reward_fn_async:
-     reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor

if reward_extra_infos_dict:
@@ -552,7 +612,17 @@ def fit(self):
# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
- self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+ with marked_timer("dump_rollout_generations", timing_raw, color="green"):
+     inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
+     outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
+     scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+     self._dump_generations(
+         inputs=inputs,
+         outputs=outputs,
+         scores=scores,
+         reward_extra_infos_dict=reward_extra_infos_dict,
+         dump_path=rollout_data_dir,
+     )

# validate
if (
12 changes: 6 additions & 6 deletions verl/workers/fsdp_workers.py
@@ -883,12 +883,12 @@ def generate_sequences(self, prompts: DataProto):
prompts = prompts.to(get_device_id())

meta_info = {
"eos_token_id": self.model_config.generation_config.eos_token_id
if self.model_config.generation_config is not None
else self.model_config.tokenizer.eos_token_id,
"pad_token_id": self.model_config.generation_config.pad_token_id
if self.model_config.generation_config is not None
else self.model_config.tokenizer.pad_token_id,
"eos_token_id": self.generation_config.eos_token_id
if self.generation_config is not None
else self.tokenizer.eos_token_id,
"pad_token_id": self.generation_config.pad_token_id
if self.generation_config is not None
else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
