Skip to content

Commit 74b5326

Browse files
committed
verified rollouts
1 parent 3fbfa48 commit 74b5326

File tree

2 files changed

+61
-8
lines changed

2 files changed

+61
-8
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,22 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
139139
tool_calls=converted_tool_calls,
140140
)
141141
]
142-
row.execution_metadata.usage = (
143-
CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field
144-
prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue]
145-
completion_tokens=response.usage.completion_tokens, # pyright: ignore[reportAttributeAccessIssue]
146-
total_tokens=response.usage.total_tokens, # pyright: ignore[reportAttributeAccessIssue]
147-
)
148-
)
149-
150142
row.messages = messages
151143

144+
usage = getattr(response, "usage", None)
145+
if usage is not None:
146+
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
147+
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
148+
total_tokens = getattr(usage, "total_tokens", None)
149+
if total_tokens is None:
150+
total_tokens = prompt_tokens + completion_tokens
151+
152+
row.execution_metadata.usage = CompletionUsage(
153+
prompt_tokens=prompt_tokens,
154+
completion_tokens=completion_tokens,
155+
total_tokens=total_tokens,
156+
)
157+
152158
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
153159

154160
default_logger.log(row)

tests/pytest/test_single_turn_rollout_processor.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,50 @@ async def fake_acompletion(**kwargs):
116116
assert [m["role"] for m in sent_msgs] == ["user", "assistant"]
117117
assert [m.role for m in out.messages] == ["user", "assistant", "assistant"]
118118
assert out.messages[-1].content == "Hello again"
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_single_turn_handles_missing_usage_block(monkeypatch):
123+
row = EvaluationRow(messages=[Message(role="user", content="Describe the picture")])
124+
125+
import eval_protocol.pytest.default_single_turn_rollout_process as mod
126+
127+
class StubChoices:
128+
pass
129+
130+
class StubModelResponse:
131+
def __init__(self, text: str):
132+
self.choices = [StubChoices()]
133+
self.choices[0].message = SimpleNamespace(content=text, tool_calls=None)
134+
self.usage = None
135+
136+
async def fake_acompletion(**kwargs):
137+
return StubModelResponse(text="It looks like creme brulee")
138+
139+
class StubLogger:
140+
def __init__(self):
141+
self.logged = []
142+
143+
def log(self, row):
144+
self.logged.append(row)
145+
146+
def read(self, rollout_id=None):
147+
return list(self.logged)
148+
149+
stub_logger = StubLogger()
150+
151+
monkeypatch.setattr(mod, "ModelResponse", StubModelResponse, raising=True)
152+
monkeypatch.setattr(mod, "Choices", StubChoices, raising=True)
153+
monkeypatch.setattr(mod, "acompletion", fake_acompletion, raising=True)
154+
monkeypatch.setattr(mod, "default_logger", stub_logger, raising=False)
155+
156+
processor = SingleTurnRolloutProcessor()
157+
config = _DummyConfig()
158+
159+
tasks = processor([row], config)
160+
out = await tasks[0]
161+
162+
assert [m.role for m in out.messages] == ["user", "assistant"]
163+
assert out.messages[-1].content == "It looks like creme brulee"
164+
# Usage should remain unset when the provider omits it
165+
assert out.execution_metadata.usage is None

0 commit comments

Comments
 (0)