Skip to content

Commit 01bc8e9

Browse files
authored
vision food reasoning eval (#331)
* vision food reasoning eval * fix tool call count * fix tests
1 parent 2ae135a commit 01bc8e9

22 files changed

+314
-181
lines changed

eval_protocol/benchmarks/test_aime25.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
EvaluationRow,
66
Message,
77
MetricResult,
8+
ChatCompletionContentPartParam,
89
ChatCompletionContentPartTextParam,
910
)
1011
from eval_protocol.pytest.default_single_turn_rollout_process import (
@@ -18,10 +19,12 @@
1819

1920

2021
def _coerce_content_to_str(
    content: str | list[ChatCompletionContentPartParam] | None,
) -> str:
    """Flatten message content into a single plain string.

    Args:
        content: A plain string, a list of content parts, or ``None``.

    Returns:
        For a list, the ``text`` of every text part concatenated with no
        separator; parts that are not text parts contribute nothing.
        For anything else, ``str(content)``, with falsy values (``None``,
        ``""``) mapped to ``""``.
    """
    if isinstance(content, list):
        # The isinstance filter guarantees ``.text`` is present, so the former
        # ``getattr(..., str(p))`` fallback could never fire and is dropped.
        return "".join(
            part.text for part in content if isinstance(part, ChatCompletionContentPartTextParam)
        )
    return str(content or "")
2629

2730

eval_protocol/benchmarks/test_gpqa.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
EvaluationRow,
1111
Message,
1212
MetricResult,
13+
ChatCompletionContentPartParam,
1314
ChatCompletionContentPartTextParam,
1415
)
1516
from eval_protocol.pytest.default_single_turn_rollout_process import (
@@ -54,10 +55,12 @@ def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
5455

5556

5657
def _coerce_content_to_str(
    content: str | list[ChatCompletionContentPartParam] | None,
) -> str:
    """Return the textual portion of a message's content as one string.

    Args:
        content: A plain string, a list of content parts, or ``None``.

    Returns:
        The concatenated ``text`` of all text parts when ``content`` is a
        list (other part kinds are skipped); otherwise ``str(content)``,
        where falsy values such as ``None`` become ``""``.
    """
    if not isinstance(content, list):
        return str(content or "")
    # Only text parts are kept; after the isinstance check ``.text`` always
    # exists, so no getattr fallback is needed.
    text_chunks = (part.text for part in content if isinstance(part, ChatCompletionContentPartTextParam))
    return "".join(text_chunks)
6265

6366

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
EvaluationRow,
99
Message,
1010
MetricResult,
11+
ChatCompletionContentPartParam,
1112
ChatCompletionContentPartTextParam,
1213
)
1314
from eval_protocol.pytest.default_single_turn_rollout_process import (
@@ -37,9 +38,11 @@ def _extract_last_boxed_segment(text: str) -> Optional[str]:
3738
return matches[-1]
3839

3940

40-
def _coerce_content_to_str(content: str | list[ChatCompletionContentPartParam] | None) -> str:
    """Collapse message content to plain text.

    Args:
        content: A plain string, a list of content parts, or ``None``.

    Returns:
        For list input, every text part's ``text`` joined without a
        separator, ignoring parts that are not text parts. For any other
        input, ``str(content)`` with falsy values rendered as ``""``.
    """
    if isinstance(content, list):
        # ``.text`` is read directly: the isinstance filter already excludes
        # every part that would have needed the old getattr default.
        kept = [part.text for part in content if isinstance(part, ChatCompletionContentPartTextParam)]
        return "".join(kept)
    return str(content or "")
4447

4548

eval_protocol/models.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,46 @@ def __iter__(self):
466466
return iter(["text", "type"])
467467

468468

469+
class ChatCompletionContentPartImageParam(BaseModel):
    """Image content part of a chat message, readable like a small mapping.

    Besides attribute access, the instance answers ``part["image_url"]``,
    ``part.get(...)``, ``keys()``, ``values()``, ``items()`` and iteration,
    all over exactly two keys: ``"image_url"`` and ``"type"``.
    """

    type: Literal["image_url"] = Field("image_url", description="The type of the content part.")
    image_url: Dict[str, Any] = Field(
        ..., description="Image descriptor (e.g., {'url': 'data:image/png;base64,...', 'detail': 'high'})."
    )

    def __getitem__(self, key: str) -> Any:
        """Return the field named ``key``; raise ``KeyError`` for any other key."""
        if key == "image_url":
            return self.image_url
        if key == "type":
            return self.type
        raise KeyError(key)

    def get(self, key: str, default: Any = None) -> Any:
        """Return ``self[key]``, or ``default`` when ``key`` is not a known key."""
        try:
            return self[key]
        except KeyError:
            return default

    def keys(self):
        """Return the two key names, in the same order as ``__iter__``."""
        # A list rather than a one-shot generator, so the result can be
        # iterated more than once and passed to len().
        return ["image_url", "type"]

    def values(self):
        """Return ``(image_url, type)``, in key order."""
        return (self.image_url, self.type)

    def items(self):
        """Return ``(key, value)`` pairs, in key order."""
        return [("image_url", self.image_url), ("type", self.type)]

    def __iter__(self):
        """Iterate over the key names (not over field/value pairs)."""
        return iter(["image_url", "type"])
499+
500+
501+
ChatCompletionContentPartParam = Union[ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam]
502+
503+
469504
class Message(BaseModel):
470505
"""Chat message model with trajectory evaluation support."""
471506

472507
role: str # assistant, user, system, tool
473-
content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
508+
content: Optional[Union[str, List[ChatCompletionContentPartParam]]] = Field(
474509
default="", description="The content of the message."
475510
)
476511
reasoning_content: Optional[str] = Field(

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1414
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
1515
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
16-
from eval_protocol.models import EvaluationRow, Message, ChatCompletionContentPartTextParam
16+
from eval_protocol.models import (
17+
EvaluationRow,
18+
Message,
19+
ChatCompletionContentPartParam,
20+
ChatCompletionContentPartTextParam,
21+
)
1722
from openai.types import CompletionUsage
1823
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1924
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
@@ -98,7 +103,7 @@ def append_message_and_log(self, message: Message):
98103
self.messages.append(message)
99104
self.logger.log(self.evaluation_row)
100105

101-
async def call_agent(self) -> Optional[Union[str, List[ChatCompletionContentPartTextParam]]]:
106+
async def call_agent(self) -> Optional[Union[str, List[ChatCompletionContentPartParam]]]:
102107
"""
103108
Call the assistant with the user query.
104109
"""
@@ -222,7 +227,7 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult | str) -> Li
222227

223228
def _format_tool_message_content(
224229
self, content: List[TextContent]
225-
) -> Union[str, List[ChatCompletionContentPartTextParam]]:
230+
) -> Union[str, List[ChatCompletionContentPartParam]]:
226231
"""Format tool result content for inclusion in a tool message.
227232
228233
- If a single text item, return plain string per OpenAI semantics.

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,17 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
166166
row.execution_metadata.tool_call_count = (
167167
len(converted_tool_calls) if converted_tool_calls is not None else 0
168168
)
169-
row.execution_metadata.usage = (
170-
CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field
171-
prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue]
172-
completion_tokens=response.usage.completion_tokens, # pyright: ignore[reportAttributeAccessIssue]
173-
total_tokens=response.usage.total_tokens, # pyright: ignore[reportAttributeAccessIssue]
169+
usage = getattr(response, "usage", None)
170+
if usage:
171+
row.execution_metadata.usage = (
172+
CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field
173+
prompt_tokens=getattr(usage, "prompt_tokens", 0),
174+
completion_tokens=getattr(usage, "completion_tokens", 0),
175+
total_tokens=getattr(usage, "total_tokens", 0),
176+
)
174177
)
175-
)
178+
else:
179+
row.execution_metadata.usage = None
176180

177181
row.messages = messages
178182

eval_protocol/rewards/accuracy.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,28 @@
1010
import re
1111
from typing import Any, Callable, Dict, List, Optional, Union, cast
1212

13-
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
13+
from ..models import (
14+
EvaluateResult,
15+
Message,
16+
MetricResult,
17+
ChatCompletionContentPartParam,
18+
ChatCompletionContentPartTextParam,
19+
)
1420

1521

16-
def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartParam]]]) -> str:
    """Coerce Message.content into a plain string for regex and comparisons.

    Args:
        content: ``None``, a plain string, or a list of content parts.

    Returns:
        ``""`` for ``None``; the string unchanged for string input; otherwise
        the ``text`` of each text part joined with newlines. Parts that are
        not text parts are skipped. If the parts cannot be iterated or
        joined, ``""`` is returned.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    # List[ChatCompletionContentPartParam]: only the text parts carry ``.text``.
    try:
        return "\n".join(part.text for part in content if isinstance(part, ChatCompletionContentPartTextParam))
    except Exception:
        # Best-effort fallback kept from the original: any failure while
        # reading the parts yields empty text instead of propagating.
        return ""
