@@ -662,126 +662,27 @@ def shutdown(self):
 
     def visualize_trajectory_last_step(self, tensor_batch, sample_idx=0, max_samples=1):
         """
-        Visualize last steps from a workflow rollout:
-        - detokenize prompts/responses
-        - show token usage mask
-        - show reward tokens (placed at the last response token)
-        - print Correct/Incorrect using `is_correct` from non_tensors
+        Visualize last steps from a workflow rollout using the shared visualization utility.
         """
-        from rllm.misc import colorful_print
+        from rllm.utils.visualization import visualize_trajectories
 
         # Select only last steps if stepwise-advantage is enabled
         if "is_last_step" in tensor_batch.non_tensor_batch:
             is_last = tensor_batch.non_tensor_batch["is_last_step"]
             if is_last is not None and len(is_last) == len(tensor_batch):
                 tensor_batch = tensor_batch[is_last]
 
-        prompts = tensor_batch.batch["prompts"]
-        responses = tensor_batch.batch["responses"]
-        # Full attention mask (covers prompt + response); split it into prompt and response parts
-        full_attn_mask = tensor_batch.batch["attention_mask"]
-        prompt_len = prompts.shape[1]
-        resp_len = responses.shape[1]
-        prompt_attn_mask = full_attn_mask[:, :prompt_len]
-        response_attn_mask = full_attn_mask[:, -resp_len:]
-
-        # Loss mask over the response tokens only
-        response_loss_mask = tensor_batch.batch.get("response_mask")
-
-        # Rewards aligned to response tokens
-        token_level_scores = tensor_batch.batch.get("step_rewards" if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "per_step" else "traj_rewards")
-
-        # Optional meta to print outcome
-        is_correct = tensor_batch.non_tensor_batch.get("is_correct", None)
-        term_reasons = tensor_batch.non_tensor_batch.get("termination_reasons", None)
-        episode_ids = tensor_batch.non_tensor_batch.get("episode_ids", None)
-        trajectory_ids = tensor_batch.non_tensor_batch.get("trajectory_ids", None)
-
-        bsz = prompts.shape[0]
-        end_idx = min(sample_idx + max_samples, bsz)
-
-        for i in range(sample_idx, end_idx):
-            colorful_print("\n" + "=" * 60, fg="cyan", bold=True)
-            # Header with ids
-            if episode_ids is not None or trajectory_ids is not None:
-                colorful_print(f"Episode: {episode_ids[i] if episode_ids is not None else '?'} | Traj: {trajectory_ids[i] if trajectory_ids is not None else '?'}", fg="cyan", bold=True)
-
-            # Outcome line
-            if is_correct is not None:
-                ok = bool(is_correct[i])
-                colorful_print(f"Outcome: {'✓ Correct' if ok else '✗ Incorrect'}", fg=("green" if ok else "red"), bold=True)
-
-            if term_reasons is not None:
-                colorful_print(f"Termination: {term_reasons[i]}", fg="yellow")
-
-            # Legend before the example
-            legend = " ".join(
-                [
-                    "\x1b[37mwhite=masked\x1b[0m",
-                    "\x1b[34mblue=unmasked\x1b[0m",
-                    "\x1b[42m green bg=reward>0 \x1b[0m",
-                    "\x1b[41m red bg=reward<=0 \x1b[0m",
-                ]
-            )
-            print(f"[{legend}]")
-
-            # Detokenize prompt
-            prompt_tokens = prompts[i]
-            prompt_valid_mask = prompt_attn_mask[i].bool()
-            # Build one-line colored prompt (prompt is always masked-from-loss => white)
-            prompt_parts = []
-            for tok_id, is_valid in zip(prompt_tokens.tolist(), prompt_valid_mask.tolist(), strict=False):
-                if not is_valid:
-                    continue
-                tok = self.tokenizer.decode([tok_id]).replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
-                prompt_parts.append(f"\x1b[37m{tok}\x1b[0m")  # white
-            print("".join(prompt_parts))
-
-            # Separator line between prompt and response for readability
-            print("----------------")
-
-            # Detokenize response with token-level highlighting
-            resp_tokens = responses[i]
-            resp_valid_mask = response_attn_mask[i].bool()
-            loss_mask = response_loss_mask[i] if response_loss_mask is not None else resp_valid_mask
-            rewards = token_level_scores[i] if token_level_scores is not None else None
-
-            # Pre-compute reward positions (typically only the last valid resp token has nonzero reward)
-            reward_idx = None
-            reward_value = 0.0
-            if rewards is not None:
-                # consider only valid response positions
-                for j, is_valid in enumerate(resp_valid_mask.tolist()):
-                    if not is_valid:
-                        continue
-                    val = float(rewards[j].item()) if hasattr(rewards[j], "item") else float(rewards[j])
-                    if abs(val) > 1e-9:
-                        reward_idx = j
-                        reward_value = val
-
-            # Fallback: if no nonzero reward found, use the last valid response token
-            if reward_idx is None:
-                valid_indices = [idx for idx, v in enumerate(resp_valid_mask.tolist()) if v]
-                if valid_indices:
-                    reward_idx = valid_indices[-1]
-                    if rewards is not None:
-                        val = float(rewards[reward_idx].item()) if hasattr(rewards[reward_idx], "item") else float(rewards[reward_idx])
-                        reward_value = val
-
-            # Colors: white for masked-from-loss; blue for contributes-to-loss; overlay background red/green if reward token
-            response_parts = []
-            for j, tok_id in enumerate(resp_tokens.tolist()):
-                if not bool(resp_valid_mask[j].item() if hasattr(resp_valid_mask[j], "item") else resp_valid_mask[j]):
-                    continue
-                tok = self.tokenizer.decode([tok_id]).replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
-
-                contributes = bool(loss_mask[j].item()) if hasattr(loss_mask[j], "item") else bool(loss_mask[j])
-                fg = "\x1b[34m" if contributes else "\x1b[37m"  # blue if in loss, else white
-
-                bg = ""
-                if reward_idx is not None and j == reward_idx:
-                    bg = "\x1b[42m" if reward_value > 0 else "\x1b[41m"  # green background for positive, red for negative/zero
-
-                response_parts.append(f"{bg}{fg}{tok}\x1b[0m")
-
-            print("".join(response_parts))
+        if len(tensor_batch) == 0:
+            return
+
+        end_idx = min(sample_idx + max_samples, len(tensor_batch))
+        indices = list(range(sample_idx, end_idx))
+
+        visualize_trajectories(
+            batch=tensor_batch,
+            tokenizer=self.tokenizer,
+            sample_indices=indices,
+            mask_key="response_mask",
+            reward_key="step_rewards" if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "per_step" else "traj_rewards",
+            show_workflow_metadata=True,
+        )
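
For context, the shared helper now owns the detokenization, masking, and reward-highlighting logic that the removed code implemented inline. A minimal sketch of the equivalent standalone call, using only the arguments visible in this diff (the batch and tokenizer objects are assumed to come from the surrounding rollout code, and the comments reflect my reading of the removed implementation rather than documented behavior):

    from rllm.utils.visualization import visualize_trajectories

    # Hypothetical usage: print the first two last-step trajectories of a rollout batch.
    visualize_trajectories(
        batch=tensor_batch,           # rollout batch with "prompts"/"responses"/"attention_mask" tensors
        tokenizer=tokenizer,          # tokenizer used to produce the rollout
        sample_indices=[0, 1],        # which rows of the batch to visualize
        mask_key="response_mask",     # loss mask over response tokens (blue vs. white in the old output)
        reward_key="traj_rewards",    # or "step_rewards" when per-step stepwise advantage is enabled
        show_workflow_metadata=True,  # episode/trajectory ids, outcome, termination reason
    )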