diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 5fea510c34..3de7c64a6d 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -21,9 +21,9 @@ jobs: - name: Update setuptools run: | pip install setuptools==78.1.1 wheel==0.45.1 - - name: Install Full Dependencies + - name: Install Dev Dependencies run: | - pip install -q -e .[full] + pip install -q -e .[dev] pip install coverage pytest - name: Run tests with coverage run: | diff --git a/src/agentscope/_utils/_common.py b/src/agentscope/_utils/_common.py index aedee7da21..b51791480d 100644 --- a/src/agentscope/_utils/_common.py +++ b/src/agentscope/_utils/_common.py @@ -10,10 +10,11 @@ import types import typing from datetime import datetime -from typing import Union, Any, Callable +from typing import Union, Any, Callable, Type, Dict import requests from json_repair import repair_json +from pydantic import BaseModel from .._logging import logger @@ -208,3 +209,62 @@ def _remove_title_field(schema: dict) -> None: _remove_title_field( schema["additionalProperties"], ) + + +def _create_tool_from_base_model( + structured_model: Type[BaseModel], + tool_name: str = "generate_structured_output", +) -> Dict[str, Any]: + """Create a function tool definition from a Pydantic BaseModel. + This function converts a Pydantic BaseModel class into a tool definition + that can be used with function calling API. The resulting tool + definition includes the model's JSON schema as parameters, enabling + structured output generation by forcing the model to call this function + with properly formatted data. + + Args: + structured_model (`Type[BaseModel]`): + A Pydantic BaseModel class that defines the expected structure + for the tool's output. + tool_name (`str`, default `"generate_structured_output"`): + The tool name that used to force the LLM to generate structured + output by calling this function. + + Returns: + `Dict[str, Any]`: A tool definition dictionary compatible with + function calling API, containing type ("function") and + function dictionary with name, description, and parameters + (JSON schema). + + .. code-block:: python + :caption: Example usage + + from pydantic import BaseModel + + class PersonInfo(BaseModel): + name: str + age: int + email: str + + tool = _create_tool_from_base_model(PersonInfo, "extract_person") + print(tool["function"]["name"]) # extract_person + print(tool["type"]) # function + + .. note:: The function automatically removes the 'title' field from + the JSON schema to ensure compatibility with function calling + format. This is handled by the internal ``_remove_title_field()`` + function. 
+ """ + schema = structured_model.model_json_schema() + + _remove_title_field(schema) + tool_definition = { + "type": "function", + "function": { + "name": tool_name, + "description": "Generate the required structured output with " + "this function", + "parameters": schema, + }, + } + return tool_definition diff --git a/src/agentscope/model/_anthropic_model.py b/src/agentscope/model/_anthropic_model.py index 2267c23ace..50aca10983 100644 --- a/src/agentscope/model/_anthropic_model.py +++ b/src/agentscope/model/_anthropic_model.py @@ -9,13 +9,20 @@ TYPE_CHECKING, List, Literal, + Type, ) from collections import OrderedDict +from pydantic import BaseModel + from ._model_base import ChatModelBase from ._model_response import ChatResponse from ._model_usage import ChatUsage -from .._utils._common import _json_loads_with_repair +from .._logging import logger +from .._utils._common import ( + _json_loads_with_repair, + _create_tool_from_base_model, +) from ..message import TextBlock, ToolUseBlock, ThinkingBlock from ..tracing import trace_llm from ..types._json import JSONSerializableObject @@ -97,6 +104,7 @@ async def __call__( tool_choice: Literal["auto", "none", "any", "required"] | str | None = None, + structured_model: Type[BaseModel] | None = None, **generate_kwargs: Any, ) -> ChatResponse | AsyncGenerator[ChatResponse, None]: """Get the response from Anthropic chat completions API by the given @@ -132,12 +140,26 @@ async def __call__( }, # More schemas here ] + tool_choice (`Literal["auto", "none", "any", "required"] | str \ | None`, default `None`): Controls which (if any) tool is called by the model. Can be "auto", "none", "any", "required", or specific tool name. For more details, please refer to https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. When provided, the model will be forced + to return data that conforms to this schema by automatically + converting the BaseModel to a tool function and setting + `tool_choice` to enforce its usage. This enables structured + output generation. + + .. note:: When `structured_model` is specified, + both `tools` and `tool_choice` parameters are ignored, + and the model will only perform structured output + generation without calling any other tools. + **generate_kwargs (`Any`): The keyword arguments for Anthropic chat completions API, e.g. `temperature`, `top_p`, etc. Please @@ -164,6 +186,22 @@ async def __call__( self._validate_tool_choice(tool_choice, tools) kwargs["tool_choice"] = self._format_tool_choice(tool_choice) + if structured_model: + if tools or tool_choice: + logger.warning( + "structured_model is provided. Both 'tools' and " + "'tool_choice' parameters will be overridden and " + "ignored. 
The model will only perform structured output " + "generation without calling any other tools.", + ) + format_tool = _create_tool_from_base_model(structured_model) + kwargs["tools"] = self._format_tools_json_schemas( + [format_tool], + ) + kwargs["tool_choice"] = self._format_tool_choice( + format_tool["function"]["name"], + ) + # Extract the system message if messages[0]["role"] == "system": kwargs["system"] = messages[0]["content"] @@ -179,12 +217,14 @@ async def __call__( return self._parse_anthropic_stream_completion_response( start_datetime, response, + structured_model, ) # Non-streaming response parsed_response = await self._parse_anthropic_completion_response( start_datetime, response, + structured_model, ) return parsed_response @@ -193,10 +233,30 @@ async def _parse_anthropic_completion_response( self, start_datetime: datetime, response: Message, + structured_model: Type[BaseModel] | None = None, ) -> ChatResponse: - """Parse the Anthropic chat completion response into a `ChatResponse` - object.""" + """Given an Anthropic Message object, extract the content blocks and + usages from it. + + Args: + start_datetime (`datetime`): + The start datetime of the response generation. + response (`Message`): + Anthropic Message object to parse. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + Returns: + ChatResponse (`ChatResponse`): + A ChatResponse object containing the content blocks and usage. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. + """ content_blocks: List[ThinkingBlock | TextBlock | ToolUseBlock] = [] + metadata = None if hasattr(response, "content") and response.content: for content_block in response.content: @@ -231,6 +291,8 @@ async def _parse_anthropic_completion_response( input=content_block.input, ), ) + if structured_model: + metadata = content_block.input usage = None if response.usage: @@ -243,6 +305,7 @@ async def _parse_anthropic_completion_response( parsed_response = ChatResponse( content=content_blocks, usage=usage, + metadata=metadata, ) return parsed_response @@ -251,9 +314,30 @@ async def _parse_anthropic_stream_completion_response( self, start_datetime: datetime, response: AsyncStream, + structured_model: Type[BaseModel] | None = None, ) -> AsyncGenerator[ChatResponse, None]: - """Parse the Anthropic chat completion response stream into an async - generator of `ChatResponse` objects.""" + """Given an Anthropic streaming response, extract the content blocks + and usages from it and yield ChatResponse objects. + + Args: + start_datetime (`datetime`): + The start datetime of the response generation. + response (`AsyncStream`): + Anthropic AsyncStream object to parse. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + Returns: + `AsyncGenerator[ChatResponse, None]`: + An async generator that yields ChatResponse objects containing + the content blocks and usage information for each chunk in + the streaming response. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. 
+ """ usage = None text_buffer = "" @@ -262,6 +346,7 @@ async def _parse_anthropic_stream_completion_response( tool_calls = OrderedDict() tool_call_buffers = {} res = None + metadata = None async for event in response: content_changed = False @@ -352,10 +437,13 @@ async def _parse_anthropic_stream_completion_response( input=input_obj, ), ) + if structured_model: + metadata = input_obj if contents: res = ChatResponse( content=contents, usage=usage, + metadata=metadata, ) yield res diff --git a/src/agentscope/model/_dashscope_model.py b/src/agentscope/model/_dashscope_model.py index 236cd9a551..89c05bdc8a 100644 --- a/src/agentscope/model/_dashscope_model.py +++ b/src/agentscope/model/_dashscope_model.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -# pylint: disable=too-many-branches """The dashscope API model classes.""" import collections from datetime import datetime @@ -12,13 +11,18 @@ TYPE_CHECKING, List, Literal, + Type, ) +from pydantic import BaseModel from aioitertools import iter as giter from ._model_base import ChatModelBase from ._model_response import ChatResponse from ._model_usage import ChatUsage -from .._utils._common import _json_loads_with_repair +from .._utils._common import ( + _json_loads_with_repair, + _create_tool_from_base_model, +) from ..message import TextBlock, ToolUseBlock, ThinkingBlock from ..tracing import trace_llm from ..types import JSONSerializableObject @@ -100,6 +104,7 @@ async def __call__( tool_choice: Literal["auto", "none", "any", "required"] | str | None = None, + structured_model: Type[BaseModel] | None = None, **kwargs: Any, ) -> ChatResponse | AsyncGenerator[ChatResponse, None]: """Get the response from the dashscope @@ -121,6 +126,19 @@ async def __call__( Can be "auto", "none", or specific tool name. For more details, please refer to https://help.aliyun.com/zh/model-studio/qwen-function-calling + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. When provided, the model will be forced + to return data that conforms to this schema by automatically + converting the BaseModel to a tool function and setting + `tool_choice` to enforce its usage. This enables structured + output generation. + + .. note:: When `structured_model` is specified, + both `tools` and `tool_choice` parameters are ignored, + and the model will only perform structured output + generation without calling any other tools. + **kwargs (`Any`): The keyword arguments for DashScope chat completions API, e.g. `temperature`, `max_tokens`, `top_p`, etc. Please @@ -161,6 +179,22 @@ async def __call__( if self.enable_thinking and "enable_thinking" not in kwargs: kwargs["enable_thinking"] = self.enable_thinking + if structured_model: + if tools or tool_choice: + logger.warning( + "structured_model is provided. Both 'tools' and " + "'tool_choice' parameters will be overridden and " + "ignored. 
The model will only perform structured output " + "generation without calling any other tools.", + ) + format_tool = _create_tool_from_base_model(structured_model) + kwargs["tools"] = self._format_tools_json_schemas( + [format_tool], + ) + kwargs["tool_choice"] = self._format_tool_choice( + format_tool["function"]["name"], + ) + start_datetime = datetime.now() if self.model_name.startswith("qvq") or "-vl" in self.model_name: response = dashscope.MultiModalConversation.call( @@ -178,15 +212,18 @@ async def __call__( return self._parse_dashscope_stream_response( start_datetime, response, + structured_model, ) parsed_response = await self._parse_dashscope_generation_response( - used_time=(datetime.now() - start_datetime).total_seconds(), - response=response, + start_datetime, + response, + structured_model, ) return parsed_response + # pylint: disable=too-many-branches async def _parse_dashscope_stream_response( self, start_datetime: datetime, @@ -194,12 +231,37 @@ async def _parse_dashscope_stream_response( AsyncGenerator[GenerationResponse, None], Generator[MultiModalConversationResponse, None, None], ], + structured_model: Type[BaseModel] | None = None, ) -> AsyncGenerator[ChatResponse, Any]: - """Parse the DashScope GenerationResponse object and return a - ChatResponse object.""" + """Given a DashScope streaming response generator, extract the content + blocks and usages from it and yield ChatResponse objects. + + Args: + start_datetime (`datetime`): + The start datetime of the response generation. + response ( + `Union[AsyncGenerator[GenerationResponse, None], Generator[ \ + MultiModalConversationResponse, None, None]]` + ): + DashScope streaming response generator (GenerationResponse or + MultiModalConversationResponse) to parse. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + Returns: + AsyncGenerator[ChatResponse, Any]: + An async generator that yields ChatResponse objects containing + the content blocks and usage information for each chunk in the + streaming response. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. + """ acc_content, acc_thinking_content = "", "" acc_tool_calls = collections.defaultdict(dict) - parsed_chunk = None + metadata = None async for chunk in giter(response): if chunk.status_code != HTTPStatus.OK: @@ -281,6 +343,9 @@ async def _parse_dashscope_stream_response( ), ) + if structured_model: + metadata = repaired_input + usage = None if chunk.usage: usage = ChatUsage( @@ -292,38 +357,48 @@ async def _parse_dashscope_stream_response( parsed_chunk = ChatResponse( content=content_blocks, usage=usage, + metadata=metadata, ) yield parsed_chunk async def _parse_dashscope_generation_response( self, - used_time: float, + start_datetime: datetime, response: Union[ GenerationResponse, MultiModalConversationResponse, ], + structured_model: Type[BaseModel] | None = None, ) -> ChatResponse: """Given a DashScope GenerationResponse object, extract the content blocks and usages from it. Args: - used_time (`float`): - The time used for the response in seconds. + start_datetime (`datetime`): + The start datetime of the response generation. response ( `Union[GenerationResponse, MultiModalConversationResponse]` ): Dashscope GenerationResponse | MultiModalConversationResponse object to parse. 
+ structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. Returns: - `ChatResponse`: + ChatResponse (`ChatResponse`): A ChatResponse object containing the content blocks and usage. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. """ # Collect the content blocks from the response. if response.status_code != 200: raise RuntimeError(response) content_blocks: List[TextBlock | ToolUseBlock] = [] + metadata = None message = response.output.choices[0].message content = message.get("content") @@ -352,35 +427,42 @@ async def _parse_dashscope_generation_response( if message.get("tool_calls"): for tool_call in message["tool_calls"]: + input_ = _json_loads_with_repair( + tool_call["function"].get( + "arguments", + "{}", + ) + or "{}", + ) content_blocks.append( ToolUseBlock( type="tool_use", name=tool_call["function"]["name"], - input=_json_loads_with_repair( - tool_call["function"].get( - "arguments", - "{}", - ) - or "{}", - ), + input=input_, id=tool_call["id"], ), ) + if structured_model: + metadata = input_ + # Usage information usage = None if response.usage: usage = ChatUsage( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, - time=used_time, + time=(datetime.now() - start_datetime).total_seconds(), ) - return ChatResponse( + parsed_response = ChatResponse( content=content_blocks, usage=usage, + metadata=metadata, ) + return parsed_response + def _format_tools_json_schemas( self, schemas: list[dict[str, Any]], diff --git a/src/agentscope/model/_gemini_model.py b/src/agentscope/model/_gemini_model.py index aa4f64b582..ff20e88055 100644 --- a/src/agentscope/model/_gemini_model.py +++ b/src/agentscope/model/_gemini_model.py @@ -2,8 +2,20 @@ # mypy: disable-error-code="dict-item" """The Google Gemini model in agentscope.""" from datetime import datetime -from typing import AsyncGenerator, Any, TYPE_CHECKING, AsyncIterator, Literal - +from typing import ( + AsyncGenerator, + Any, + TYPE_CHECKING, + AsyncIterator, + Literal, + Type, + List, +) + +from pydantic import BaseModel + +from .._logging import logger +from .._utils._common import _json_loads_with_repair from ..message import ToolUseBlock, TextBlock, ThinkingBlock from ._model_usage import ChatUsage from ._model_base import ChatModelBase @@ -83,6 +95,7 @@ async def __call__( tool_choice: Literal["auto", "none", "any", "required"] | str | None = None, + structured_model: Type[BaseModel] | None = None, **config_kwargs: Any, ) -> ChatResponse | AsyncGenerator[ChatResponse, None]: """Call the Gemini model with the provided arguments. @@ -99,6 +112,18 @@ async def __call__( Can be "auto", "none", "any", "required", or specific tool name. For more details, please refer to https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + .. note:: When `structured_model` is specified, + both `tools` and `tool_choice` parameters are ignored, + and the model will only perform structured output + generation without calling any other tools. + + For more details, please refer to + https://ai.google.dev/gemini-api/docs/structured-output + **config_kwargs (`Any`): The keyword arguments for Gemini chat completions API. 
""" @@ -111,10 +136,24 @@ async def __call__( if tools: config["tools"] = self._format_tools_json_schemas(tools) + if tool_choice: self._validate_tool_choice(tool_choice, tools) config["tool_config"] = self._format_tool_choice(tool_choice) + if structured_model: + if tools or tool_choice: + logger.warning( + "structured_model is provided. Both 'tools' and " + "'tool_choice' parameters will be overridden and " + "ignored. The model will only perform structured output " + "generation without calling any other tools.", + ) + config.pop("tools", None) + config.pop("tool_config", None) + config["response_mime_type"] = "application/json" + config["response_schema"] = structured_model + # Prepare the arguments for the Gemini API call kwargs: dict[str, JSONSerializableObject] = { "model": self.model_name, @@ -131,6 +170,7 @@ async def __call__( return self._parse_gemini_stream_generation_response( start_datetime, response, + structured_model, ) # non-streaming @@ -139,8 +179,9 @@ async def __call__( ) parsed_response = self._parse_gemini_generation_response( - used_time=(datetime.now() - start_datetime).total_seconds(), - response=response, + start_datetime, + response, + structured_model, ) return parsed_response @@ -149,12 +190,34 @@ async def _parse_gemini_stream_generation_response( self, start_datetime: datetime, response: AsyncIterator[GenerateContentResponse], + structured_model: Type[BaseModel] | None = None, ) -> AsyncGenerator[ChatResponse, None]: - """Parse the Gemini streaming generation response into ChatResponse""" + """Given a Gemini streaming generation response, extract the + content blocks and usages from it and yield ChatResponse objects. + + Args: + start_datetime (`datetime`): + The start datetime of the response generation. + response (`AsyncIterator[GenerateContentResponse]`): + Gemini GenerateContentResponse async iterator to parse. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + Returns: + `AsyncGenerator[ChatResponse, None]`: + An async generator that yields ChatResponse objects containing + the content blocks and usage information for each chunk in the + streaming response. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. + """ - parsed_chunk = None text = "" thinking = "" + metadata = None async for chunk in response: content_block: list = [] @@ -171,6 +234,8 @@ async def _parse_gemini_stream_generation_response( # Text parts if chunk.text: text += chunk.text + if structured_model: + metadata = _json_loads_with_repair(text) # Function calls tool_calls = [] @@ -219,16 +284,38 @@ async def _parse_gemini_stream_generation_response( parsed_chunk = ChatResponse( content=content_block, usage=usage, + metadata=metadata, ) yield parsed_chunk def _parse_gemini_generation_response( self, - used_time: float, + start_datetime: datetime, response: GenerateContentResponse, + structured_model: Type[BaseModel] | None = None, ) -> ChatResponse: - """Parse the Gemini generation response into ChatResponse""" - content: list = [] + """Given a Gemini chat completion response object, extract the content + blocks and usages from it. + + Args: + start_datetime (`datetime`): + The start datetime of the response generation. + response (`ChatCompletion`): + The OpenAI chat completion response object to parse. 
+ structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + Returns: + ChatResponse (`ChatResponse`): + A ChatResponse object containing the content blocks and usage. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. + """ + content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = [] + metadata = None if ( response.candidates @@ -237,7 +324,7 @@ def _parse_gemini_generation_response( ): for part in response.candidates[0].content.parts: if part.thought and part.text: - content.append( + content_blocks.append( ThinkingBlock( type="thinking", thinking=part.text, @@ -245,16 +332,18 @@ def _parse_gemini_generation_response( ) if response.text: - content.append( + content_blocks.append( TextBlock( type="text", text=response.text, ), ) + if structured_model: + metadata = _json_loads_with_repair(response.text) if response.function_calls: for tool_call in response.function_calls: - content.append( + content_blocks.append( ToolUseBlock( type="tool_use", id=tool_call.id, @@ -268,15 +357,16 @@ def _parse_gemini_generation_response( input_tokens=response.usage_metadata.prompt_token_count, output_tokens=response.usage_metadata.total_token_count - response.usage_metadata.prompt_token_count, - time=used_time, + time=(datetime.now() - start_datetime).total_seconds(), ) else: usage = None return ChatResponse( - content=content, + content=content_blocks, usage=usage, + metadata=metadata, ) def _format_tools_json_schemas( diff --git a/src/agentscope/model/_model_response.py b/src/agentscope/model/_model_response.py index 87c00d9f81..b6c23d537d 100644 --- a/src/agentscope/model/_model_response.py +++ b/src/agentscope/model/_model_response.py @@ -10,6 +10,7 @@ from .._utils._common import _get_timestamp from .._utils._mixin import DictMixin from ..message import TextBlock, ToolUseBlock +from ..types import JSONSerializableObject @dataclass @@ -31,3 +32,8 @@ class ChatResponse(DictMixin): usage: ChatUsage | None = field(default_factory=lambda: None) """The usage information of the chat response, if available.""" + + metadata: JSONSerializableObject | None = field( + default_factory=lambda: None, + ) + """The metadata of the chat response""" diff --git a/src/agentscope/model/_ollama_model.py b/src/agentscope/model/_ollama_model.py index 80d45d94b7..ddae610c12 100644 --- a/src/agentscope/model/_ollama_model.py +++ b/src/agentscope/model/_ollama_model.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -# pylint: disable=too-many-branches """Model wrapper for Ollama models.""" from datetime import datetime from typing import ( @@ -9,9 +8,12 @@ AsyncGenerator, AsyncIterator, Literal, + Type, ) from collections import OrderedDict +from pydantic import BaseModel + from . 
import ChatResponse from ._model_base import ChatModelBase from ._model_usage import ChatUsage @@ -22,7 +24,7 @@ if TYPE_CHECKING: - from ollama._types import OllamaChatResponse + from ollama._types import ChatResponse as OllamaChatResponse else: OllamaChatResponse = "ollama._types.ChatResponse" @@ -92,6 +94,7 @@ async def __call__( tool_choice: Literal["auto", "none", "any", "required"] | str | None = None, + structured_model: Type[BaseModel] | None = None, **kwargs: Any, ) -> ChatResponse | AsyncGenerator[ChatResponse, None]: """Get the response from Ollama chat completions API by the given @@ -108,6 +111,9 @@ async def __call__( Controls which (if any) tool is called by the model. Can be "auto", "none", "any", "required", or specific tool name. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. **kwargs (`Any`): The keyword arguments for Ollama chat completions API, e.g. `think`etc. Please refer to the Ollama API @@ -136,6 +142,9 @@ async def __call__( if tool_choice: logger.warning("Ollama does not support tool_choice yet, ignored.") + if structured_model: + kwargs["format"] = structured_model.model_json_schema() + start_datetime = datetime.now() response = await self.client.chat(**kwargs) @@ -143,11 +152,13 @@ async def __call__( return self._parse_ollama_stream_completion_response( start_datetime, response, + structured_model, ) parsed_response = await self._parse_ollama_completion_response( start_datetime, response, + structured_model, ) return parsed_response @@ -155,56 +166,53 @@ async def __call__( async def _parse_ollama_stream_completion_response( self, start_datetime: datetime, - response: AsyncIterator[Any], + response: AsyncIterator[OllamaChatResponse], + structured_model: Type[BaseModel] | None = None, ) -> AsyncGenerator[ChatResponse, None]: - """Parse the Ollama chat completion response stream into an async - generator of `ChatResponse` objects.""" + """Given an Ollama streaming completion response, extract the + content blocks and usages from it and yield ChatResponse objects. + + Args: + start_datetime (`datetime`): + The start datetime of the response generation. + response (`AsyncIterator[OllamaChatResponse]`): + Ollama streaming response async iterator to parse. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + Returns: + AsyncGenerator[ChatResponse, None]: + An async generator that yields ChatResponse objects containing + the content blocks and usage information for each chunk in the + streaming response. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. 
+ + """ accumulated_text = "" acc_thinking_content = "" tool_calls = OrderedDict() # Store tool calls + metadata = None async for chunk in response: - has_new_content = False - has_new_thinking = False - # Handle text content - if hasattr(chunk, "message"): - msg = chunk.message - - if getattr(msg, "thinking", None): - acc_thinking_content += msg.thinking - has_new_thinking = True - - if getattr(msg, "content", None): - accumulated_text += msg.content - has_new_content = True - - # Handle tool calls - if getattr(msg, "tool_calls", None): - has_new_content = True - for idx, tool_call in enumerate(msg.tool_calls): - function_name = ( - getattr( - tool_call, - "function", - None, - ) - and tool_call.function.name - or "tool" - ) - tool_id = getattr( - tool_call, - "id", - f"{function_name}_{idx}", - ) - if hasattr(tool_call, "function"): - function = tool_call.function - tool_calls[tool_id] = { - "type": "tool_use", - "id": tool_id, - "name": function.name, - "input": function.arguments, - } + msg = chunk.message + acc_thinking_content += msg.thinking or "" + accumulated_text += msg.content or "" + + # Handle tool calls + for idx, tool_call in enumerate(msg.tool_calls or []): + function = tool_call.function + tool_id = f"{idx}_{function.name}" + tool_calls[tool_id] = { + "type": "tool_use", + "id": tool_id, + "name": function.name, + "input": function.arguments, + } # Calculate usage statistics current_time = (datetime.now() - start_datetime).total_seconds() usage = ChatUsage( @@ -225,51 +233,63 @@ async def _parse_ollama_stream_completion_response( if accumulated_text: contents.append(TextBlock(type="text", text=accumulated_text)) + if structured_model: + metadata = _json_loads_with_repair(accumulated_text) # Add tool call blocks - if tool_calls: - for tool_call in tool_calls.values(): - try: - input_data = tool_call["input"] - if isinstance(input_data, str): - input_data = _json_loads_with_repair(input_data) - contents.append( - ToolUseBlock( - type=tool_call["type"], - id=tool_call["id"], - name=tool_call["name"], - input=input_data, - ), - ) - except Exception as e: - print(f"Error parsing tool call input: {e}") + for tool_call in tool_calls.values(): + try: + input_data = tool_call["input"] + if isinstance(input_data, str): + input_data = _json_loads_with_repair(input_data) + contents.append( + ToolUseBlock( + type=tool_call["type"], + id=tool_call["id"], + name=tool_call["name"], + input=input_data, + ), + ) + except Exception as e: + print(f"Error parsing tool call input: {e}") # Generate response when there's new content or at final chunk - is_final = getattr(chunk, "done", False) - if (has_new_thinking or has_new_content or is_final) and contents: - res = ChatResponse(content=contents, usage=usage) + if chunk.done and contents: + res = ChatResponse( + content=contents, + usage=usage, + metadata=metadata, + ) yield res async def _parse_ollama_completion_response( self, start_datetime: datetime, response: OllamaChatResponse, + structured_model: Type[BaseModel] | None = None, ) -> ChatResponse: - """Parse the Ollama chat completion response into a `ChatResponse` - object. + """Given an Ollama chat completion response object, extract the content + blocks and usages from it. Args: start_datetime (`datetime`): The start datetime of the response generation. response (`OllamaChatResponse`): - The Ollama chat response object to parse. + Ollama OllamaChatResponse object to parse. 
+ structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. Returns: `ChatResponse`: - The content blocks and usage information extracted from the - response. + A ChatResponse object containing the content blocks and usage. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. """ content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = [] + metadata = None if response.message.thinking: content_blocks.append( @@ -286,17 +306,18 @@ async def _parse_ollama_completion_response( text=response.message.content, ), ) + if structured_model: + metadata = _json_loads_with_repair(response.message.content) - if response.message.tool_calls: - for tool_call in response.message.tool_calls: - content_blocks.append( - ToolUseBlock( - type="tool_use", - id=tool_call.function.name, - name=tool_call.function.name, - input=tool_call.function.arguments, - ), - ) + for idx, tool_call in enumerate(response.message.tool_calls or []): + content_blocks.append( + ToolUseBlock( + type="tool_use", + id=f"{idx}_{tool_call.function.name}", + name=tool_call.function.name, + input=tool_call.function.arguments, + ), + ) usage = None if "prompt_eval_count" in response and "eval_count" in response: @@ -309,6 +330,7 @@ async def _parse_ollama_completion_response( parsed_response = ChatResponse( content=content_blocks, usage=usage, + metadata=metadata, ) return parsed_response diff --git a/src/agentscope/model/_openai_model.py b/src/agentscope/model/_openai_model.py index fa034e17b3..a8785d2404 100644 --- a/src/agentscope/model/_openai_model.py +++ b/src/agentscope/model/_openai_model.py @@ -8,12 +8,16 @@ List, AsyncGenerator, Literal, + Type, ) from collections import OrderedDict +from pydantic import BaseModel + from . import ChatResponse from ._model_base import ChatModelBase from ._model_usage import ChatUsage +from .._logging import logger from .._utils._common import _json_loads_with_repair from ..message import ToolUseBlock, TextBlock, ThinkingBlock from ..tracing import trace_llm @@ -88,6 +92,7 @@ async def __call__( tool_choice: Literal["auto", "none", "any", "required"] | str | None = None, + structured_model: Type[BaseModel] | None = None, **kwargs: Any, ) -> ChatResponse | AsyncGenerator[ChatResponse, None]: """Get the response from OpenAI chat completions API by the given @@ -105,6 +110,22 @@ async def __call__( Can be "auto", "none", "any", "required", or specific tool name. For more details, please refer to https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. When provided, the model will be forced + to return data that conforms to this schema by automatically + converting the BaseModel to a tool function and setting + `tool_choice` to enforce its usage. This enables structured + output generation. + + .. note:: When `structured_model` is specified, + both `tools` and `tool_choice` parameters are ignored, + and the model will only perform structured output + generation without calling any other tools. + + For more details, please refer to the `official document + `_ + **kwargs (`Any`): The keyword arguments for OpenAI chat completions API, e.g. `temperature`, `max_tokens`, `top_p`, etc. 
Please @@ -148,54 +169,106 @@ async def __call__( kwargs["stream_options"] = {"include_usage": True} start_datetime = datetime.now() - response = await self.client.chat.completions.create(**kwargs) + + if structured_model: + if tools or tool_choice: + logger.warning( + "structured_model is provided. Both 'tools' and " + "'tool_choice' parameters will be overridden and " + "ignored. The model will only perform structured output " + "generation without calling any other tools.", + ) + kwargs.pop("stream", None) + kwargs.pop("tools", None) + kwargs.pop("tool_choice", None) + kwargs["response_format"] = structured_model + if not self.stream: + response = await self.client.chat.completions.parse(**kwargs) + else: + response = self.client.chat.completions.stream(**kwargs) + return self._parse_openai_stream_response( + start_datetime, + response, + structured_model, + ) + else: + response = await self.client.chat.completions.create(**kwargs) if self.stream: - return self._parse_openai_stream_completion_response( + return self._parse_openai_stream_response( start_datetime, response, + structured_model, ) # Non-streaming response parsed_response = self._parse_openai_completion_response( start_datetime, response, + structured_model, ) return parsed_response - async def _parse_openai_stream_completion_response( + async def _parse_openai_stream_response( self, start_datetime: datetime, response: AsyncStream, + structured_model: Type[BaseModel] | None = None, ) -> AsyncGenerator[ChatResponse, None]: - """Parse the OpenAI chat completion response stream into an async - generator of `ChatResponse` objects.""" + """Given an OpenAI streaming completion response, extract the content + blocks and usages from it and yield ChatResponse objects. + + Args: + start_datetime (`datetime`): + The start datetime of the response generation. + response (`AsyncStream`): + OpenAI AsyncStream object to parse. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. + + Returns: + `AsyncGenerator[ChatResponse, None]`: + An async generator that yields ChatResponse objects containing + the content blocks and usage information for each chunk in + the streaming response. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. 
+ """ usage, res = None, None text = "" thinking = "" tool_calls = OrderedDict() - async for chunk in response: - if chunk.usage: - usage = ChatUsage( - input_tokens=chunk.usage.prompt_tokens, - output_tokens=chunk.usage.completion_tokens, - time=(datetime.now() - start_datetime).total_seconds(), - ) + metadata = None + + async with response as stream: + async for item in stream: + if structured_model: + if item.type != "chunk": + continue + chunk = item.chunk + else: + chunk = item + + if chunk.usage: + usage = ChatUsage( + input_tokens=chunk.usage.prompt_tokens, + output_tokens=chunk.usage.completion_tokens, + time=(datetime.now() - start_datetime).total_seconds(), + ) - if chunk.choices: - choice = chunk.choices[0] - if ( - hasattr(choice.delta, "reasoning_content") - and choice.delta.reasoning_content is not None - ): - thinking += choice.delta.reasoning_content + if chunk.choices: + choice = chunk.choices[0] - if choice.delta.content: - text += choice.delta.content + thinking += ( + getattr(choice.delta, "reasoning_content", None) or "" + ) + text += choice.delta.content or "" - if choice.delta.tool_calls: - for tool_call in choice.delta.tool_calls: + for tool_call in choice.delta.tool_calls or []: if tool_call.index in tool_calls: if tool_call.function.arguments is not None: tool_calls[tool_call.index][ @@ -210,25 +283,29 @@ async def _parse_openai_stream_completion_response( "input": tool_call.function.arguments or "", } - contents: List[TextBlock | ToolUseBlock | ThinkingBlock] = [] + contents: List[ + TextBlock | ToolUseBlock | ThinkingBlock + ] = [] - if thinking: - contents.append( - ThinkingBlock( - type="thinking", - thinking=thinking, - ), - ) + if thinking: + contents.append( + ThinkingBlock( + type="thinking", + thinking=thinking, + ), + ) - if text: - contents.append( - TextBlock( - type="text", - text=text, - ), - ) + if text: + contents.append( + TextBlock( + type="text", + text=text, + ), + ) + + if structured_model: + metadata = _json_loads_with_repair(text) - if tool_calls: for tool_call in tool_calls.values(): contents.append( ToolUseBlock( @@ -241,33 +318,42 @@ async def _parse_openai_stream_completion_response( ), ) - if contents: - res = ChatResponse( - content=contents, - usage=usage, - ) - yield res + if contents: + res = ChatResponse( + content=contents, + usage=usage, + metadata=metadata, + ) + yield res def _parse_openai_completion_response( self, start_datetime: datetime, response: ChatCompletion, + structured_model: Type[BaseModel] | None = None, ) -> ChatResponse: - """Parse the OpenAI chat completion response into a `ChatResponse` - object. + """Given an OpenAI chat completion response object, extract the content + blocks and usages from it. Args: start_datetime (`datetime`): The start datetime of the response generation. response (`ChatCompletion`): - The OpenAI chat completion response object to parse. + OpenAI ChatCompletion object to parse. + structured_model (`Type[BaseModel] | None`, default `None`): + A Pydantic BaseModel class that defines the expected structure + for the model's output. Returns: - `Tuple[List[TextBlock | ToolUseBlock] | None, ChatUsage | None]`: - The content blocks and usage information extracted from the - response. + ChatResponse (`ChatResponse`): + A ChatResponse object containing the content blocks and usage. + + .. note:: + If `structured_model` is not `None`, the expected structured output + will be stored in the metadata of the `ChatResponse`. 
""" content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = [] + metadata = None if response.choices: choice = response.choices[0] @@ -290,18 +376,20 @@ def _parse_openai_completion_response( ), ) - if choice.message.tool_calls: - for tool_call in choice.message.tool_calls: - content_blocks.append( - ToolUseBlock( - type="tool_use", - id=tool_call.id, - name=tool_call.function.name, - input=_json_loads_with_repair( - tool_call.function.arguments, - ), + for tool_call in choice.message.tool_calls or []: + content_blocks.append( + ToolUseBlock( + type="tool_use", + id=tool_call.id, + name=tool_call.function.name, + input=_json_loads_with_repair( + tool_call.function.arguments, ), - ) + ), + ) + + if structured_model: + metadata = choice.message.parsed.model_dump() usage = None if response.usage: @@ -314,6 +402,7 @@ def _parse_openai_completion_response( parsed_response = ChatResponse( content=content_blocks, usage=usage, + metadata=metadata, ) return parsed_response diff --git a/tests/model_anthropic_test.py b/tests/model_anthropic_test.py new file mode 100644 index 0000000000..a47520e74a --- /dev/null +++ b/tests/model_anthropic_test.py @@ -0,0 +1,431 @@ +# -*- coding: utf-8 -*- +"""Unit tests for Anthropic API model class.""" +from typing import Any, AsyncGenerator +from unittest.async_case import IsolatedAsyncioTestCase +from unittest.mock import Mock, patch, AsyncMock +from pydantic import BaseModel + +from agentscope.model import AnthropicChatModel, ChatResponse +from agentscope.message import TextBlock, ToolUseBlock, ThinkingBlock + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing structured output.""" + + name: str + age: int + + +class AnthropicMessageMock: + """Mock class for Anthropic message objects.""" + + def __init__(self, content: list = None, usage: dict = None): + self.content = content or [] + self.usage = self._create_usage_mock(usage) if usage else None + + def _create_usage_mock(self, usage_data: dict) -> Mock: + usage_mock = Mock() + usage_mock.input_tokens = usage_data.get("input_tokens", 0) + usage_mock.output_tokens = usage_data.get("output_tokens", 0) + return usage_mock + + +class AnthropicContentBlockMock: + """Mock class for Anthropic content blocks.""" + + def __init__(self, block_type: str, **kwargs: Any) -> None: + self.type = block_type + for key, value in kwargs.items(): + setattr(self, key, value) + + +class AnthropicEventMock: + """Mock class for Anthropic streaming events.""" + + def __init__(self, event_type: str, **kwargs: Any) -> None: + self.type = event_type + for key, value in kwargs.items(): + setattr(self, key, value) + + +class TestAnthropicChatModel(IsolatedAsyncioTestCase): + """Test cases for AnthropicChatModel.""" + + def test_init_default_params(self) -> None: + """Test initialization with default parameters.""" + with patch("anthropic.AsyncAnthropic") as mock_client: + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + ) + self.assertEqual(model.model_name, "claude-3-sonnet-20240229") + self.assertEqual(model.max_tokens, 2048) + self.assertTrue(model.stream) + self.assertIsNone(model.thinking) + self.assertEqual(model.generate_kwargs, {}) + mock_client.assert_called_once_with(api_key="test_key") + + def test_init_with_custom_params(self) -> None: + """Test initialization with custom parameters.""" + thinking_config = {"type": "enabled", "budget_tokens": 1024} + generate_kwargs = {"temperature": 0.7, "top_p": 0.9} + client_args = {"timeout": 30} + + with 
patch("anthropic.AsyncAnthropic") as mock_client: + model = AnthropicChatModel( + model_name="claude-3-opus-20240229", + api_key="test_key", + max_tokens=4096, + stream=False, + thinking=thinking_config, + client_args=client_args, + generate_kwargs=generate_kwargs, + ) + self.assertEqual(model.model_name, "claude-3-opus-20240229") + self.assertEqual(model.max_tokens, 4096) + self.assertFalse(model.stream) + self.assertEqual(model.thinking, thinking_config) + self.assertEqual(model.generate_kwargs, generate_kwargs) + mock_client.assert_called_once_with(api_key="test_key", timeout=30) + + async def test_call_with_regular_messages(self) -> None: + """Test calling with regular messages.""" + with patch("anthropic.AsyncAnthropic") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Hello"}] + mock_response = AnthropicMessageMock( + content=[ + AnthropicContentBlockMock( + "text", + text="Hello! How can I help you?", + ), + ], + usage={"input_tokens": 10, "output_tokens": 20}, + ) + mock_client.messages.create = AsyncMock(return_value=mock_response) + + result = await model(messages) + call_args = mock_client.messages.create.call_args[1] + self.assertEqual(call_args["model"], "claude-3-sonnet-20240229") + self.assertEqual(call_args["max_tokens"], 2048) + self.assertFalse(call_args["stream"]) + self.assertEqual(call_args["messages"], messages) + self.assertIsInstance(result, ChatResponse) + expected_content = [ + TextBlock(type="text", text="Hello! How can I help you?"), + ] + self.assertEqual(result.content, expected_content) + self.assertEqual(result.usage.input_tokens, 10) + self.assertEqual(result.usage.output_tokens, 20) + + async def test_call_with_system_message(self) -> None: + """Test calling with system message extraction.""" + with patch("anthropic.AsyncAnthropic") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + ] + mock_response = AnthropicMessageMock( + content=[AnthropicContentBlockMock("text", text="Hi there!")], + usage={"input_tokens": 15, "output_tokens": 5}, + ) + mock_client.messages.create = AsyncMock(return_value=mock_response) + await model(messages) + + call_args = mock_client.messages.create.call_args[1] + self.assertEqual( + call_args["system"], + "You are a helpful assistant", + ) + self.assertEqual( + call_args["messages"], + [ + {"role": "user", "content": "Hello"}, + ], + ) + + async def test_call_with_thinking_enabled(self) -> None: + """Test calling with thinking functionality enabled.""" + with patch("anthropic.AsyncAnthropic") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + thinking_config = {"type": "enabled", "budget_tokens": 1024} + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + stream=False, + thinking=thinking_config, + ) + model.client = mock_client + + messages = [ + {"role": "user", "content": "Think about this problem"}, + ] + thinking_block = AnthropicContentBlockMock( + "thinking", + thinking="Let me analyze this 
step by step...", + signature="thinking_signature_123", + ) + text_block = AnthropicContentBlockMock( + "text", + text="Here's my analysis", + ) + mock_response = AnthropicMessageMock( + content=[thinking_block, text_block], + usage={"input_tokens": 20, "output_tokens": 40}, + ) + mock_client.messages.create = AsyncMock(return_value=mock_response) + result = await model(messages) + + call_args = mock_client.messages.create.call_args[1] + self.assertEqual(call_args["thinking"], thinking_config) + expected_thinking_block = ThinkingBlock( + type="thinking", + thinking="Let me analyze this step by step...", + ) + expected_thinking_block["signature"] = "thinking_signature_123" + expected_content = [ + expected_thinking_block, + TextBlock(type="text", text="Here's my analysis"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_tools_integration(self) -> None: + """Test full integration of tool calls.""" + with patch("anthropic.AsyncAnthropic") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather info", + "parameters": {"type": "object"}, + }, + }, + ] + text_block = AnthropicContentBlockMock( + "text", + text="I'll check the weather", + ) + tool_block = AnthropicContentBlockMock( + "tool_use", + id="tool_123", + name="get_weather", + input={"location": "Beijing"}, + ) + + mock_response = AnthropicMessageMock( + content=[text_block, tool_block], + usage={"input_tokens": 25, "output_tokens": 15}, + ) + mock_client.messages.create = AsyncMock(return_value=mock_response) + result = await model(messages, tools=tools, tool_choice="auto") + # Verify tool formatting + call_args = mock_client.messages.create.call_args[1] + expected_tools = [ + { + "name": "get_weather", + "description": "Get weather info", + "input_schema": {"type": "object"}, + }, + ] + self.assertEqual(call_args["tools"], expected_tools) + self.assertEqual(call_args["tool_choice"], {"type": "auto"}) + expected_content = [ + TextBlock(type="text", text="I'll check the weather"), + ToolUseBlock( + type="tool_use", + id="tool_123", + name="get_weather", + input={"location": "Beijing"}, + ), + ] + self.assertEqual(result.content, expected_content) + + async def test_streaming_response_processing(self) -> None: + """Test processing of streaming response.""" + with patch("anthropic.AsyncAnthropic") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + stream=True, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Hello"}] + events = [ + AnthropicEventMock( + "message_start", + message=Mock(usage=Mock(input_tokens=10, output_tokens=0)), + ), + AnthropicEventMock( + "content_block_delta", + index=0, + delta=Mock(type="text_delta", text="Hello"), + ), + AnthropicEventMock( + "content_block_delta", + index=0, + delta=Mock(type="text_delta", text=" there!"), + ), + AnthropicEventMock( + "message_delta", + usage=Mock(output_tokens=5), + ), + ] + + async def mock_stream() -> AsyncGenerator: + for event in events: + yield event + + mock_client.messages.create = 
AsyncMock(return_value=mock_stream()) + result = await model(messages) + responses = [] + async for response in result: + responses.append(response) + + self.assertGreater(len(responses), 0) + final_response = responses[-1] + self.assertIsInstance(final_response, ChatResponse) + expected_content = [ + TextBlock(type="text", text="Hello there!"), + ] + self.assertEqual(final_response.content, expected_content) + + async def test_generate_kwargs_integration(self) -> None: + """Test integration of generate_kwargs.""" + with patch("anthropic.AsyncAnthropic") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + generate_kwargs = {"temperature": 0.7, "top_p": 0.9} + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + stream=False, + generate_kwargs=generate_kwargs, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Test"}] + mock_response = AnthropicMessageMock( + content=[ + AnthropicContentBlockMock("text", text="Test response"), + ], + usage={"input_tokens": 5, "output_tokens": 10}, + ) + mock_client.messages.create = AsyncMock(return_value=mock_response) + await model(messages, top_k=40) + call_args = mock_client.messages.create.call_args[1] + self.assertEqual(call_args["temperature"], 0.7) + self.assertEqual(call_args["top_p"], 0.9) + self.assertEqual(call_args["top_k"], 40) + + async def test_call_with_structured_model_integration(self) -> None: + """Test full integration of structured model.""" + with patch("anthropic.AsyncAnthropic") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = AnthropicChatModel( + model_name="claude-3-sonnet-20240229", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Generate a person"}] + + text_block = AnthropicContentBlockMock( + "text", + text="Here's a person", + ) + tool_block = AnthropicContentBlockMock( + "tool_use", + id="tool_123", + name="generate_structured_output", + input={"name": "John", "age": 30}, + ) + + mock_response = AnthropicMessageMock( + content=[text_block, tool_block], + usage={"input_tokens": 20, "output_tokens": 15}, + ) + + mock_client.messages.create = AsyncMock(return_value=mock_response) + result = await model(messages, structured_model=SampleModel) + + call_args = mock_client.messages.create.call_args[1] + self.assertIn("tools", call_args) + self.assertIn("tool_choice", call_args) + expected_tools = [ + { + "name": "generate_structured_output", + "description": "Generate the required structured output" + " with this function", + "input_schema": { + "description": "Sample Pydantic model for testing " + "structured output.", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + "type": "object", + }, + }, + ] + self.assertEqual(call_args["tools"], expected_tools) + self.assertEqual( + call_args["tool_choice"], + { + "type": "tool", + "name": "generate_structured_output", + }, + ) + self.assertIsInstance(result, ChatResponse) + expected_content = [ + TextBlock(type="text", text="Here's a person"), + ToolUseBlock( + type="tool_use", + id="tool_123", + name="generate_structured_output", + input={"name": "John", "age": 30}, + ), + ] + self.assertEqual(result.content, expected_content) + self.assertEqual(result.metadata, {"name": "John", "age": 30}) diff --git a/tests/model_dashscope_test.py b/tests/model_dashscope_test.py new file mode 
100644 index 0000000000..470bcf706c --- /dev/null +++ b/tests/model_dashscope_test.py @@ -0,0 +1,449 @@ +# -*- coding: utf-8 -*- +"""Unit tests for DashScope API model class.""" +from typing import Any, AsyncGenerator +from unittest.async_case import IsolatedAsyncioTestCase +from unittest.mock import Mock, patch +from http import HTTPStatus +from pydantic import BaseModel + +from agentscope.model import DashScopeChatModel, ChatResponse +from agentscope.message import TextBlock, ToolUseBlock, ThinkingBlock + + +class MessageMock(dict): + """Mock class for message objects, supports both dictionary and + attribute access.""" + + def __init__(self, data: dict[str, Any]): + super().__init__(data) + for key, value in data.items(): + setattr(self, key, value) + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing structured output.""" + + name: str + age: int + + +class TestDashScopeChatModel(IsolatedAsyncioTestCase): + """Test cases for DashScopeChatModel.""" + + def test_init_default_params(self) -> None: + """Test initialization with default parameters.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + ) + self.assertEqual(model.model_name, "qwen-turbo") + self.assertEqual(model.api_key, "test_key") + self.assertTrue(model.stream) + self.assertIsNone(model.enable_thinking) + self.assertEqual(model.generate_kwargs, {}) + + def test_init_with_enable_thinking_forces_stream(self) -> None: + """Test that enable_thinking=True forces stream=True.""" + with patch("agentscope.model._dashscope_model.logger") as mock_logger: + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + stream=False, + enable_thinking=True, + ) + self.assertTrue(model.stream) + self.assertTrue(model.enable_thinking) + mock_logger.info.assert_called_once() + + def test_init_with_custom_params(self) -> None: + """Test initialization with custom parameters.""" + generate_kwargs = {"temperature": 0.7, "max_tokens": 1000} + model = DashScopeChatModel( + model_name="qwen-max", + api_key="test_key", + stream=False, + enable_thinking=False, + generate_kwargs=generate_kwargs, + ) + self.assertEqual(model.model_name, "qwen-max") + self.assertFalse(model.stream) + self.assertFalse(model.enable_thinking) + self.assertEqual(model.generate_kwargs, generate_kwargs) + + async def test_call_with_regular_model(self) -> None: + """Test calling a regular model.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + stream=False, + ) + messages = [{"role": "user", "content": "Hello"}] + + mock_response = self._create_mock_response( + "Hello! How can I help you?", + ) + with patch( + "dashscope.aigc.generation.AioGeneration.call", + ) as mock_call: + mock_call.return_value = mock_response + result = await model(messages) + call_args = mock_call.call_args[1] + self.assertEqual(call_args["messages"], messages) + self.assertEqual(call_args["model"], "qwen-turbo") + self.assertFalse(call_args["stream"]) + self.assertIsInstance(result, ChatResponse) + expected_content = [ + TextBlock(type="text", text="Hello! 
How can I help you?"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_tools_integration(self) -> None: + """Test full integration of tool calls.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + stream=False, + ) + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather info", + "parameters": {"type": "object"}, + }, + }, + ] + + mock_response = self._create_mock_response_with_tools( + "I'll check the weather for you.", + [ + { + "id": "call_123", + "function": { + "name": "get_weather", + "arguments": '{"location": "Beijing"}', + }, + }, + ], + ) + + with patch( + "dashscope.aigc.generation.AioGeneration.call", + ) as mock_call: + mock_call.return_value = mock_response + result = await model(messages, tools=tools, tool_choice="auto") + call_args = mock_call.call_args[1] + self.assertIn("tools", call_args) + self.assertIn("tool_choice", call_args) + self.assertEqual(call_args["tool_choice"], "auto") + + expected_content = [ + TextBlock(type="text", text="I'll check the weather for you."), + ToolUseBlock( + type="tool_use", + id="call_123", + name="get_weather", + input={"location": "Beijing"}, + ), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_enable_thinking_streaming(self) -> None: + """Test streaming response with thinking mode enabled.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + enable_thinking=True, + ) + messages = [{"role": "user", "content": "Solve this problem"}] + + chunks = [ + self._create_mock_chunk( + content="Solution", + reasoning_content="Let me think...", + ), + ] + + with patch( + "dashscope.aigc.generation.AioGeneration.call", + ) as mock_call: + mock_call.return_value = self._create_async_generator(chunks) + result = await model(messages) + + call_args = mock_call.call_args[1] + self.assertTrue(call_args["enable_thinking"]) + self.assertTrue(call_args["stream"]) + responses = [] + async for response in result: + responses.append(response) + self.assertGreater(len(responses), 0) + self.assertIsInstance(responses[0], ChatResponse) + + expected_content = [ + ThinkingBlock(type="thinking", thinking="Let me think..."), + TextBlock(type="text", text="Solution"), + ] + self.assertEqual(responses[0].content, expected_content) + + async def test_call_with_structured_model_integration(self) -> None: + """Test full integration of a structured model.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + stream=False, + ) + messages = [{"role": "user", "content": "Generate a person"}] + + mock_response = self._create_mock_response_with_tools( + "Here's a person", + [ + { + "id": "call_123", + "function": { + "name": "generate_structured_output", + "arguments": '{"name": "John", "age": 30}', + }, + }, + ], + ) + + with patch( + "dashscope.aigc.generation.AioGeneration.call", + ) as mock_call: + mock_call.return_value = mock_response + + result = await model(messages, structured_model=SampleModel) + call_args = mock_call.call_args[1] + + expected_tools = [ + { + "type": "function", + "function": { + "name": "generate_structured_output", + "description": "Generate the required structured" + " output with this function", + "parameters": { + "description": "Sample Pydantic model for " + "testing structured output.", + "properties": { + "name": { + "type": "string", + }, + "age": { + "type": 
"integer", + }, + }, + "required": [ + "name", + "age", + ], + "type": "object", + }, + }, + }, + ] + self.assertEqual(call_args["tools"], expected_tools) + self.assertEqual( + call_args["tool_choice"], + { + "type": "function", + "function": { + "name": "generate_structured_output", + }, + }, + ) + + self.assertIsInstance(result, ChatResponse) + self.assertEqual(result.metadata, {"name": "John", "age": 30}) + expected_content = [ + TextBlock(type="text", text="Here's a person"), + ToolUseBlock( + type="tool_use", + id="call_123", + name="generate_structured_output", + input={"name": "John", "age": 30}, + ), + ] + self.assertEqual(result.content, expected_content) + + async def test_streaming_response_processing(self) -> None: + """Test processing of streaming response.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + stream=True, + ) + messages = [{"role": "user", "content": "Hello"}] + + chunks = [ + self._create_mock_chunk( + content="Hello", + reasoning_content="I should greet", + tool_calls=[], + ), + self._create_mock_chunk( + content=" there!", + reasoning_content=" the user", + tool_calls=[ + { + "index": 0, + "id": "call_123", + "function": { + "name": "greet", + "arguments": '{"name": "user"}', + }, + }, + ], + ), + ] + + with patch( + "dashscope.aigc.generation.AioGeneration.call", + ) as mock_call: + mock_call.return_value = self._create_async_generator(chunks) + result = await model(messages) + + responses = [] + async for response in result: + responses.append(response) + self.assertEqual(len(responses), 2) + final_response = responses[-1] + + expected_content = [ + ThinkingBlock( + type="thinking", + thinking="I should greet the user", + ), + TextBlock(type="text", text="Hello there!"), + ToolUseBlock( + id="call_123", + name="greet", + input={"name": "user"}, + type="tool_use", + ), + ] + self.assertEqual(final_response.content, expected_content) + + def test_tools_schema_validation_through_api(self) -> None: + """Test tools schema validation through API call.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + stream=False, + ) + # Test valid tools schema + valid_tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather info", + }, + }, + ] + + # This test validates the format of the tools schema via an actual + # API call + messages = [{"role": "user", "content": "Test"}] + mock_response = self._create_mock_response("Test") + + with patch( + "dashscope.aigc.generation.AioGeneration.call", + ) as mock_call: + mock_call.return_value = mock_response + + # Should not throw an exception + try: + import asyncio + + loop = asyncio.get_event_loop() + if loop.is_running(): + # If event loop is already running, create a task + loop.create_task(model(messages, tools=valid_tools)) + else: + loop.run_until_complete(model(messages, tools=valid_tools)) + except Exception as e: + if "schema must be a dict" in str(e): + self.fail("Valid tools schema was rejected") + + async def test_error_handling_scenarios(self) -> None: + """Test various error handling scenarios.""" + model = DashScopeChatModel( + model_name="qwen-turbo", + api_key="test_key", + stream=False, + ) + messages = [{"role": "user", "content": "Hello"}] + + # Test failure of non-streaming API call + mock_response = Mock() + mock_response.status_code = 400 + with patch( + "dashscope.aigc.generation.AioGeneration.call", + ) as mock_call: + mock_call.return_value = mock_response + with 
self.assertRaises(RuntimeError): + await model(messages) + + # Auxiliary methods + def _create_mock_response(self, content: str) -> Mock: + """Create a standard mock response.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.output.choices = [Mock()] + mock_response.output.choices[0].message = MessageMock( + {"content": content}, + ) + mock_response.usage = Mock() + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 20 + return mock_response + + def _create_mock_response_with_tools( + self, + content: str, + tool_calls: list, + ) -> Mock: + """Create a mock response containing tool calls.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.output.choices = [Mock()] + mock_response.output.choices[0].message = MessageMock( + { + "content": content, + "tool_calls": tool_calls, + }, + ) + mock_response.usage = Mock() + mock_response.usage.input_tokens = 20 + mock_response.usage.output_tokens = 30 + return mock_response + + def _create_mock_chunk( + self, + content: str = "", + reasoning_content: str = "", + tool_calls: list = None, + ) -> Mock: + """Create a mock chunk for streaming responses.""" + chunk = Mock() + chunk.status_code = HTTPStatus.OK + chunk.output.choices = [Mock()] + chunk.output.choices[0].message = MessageMock( + { + "content": content, + "reasoning_content": reasoning_content, + "tool_calls": tool_calls or [], + }, + ) + chunk.usage = Mock() + chunk.usage.input_tokens = 5 + chunk.usage.output_tokens = 10 + return chunk + + async def _create_async_generator(self, items: list) -> AsyncGenerator: + """Create an asynchronous generator.""" + for item in items: + yield item diff --git a/tests/model_gemini_test.py b/tests/model_gemini_test.py new file mode 100644 index 0000000000..bb561e1370 --- /dev/null +++ b/tests/model_gemini_test.py @@ -0,0 +1,422 @@ +# -*- coding: utf-8 -*- +"""Unit tests for Google Gemini API model class.""" +from typing import AsyncGenerator +from unittest.async_case import IsolatedAsyncioTestCase +from unittest.mock import Mock, patch, AsyncMock +from pydantic import BaseModel + +from agentscope.model import GeminiChatModel, ChatResponse +from agentscope.message import TextBlock, ToolUseBlock, ThinkingBlock + + +class GeminiResponseMock: + """Mock class for Gemini response objects.""" + + def __init__( + self, + text: str = "", + function_calls: list = None, + usage_metadata: dict = None, + candidates: list = None, + ): + self.text = text + self.function_calls = function_calls or [] + self.usage_metadata = ( + self._create_usage_mock(usage_metadata) if usage_metadata else None + ) + self.candidates = candidates or [] + + def _create_usage_mock(self, usage_data: dict) -> Mock: + usage_mock = Mock() + usage_mock.prompt_token_count = usage_data.get("prompt_token_count", 0) + usage_mock.total_token_count = usage_data.get("total_token_count", 0) + return usage_mock + + +class GeminiFunctionCallMock: + """Mock class for Gemini function calls.""" + + def __init__(self, call_id: str, name: str, args: dict = None): + self.id = call_id + self.name = name + self.args = args or {} + + +class GeminiPartMock: + """Mock class for Gemini content parts.""" + + def __init__(self, text: str = "", thought: bool = False): + self.text = text + self.thought = thought + + +class GeminiCandidateMock: + """Mock class for Gemini candidates.""" + + def __init__(self, parts: list = None): + self.content = Mock() + self.content.parts = parts or [] + + +class SampleModel(BaseModel): + """Sample Pydantic 
model for testing structured output.""" + + name: str + age: int + + +class TestGeminiChatModel(IsolatedAsyncioTestCase): + """Test cases for GeminiChatModel.""" + + def test_init_default_params(self) -> None: + """Test initialization with default parameters.""" + with patch("google.genai.Client") as mock_client: + model = GeminiChatModel( + model_name="gemini-2.5-flash", + api_key="test_key", + ) + self.assertEqual(model.model_name, "gemini-2.5-flash") + self.assertTrue(model.stream) + self.assertIsNone(model.thinking_config) + self.assertEqual(model.generate_kwargs, {}) + mock_client.assert_called_once_with(api_key="test_key") + + def test_init_with_custom_params(self) -> None: + """Test initialization with custom parameters.""" + thinking_config = {"include_thoughts": True, "thinking_budget": 1024} + generate_kwargs = {"temperature": 0.7, "top_p": 0.9} + client_args = {"timeout": 30} + + with patch("google.genai.Client") as mock_client: + model = GeminiChatModel( + model_name="gemini-2.5-pro", + api_key="test_key", + stream=False, + thinking_config=thinking_config, + client_args=client_args, + generate_kwargs=generate_kwargs, + ) + self.assertEqual(model.model_name, "gemini-2.5-pro") + self.assertFalse(model.stream) + self.assertEqual(model.thinking_config, thinking_config) + self.assertEqual(model.generate_kwargs, generate_kwargs) + mock_client.assert_called_once_with(api_key="test_key", timeout=30) + + async def test_call_with_regular_model(self) -> None: + """Test calling a regular model.""" + with patch("google.genai.Client") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = GeminiChatModel( + model_name="gemini-2.5-flash", + api_key="test_key", + stream=False, + ) + model.client = mock_client + messages = [{"role": "user", "content": "Hello"}] + mock_response = self._create_mock_response( + "Hello! How can I help you?", + ) + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + + result = await model(messages) + call_args = mock_client.aio.models.generate_content.call_args[1] + self.assertEqual(call_args["model"], "gemini-2.5-flash") + self.assertEqual(call_args["contents"], messages) + self.assertIsInstance(result, ChatResponse) + expected_content = [ + TextBlock(type="text", text="Hello! 
How can I help you?"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_tools_integration(self) -> None: + """Test full integration of tool calls.""" + with patch("google.genai.Client") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = GeminiChatModel( + model_name="gemini-2.5-flash", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather info", + "parameters": {"type": "object"}, + }, + }, + ] + + mock_response = self._create_mock_response_with_tools( + "I'll check the weather for you.", + [ + GeminiFunctionCallMock( + call_id="call_123", + name="get_weather", + args={"location": "Beijing"}, + ), + ], + ) + + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + result = await model(messages, tools=tools, tool_choice="auto") + + call_args = mock_client.aio.models.generate_content.call_args[1] + self.assertIn("tools", call_args["config"]) + self.assertIn("tool_config", call_args["config"]) + expected_tools = [ + { + "function_declarations": [ + { + "name": "get_weather", + "description": "Get weather info", + "parameters": {"type": "object"}, + }, + ], + }, + ] + self.assertEqual(call_args["config"]["tools"], expected_tools) + self.assertEqual( + call_args["config"]["tool_config"], + { + "function_calling_config": {"mode": "AUTO"}, + }, + ) + expected_content = [ + TextBlock(type="text", text="I'll check the weather for you."), + ToolUseBlock( + type="tool_use", + id="call_123", + name="get_weather", + input={"location": "Beijing"}, + ), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_thinking_enabled(self) -> None: + """Test calling with thinking functionality enabled.""" + with patch("google.genai.Client") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + thinking_config = { + "include_thoughts": True, + "thinking_budget": 1024, + } + model = GeminiChatModel( + model_name="gemini-2.5-pro", + api_key="test_key", + stream=False, + thinking_config=thinking_config, + ) + model.client = mock_client + + messages = [ + {"role": "user", "content": "Think about this problem"}, + ] + thinking_part = GeminiPartMock( + text="Let me analyze this step by step...", + thought=True, + ) + candidate = GeminiCandidateMock(parts=[thinking_part]) + mock_response = self._create_mock_response_with_thinking( + "Here's my analysis", + candidates=[candidate], + ) + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + result = await model(messages) + + call_args = mock_client.aio.models.generate_content.call_args[1] + self.assertEqual( + call_args["config"]["thinking_config"], + thinking_config, + ) + expected_content = [ + ThinkingBlock( + type="thinking", + thinking="Let me analyze this step by step...", + ), + TextBlock(type="text", text="Here's my analysis"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_structured_model_integration(self) -> None: + """Test full integration of a structured model.""" + with patch("google.genai.Client") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = GeminiChatModel( + model_name="gemini-2.5-flash", + api_key="test_key", + stream=False, + 
) + model.client = mock_client + messages = [{"role": "user", "content": "Generate a person"}] + mock_response = self._create_mock_response( + '{"name": "John", "age": 30}', + ) + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + + result = await model(messages, structured_model=SampleModel) + call_args = mock_client.aio.models.generate_content.call_args[1] + self.assertEqual( + call_args["config"]["response_mime_type"], + "application/json", + ) + self.assertEqual( + call_args["config"]["response_schema"], + SampleModel, + ) + self.assertNotIn("tools", call_args["config"]) + self.assertNotIn("tool_config", call_args["config"]) + + self.assertIsInstance(result, ChatResponse) + self.assertEqual(result.metadata, {"name": "John", "age": 30}) + expected_content = [ + TextBlock(type="text", text='{"name": "John", "age": 30}'), + ] + self.assertEqual(result.content, expected_content) + + async def test_streaming_response_processing(self) -> None: + """Test processing of streaming response.""" + with patch("google.genai.Client") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = GeminiChatModel( + model_name="gemini-2.5-flash", + api_key="test_key", + stream=True, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Hello"}] + chunks = [ + self._create_mock_chunk(text="Hello"), + self._create_mock_chunk(text=" there!"), + ] + + mock_client.aio.models.generate_content_stream = AsyncMock( + return_value=self._create_async_generator(chunks), + ) + result = await model(messages) + responses = [] + async for response in result: + responses.append(response) + + self.assertEqual(len(responses), 2) + final_response = responses[-1] + expected_content = [ + TextBlock(type="text", text="Hello there!"), + ] + self.assertEqual(final_response.content, expected_content) + + async def test_generate_kwargs_integration(self) -> None: + """Test integration of generate_kwargs.""" + with patch("google.genai.Client") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + generate_kwargs = {"temperature": 0.7, "top_p": 0.9} + model = GeminiChatModel( + model_name="gemini-2.5-flash", + api_key="test_key", + stream=False, + generate_kwargs=generate_kwargs, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Test"}] + mock_response = self._create_mock_response("Test response") + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + + await model(messages, top_k=40) + + call_args = mock_client.aio.models.generate_content.call_args[1] + self.assertEqual(call_args["config"]["temperature"], 0.7) + self.assertEqual(call_args["config"]["top_p"], 0.9) + self.assertEqual(call_args["config"]["top_k"], 40) + + # Auxiliary methods + def _create_mock_response( + self, + text: str = "", + usage_metadata: dict = None, + ) -> GeminiResponseMock: + """Create a standard mock response.""" + return GeminiResponseMock( + text=text, + usage_metadata=usage_metadata + or {"prompt_token_count": 10, "total_token_count": 30}, + ) + + def _create_mock_response_with_tools( + self, + text: str, + function_calls: list, + usage_metadata: dict = None, + ) -> GeminiResponseMock: + """Create a mock response containing tool calls.""" + return GeminiResponseMock( + text=text, + function_calls=function_calls, + usage_metadata=usage_metadata + or {"prompt_token_count": 20, "total_token_count": 50}, + ) + + def 
_create_mock_response_with_thinking( + self, + text: str, + candidates: list = None, + usage_metadata: dict = None, + ) -> GeminiResponseMock: + """Create a mock response with thinking parts.""" + return GeminiResponseMock( + text=text, + candidates=candidates or [], + usage_metadata=usage_metadata + or {"prompt_token_count": 15, "total_token_count": 35}, + ) + + def _create_mock_chunk( + self, + text: str = "", + function_calls: list = None, + candidates: list = None, + usage_metadata: dict = None, + ) -> GeminiResponseMock: + """Create a mock chunk for streaming responses.""" + return GeminiResponseMock( + text=text, + function_calls=function_calls or [], + candidates=candidates or [], + usage_metadata=usage_metadata + or { + "prompt_token_count": 5, + "total_token_count": 15, + }, + ) + + async def _create_async_generator(self, items: list) -> AsyncGenerator: + """Create an asynchronous generator.""" + for item in items: + yield item diff --git a/tests/model_ollama_test.py b/tests/model_ollama_test.py new file mode 100644 index 0000000000..eefe1afe22 --- /dev/null +++ b/tests/model_ollama_test.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +"""Unit tests for Ollama API model class.""" +from typing import AsyncGenerator, Any +from unittest.async_case import IsolatedAsyncioTestCase +from unittest.mock import patch, AsyncMock +from pydantic import BaseModel + +from agentscope.model import OllamaChatModel, ChatResponse +from agentscope.message import TextBlock, ToolUseBlock, ThinkingBlock + + +class OllamaMessageMock: + """Mock class for Ollama message objects.""" + + def __init__( + self, + content: str = "", + thinking: str = "", + tool_calls: list = None, + ): + self.content = content + self.thinking = thinking + self.tool_calls = tool_calls or [] + + +class OllamaFunctionMock: + """Mock class for Ollama function objects.""" + + def __init__(self, name: str, arguments: dict = None): + self.name = name + self.arguments = arguments or {} + + +class OllamaToolCallMock: + """Mock class for Ollama tool call objects.""" + + def __init__( + self, + call_id: str = None, + function: OllamaFunctionMock = None, + ): + self.id = call_id + self.function = function + + +class OllamaResponseMock: + """Mock class for Ollama response objects.""" + + def __init__( + self, + message: OllamaMessageMock = None, + done: bool = True, + prompt_eval_count: int = 0, + eval_count: int = 0, + ) -> None: + self.message = message or OllamaMessageMock() + self.done = done + self.prompt_eval_count = prompt_eval_count + self.eval_count = eval_count + + def get(self, key: str, default: Any = None) -> Any: + """Mock dict-like get method.""" + return getattr(self, key, default) + + def __contains__(self, key: str) -> bool: + """Mock dict-like contains method.""" + return hasattr(self, key) + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing structured output.""" + + name: str + age: int + + +class TestOllamaChatModel(IsolatedAsyncioTestCase): + """Test cases for OllamaChatModel.""" + + def test_init_default_params(self) -> None: + """Test initialization with default parameters.""" + with patch("ollama.AsyncClient") as mock_client: + model = OllamaChatModel(model_name="llama3.2") + self.assertEqual(model.model_name, "llama3.2") + self.assertFalse(model.stream) + self.assertIsNone(model.options) + self.assertEqual(model.keep_alive, "5m") + self.assertIsNone(model.think) + mock_client.assert_called_once_with(host=None) + + def test_init_with_custom_params(self) -> None: + """Test initialization with 
custom parameters.""" + options = {"temperature": 0.7, "top_p": 0.9} + with patch("ollama.AsyncClient") as mock_client: + model = OllamaChatModel( + model_name="qwen2.5", + stream=True, + options=options, + keep_alive="10m", + enable_thinking=True, + host="http://localhost:11434", + timeout=30, + ) + self.assertEqual(model.model_name, "qwen2.5") + self.assertTrue(model.stream) + self.assertEqual(model.options, options) + self.assertEqual(model.keep_alive, "10m") + self.assertTrue(model.think) + mock_client.assert_called_once_with( + host="http://localhost:11434", + timeout=30, + ) + + async def test_call_with_regular_model(self) -> None: + """Test calling a regular model.""" + with patch("ollama.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OllamaChatModel(model_name="llama3.2", stream=False) + model.client = mock_client + + messages = [{"role": "user", "content": "Hello"}] + mock_response = self._create_mock_response( + "Hello! How can I help you?", + ) + mock_client.chat = AsyncMock(return_value=mock_response) + + result = await model(messages) + call_args = mock_client.chat.call_args[1] + self.assertEqual(call_args["model"], "llama3.2") + self.assertEqual(call_args["messages"], messages) + self.assertFalse(call_args["stream"]) + self.assertIsInstance(result, ChatResponse) + expected_content = [ + TextBlock(type="text", text="Hello! How can I help you?"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_tools_integration(self) -> None: + """Test full integration of tool calls.""" + with patch("ollama.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OllamaChatModel(model_name="llama3.2", stream=False) + model.client = mock_client + + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather info", + "parameters": {"type": "object"}, + }, + }, + ] + + function_mock = OllamaFunctionMock( + name="get_weather", + arguments={"location": "Beijing"}, + ) + tool_call_mock = OllamaToolCallMock( + call_id="call_123", + function=function_mock, + ) + message_mock = OllamaMessageMock( + content="I'll check the weather for you.", + tool_calls=[tool_call_mock], + ) + mock_response = self._create_mock_response_with_message( + message_mock, + ) + + mock_client.chat = AsyncMock(return_value=mock_response) + result = await model(messages, tools=tools) + + call_args = mock_client.chat.call_args[1] + self.assertIn("tools", call_args) + self.assertEqual(call_args["tools"], tools) + expected_content = [ + TextBlock(type="text", text="I'll check the weather for you."), + ToolUseBlock( + type="tool_use", + id="0_get_weather", + name="get_weather", + input={"location": "Beijing"}, + ), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_thinking_enabled(self) -> None: + """Test calling with thinking functionality enabled.""" + with patch("ollama.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OllamaChatModel( + model_name="qwen2.5", + stream=False, + enable_thinking=True, + ) + model.client = mock_client + + messages = [ + {"role": "user", "content": "Think about this problem"}, + ] + message_mock = OllamaMessageMock( + content="Here's my analysis", + thinking="Let me analyze this step by step...", + ) + 
mock_response = self._create_mock_response_with_message( + message_mock, + ) + + mock_client.chat = AsyncMock(return_value=mock_response) + result = await model(messages) + + call_args = mock_client.chat.call_args[1] + self.assertTrue(call_args["think"]) + expected_content = [ + ThinkingBlock( + type="thinking", + thinking="Let me analyze this step by step...", + ), + TextBlock(type="text", text="Here's my analysis"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_structured_model_integration(self) -> None: + """Test full integration of a structured model.""" + with patch("ollama.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OllamaChatModel(model_name="llama3.2", stream=False) + model.client = mock_client + + messages = [{"role": "user", "content": "Generate a person"}] + mock_response = self._create_mock_response( + '{"name": "John", "age": 30}', + ) + mock_client.chat = AsyncMock(return_value=mock_response) + + result = await model(messages, structured_model=SampleModel) + call_args = mock_client.chat.call_args[1] + self.assertIn("format", call_args) + self.assertEqual( + call_args["format"], + SampleModel.model_json_schema(), + ) + self.assertIsInstance(result, ChatResponse) + self.assertEqual(result.metadata, {"name": "John", "age": 30}) + expected_content = [ + TextBlock(type="text", text='{"name": "John", "age": 30}'), + ] + self.assertEqual(result.content, expected_content) + + async def test_streaming_response_processing(self) -> None: + """Test processing of streaming response.""" + with patch("ollama.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OllamaChatModel(model_name="llama3.2", stream=True) + model.client = mock_client + + messages = [{"role": "user", "content": "Hello"}] + chunks = [ + self._create_mock_chunk(content="Hello", done=False), + self._create_mock_chunk(content=" there!", done=True), + ] + + mock_client.chat = AsyncMock( + return_value=self._create_async_generator(chunks), + ) + result = await model(messages) + responses = [] + async for response in result: + responses.append(response) + + self.assertGreaterEqual(len(responses), 1) + final_response = responses[-1] + expected_content = [TextBlock(type="text", text="Hello there!")] + self.assertEqual(final_response.content, expected_content) + + async def test_options_integration(self) -> None: + """Test integration of options parameter.""" + with patch("ollama.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + options = {"temperature": 0.7, "top_p": 0.9} + model = OllamaChatModel( + model_name="llama3.2", + stream=False, + options=options, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Test"}] + mock_response = self._create_mock_response("Test response") + mock_client.chat = AsyncMock(return_value=mock_response) + + await model(messages, top_k=40) + + call_args = mock_client.chat.call_args[1] + self.assertEqual(call_args["options"], options) + self.assertEqual(call_args["keep_alive"], "5m") + self.assertEqual(call_args["top_k"], 40) + + # Auxiliary methods + def _create_mock_response( + self, + content: str = "", + prompt_eval_count: int = 10, + eval_count: int = 20, + ) -> OllamaResponseMock: + """Create a standard mock response.""" + message = OllamaMessageMock(content=content) + return OllamaResponseMock( + message=message, + 
prompt_eval_count=prompt_eval_count, + eval_count=eval_count, + ) + + def _create_mock_response_with_message( + self, + message: OllamaMessageMock, + prompt_eval_count: int = 10, + eval_count: int = 20, + ) -> OllamaResponseMock: + """Create a mock response with specific message.""" + return OllamaResponseMock( + message=message, + prompt_eval_count=prompt_eval_count, + eval_count=eval_count, + ) + + def _create_mock_chunk( + self, + content: str = "", + thinking: str = "", + tool_calls: list = None, + done: bool = True, + prompt_eval_count: int = 5, + eval_count: int = 10, + ) -> OllamaResponseMock: + """Create a mock chunk for streaming responses.""" + message = OllamaMessageMock( + content=content, + thinking=thinking, + tool_calls=tool_calls or [], + ) + return OllamaResponseMock( + message=message, + done=done, + prompt_eval_count=prompt_eval_count, + eval_count=eval_count, + ) + + async def _create_async_generator(self, items: list) -> AsyncGenerator: + """Create an asynchronous generator.""" + for item in items: + yield item diff --git a/tests/model_openai_test.py b/tests/model_openai_test.py new file mode 100644 index 0000000000..8f7d1f5b6c --- /dev/null +++ b/tests/model_openai_test.py @@ -0,0 +1,381 @@ +# -*- coding: utf-8 -*- +"""Unit tests for OpenAI API model class.""" +from typing import AsyncGenerator, Any +from unittest.async_case import IsolatedAsyncioTestCase +from unittest.mock import Mock, patch, AsyncMock +from pydantic import BaseModel + +from agentscope.model import OpenAIChatModel, ChatResponse +from agentscope.message import TextBlock, ToolUseBlock, ThinkingBlock + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing structured output.""" + + name: str + age: int + + +class TestOpenAIChatModel(IsolatedAsyncioTestCase): + """Test cases for OpenAIChatModel.""" + + def test_init_default_params(self) -> None: + """Test initialization with default parameters.""" + with patch("openai.AsyncClient") as mock_client: + model = OpenAIChatModel(model_name="gpt-4", api_key="test_key") + self.assertEqual(model.model_name, "gpt-4") + self.assertTrue(model.stream) + self.assertIsNone(model.reasoning_effort) + self.assertEqual(model.generate_kwargs, {}) + mock_client.assert_called_once_with( + api_key="test_key", + organization=None, + ) + + def test_init_with_custom_params(self) -> None: + """Test initialization with custom parameters.""" + generate_kwargs = {"temperature": 0.7, "max_tokens": 1000} + client_args = {"timeout": 30} + with patch("openai.AsyncClient") as mock_client: + model = OpenAIChatModel( + model_name="gpt-4o", + api_key="test_key", + stream=False, + reasoning_effort="high", + organization="org-123", + client_args=client_args, + generate_kwargs=generate_kwargs, + ) + self.assertEqual(model.model_name, "gpt-4o") + self.assertFalse(model.stream) + self.assertEqual(model.reasoning_effort, "high") + self.assertEqual(model.generate_kwargs, generate_kwargs) + mock_client.assert_called_once_with( + api_key="test_key", + organization="org-123", + timeout=30, + ) + + async def test_call_with_regular_model(self) -> None: + """Test calling a regular model.""" + with patch("openai.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OpenAIChatModel( + model_name="gpt-4", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Hello"}] + mock_response = self._create_mock_response( + "Hello! 
How can I help you?", + ) + mock_client.chat.completions.create = AsyncMock( + return_value=mock_response, + ) + + result = await model(messages) + call_args = mock_client.chat.completions.create.call_args[1] + self.assertEqual(call_args["model"], "gpt-4") + self.assertEqual(call_args["messages"], messages) + self.assertFalse(call_args["stream"]) + self.assertIsInstance(result, ChatResponse) + expected_content = [ + TextBlock(type="text", text="Hello! How can I help you?"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_tools_integration(self) -> None: + """Test full integration of tool calls.""" + with patch("openai.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OpenAIChatModel( + model_name="gpt-4", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather info", + "parameters": {"type": "object"}, + }, + }, + ] + + mock_response = self._create_mock_response_with_tools( + "I'll check the weather for you.", + [ + { + "id": "call_123", + "name": "get_weather", + "arguments": '{"location": "Beijing"}', + }, + ], + ) + mock_client.chat.completions.create = AsyncMock( + return_value=mock_response, + ) + result = await model(messages, tools=tools, tool_choice="auto") + call_args = mock_client.chat.completions.create.call_args[1] + self.assertIn("tools", call_args) + self.assertEqual(call_args["tools"], tools) + self.assertEqual(call_args["tool_choice"], "auto") + expected_content = [ + TextBlock(type="text", text="I'll check the weather for you."), + ToolUseBlock( + type="tool_use", + id="call_123", + name="get_weather", + input={"location": "Beijing"}, + ), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_reasoning_effort(self) -> None: + """Test calling with reasoning effort enabled.""" + with patch("openai.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OpenAIChatModel( + model_name="o3-mini", + api_key="test_key", + stream=False, + reasoning_effort="high", + ) + model.client = mock_client + + messages = [ + {"role": "user", "content": "Think about this problem"}, + ] + mock_response = self._create_mock_response_with_reasoning( + "Here's my analysis", + "Let me analyze this step by step...", + ) + mock_client.chat.completions.create = AsyncMock( + return_value=mock_response, + ) + result = await model(messages) + + call_args = mock_client.chat.completions.create.call_args[1] + self.assertEqual(call_args["reasoning_effort"], "high") + expected_content = [ + ThinkingBlock( + type="thinking", + thinking="Let me analyze this step by step...", + ), + TextBlock(type="text", text="Here's my analysis"), + ] + self.assertEqual(result.content, expected_content) + + async def test_call_with_structured_model_integration(self) -> None: + """Test full integration of a structured model.""" + with patch("openai.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OpenAIChatModel( + model_name="gpt-4", + api_key="test_key", + stream=False, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Generate a person"}] + mock_response = self._create_mock_response_with_structured_data( + {"name": "John", "age": 
30}, + ) + mock_client.chat.completions.parse = AsyncMock( + return_value=mock_response, + ) + + result = await model(messages, structured_model=SampleModel) + call_args = mock_client.chat.completions.parse.call_args[1] + self.assertEqual(call_args["response_format"], SampleModel) + self.assertNotIn("tools", call_args) + self.assertNotIn("tool_choice", call_args) + self.assertNotIn("stream", call_args) + self.assertIsInstance(result, ChatResponse) + self.assertEqual(result.metadata, {"name": "John", "age": 30}) + + async def test_streaming_response_processing(self) -> None: + """Test processing of streaming response.""" + with patch("openai.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + model = OpenAIChatModel( + model_name="gpt-4", + api_key="test_key", + stream=True, + ) + model.client = mock_client + + messages = [{"role": "user", "content": "Hello"}] + stream_mock = self._create_stream_mock( + [ + {"content": "Hello"}, + {"content": " there!"}, + ], + ) + + mock_client.chat.completions.create = AsyncMock( + return_value=stream_mock, + ) + result = await model(messages) + + call_args = mock_client.chat.completions.create.call_args[1] + self.assertEqual( + call_args["stream_options"], + {"include_usage": True}, + ) + responses = [] + async for response in result: + responses.append(response) + + self.assertGreaterEqual(len(responses), 1) + final_response = responses[-1] + expected_content = [TextBlock(type="text", text="Hello there!")] + self.assertEqual(final_response.content, expected_content) + + # Auxiliary methods - ensure all Mock objects have complete attributes + def _create_mock_response( + self, + content: str = "", + prompt_tokens: int = 10, + completion_tokens: int = 20, + ) -> Mock: + """Create a standard mock response.""" + message = Mock() + message.content = content + message.reasoning_content = None + message.tool_calls = [] + message.parsed = None + + choice = Mock() + choice.message = message + + response = Mock() + response.choices = [choice] + + usage = Mock() + usage.prompt_tokens = prompt_tokens + usage.completion_tokens = completion_tokens + response.usage = usage + return response + + def _create_mock_response_with_tools( + self, + content: str, + tool_calls: list, + ) -> Mock: + """Create a mock response with tool calls.""" + response = self._create_mock_response(content) + tool_call_mocks = [] + for tool_call in tool_calls: + tc_mock = Mock() + tc_mock.id = tool_call["id"] + tc_mock.function = Mock() + tc_mock.function.name = tool_call["name"] + tc_mock.function.arguments = tool_call["arguments"] + tool_call_mocks.append(tc_mock) + response.choices[0].message.tool_calls = tool_call_mocks + return response + + def _create_mock_response_with_reasoning( + self, + content: str, + reasoning_content: str, + ) -> Mock: + """Create a mock response with reasoning content.""" + response = self._create_mock_response(content) + response.choices[0].message.reasoning_content = reasoning_content + return response + + def _create_mock_response_with_structured_data(self, data: dict) -> Mock: + """Create a mock response with structured data.""" + message = Mock() + message.parsed = Mock() + message.parsed.model_dump.return_value = data + message.content = None + message.reasoning_content = None + message.tool_calls = [] + + choice = Mock() + choice.message = message + + response = Mock() + response.choices = [choice] + response.usage = None + + return response + + def _create_stream_mock(self, chunks_data: list) 
-> Any: + """Create a mock stream with proper async context management.""" + + class MockStream: + """Mock stream class.""" + + def __init__(self, chunks_data: list) -> None: + self.chunks_data = chunks_data + self.index = 0 + + async def __aenter__(self) -> "MockStream": + return self + + async def __aexit__( + self, + exc_type: Any, + exc_val: Any, + exc_tb: Any, + ) -> None: + pass + + def __aiter__(self) -> "MockStream": + return self + + async def __anext__(self) -> AsyncGenerator: + if self.index >= len(self.chunks_data): + raise StopAsyncIteration + chunk_data = self.chunks_data[self.index] + self.index += 1 + + delta = Mock() + delta.content = chunk_data.get("content") + delta.reasoning_content = chunk_data.get("reasoning_content") + if "tool_calls" in chunk_data: + tool_call_mocks = [] + for tc_data in chunk_data["tool_calls"]: + tc_mock = Mock() + tc_mock.id = tc_data["id"] + tc_mock.index = 0 + tc_mock.function = Mock() + tc_mock.function.name = tc_data["name"] + tc_mock.function.arguments = tc_data["arguments"] + tool_call_mocks.append(tc_mock) + delta.tool_calls = tool_call_mocks + else: + delta.tool_calls = [] + + choice = Mock() + choice.delta = delta + + chunk = Mock() + chunk.choices = [choice] + chunk.usage = Mock() + chunk.usage.prompt_tokens = 5 + chunk.usage.completion_tokens = 10 + return chunk + + return MockStream(chunks_data)
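For reference, a minimal usage sketch of the structured-output path these tests exercise (this is an illustrative addition, not part of the diff; the API key and model name below are placeholders and a valid DashScope credential plus the dashscope SDK are assumed):

# Sketch only: the caller-facing pattern asserted by the structured-output tests.
import asyncio

from pydantic import BaseModel

from agentscope.model import DashScopeChatModel


class PersonInfo(BaseModel):
    """Schema the model is asked to fill."""

    name: str
    age: int


async def main() -> None:
    model = DashScopeChatModel(
        model_name="qwen-turbo",          # placeholder model name
        api_key="YOUR_DASHSCOPE_API_KEY",  # placeholder credential
        stream=False,
    )
    messages = [{"role": "user", "content": "Generate a person"}]

    # Passing structured_model forces a tool call named
    # "generate_structured_output"; the parsed arguments are surfaced on
    # ChatResponse.metadata, as asserted in the tests above.
    response = await model(messages, structured_model=PersonInfo)
    print(response.metadata)  # e.g. {"name": "John", "age": 30}


if __name__ == "__main__":
    asyncio.run(main())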