
Commit 311518e

Merge pull request #317 from listar2000/fix-color-print-display-issue
Fix color print display issue
2 parents bc79e63 + 90e1b4c · commit 311518e

7 files changed: +221 -229 lines changed

rllm/engine/agent_execution_engine.py

Lines changed: 1 addition & 1 deletion
@@ -17,8 +17,8 @@
     compute_mc_return,
     compute_trajectory_reward,
 )
-from rllm.misc import colorful_print
 from rllm.parser import ChatTemplateParser
+from rllm.utils import colorful_print

 logger = logging.getLogger(__name__)

rllm/engine/agent_workflow_engine.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@

 from rllm.agents.agent import Episode
 from rllm.engine.rollout import ModelOutput, RolloutEngine
-from rllm.misc import colorful_print
+from rllm.utils import colorful_print
 from rllm.workflows.workflow import TerminationReason, Workflow

 # Avoid hard dependency on verl at import time; only for typing

rllm/misc.py

Lines changed: 0 additions & 11 deletions
@@ -3,22 +3,11 @@
 """

 import random
-import warnings

-import click
 import numpy as np
 from PIL import Image


-def colorful_print(string: str, *args, **kwargs) -> None:
-    end = kwargs.pop("end", "\n")
-    print(click.style(string, *args, **kwargs), end=end, flush=True)
-
-
-def colorful_warning(string: str, *args, **kwargs) -> None:
-    warnings.warn(click.style(string, *args, **kwargs), stacklevel=2)
-
-
 def get_image(image_path):
     with Image.open(image_path) as img:
         return img.convert("RGB")
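
Note: colorful_print and colorful_warning are not dropped; they move into rllm.utils (see the rllm/utils/__init__.py change below), with the new module rllm/utils/visualization.py as the seventh changed file, whose diff is not shown in this listing. A minimal sketch of the relocated helpers, assuming they keep the implementations removed from rllm/misc.py above:

    # Sketch of the relocated print helpers (assumed to live in
    # rllm/utils/visualization.py; that module's actual contents are not shown here).
    import warnings

    import click


    def colorful_print(string: str, *args, **kwargs) -> None:
        # Style the string with click and flush immediately so colors render in streamed logs.
        end = kwargs.pop("end", "\n")
        print(click.style(string, *args, **kwargs), end=end, flush=True)


    def colorful_warning(string: str, *args, **kwargs) -> None:
        # Emit a styled warning attributed to the caller's frame.
        warnings.warn(click.style(string, *args, **kwargs), stacklevel=2)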

rllm/trainer/verl/agent_ppo_trainer.py

Lines changed: 17 additions & 100 deletions
@@ -710,107 +710,24 @@ def _transform_agent_trajectories(self, trajectories: list[dict]):

     def visualize_trajectory(self, tensor_batch, sample_idx=0, max_samples=1, mask_key="response_mask"):
         """
-        Visualize the trajectory from tensor_batch by detokenizing prompts and responses,
-        and highlighting the masked parts with color.
-
-        Args:
-            tensor_batch: The tensor batch containing trajectory data
-            sample_idx: Starting index of samples to visualize
-            max_samples: Maximum number of samples to visualize
+        Visualize the trajectory from tensor_batch using the shared visualization utility.
         """
-        from rllm.misc import colorful_print
-
-        # Get the relevant tensors
-        prompts = tensor_batch.batch["prompts"]
-        responses = tensor_batch.batch["responses"]
-        traj_mask = tensor_batch.batch[mask_key]
-        token_level_scores = tensor_batch.batch["token_level_scores"]
-
-        # Full attention mask (covers prompt + response); split it into prompt and response parts
-        full_attn_mask = tensor_batch.batch["attention_mask"]
-        prompt_len = prompts.shape[1]
-        resp_len = responses.shape[1]
-        prompt_attn_mask = full_attn_mask[:, :prompt_len]
-        response_attn_mask = full_attn_mask[:, -resp_len:]
-
-        batch_size = prompts.shape[0]
-        end_idx = min(sample_idx + max_samples, batch_size)
-
-        for i in range(sample_idx, end_idx):
-            colorful_print("\n" + "=" * 60, fg="cyan", bold=True)
-            colorful_print(f"Sample {i}", fg="cyan", bold=True)
-
-            # Legend before the example
-            legend = " ".join(
-                [
-                    "\x1b[37mwhite=masked\x1b[0m",
-                    "\x1b[34mblue=unmasked\x1b[0m",
-                    "\x1b[42m green bg=reward>0 \x1b[0m",
-                    "\x1b[41m red bg=reward<=0 \x1b[0m",
-                ]
-            )
-            print(f"[{legend}]")
-
-            # Detokenize prompt
-            prompt_tokens = prompts[i]
-            prompt_valid_mask = prompt_attn_mask[i].bool()
-            # Build one-line colored prompt (prompt is always masked-from-loss => white)
-            prompt_parts = []
-            for tok_id, is_valid in zip(prompt_tokens.tolist(), prompt_valid_mask.tolist(), strict=False):
-                if not is_valid:
-                    continue
-                tok = self.tokenizer.decode([tok_id]).replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
-                prompt_parts.append(f"\x1b[37m{tok}\x1b[0m")  # white
-            print("".join(prompt_parts))
-
-            # Separator line between prompt and response for readability
-            print("----------------")
-
-            # Detokenize response with token-level highlighting
-            resp_tokens = responses[i]
-            resp_valid_mask = response_attn_mask[i].bool()
-            loss_mask = traj_mask[i]
-            rewards = token_level_scores[i]
-
-            # Pre-compute reward positions (typically only the last valid resp token has nonzero reward)
-            reward_idx = None
-            reward_value = 0.0
-            if rewards is not None:
-                # consider only valid response positions
-                for j, is_valid in enumerate(resp_valid_mask.tolist()):
-                    if not is_valid:
-                        continue
-                    val = float(rewards[j].item()) if hasattr(rewards[j], "item") else float(rewards[j])
-                    if abs(val) > 1e-9:
-                        reward_idx = j
-                        reward_value = val
-
-            # Fallback: if no nonzero reward found, use the last valid response token
-            if reward_idx is None:
-                valid_indices = [idx for idx, v in enumerate(resp_valid_mask.tolist()) if v]
-                if valid_indices:
-                    reward_idx = valid_indices[-1]
-                    if rewards is not None:
-                        val = float(rewards[reward_idx].item()) if hasattr(rewards[reward_idx], "item") else float(rewards[reward_idx])
-                        reward_value = val
-
-            # Colors: white for masked-from-loss; blue for contributes-to-loss; overlay background red/green if reward token
-            response_parts = []
-            for j, tok_id in enumerate(resp_tokens.tolist()):
-                if not bool(resp_valid_mask[j].item() if hasattr(resp_valid_mask[j], "item") else resp_valid_mask[j]):
-                    continue
-                tok = self.tokenizer.decode([tok_id]).replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
-
-                contributes = bool(loss_mask[j].item()) if hasattr(loss_mask[j], "item") else bool(loss_mask[j])
-                fg = "\x1b[34m" if contributes else "\x1b[37m"  # blue if in loss, else white
-
-                bg = ""
-                if reward_idx is not None and j == reward_idx:
-                    bg = "\x1b[42m" if reward_value > 0 else "\x1b[41m"  # green background for positive, red for negative/zero
-
-                response_parts.append(f"{bg}{fg}{tok}\x1b[0m")
-
-            print("".join(response_parts))
+        from rllm.utils.visualization import visualize_trajectories
+
+        if len(tensor_batch) == 0:
+            return
+
+        end_idx = min(sample_idx + max_samples, len(tensor_batch))
+        indices = list(range(sample_idx, end_idx))
+
+        visualize_trajectories(
+            batch=tensor_batch,
+            tokenizer=self.tokenizer,
+            sample_indices=indices,
+            mask_key=mask_key,
+            reward_key="token_level_scores",
+            show_workflow_metadata=False,
+        )

     def generate_agent_trajectories_async(self, timing_raw=None, meta_info=None, mode="Token"):
         """
rllm/trainer/verl/agent_workflow_trainer.py

Lines changed: 16 additions & 115 deletions
@@ -662,126 +662,27 @@ def shutdown(self):

     def visualize_trajectory_last_step(self, tensor_batch, sample_idx=0, max_samples=1):
         """
-        Visualize last steps from a workflow rollout:
-        - detokenize prompts/responses
-        - show token usage mask
-        - show reward tokens (placed at the last response token)
-        - print Correct/Incorrect using `is_correct` from non_tensors
+        Visualize last steps from a workflow rollout using the shared visualization utility.
         """
-        from rllm.misc import colorful_print
+        from rllm.utils.visualization import visualize_trajectories

         # Select only last steps if stepwise-advantage is enabled
         if "is_last_step" in tensor_batch.non_tensor_batch:
             is_last = tensor_batch.non_tensor_batch["is_last_step"]
             if is_last is not None and len(is_last) == len(tensor_batch):
                 tensor_batch = tensor_batch[is_last]

-        prompts = tensor_batch.batch["prompts"]
-        responses = tensor_batch.batch["responses"]
-        # Full attention mask (covers prompt + response); split it into prompt and response parts
-        full_attn_mask = tensor_batch.batch["attention_mask"]
-        prompt_len = prompts.shape[1]
-        resp_len = responses.shape[1]
-        prompt_attn_mask = full_attn_mask[:, :prompt_len]
-        response_attn_mask = full_attn_mask[:, -resp_len:]
-
-        # Loss mask over the response tokens only
-        response_loss_mask = tensor_batch.batch.get("response_mask")
-
-        # Rewards aligned to response tokens
-        token_level_scores = tensor_batch.batch.get("step_rewards" if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "per_step" else "traj_rewards")
-
-        # Optional meta to print outcome
-        is_correct = tensor_batch.non_tensor_batch.get("is_correct", None)
-        term_reasons = tensor_batch.non_tensor_batch.get("termination_reasons", None)
-        episode_ids = tensor_batch.non_tensor_batch.get("episode_ids", None)
-        trajectory_ids = tensor_batch.non_tensor_batch.get("trajectory_ids", None)
-
-        bsz = prompts.shape[0]
-        end_idx = min(sample_idx + max_samples, bsz)
-
-        for i in range(sample_idx, end_idx):
-            colorful_print("\n" + "=" * 60, fg="cyan", bold=True)
-            # Header with ids
-            if episode_ids is not None or trajectory_ids is not None:
-                colorful_print(f"Episode: {episode_ids[i] if episode_ids is not None else '?'} | Traj: {trajectory_ids[i] if trajectory_ids is not None else '?'}", fg="cyan", bold=True)
-
-            # Outcome line
-            if is_correct is not None:
-                ok = bool(is_correct[i])
-                colorful_print(f"Outcome: {'✓ Correct' if ok else '✗ Incorrect'}", fg=("green" if ok else "red"), bold=True)
-
-            if term_reasons is not None:
-                colorful_print(f"Termination: {term_reasons[i]}", fg="yellow")
-
-            # Legend before the example
-            legend = " ".join(
-                [
-                    "\x1b[37mwhite=masked\x1b[0m",
-                    "\x1b[34mblue=unmasked\x1b[0m",
-                    "\x1b[42m green bg=reward>0 \x1b[0m",
-                    "\x1b[41m red bg=reward<=0 \x1b[0m",
-                ]
-            )
-            print(f"[{legend}]")
-
-            # Detokenize prompt
-            prompt_tokens = prompts[i]
-            prompt_valid_mask = prompt_attn_mask[i].bool()
-            # Build one-line colored prompt (prompt is always masked-from-loss => white)
-            prompt_parts = []
-            for tok_id, is_valid in zip(prompt_tokens.tolist(), prompt_valid_mask.tolist(), strict=False):
-                if not is_valid:
-                    continue
-                tok = self.tokenizer.decode([tok_id]).replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
-                prompt_parts.append(f"\x1b[37m{tok}\x1b[0m")  # white
-            print("".join(prompt_parts))
-
-            # Separator line between prompt and response for readability
-            print("----------------")
-
-            # Detokenize response with token-level highlighting
-            resp_tokens = responses[i]
-            resp_valid_mask = response_attn_mask[i].bool()
-            loss_mask = response_loss_mask[i] if response_loss_mask is not None else resp_valid_mask
-            rewards = token_level_scores[i] if token_level_scores is not None else None
-
-            # Pre-compute reward positions (typically only the last valid resp token has nonzero reward)
-            reward_idx = None
-            reward_value = 0.0
-            if rewards is not None:
-                # consider only valid response positions
-                for j, is_valid in enumerate(resp_valid_mask.tolist()):
-                    if not is_valid:
-                        continue
-                    val = float(rewards[j].item()) if hasattr(rewards[j], "item") else float(rewards[j])
-                    if abs(val) > 1e-9:
-                        reward_idx = j
-                        reward_value = val
-
-            # Fallback: if no nonzero reward found, use the last valid response token
-            if reward_idx is None:
-                valid_indices = [idx for idx, v in enumerate(resp_valid_mask.tolist()) if v]
-                if valid_indices:
-                    reward_idx = valid_indices[-1]
-                    if rewards is not None:
-                        val = float(rewards[reward_idx].item()) if hasattr(rewards[reward_idx], "item") else float(rewards[reward_idx])
-                        reward_value = val
-
-            # Colors: white for masked-from-loss; blue for contributes-to-loss; overlay background red/green if reward token
-            response_parts = []
-            for j, tok_id in enumerate(resp_tokens.tolist()):
-                if not bool(resp_valid_mask[j].item() if hasattr(resp_valid_mask[j], "item") else resp_valid_mask[j]):
-                    continue
-                tok = self.tokenizer.decode([tok_id]).replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
-
-                contributes = bool(loss_mask[j].item()) if hasattr(loss_mask[j], "item") else bool(loss_mask[j])
-                fg = "\x1b[34m" if contributes else "\x1b[37m"  # blue if in loss, else white
-
-                bg = ""
-                if reward_idx is not None and j == reward_idx:
-                    bg = "\x1b[42m" if reward_value > 0 else "\x1b[41m"  # green background for positive, red for negative/zero
-
-                response_parts.append(f"{bg}{fg}{tok}\x1b[0m")
-
-            print("".join(response_parts))
+        if len(tensor_batch) == 0:
+            return
+
+        end_idx = min(sample_idx + max_samples, len(tensor_batch))
+        indices = list(range(sample_idx, end_idx))
+
+        visualize_trajectories(
+            batch=tensor_batch,
+            tokenizer=self.tokenizer,
+            sample_indices=indices,
+            mask_key="response_mask",
+            reward_key="step_rewards" if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "per_step" else "traj_rewards",
+            show_workflow_metadata=True,
+        )

rllm/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,5 +2,6 @@

 from rllm.utils.compute_pass_at_k import compute_pass_at_k
 from rllm.utils.episode_logger import EpisodeLogger
+from rllm.utils.visualization import VisualizationConfig, colorful_print, colorful_warning, visualize_trajectories

-__all__ = ["EpisodeLogger", "compute_pass_at_k"]
+__all__ = ["EpisodeLogger", "compute_pass_at_k", "visualize_trajectories", "VisualizationConfig", "colorful_print", "colorful_warning"]
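
The shared utility itself, visualize_trajectories, lives in the new rllm/utils/visualization.py (the seventh changed file; its diff is not shown in this listing), so its exact signature is not visible here. A minimal call-site sketch, assuming only the keyword arguments used by the two trainers above; the wrapper name show_last_steps is hypothetical:

    from rllm.utils import visualize_trajectories


    def show_last_steps(tensor_batch, tokenizer, sample_idx=0, max_samples=1):
        # Guard against empty batches, as both refactored trainers now do.
        if len(tensor_batch) == 0:
            return
        end_idx = min(sample_idx + max_samples, len(tensor_batch))
        visualize_trajectories(
            batch=tensor_batch,                               # verl DataProto-style batch
            tokenizer=tokenizer,                              # tokenizer used to detokenize prompts/responses
            sample_indices=list(range(sample_idx, end_idx)),  # which samples to print
            mask_key="response_mask",                         # loss mask highlighted in blue
            reward_key="traj_rewards",                        # or "step_rewards" / "token_level_scores"
            show_workflow_metadata=True,                      # print episode/trajectory ids and outcome
        )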

0 commit comments