5 changes: 3 additions & 2 deletions src/agentscope/_utils/_common.py
@@ -226,8 +226,9 @@ def _create_tool_from_base_model(
structured_model (`Type[BaseModel]`):
A Pydantic BaseModel class that defines the expected structure
for the tool's output.
tool_name (`str`, default `"format_output"`):
The name to assign to the generated tool.
tool_name (`str`, default `"generate_structured_output"`):
The tool name used to force the LLM to generate structured
output by calling this function.

Returns:
`Dict[str, Any]`: A tool definition dictionary compatible with
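The renamed default makes the mechanism explicit: structured output is obtained by forcing the model to call a tool whose parameters mirror the target schema. A minimal sketch of that idea, assuming an OpenAI-style function schema (the helper below is illustrative, not AgentScope's actual implementation):

```python
# Illustrative sketch: derive a tool definition from a Pydantic model.
# The dict layout follows the common OpenAI-style function schema; the
# real _create_tool_from_base_model may differ in detail.
from typing import Any, Dict, Type

from pydantic import BaseModel


def create_tool_from_base_model(
    structured_model: Type[BaseModel],
    tool_name: str = "generate_structured_output",
) -> Dict[str, Any]:
    """Build a tool whose parameters mirror the model's JSON schema."""
    return {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": "Generate the structured output.",
            "parameters": structured_model.model_json_schema(),
        },
    }


class Answer(BaseModel):
    city: str
    temperature: float


print(create_tool_from_base_model(Answer)["function"]["name"])
# -> generate_structured_output
```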
2 changes: 1 addition & 1 deletion src/agentscope/model/_anthropic_model.py
@@ -236,7 +236,7 @@ async def _parse_anthropic_completion_response(
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an Anthropic Message object, extract the content blocks and
usages from it.
usages from it.

Args:
start_datetime (`datetime`):
2 changes: 1 addition & 1 deletion src/agentscope/model/_dashscope_model.py
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The dashscope API model classes."""
import collections
from datetime import datetime
@@ -215,6 +214,7 @@ async def __call__(

return parsed_response

# pylint: disable=too-many-branches
async def _parse_dashscope_stream_response(
self,
start_datetime: datetime,
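Moving the `# pylint: disable=too-many-branches` pragma out of the module header narrows its reach: a standalone disable applies from its line to the end of the enclosing scope, so at the top of the file it exempts every function below it. A schematic illustration:

```python
# Schematic: a standalone `# pylint: disable=...` applies from its
# line to the end of the enclosing scope, not just the next statement.

def early_function() -> None:
    """Checked: the disable below has not taken effect yet."""


# pylint: disable=too-many-branches
async def _parse_stream_response() -> None:
    """Exempt: the pragma sits immediately above it, as in the diff."""
```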
2 changes: 2 additions & 0 deletions src/agentscope/model/_gemini_model.py
@@ -136,9 +136,11 @@ async def __call__(

if tools:
config["tools"] = self._format_tools_json_schemas(tools)

if tool_choice:
self._validate_tool_choice(tool_choice, tools)
config["tool_config"] = self._format_tool_choice(tool_choice)

if structured_model:
if tools or tool_choice:
logger.warning(
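The Gemini wrapper builds one `config` dict from independent branches and warns when `structured_model` is combined with `tools` or `tool_choice`. A rough sketch of that control flow (the `_format_*` helpers are elided and the warning text is an assumption):

```python
# Rough sketch of the branch structure shown above; formatting helpers
# are elided and the warning text is an assumption.
import logging

logger = logging.getLogger(__name__)


def build_config(
    tools: list | None = None,
    tool_choice: str | None = None,
    structured_model: type | None = None,
) -> dict:
    config: dict = {}
    if tools:
        config["tools"] = tools  # stands in for _format_tools_json_schemas

    if tool_choice:
        config["tool_config"] = tool_choice  # stands in for _format_tool_choice

    if structured_model and (tools or tool_choice):
        # Structured output claims the tool channel, so combining them
        # is ambiguous; surface that to the caller instead of failing.
        logger.warning(
            "structured_model overrides tools/tool_choice for this call",
        )
    return config
```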
114 changes: 42 additions & 72 deletions src/agentscope/model/_ollama_model.py
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""Model wrapper for Ollama models."""
from datetime import datetime
from typing import (
@@ -25,7 +24,7 @@


if TYPE_CHECKING:
from ollama._types import OllamaChatResponse
from ollama._types import ChatResponse as OllamaChatResponse
else:
OllamaChatResponse = "ollama._types.ChatResponse"
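The corrected import matters because `ollama._types` exposes `ChatResponse`, not `OllamaChatResponse`; the alias also keeps it from shadowing AgentScope's own `ChatResponse`. The `TYPE_CHECKING` pattern in isolation:

```python
# The TYPE_CHECKING alias pattern from the diff: import a type only for
# static analysis, and alias it to avoid clashing with a local name.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; no runtime dependency on ollama here.
    from ollama._types import ChatResponse as OllamaChatResponse
else:
    # At runtime, keep a string placeholder so annotations referencing
    # the name still resolve lazily.
    OllamaChatResponse = "ollama._types.ChatResponse"


def parse(response: "OllamaChatResponse") -> None:
    ...
```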

Expand Down Expand Up @@ -167,7 +166,7 @@ async def __call__(
async def _parse_ollama_stream_completion_response(
self,
start_datetime: datetime,
response: AsyncIterator[Any],
response: AsyncIterator[OllamaChatResponse],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an Ollama streaming completion response, extract the
@@ -176,7 +175,7 @@ async def _parse_ollama_stream_completion_response(
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncIterator[Any]`):
response (`AsyncIterator[OllamaChatResponse]`):
Ollama streaming response async iterator to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
@@ -199,47 +198,21 @@ async def _parse_ollama_stream_completion_response(
metadata = None

async for chunk in response:
has_new_content = False
has_new_thinking = False

# Handle text content
if hasattr(chunk, "message"):
msg = chunk.message

if getattr(msg, "thinking", None):
acc_thinking_content += msg.thinking
has_new_thinking = True

if getattr(msg, "content", None):
accumulated_text += msg.content
has_new_content = True

# Handle tool calls
if getattr(msg, "tool_calls", None):
has_new_content = True
for idx, tool_call in enumerate(msg.tool_calls):
function_name = (
getattr(
tool_call,
"function",
None,
)
and tool_call.function.name
or "tool"
)
tool_id = getattr(
tool_call,
"id",
f"{function_name}_{idx}",
)
if hasattr(tool_call, "function"):
function = tool_call.function
tool_calls[tool_id] = {
"type": "tool_use",
"id": tool_id,
"name": function.name,
"input": function.arguments,
}
msg = chunk.message
acc_thinking_content += msg.thinking or ""
accumulated_text += msg.content or ""

# Handle tool calls
for idx, tool_call in enumerate(msg.tool_calls or []):
function = tool_call.function
tool_id = f"{idx}_{function.name}"
tool_calls[tool_id] = {
"type": "tool_use",
"id": tool_id,
"name": function.name,
"input": function.arguments,
}
# Calculate usage statistics
current_time = (datetime.now() - start_datetime).total_seconds()
usage = ChatUsage(
@@ -264,26 +237,24 @@ async def _parse_ollama_stream_completion_response(
metadata = _json_loads_with_repair(accumulated_text)

# Add tool call blocks
if tool_calls:
for tool_call in tool_calls.values():
try:
input_data = tool_call["input"]
if isinstance(input_data, str):
input_data = _json_loads_with_repair(input_data)
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=input_data,
),
)
except Exception as e:
print(f"Error parsing tool call input: {e}")
for tool_call in tool_calls.values():
try:
input_data = tool_call["input"]
if isinstance(input_data, str):
input_data = _json_loads_with_repair(input_data)
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=input_data,
),
)
except Exception as e:
print(f"Error parsing tool call input: {e}")

# Generate response when there's new content or at final chunk
is_final = getattr(chunk, "done", False)
if (has_new_thinking or has_new_content or is_final) and contents:
if chunk.done and contents:
res = ChatResponse(
content=contents,
usage=usage,
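The rewritten loop drops the `hasattr`/`getattr` guards in favor of direct attribute access with `or ""` / `or []` defaults, and now yields only on the final chunk (`chunk.done`). A self-contained sketch of the accumulation pattern, with stand-in types in place of ollama's response objects:

```python
# Sketch of the simplified accumulation loop. `Msg` and `Chunk` are
# hypothetical stand-ins for ollama's streaming response objects.
from dataclasses import dataclass, field


@dataclass
class Msg:
    thinking: str | None = None
    content: str | None = None
    tool_calls: list | None = None


@dataclass
class Chunk:
    message: Msg = field(default_factory=Msg)
    done: bool = False


def accumulate(chunks: list[Chunk]) -> tuple[str, str]:
    thinking, text = "", ""
    for chunk in chunks:
        msg = chunk.message
        # `or ""` collapses a None field into a no-op append, replacing
        # the old hasattr/getattr guards.
        thinking += msg.thinking or ""
        text += msg.content or ""
    return thinking, text


print(accumulate([Chunk(Msg(content="Hel")), Chunk(Msg(content="lo"), done=True)]))
# -> ('', 'Hello')
```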
@@ -338,16 +309,15 @@ async def _parse_ollama_completion_response(
if structured_model:
metadata = _json_loads_with_repair(response.message.content)

if response.message.tool_calls:
for tool_call in response.message.tool_calls:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.function.name,
name=tool_call.function.name,
input=tool_call.function.arguments,
),
)
for idx, tool_call in enumerate(response.message.tool_calls or []):
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=f"{idx}_{tool_call.function.name}",
name=tool_call.function.name,
input=tool_call.function.arguments,
),
)

usage = None
if "prompt_eval_count" in response and "eval_count" in response:
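The non-streaming parser now derives block ids as `f"{idx}_{tool_call.function.name}"` rather than reusing the bare function name, so two calls to the same tool in one response no longer collide. A tiny illustration:

```python
# Two calls to the same tool would previously share the id "search";
# prefixing the enumeration index keeps the blocks distinct.
calls = [("search", {"q": "alpha"}), ("search", {"q": "beta"})]

blocks = [
    {"type": "tool_use", "id": f"{idx}_{name}", "name": name, "input": args}
    for idx, (name, args) in enumerate(calls)
]
print([block["id"] for block in blocks])
# -> ['0_search', '1_search']
```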
89 changes: 42 additions & 47 deletions src/agentscope/model/_openai_model.py
@@ -252,6 +252,7 @@ async def _parse_openai_stream_response(
chunk = item.chunk
else:
chunk = item

if chunk.usage:
usage = ChatUsage(
input_tokens=chunk.usage.prompt_tokens,
Expand All @@ -261,31 +262,26 @@ async def _parse_openai_stream_response(

if chunk.choices:
choice = chunk.choices[0]
if (
hasattr(choice.delta, "reasoning_content")
and choice.delta.reasoning_content is not None
):
thinking += choice.delta.reasoning_content

if choice.delta.content:
text += choice.delta.content

if choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
if tool_call.index in tool_calls:
if tool_call.function.arguments is not None:
tool_calls[tool_call.index][
"input"
] += tool_call.function.arguments

else:
tool_calls[tool_call.index] = {
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": tool_call.function.arguments
or "",
}

thinking += (
getattr(choice.delta, "reasoning_content", None) or ""
)
text += choice.delta.content or ""

for tool_call in choice.delta.tool_calls or []:
if tool_call.index in tool_calls:
if tool_call.function.arguments is not None:
tool_calls[tool_call.index][
"input"
] += tool_call.function.arguments

else:
tool_calls[tool_call.index] = {
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": tool_call.function.arguments or "",
}

contents: List[
TextBlock | ToolUseBlock | ThinkingBlock
@@ -310,18 +306,17 @@ async def _parse_openai_stream_response(
if structured_model:
metadata = _json_loads_with_repair(text)

if tool_calls:
for tool_call in tool_calls.values():
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=_json_loads_with_repair(
tool_call["input"] or "{}",
),
for tool_call in tool_calls.values():
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=_json_loads_with_repair(
tool_call["input"] or "{}",
),
)
),
)

if contents:
res = ChatResponse(
@@ -381,18 +376,18 @@ def _parse_openai_completion_response(
),
)

if choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=_json_loads_with_repair(
tool_call.function.arguments,
),
for tool_call in choice.message.tool_calls or []:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=_json_loads_with_repair(
tool_call.function.arguments,
),
)
),
)

if structured_model:
metadata = choice.message.parsed.model_dump()

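For the OpenAI stream, tool-call arguments arrive as JSON fragments keyed by the call's `index`, and only the first fragment for each index carries the `id` and `name`. A self-contained sketch of that assembly, with hypothetical stand-ins for the SDK's delta types:

```python
# Sketch of OpenAI-style streaming tool-call assembly. `FnDelta` and
# `ToolCallDelta` are hypothetical stand-ins for the SDK's delta types.
from dataclasses import dataclass


@dataclass
class FnDelta:
    name: str | None = None
    arguments: str | None = None


@dataclass
class ToolCallDelta:
    index: int
    id: str | None
    function: FnDelta


def assemble(deltas: list[ToolCallDelta]) -> dict[int, dict]:
    tool_calls: dict[int, dict] = {}
    for tc in deltas:
        if tc.index in tool_calls:
            # Later chunks only carry argument fragments; append them.
            if tc.function.arguments is not None:
                tool_calls[tc.index]["input"] += tc.function.arguments
        else:
            # The first chunk for this index carries the id and name.
            tool_calls[tc.index] = {
                "type": "tool_use",
                "id": tc.id,
                "name": tc.function.name,
                "input": tc.function.arguments or "",
            }
    return tool_calls


deltas = [
    ToolCallDelta(0, "call_1", FnDelta("get_weather", '{"city": ')),
    ToolCallDelta(0, None, FnDelta(None, '"Paris"}')),
]
print(assemble(deltas)[0]["input"])  # -> {"city": "Paris"}
```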