@@ -60,32 +60,36 @@ class GenerationBatchFuture:
     Wrapper class for encapsulating batch generation results
     """
 
-    def __init__(self, epoch, batch, gen_batch_output):
+    def __init__(self, epoch, batch, gen_batch_output, future_reward=None):
         """
         :param epoch: current epoch
         :param batch: Input batch data
         :param gen_batch_output: Generated sequences from the main model (DataProtoFuture)
+        :param future_reward: Future for reward computation (optional)
         """
         self.epoch = epoch
         self.batch = batch
         self.gen_batch_output = gen_batch_output
+        self.future_reward = future_reward
 
     def get(self):
         """
         Get the actual results by calling get() method on gen_batch_output
 
         Returns:
-            tuple: (batch, gen_batch_result)
+            tuple: (epoch, batch, gen_batch_result, future_reward)
+                - epoch: Current epoch
                 - batch: Original input batch data
                 - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself
+                - future_reward: Future for reward computation if available, else None
         """
         # Call get() method on gen_batch_output if available
         if hasattr(self.gen_batch_output, "get"):
             gen_batch_result = self.gen_batch_output.get()
         else:
             gen_batch_result = self.gen_batch_output
 
-        return self.epoch, self.batch, gen_batch_result
+        return self.epoch, self.batch, gen_batch_result, self.future_reward
 
 
 class OneStepOffRayTrainer(RayPPOTrainer):
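A minimal standalone sketch of the fallback that `GenerationBatchFuture.get()` relies on: `gen_batch_output` is only resolved when it exposes a `get()` method, so either a `DataProtoFuture` or an already-materialized result can be wrapped. The stub below is illustrative only and not part of this patch.

```python
# Illustrative stub, not code from this PR: mimics the hasattr(...) fallback in get().
class FakeFuture:
    def get(self):
        return "materialized result"


def resolve(gen_batch_output):
    # Same branch as GenerationBatchFuture.get(): call .get() only when it exists.
    if hasattr(gen_batch_output, "get"):
        return gen_batch_output.get()
    return gen_batch_output


print(resolve(FakeFuture()))        # -> "materialized result"
print(resolve("already a result"))  # -> "already a result" (passed through unchanged)
```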
@@ -315,7 +319,10 @@ def _async_gen_next_batch(self, continuous_iterator):
         except Exception as e:
             print(f"Error in async_gen_next_batch: {e}")
             return None
+
+        # Create the initial batch from the data loader
         batch = DataProto.from_single_dict(batch_dict)
+
         # pop those keys for generation
         batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
         non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
@@ -327,16 +334,68 @@ def _async_gen_next_batch(self, continuous_iterator):
             non_tensor_batch_keys_to_pop.append("tools_kwargs")
         if "interaction_kwargs" in batch.non_tensor_batch:
             non_tensor_batch_keys_to_pop.append("interaction_kwargs")
+
         gen_batch = batch.pop(
             batch_keys=batch_keys_to_pop,
             non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
         )
         gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
         # sync weights from actor to rollout
         self.sync_rollout_weights()
+
         # async generation
         gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
-        return GenerationBatchFuture(epoch, batch, gen_batch_output)
+
+        # Launch individual reward computations as each generation completes
+        future_reward = None
+        if self.config.reward_model.launch_reward_fn_async:
+            # Keep the object reference so fit() can retrieve the result later
+            future_reward = self._launch_individual_rewards.remote(
+                gen_batch_output, self.config, self.tokenizer, batch.non_tensor_batch
+            )
+
+        # Return the original, now-modified `batch` and the `future_reward`
+        return GenerationBatchFuture(epoch, batch, gen_batch_output, future_reward)
+
+    @staticmethod
+    @ray.remote
+    def _launch_individual_rewards(gen_batch_output, config, tokenizer, original_non_tensor_batch):
+        # Get generation results
+        gen_batch_result = gen_batch_output.get()
+
+        # Repeat non_tensor_batch to match the number of responses
+        n = config.actor_rollout_ref.rollout.n
+        repeated_non_tensor_batch = {}
+        for key, value in original_non_tensor_batch.items():
+            repeated_non_tensor_batch[key] = np.repeat(value, n, axis=0)
+
+        # Split into individual responses with preserved non_tensor_batch
+        responses_split = []
+        for i in range(len(gen_batch_result)):
+            response_data = gen_batch_result[i : i + 1]  # Get single response
+            # Add repeated non_tensor_batch values
+            for key in repeated_non_tensor_batch:
+                response_data.non_tensor_batch[key] = repeated_non_tensor_batch[key][i : i + 1]
+            responses_split.append(response_data)
+
+        # Launch async reward computation
+        reward_futures = [
+            compute_reward_async.remote(response_data, config, tokenizer) for response_data in responses_split
+        ]
+
+        # Wait for results and combine
+        results = ray.get(reward_futures)
+        rewards_list = [r[0] for r in results]
+        extras_list = [r[1] for r in results]
+
+        combined_reward_tensor = torch.cat(rewards_list, dim=0)
+        combined_extras_dict = {}
+        if extras_list and extras_list[0]:
+            for key in extras_list[0].keys():
+                combined_extras_dict[key] = [d[key] for d in extras_list if key in d]
+
+        return combined_reward_tensor, combined_extras_dict
 
     def fit(self):
         """
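One alignment detail in `_launch_individual_rewards`: the prompts were expanded with `gen_batch.repeat(repeat_times=n, interleave=True)`, which (assuming `interleave=True` repeats each prompt's rows consecutively, i.e. `torch.repeat_interleave` semantics) orders responses as prompt 0 × n, prompt 1 × n, and so on. `np.repeat(value, n, axis=0)` duplicates the prompt-level non-tensor fields in exactly that order, so response `i` pairs with `repeated_non_tensor_batch[key][i]`. A standalone NumPy check of the ordering, not code from this PR:

```python
import numpy as np

# Prompt-level metadata for two prompts, with rollout.n = 3 responses per prompt.
prompt_meta = np.array(["p0", "p1"])
n = 3

# np.repeat duplicates each element n times in place, matching interleaved rollouts:
# responses are ordered p0, p0, p0, p1, p1, p1.
repeated = np.repeat(prompt_meta, n, axis=0)
print(repeated.tolist())  # ['p0', 'p0', 'p0', 'p1', 'p1', 'p1']
```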
@@ -345,6 +404,7 @@ def fit(self):
         to construct the PPO dataflow.
         The light-weight advantage computation is done on the driver process.
         """
+
         from omegaconf import OmegaConf
 
         from verl.utils.tracking import Tracking
@@ -408,7 +468,7 @@ def fit(self):
             with marked_timer("step", timing_raw):
                 # wait for the previous batch
                 with marked_timer("wait_prev_gen", timing_raw, color="red"):
-                    epoch, batch, gen_batch_output = batch_data_future.get()
+                    epoch, batch, gen_batch_output, future_reward = batch_data_future.get()
                     timing_raw.update(gen_batch_output.meta_info["timing"])
                     gen_batch_output.meta_info.pop("timing", None)
 
@@ -442,8 +502,10 @@ def fit(self):
                         reward_tensor = self.rm_wg.compute_rm_score(batch)
                         batch = batch.union(reward_tensor)
 
+                    # Use the pre-launched future reward if available
                     if self.config.reward_model.launch_reward_fn_async:
-                        future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
+                        # future_reward was already started in _async_gen_next_batch
+                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                     else:
                         reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
 
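With this change, `ray.get(future_reward)` retrieves a task that was already launched in `_async_gen_next_batch`, so reward computation overlaps with generation and training, and this call only blocks on whatever has not yet finished. Below is a standalone sketch of that launch-early/collect-late pattern, using a hypothetical `slow_reward` task rather than this trainer's API:

```python
import time

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def slow_reward(x):
    # Stand-in for an expensive reward computation.
    time.sleep(1.0)
    return x * 2


# Launch early: the task starts running in the background.
future = slow_reward.remote(21)

# Meanwhile the driver does other work (e.g., waits on the previous generation).
time.sleep(1.0)

# Collect late: most or all of the remote work has already completed.
print(ray.get(future))  # 42
```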
@@ -501,8 +563,6 @@ def fit(self):
                 with marked_timer("adv", timing_raw, color="brown"):
                     # we combine with rule-based rm
                     reward_extra_infos_dict: dict[str, list]
-                    if self.config.reward_model.launch_reward_fn_async:
-                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                     batch.batch["token_level_scores"] = reward_tensor
 
                     if reward_extra_infos_dict:
@@ -552,7 +612,17 @@ def fit(self):
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                 if rollout_data_dir:
-                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+                    with marked_timer("dump_rollout_generations", timing_raw, color="green"):
+                        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
+                        outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
+                        scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+                        self._dump_generations(
+                            inputs=inputs,
+                            outputs=outputs,
+                            scores=scores,
+                            reward_extra_infos_dict=reward_extra_infos_dict,
+                            dump_path=rollout_data_dir,
+                        )
 
                 # validate
                 if (