Skip to content

Commit 511b5d5

Browse files
SStas and claude committed
Add cross-model text baseline pipeline for GSM8K 2-agent
New pipeline_text_cross_model.py: Agent A (model X) generates text, Agent B (model Y) reads that text as context. This is the text baseline needed to interpret rosetta cross-model results — without it, we can't tell if rosetta projection adds value over simply piping text between different models. Integrated into run_gsm8k_2agent.py as --mode text_cross_model. Model B loading shared with rosetta mode to avoid duplicate loads. Usage: python benchmarks/gsm8k_2agent/run_gsm8k_2agent.py \ --mode text_cross_model \ --model_name Qwen/Qwen2.5-7B-Instruct \ --model_b meta-llama/Llama-3.2-3B-Instruct \ --max_samples 200 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 60f7406 commit 511b5d5

File tree

2 files changed

+229
-8
lines changed

2 files changed

+229
-8
lines changed
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""Cross-model text pipeline: 2-agent chain where Researcher (model A) passes text to Solver (model B).
2+
3+
This is the text baseline for cross-model comparison. Both agents communicate via
4+
text (like pipeline_text.py), but each agent runs on a different model. This lets
5+
us measure whether rosetta projection adds value over simply piping text between
6+
different models.
7+
"""
8+
9+
import time
10+
from typing import Any, Dict, List
11+
12+
from benchmarks.shared.generation import generate_text, render_prompt, tokenize_prompt
13+
from benchmarks.shared.metrics import gpu_memory_tracker
14+
from .agents import AGENTS, build_text_prompt
15+
from .evaluate import extract_gold, extract_gsm8k_answer, check_correct
16+
17+
18+
def run_text_cross_model_pipeline(
    model_a: Any,
    tokenizer_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    question: str,
    gold_solution: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
) -> Dict:
    """Run the 2-agent cross-model text pipeline on a single GSM8K problem.

    Researcher (model A) generates a text analysis; Solver (model B) receives
    that text verbatim in its prompt and generates the final answer. Because
    the models may come from different families, all token counts are computed
    with each model's own tokenizer.

    Args:
        model_a / tokenizer_a: Researcher model and its tokenizer.
        model_b / tokenizer_b: Solver model and its tokenizer.
        device: torch device string (e.g. "cuda").
        question: GSM8K question text.
        gold_solution: raw GSM8K solution string; gold answer is extracted
            via ``extract_gold``.
        max_new_tokens / temperature / top_p: sampling parameters, shared by
            both agents.
        verbose: if True, print truncated agent outputs as they are produced.

    Returns:
        Result dict with prediction, correctness, timing / token accounting,
        per-agent traces, and ``"mode": "text_cross_model"``.
    """
    with gpu_memory_tracker(device) as mem:
        t0 = time.perf_counter()

        researcher = AGENTS[0]
        solver = AGENTS[1]

        # --- Agent 1: Researcher on model A (no upstream context) ---
        researcher_text, researcher_trace = _agent_turn(
            agent=researcher,
            model=model_a,
            tokenizer=tokenizer_a,
            device=device,
            messages=build_text_prompt(researcher.role, question),
            model_label="model_a",
            context_tokens=0,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

        if verbose:
            print(f" [{researcher.name} (A)] output ({len(researcher_text)} chars): "
                  f"{researcher_text[:200]}...")

        # --- Agent 2: Solver on model B ---
        # Context cost = Researcher's text re-tokenized by model B's tokenizer,
        # since that is what model B actually pays to read it.
        context_encoded = tokenizer_b(researcher_text, add_special_tokens=False)
        context_token_count = len(context_encoded["input_ids"])

        solver_text, solver_trace = _agent_turn(
            agent=solver,
            model=model_b,
            tokenizer=tokenizer_b,
            device=device,
            messages=build_text_prompt(solver.role, question, researcher_text),
            model_label="model_b",
            context_tokens=context_token_count,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

        if verbose:
            print(f" [{solver.name} (B)] output ({len(solver_text)} chars): "
                  f"{solver_text[:200]}...")

        agent_traces: List[Dict] = [researcher_trace, solver_trace]

        wall_time = time.perf_counter() - t0

        total_prompt_tokens = sum(t["prompt_tokens"] for t in agent_traces)
        total_output_tokens = sum(t["output_tokens"] for t in agent_traces)
        total_context_tokens = sum(t["context_tokens"] for t in agent_traces)
        total_tokens = total_prompt_tokens + total_output_tokens
        # Guard against a degenerate zero wall time (e.g. mocked generation).
        tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0

        gold = extract_gold(gold_solution)
        prediction = extract_gsm8k_answer(agent_traces[-1]["output"])
        correct = check_correct(prediction, gold)

        return {
            "question": question,
            "gold": gold,
            "prediction": prediction,
            "raw_output": agent_traces[-1]["output"],
            "correct": correct,
            "wall_time": wall_time,
            "total_prompt_tokens": total_prompt_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "total_context_tokens": total_context_tokens,
            "tokens_per_sec": tokens_per_sec,
            "peak_memory_mb": mem["peak_memory_mb"],
            "agents": agent_traces,
            "mode": "text_cross_model",
        }


def _agent_turn(
    agent: Any,
    model: Any,
    tokenizer: Any,
    device: str,
    messages: List[Dict],
    model_label: str,
    context_tokens: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple:
    """Run one agent turn: render + tokenize the prompt, generate, build a trace.

    Shared by both agents so prompt rendering, generation, and token
    accounting stay identical across models.

    Returns:
        (generated_text, trace_dict) where trace_dict matches the per-agent
        schema used by the other pipelines.
    """
    turn_t0 = time.perf_counter()

    prompt_text = render_prompt(tokenizer, messages)
    input_ids, attention_mask = tokenize_prompt(tokenizer, prompt_text, device)
    prompt_tokens = int(input_ids.shape[-1])

    text, _ = generate_text(
        model, tokenizer, input_ids, attention_mask, device,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    # Output token count = generated text re-tokenized without special tokens,
    # matching the accounting convention of the sibling pipelines.
    output_encoded = tokenizer(text, add_special_tokens=False)
    output_tokens = len(output_encoded["input_ids"])

    trace = {
        "name": agent.name,
        "role": agent.role,
        "model": model_label,
        "prompt_tokens": prompt_tokens,
        "output_tokens": output_tokens,
        "context_tokens": context_tokens,
        "agent_time_ms": (time.perf_counter() - turn_t0) * 1000,
        "output": text,
    }
    return text, trace
147+
148+
149+
def run_text_cross_model_benchmark(
    model_a: Any,
    tokenizer_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    dataset: List[Dict],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
) -> List[Dict]:
    """Run cross-model text pipeline on a list of GSM8K samples.

    Args:
        model_a / tokenizer_a: Researcher model and tokenizer.
        model_b / tokenizer_b: Solver model and tokenizer.
        device: torch device string.
        dataset: list of samples, each with "question" and "answer" keys.
        max_new_tokens / temperature / top_p: sampling parameters forwarded
            to each per-sample run.
        verbose: per-sample detail vs. one-line progress output.

    Returns:
        One result dict per sample (see ``run_text_cross_model_pipeline``).
    """
    results: List[Dict] = []
    # Running tally of correct answers; avoids re-scanning the full results
    # list on every iteration (previously O(n) per sample).
    num_correct = 0
    for i, sample in enumerate(dataset):
        if verbose:
            print(f"\n[TextCrossModel] Sample {i + 1}/{len(dataset)}: "
                  f"{sample['question'][:80]}...")

        result = run_text_cross_model_pipeline(
            model_a, tokenizer_a, model_b, tokenizer_b, device,
            question=sample["question"],
            gold_solution=sample["answer"],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            verbose=verbose,
        )
        results.append(result)
        if result["correct"]:
            num_correct += 1

        if verbose:
            status = "CORRECT" if result["correct"] else "WRONG"
            print(f" => {status} (pred={result['prediction']}, gold={result['gold']}, "
                  f"time={result['wall_time']:.1f}s)")
        else:
            print(f" [TextCrossModel] {i + 1}/{len(dataset)} "
                  f"({num_correct}/{i + 1} correct, {result['wall_time']:.1f}s)",
                  flush=True)

    return results

benchmarks/gsm8k_2agent/run_gsm8k_2agent.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def parse_args() -> argparse.Namespace:
3939
)
4040
parser.add_argument(
4141
"--mode",
42-
choices=["latent", "text", "direct", "rosetta", "both", "all"],
42+
choices=["latent", "text", "direct", "rosetta", "text_cross_model", "both", "all"],
4343
default="all",
4444
help="Pipeline(s) to run (default: all)",
4545
)
@@ -111,19 +111,20 @@ def run_benchmark(config: dict) -> dict:
111111
run_latent = mode in ("latent", "both", "all")
112112
run_text = mode in ("text", "both", "all")
113113
run_rosetta = mode in ("rosetta", "all")
114+
run_text_cross_model = mode in ("text_cross_model", "all")
114115

115116
print(f"Device: {device}")
116117
print(f"Mode: {mode}")
117118
print(f"Model A: {model_name}")
118-
if run_rosetta:
119+
if run_rosetta or run_text_cross_model:
119120
print(f"Model B: {model_b_name}")
120121
print(f"Samples: {max_samples}")
121122
print(f"Latent steps: {latent_steps}")
122123
print(f"Max new tokens: {max_new_tokens}")
123124
print(f"Temperature: {temperature}")
124125
print(f"Seed: {seed}")
125126
print(f"Pipelines: direct={run_direct}, text={run_text}, latent={run_latent}, "
126-
f"rosetta={run_rosetta}")
127+
f"rosetta={run_rosetta}, text_cross_model={run_text_cross_model}")
127128
print()
128129

129130
dataset = load_dataset(max_samples)
@@ -133,6 +134,7 @@ def run_benchmark(config: dict) -> dict:
133134
latent_results = None
134135
text_results = None
135136
rosetta_results = None
137+
text_cross_model_results = None
136138

137139
if run_direct:
138140
from benchmarks.gsm8k_2agent.pipeline_direct import run_direct_benchmark
@@ -178,6 +180,29 @@ def run_benchmark(config: dict) -> dict:
178180
top_p=top_p, verbose=verbose,
179181
)
180182

183+
# Load model B if needed for cross-model modes
184+
model_b = tokenizer_b = connector_b = identity_b = None
185+
if run_rosetta or run_text_cross_model:
186+
model_b, tokenizer_b, connector_b, identity_b = load_model(model_b_name, device)
187+
188+
if run_text_cross_model:
189+
from benchmarks.gsm8k_2agent.pipeline_text_cross_model import run_text_cross_model_benchmark
190+
191+
print("\n" + "=" * 50)
192+
print("Running TEXT CROSS-MODEL (A generates text → B reads text) pipeline...")
193+
print(f" Model A (Researcher): {model_name}")
194+
print(f" Model B (Solver): {model_b_name}")
195+
print("=" * 50)
196+
set_seed(seed)
197+
198+
text_cross_model_results = run_text_cross_model_benchmark(
199+
model_a=model, tokenizer_a=tokenizer,
200+
model_b=model_b, tokenizer_b=tokenizer_b,
201+
device=device, dataset=dataset,
202+
max_new_tokens=max_new_tokens, temperature=temperature,
203+
top_p=top_p, verbose=verbose,
204+
)
205+
181206
if run_rosetta:
182207
from benchmarks.gsm8k_2agent.pipeline_rosetta import run_rosetta_benchmark
183208
from avp.rosetta.calibrate import calibrate
@@ -189,9 +214,6 @@ def run_benchmark(config: dict) -> dict:
189214
print("=" * 50)
190215
set_seed(seed)
191216

192-
# Load model B (model A is already loaded)
193-
model_b, tokenizer_b, connector_b, identity_b = load_model(model_b_name, device)
194-
195217
# Calibrate once — instant for same-family vocab-mediated
196218
print("Calibrating Rosetta Stone projection...")
197219
avp_map = calibrate(
@@ -213,7 +235,8 @@ def run_benchmark(config: dict) -> dict:
213235
num_transfer_states=num_transfer_states,
214236
)
215237

216-
# Free model B to reclaim GPU memory
238+
# Free model B to reclaim GPU memory
239+
if model_b is not None:
217240
del model_b, tokenizer_b, connector_b, identity_b
218241
if device == "cuda":
219242
import torch
@@ -231,6 +254,8 @@ def run_benchmark(config: dict) -> dict:
231254
modes.append(("Text", 13, text_results))
232255
if rosetta_results is not None:
233256
modes.append(("Rosetta", 13, rosetta_results))
257+
if text_cross_model_results is not None:
258+
modes.append(("Text Cross-Model", 16, text_cross_model_results))
234259

235260
# Compute agreement across available modes
236261
available = {}
@@ -242,6 +267,8 @@ def run_benchmark(config: dict) -> dict:
242267
available["latent"] = latent_results
243268
if rosetta_results is not None:
244269
available["rosetta"] = rosetta_results
270+
if text_cross_model_results is not None:
271+
available["text_cross_model"] = text_cross_model_results
245272
agreement_data = compute_agreement(available) if len(available) > 1 else None
246273

247274
print_summary(
@@ -262,7 +289,7 @@ def run_benchmark(config: dict) -> dict:
262289
"config": {
263290
"benchmark": "gsm8k_2agent",
264291
"model_a": model_name,
265-
"model_b": model_b_name if run_rosetta else None,
292+
"model_b": model_b_name if (run_rosetta or run_text_cross_model) else None,
266293
"device": device,
267294
"mode": mode,
268295
"max_samples": max_samples,
@@ -293,6 +320,11 @@ def run_benchmark(config: dict) -> dict:
293320
"summary": compute_stats(rosetta_results),
294321
"samples": rosetta_results,
295322
}
323+
if text_cross_model_results is not None:
324+
output_data["text_cross_model"] = {
325+
"summary": compute_stats(text_cross_model_results),
326+
"samples": text_cross_model_results,
327+
}
296328
if agreement_data is not None:
297329
output_data["agreement"] = agreement_data
298330

0 commit comments

Comments
 (0)