2737

eval_protocol/rewards/json_schema.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
import re
33
from typing import Any, Dict, List, Optional, Union
44

5-
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
5+
from ..models import (
6+
EvaluateResult,
7+
Message,
8+
MetricResult,
9+
ChatCompletionContentPartParam,
10+
ChatCompletionContentPartTextParam,
11+
)
612
from ..typed_interface import reward_function
713
from .function_calling import (
814
calculate_jaccard_similarity,
@@ -59,8 +65,10 @@ def json_schema_reward(
5965
content_text = last_message.content
6066
else:
6167
try:
62-
parts: List[ChatCompletionContentPartTextParam] = last_message.content # type: ignore[assignment]
63-
content_text = "\n".join(getattr(p, "text", "") for p in parts)
68+
parts: List[ChatCompletionContentPartParam] = last_message.content # type: ignore[assignment]
69+
content_text = "\n".join(
70+
getattr(p, "text", "") for p in parts if isinstance(p, ChatCompletionContentPartTextParam)
71+
)
6472
except Exception:
6573
content_text = ""
6674
else:

eval_protocol/rewards/language_consistency.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
import re
1010
from typing import Any, Dict, List, Optional, Set, Tuple, Union
1111

12-
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
12+
from ..models import (
13+
EvaluateResult,
14+
Message,
15+
MetricResult,
16+
ChatCompletionContentPartParam,
17+
ChatCompletionContentPartTextParam,
18+
)
1319
from ..typed_interface import reward_function
1420

1521
# Dictionary mapping language codes to common words/patterns in that language
@@ -573,13 +579,17 @@ def language_consistency_reward(
573579
},
574580
)
575581

576-
def _to_text(content: Union[str, List[ChatCompletionContentPartParam], None]) -> str:
    """Reduce message content to the plain text it contains.

    Args:
        content: ``None``, a plain string, or a list of content parts.

    Returns:
        ``""`` for ``None``, the input itself for a string, and for a list
        the newline-joined ``text`` of its text parts (other part kinds
        are ignored). Returns ``""`` if reading the parts fails.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    try:
        text_parts = [part.text for part in content if isinstance(part, ChatCompletionContentPartTextParam)]
        return "\n".join(text_parts)
    except Exception:
        # Original best-effort behavior preserved: unreadable content counts as empty.
        return ""
585595

eval_protocol/rewards/repetition.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,26 @@
88
import re
99
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
1010

11-
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
11+
from ..models import (
12+
EvaluateResult,
13+
Message,
14+
MetricResult,
15+
ChatCompletionContentPartParam,
16+
ChatCompletionContentPartTextParam,
17+
)
1218

1319

14-
def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartParam]]]) -> str:
    """Extract the plain text from message content.

    Args:
        content: ``None``, a plain string, or a list of content parts.

    Returns:
        An empty string for ``None``; a string input unchanged; for a list,
        the ``text`` of every text part joined by ``"\\n"``, skipping parts
        that are not text parts. An empty string is also returned when the
        parts cannot be iterated or joined.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    try:
        return "\n".join(
            part.text for part in content if isinstance(part, ChatCompletionContentPartTextParam)
        )
    except Exception:
        # Kept from the original: failures here degrade to empty text rather than raising.
        return ""
2333

0 commit comments

Comments
 (0)