Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/react_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

async def main() -> None:
"""The main entry point for the ReAct agent example."""
import agentscope

agentscope.init(studio_url="http://localhost:3000")

toolkit = Toolkit()
toolkit.register_tool_function(execute_shell_command)
toolkit.register_tool_function(execute_python_code)
Expand Down
5 changes: 3 additions & 2 deletions src/agentscope/_utils/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ def _create_tool_from_base_model(
structured_model (`Type[BaseModel]`):
A Pydantic BaseModel class that defines the expected structure
for the tool's output.
tool_name (`str`, default `"format_output"`):
The name to assign to the generated tool.
tool_name (`str`, default `"generate_structured_output"`):
The tool name that used to force the LLM to generate structured
output by calling this function.

Returns:
`Dict[str, Any]`: A tool definition dictionary compatible with
Expand Down
2 changes: 1 addition & 1 deletion src/agentscope/model/_anthropic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def _parse_anthropic_completion_response(
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an Anthropic Message object, extract the content blocks and
usages from it.
usages from it.

Args:
start_datetime (`datetime`):
Expand Down
2 changes: 1 addition & 1 deletion src/agentscope/model/_dashscope_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The dashscope API model classes."""
import collections
from datetime import datetime
Expand Down Expand Up @@ -215,6 +214,7 @@ async def __call__(

return parsed_response

# pylint: disable=too-many-branches
async def _parse_dashscope_stream_response(
self,
start_datetime: datetime,
Expand Down
2 changes: 2 additions & 0 deletions src/agentscope/model/_gemini_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,11 @@ async def __call__(

if tools:
config["tools"] = self._format_tools_json_schemas(tools)

if tool_choice:
self._validate_tool_choice(tool_choice, tools)
config["tool_config"] = self._format_tool_choice(tool_choice)

if structured_model:
if tools or tool_choice:
logger.warning(
Expand Down
114 changes: 42 additions & 72 deletions src/agentscope/model/_ollama_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""Model wrapper for Ollama models."""
from datetime import datetime
from typing import (
Expand All @@ -25,7 +24,7 @@


if TYPE_CHECKING:
from ollama._types import OllamaChatResponse
from ollama._types import ChatResponse as OllamaChatResponse
else:
OllamaChatResponse = "ollama._types.ChatResponse"

Expand Down Expand Up @@ -167,7 +166,7 @@ async def __call__(
async def _parse_ollama_stream_completion_response(
self,
start_datetime: datetime,
response: AsyncIterator[Any],
response: AsyncIterator[OllamaChatResponse],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an Ollama streaming completion response, extract the
Expand All @@ -176,7 +175,7 @@ async def _parse_ollama_stream_completion_response(
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncIterator[Any]`):
response (`AsyncIterator[OllamaChatResponse]`):
Ollama streaming response async iterator to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
Expand All @@ -199,47 +198,21 @@ async def _parse_ollama_stream_completion_response(
metadata = None

async for chunk in response:
has_new_content = False
has_new_thinking = False

# Handle text content
if hasattr(chunk, "message"):
msg = chunk.message

if getattr(msg, "thinking", None):
acc_thinking_content += msg.thinking
has_new_thinking = True

if getattr(msg, "content", None):
accumulated_text += msg.content
has_new_content = True

# Handle tool calls
if getattr(msg, "tool_calls", None):
has_new_content = True
for idx, tool_call in enumerate(msg.tool_calls):
function_name = (
getattr(
tool_call,
"function",
None,
)
and tool_call.function.name
or "tool"
)
tool_id = getattr(
tool_call,
"id",
f"{function_name}_{idx}",
)
if hasattr(tool_call, "function"):
function = tool_call.function
tool_calls[tool_id] = {
"type": "tool_use",
"id": tool_id,
"name": function.name,
"input": function.arguments,
}
msg = chunk.message
acc_thinking_content += msg.thinking or ""
accumulated_text += msg.content or ""

# Handle tool calls
for idx, tool_call in enumerate(msg.tool_calls or []):
function = tool_call.function
tool_id = f"{idx}_{function.name}"
tool_calls[tool_id] = {
"type": "tool_use",
"id": tool_id,
"name": function.name,
"input": function.arguments,
}
# Calculate usage statistics
current_time = (datetime.now() - start_datetime).total_seconds()
usage = ChatUsage(
Expand All @@ -264,26 +237,24 @@ async def _parse_ollama_stream_completion_response(
metadata = _json_loads_with_repair(accumulated_text)

# Add tool call blocks
if tool_calls:
for tool_call in tool_calls.values():
try:
input_data = tool_call["input"]
if isinstance(input_data, str):
input_data = _json_loads_with_repair(input_data)
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=input_data,
),
)
except Exception as e:
print(f"Error parsing tool call input: {e}")
for tool_call in tool_calls.values():
try:
input_data = tool_call["input"]
if isinstance(input_data, str):
input_data = _json_loads_with_repair(input_data)
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=input_data,
),
)
except Exception as e:
print(f"Error parsing tool call input: {e}")

# Generate response when there's new content or at final chunk
is_final = getattr(chunk, "done", False)
if (has_new_thinking or has_new_content or is_final) and contents:
if chunk.done and contents:
res = ChatResponse(
content=contents,
usage=usage,
Expand Down Expand Up @@ -338,16 +309,15 @@ async def _parse_ollama_completion_response(
if structured_model:
metadata = _json_loads_with_repair(response.message.content)

if response.message.tool_calls:
for tool_call in response.message.tool_calls:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.function.name,
name=tool_call.function.name,
input=tool_call.function.arguments,
),
)
for idx, tool_call in enumerate(response.message.tool_calls or []):
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=f"{idx}_{tool_call.function.name}",
name=tool_call.function.name,
input=tool_call.function.arguments,
),
)

usage = None
if "prompt_eval_count" in response and "eval_count" in response:
Expand Down
89 changes: 42 additions & 47 deletions src/agentscope/model/_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ async def _parse_openai_stream_response(
chunk = item.chunk
else:
chunk = item

if chunk.usage:
usage = ChatUsage(
input_tokens=chunk.usage.prompt_tokens,
Expand All @@ -261,31 +262,26 @@ async def _parse_openai_stream_response(

if chunk.choices:
choice = chunk.choices[0]
if (
hasattr(choice.delta, "reasoning_content")
and choice.delta.reasoning_content is not None
):
thinking += choice.delta.reasoning_content

if choice.delta.content:
text += choice.delta.content

if choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
if tool_call.index in tool_calls:
if tool_call.function.arguments is not None:
tool_calls[tool_call.index][
"input"
] += tool_call.function.arguments

else:
tool_calls[tool_call.index] = {
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": tool_call.function.arguments
or "",
}

thinking += (
getattr(choice.delta, "thinking_content", None) or ""
)
text += choice.delta.content or ""

for tool_call in choice.delta.tool_calls or []:
if tool_call.index in tool_calls:
if tool_call.function.arguments is not None:
tool_calls[tool_call.index][
"input"
] += tool_call.function.arguments

else:
tool_calls[tool_call.index] = {
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": tool_call.function.arguments or "",
}

contents: List[
TextBlock | ToolUseBlock | ThinkingBlock
Expand All @@ -310,18 +306,17 @@ async def _parse_openai_stream_response(
if structured_model:
metadata = _json_loads_with_repair(text)

if tool_calls:
for tool_call in tool_calls.values():
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=_json_loads_with_repair(
tool_call["input"] or "{}",
),
for tool_call in tool_calls.values():
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=_json_loads_with_repair(
tool_call["input"] or "{}",
),
)
),
)

if contents:
res = ChatResponse(
Expand Down Expand Up @@ -381,18 +376,18 @@ def _parse_openai_completion_response(
),
)

if choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=_json_loads_with_repair(
tool_call.function.arguments,
),
for tool_call in choice.message.tool_calls or []:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=_json_loads_with_repair(
tool_call.function.arguments,
),
)
),
)

if structured_model:
metadata = choice.message.parsed.model_dump()

Expand Down