
Commit 6b1255f

[recipe] fix: Fix a Typo in One_Step_Off_Policy and Add Async Generative Reward Model Evaluation in Response Generation (volcengine#3369)

Fix a typo in verl/workers/fsdp_workers.py: the original code checked `if self.model_config.generation_config is not None`; the updated code checks `if self.generation_config is not None`.

Add async evaluation with the generative reward model (GRM): GRM calls are slow, so it is wasteful to wait for every response in the batch to finish generating before sending anything to the GRM. This change starts GRM evaluation asynchronously as soon as individual response generation is finished.

---------

Co-authored-by: zhichao (jimmy) <[email protected]>
1 parent 832df12 commit 6b1255f
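A rough, self-contained sketch of the fan-out pattern described in the commit message, using plain Ray; the names `grm_score` and `score_responses_async` are hypothetical stand-ins for verl's reward plumbing, not its actual API:

```python
# Illustrative sketch only: score each response with its own remote task
# instead of one blocking call over the whole generated batch.
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def grm_score(response: str) -> float:
    # Stand-in for a slow generative-reward-model call.
    return float(len(response))


def score_responses_async(responses):
    # Fan out: launch one remote reward call per response.
    futures = [grm_score.remote(r) for r in responses]
    # The caller can hold these futures and block only when scores are needed;
    # here we block immediately for demonstration.
    return ray.get(futures)


if __name__ == "__main__":
    print(score_responses_async(["short answer", "a much longer answer"]))
```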

File tree: 2 files changed (+85, -15 lines)

recipe/one_step_off_policy/ray_trainer.py

Lines changed: 79 additions & 9 deletions
@@ -60,32 +60,36 @@ class GenerationBatchFuture:
     Wrapper class for encapsulating batch generation results
     """
 
-    def __init__(self, epoch, batch, gen_batch_output):
+    def __init__(self, epoch, batch, gen_batch_output, future_reward=None):
         """
         :param epoch: current epoch
         :param batch: Input batch data
         :param gen_batch_output: Generated sequences from the main model (DataProtoFuture)
+        :param future_reward: Future for reward computation (optional)
         """
         self.epoch = epoch
         self.batch = batch
         self.gen_batch_output = gen_batch_output
+        self.future_reward = future_reward
 
     def get(self):
         """
         Get the actual results by calling get() method on gen_batch_output
 
         Returns:
-            tuple: (batch, gen_batch_result)
+            tuple: (epoch, batch, gen_batch_result, future_reward)
+                - epoch: Current epoch
                 - batch: Original input batch data
                 - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself
+                - future_reward: Future for reward computation if available, else None
         """
         # Call get() method on gen_batch_output if available
         if hasattr(self.gen_batch_output, "get"):
             gen_batch_result = self.gen_batch_output.get()
         else:
             gen_batch_result = self.gen_batch_output
 
-        return self.epoch, self.batch, gen_batch_result
+        return self.epoch, self.batch, gen_batch_result, self.future_reward
 
 
 class OneStepOffRayTrainer(RayPPOTrainer):
@@ -315,7 +319,10 @@ def _async_gen_next_batch(self, continuous_iterator):
         except Exception as e:
             print(f"Error in async_gen_next_batch: {e}")
             return None
+
+        # Create the initial batch from the data loader
         batch = DataProto.from_single_dict(batch_dict)
+
         # pop those keys for generation
         batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
         non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
@@ -327,16 +334,68 @@ def _async_gen_next_batch(self, continuous_iterator):
             non_tensor_batch_keys_to_pop.append("tools_kwargs")
         if "interaction_kwargs" in batch.non_tensor_batch:
             non_tensor_batch_keys_to_pop.append("interaction_kwargs")
+
         gen_batch = batch.pop(
             batch_keys=batch_keys_to_pop,
             non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
         )
         gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
         # sync weights from actor to rollout
         self.sync_rollout_weights()
+
         # async generation
         gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
-        return GenerationBatchFuture(epoch, batch, gen_batch_output)
+
+        # Launch individual reward computations as each generation completes
+        future_reward = None
+        if self.config.reward_model.launch_reward_fn_async:
+            # Store the object reference and set up callback
+            future_reward = self._launch_individual_rewards.remote(
+                gen_batch_output, self.config, self.tokenizer, batch.non_tensor_batch
+            )
+
+        # Return the original, now-modified `batch` and the `future_reward`
+        return GenerationBatchFuture(epoch, batch, gen_batch_output, future_reward)
+
+    @staticmethod
+    @ray.remote
+    def _launch_individual_rewards(gen_batch_output, config, tokenizer, original_non_tensor_batch):
+        # Get generation results
+        gen_batch_result = gen_batch_output.get()
+
+        # Repeat non_tensor_batch to match the number of responses
+        n = config.actor_rollout_ref.rollout.n
+        repeated_non_tensor_batch = {}
+        for key, value in original_non_tensor_batch.items():
+            repeated_non_tensor_batch[key] = np.repeat(value, n, axis=0)
+
+        # Split into individual responses with preserved non_tensor_batch
+        responses_split = []
+        for i in range(len(gen_batch_result)):
+            response_data = gen_batch_result[i : i + 1]  # Get single response
+            # Add repeated non_tensor_batch values
+            for key in repeated_non_tensor_batch:
+                response_data.non_tensor_batch[key] = repeated_non_tensor_batch[key][i : i + 1]
+            responses_split.append(response_data)
+
+        # Launch async reward computation
+        reward_futures = [
+            compute_reward_async.remote(response_data, config, tokenizer) for response_data in responses_split
+        ]
+
+        # Wait for results and combine
+        results = ray.get(reward_futures)
+        rewards_list = [r[0] for r in results]
+        extras_list = [r[1] for r in results]
+
+        combined_reward_tensor = torch.cat(rewards_list, dim=0)
+        combined_extras_dict = {}
+        if extras_list and extras_list[0]:
+            for key in extras_list[0].keys():
+                combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
+
+        return combined_reward_tensor, combined_extras_dict
 
     def fit(self):
         """
@@ -345,6 +404,7 @@ def fit(self):
         to construct the PPO dataflow.
         The light-weight advantage computation is done on the driver process.
         """
+
         from omegaconf import OmegaConf
 
         from verl.utils.tracking import Tracking
@@ -408,7 +468,7 @@ def fit(self):
             with marked_timer("step", timing_raw):
                 # wait for the previous batch
                 with marked_timer("wait_prev_gen", timing_raw, color="red"):
-                    epoch, batch, gen_batch_output = batch_data_future.get()
+                    epoch, batch, gen_batch_output, future_reward = batch_data_future.get()
                     timing_raw.update(gen_batch_output.meta_info["timing"])
                     gen_batch_output.meta_info.pop("timing", None)

@@ -442,8 +502,10 @@ def fit(self):
                         reward_tensor = self.rm_wg.compute_rm_score(batch)
                         batch = batch.union(reward_tensor)
 
+                    # Use the pre-launched future reward if available
                     if self.config.reward_model.launch_reward_fn_async:
-                        future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
+                        # future_reward was already started in _async_gen_next_batch
+                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                     else:
                         reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

@@ -501,8 +563,6 @@ def fit(self):
                 with marked_timer("adv", timing_raw, color="brown"):
                     # we combine with rule-based rm
                     reward_extra_infos_dict: dict[str, list]
-                    if self.config.reward_model.launch_reward_fn_async:
-                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                     batch.batch["token_level_scores"] = reward_tensor
 
                     if reward_extra_infos_dict:
@@ -552,7 +612,17 @@ def fit(self):
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                 if rollout_data_dir:
-                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+                    with marked_timer("dump_rollout_generations", timing_raw, color="green"):
+                        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
+                        outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
+                        scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+                        self._dump_generations(
+                            inputs=inputs,
+                            outputs=outputs,
+                            scores=scores,
+                            reward_extra_infos_dict=reward_extra_infos_dict,
+                            dump_path=rollout_data_dir,
+                        )
 
                 # validate
                 if (
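For reference, a minimal standalone sketch of the combine step used by `_launch_individual_rewards` in the hunk above: per-response reward tensors are concatenated along the batch dimension and the extra-info dicts are merged key-wise. The tensor shapes and the `acc` key below are illustrative assumptions, not verl's actual outputs.

```python
# Illustrative sketch of combining per-response reward outputs.
import torch


def combine_reward_outputs(results):
    # results: list of (reward_tensor, extras_dict) pairs, one per response
    rewards_list = [r[0] for r in results]
    extras_list = [r[1] for r in results]

    # Stack the per-response token-level rewards into one batch tensor.
    combined_reward_tensor = torch.cat(rewards_list, dim=0)

    # Merge extra info key-wise into lists, keeping per-response order.
    combined_extras_dict = {}
    if extras_list and extras_list[0]:
        for key in extras_list[0].keys():
            combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
    return combined_reward_tensor, combined_extras_dict


if __name__ == "__main__":
    fake_results = [
        (torch.zeros(1, 8), {"acc": 1.0}),
        (torch.ones(1, 8), {"acc": 0.0}),
    ]
    rewards, extras = combine_reward_outputs(fake_results)
    print(rewards.shape, extras)  # torch.Size([2, 8]) {'acc': [1.0, 0.0]}
```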

verl/workers/fsdp_workers.py

Lines changed: 6 additions & 6 deletions
@@ -883,12 +883,12 @@ def generate_sequences(self, prompts: DataProto):
         prompts = prompts.to(get_device_id())
 
         meta_info = {
-            "eos_token_id": self.model_config.generation_config.eos_token_id
-            if self.model_config.generation_config is not None
-            else self.model_config.tokenizer.eos_token_id,
-            "pad_token_id": self.model_config.generation_config.pad_token_id
-            if self.model_config.generation_config is not None
-            else self.model_config.tokenizer.pad_token_id,
+            "eos_token_id": self.generation_config.eos_token_id
+            if self.generation_config is not None
+            else self.tokenizer.eos_token_id,
+            "pad_token_id": self.generation_config.pad_token_id
+            if self.generation_config is not None
+            else self.tokenizer.pad_token_id,
         }
         prompts.meta_info.update(meta_info)
