Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/app/agent/factory/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def browser_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.browser_agent
options.project_id,
agent_name=Agents.browser_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ async def developer_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.developer_agent
options.project_id,
agent_name=Agents.developer_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ async def document_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.document_agent
options.project_id,
agent_name=Agents.document_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/multi_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def multi_modal_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.multi_modal_agent
options.project_id,
agent_name=Agents.multi_modal_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/social_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ async def social_media_agent(options: Chat):
user_id=options.skill_config_user_id(),
).get_tools(),
*SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.social_media_agent
options.project_id,
agent_name=Agents.social_media_agent,
),
# *DiscordToolkit(options.project_id).get_tools(),
# *GoogleSuiteToolkit(options.project_id).get_tools(),
Expand Down
47 changes: 32 additions & 15 deletions backend/app/agent/toolkit/search_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,21 @@ def cloud_search_google(
)
return res.json()

# @listen_toolkit(
# BaseSearchToolkit.search_duckduckgo,
# lambda _,
# query,
# source="text",
# max_results=5: f"Search DuckDuckGo with query '{query}', source '{source}', and max results {max_results}",
# lambda result: f"Search DuckDuckGo returned {len(result)} results",
# )
# def search_duckduckgo(self, query: str, source: str = "text", max_results: int = 5) -> list[dict[str, Any]]:
# return super().search_duckduckgo(query, source, max_results)
@listen_toolkit(
BaseSearchToolkit.search_duckduckgo,
lambda _,
query,
source="text",
number_of_result_pages=10: f"Search DuckDuckGo with query '{query}', source '{source}', and {number_of_result_pages} result pages",
lambda result: f"Search DuckDuckGo returned {len(result)} results",
)
def search_duckduckgo(
self,
query: str,
source: str = "text",
number_of_result_pages: int = 10,
) -> list[dict[str, Any]]:
return super().search_duckduckgo(query, source, number_of_result_pages)

# @listen_toolkit(
# BaseSearchToolkit.tavily_search,
Expand Down Expand Up @@ -365,9 +370,14 @@ def cloud_search_google(

@classmethod
def get_can_use_tools(
cls, api_task_id: str, agent_name: str | None = None
cls,
api_task_id: str,
agent_name: str | None = None,
) -> list[FunctionTool]:
search_toolkit = SearchToolkit(api_task_id, agent_name=agent_name)
search_toolkit = SearchToolkit(
api_task_id,
agent_name=agent_name,
)
tools = [
# FunctionTool(search_toolkit.search_wiki),
# FunctionTool(search_toolkit.search_duckduckgo),
Expand All @@ -380,10 +390,17 @@ def get_can_use_tools(
# if env("BRAVE_API_KEY"):
# tools.append(FunctionTool(search_toolkit.search_brave))

if (env("GOOGLE_API_KEY") and env("SEARCH_ENGINE_ID")) or env(
"cloud_api_key"
):
if env("GOOGLE_API_KEY") and env("SEARCH_ENGINE_ID"):
logger.info("Using search tool: search_google (user API keys)")
tools.append(FunctionTool(search_toolkit.search_google))
elif env("cloud_api_key"):
logger.info("Using search tool: search_google (cloud proxy)")
tools.append(FunctionTool(search_toolkit.search_google))
else:
logger.info(
"Using search tool: search_duckduckgo (no API keys configured)"
)
tools.append(FunctionTool(search_toolkit.search_duckduckgo))

# if env("TAVILY_API_KEY"):
# tools.append(FunctionTool(search_toolkit.tavily_search))
Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"opentelemetry-api>=1.34.1",
"opentelemetry-sdk>=1.34.1",
"opentelemetry-exporter-otlp-proto-http>=1.34.1",
"duckduckgo-search>=7.0.0",
]


Expand Down
175 changes: 175 additions & 0 deletions backend/tests/app/agent/toolkit/test_search_toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========

import asyncio
from unittest.mock import MagicMock, patch

import pytest

from app.agent.toolkit.search_toolkit import SearchToolkit
from app.service.task import TaskLock, task_locks

pytestmark = pytest.mark.unit

_ENV_MOD = "app.agent.toolkit.search_toolkit.env"
_ENV_NOT_EMPTY_MOD = "app.agent.toolkit.search_toolkit.env_not_empty"
_TEST_TASK_ID = "test_task_search"


def _ensure_task_lock(task_id: str = _TEST_TASK_ID):
"""Ensure a task lock exists for the given task_id."""
if task_id not in task_locks:
task_locks[task_id] = TaskLock(
id=task_id, queue=asyncio.Queue(), human_input={}
)


def test_get_can_use_tools_duckduckgo_fallback_when_no_keys():
"""When no Google API keys or cloud_api_key, DuckDuckGo is used."""
with patch(_ENV_MOD, return_value=None):
tools = SearchToolkit.get_can_use_tools("test_task")
assert len(tools) == 1
assert "duckduckgo" in tools[0].func.__name__


def test_get_can_use_tools_google_api_when_keys_present():
"""When Google API keys are present, search_google is used."""

def mock_env(key, default=None):
return {
"GOOGLE_API_KEY": "test-key",
"SEARCH_ENGINE_ID": "test-cx",
}.get(key, default)

with patch(_ENV_MOD, side_effect=mock_env):
tools = SearchToolkit.get_can_use_tools("test_task")
assert len(tools) == 1
assert "search_google" == tools[0].func.__name__


def test_get_can_use_tools_cloud_api_key():
"""When cloud_api_key is present, search_google is used."""

def mock_env(key, default=None):
return {"cloud_api_key": "cloud-key"}.get(key, default)

with patch(_ENV_MOD, side_effect=mock_env):
tools = SearchToolkit.get_can_use_tools("test_task")
assert len(tools) == 1
assert "search_google" == tools[0].func.__name__


def test_get_can_use_tools_accepts_agent_name():
"""get_can_use_tools passes agent_name to the toolkit instance."""
with patch(_ENV_MOD, return_value=None):
tools = SearchToolkit.get_can_use_tools(
"test_task", agent_name="test_agent"
)
assert len(tools) == 1


def test_search_google_uses_user_keys():
"""search_google uses user-configured API keys when available."""
_ensure_task_lock()

def mock_env(key, default=None):
return {
"GOOGLE_API_KEY": "user-key",
"SEARCH_ENGINE_ID": "user-cx",
}.get(key, default)

toolkit = SearchToolkit(_TEST_TASK_ID)
with patch(_ENV_MOD, side_effect=mock_env):
with patch.object(
SearchToolkit.__bases__[0],
"search_google",
return_value=[{"result_id": 1, "title": "test"}],
) as mock_super:
result = toolkit.search_google("test query")
mock_super.assert_called_once()
assert result == [{"result_id": 1, "title": "test"}]


def test_search_google_falls_back_to_cloud():
"""search_google falls back to cloud search when no user keys."""
_ensure_task_lock()

toolkit = SearchToolkit(_TEST_TASK_ID)
with patch(_ENV_MOD, return_value=None):
with patch.object(
toolkit,
"cloud_search_google",
return_value=[{"result_id": 1, "title": "cloud"}],
) as mock_cloud:
result = toolkit.search_google("test query")
mock_cloud.assert_called_once_with("test query", "web", 10, 1)
assert result == [{"result_id": 1, "title": "cloud"}]


def test_get_can_use_tools_google_keys_no_duckduckgo():
"""When Google API keys are present, DuckDuckGo is NOT included."""

def mock_env(key, default=None):
return {
"GOOGLE_API_KEY": "test-key",
"SEARCH_ENGINE_ID": "test-cx",
}.get(key, default)

with patch(_ENV_MOD, side_effect=mock_env):
tools = SearchToolkit.get_can_use_tools("test_task")
names = [t.func.__name__ for t in tools]
assert "duckduckgo" not in " ".join(names)


def test_search_duckduckgo_delegates_to_base():
"""search_duckduckgo delegates to the base class method."""
_ensure_task_lock()

toolkit = SearchToolkit(_TEST_TASK_ID)
expected = [{"result_id": 1, "title": "duck result"}]

with patch.object(
SearchToolkit.__bases__[0],
"search_duckduckgo",
return_value=expected,
) as mock_super:
result = toolkit.search_duckduckgo("test query")
mock_super.assert_called_once()
assert result == expected


def test_cloud_search_google_calls_server():
"""cloud_search_google makes HTTP request to server proxy."""
toolkit = SearchToolkit("test_task")

mock_response = MagicMock()
mock_response.json.return_value = [{"result_id": 1, "title": "proxied"}]

with (
patch(
_ENV_NOT_EMPTY_MOD,
side_effect=lambda k: {
"SERVER_URL": "http://test-server",
"cloud_api_key": "test-cloud-key",
}[k],
),
patch(
"app.agent.toolkit.search_toolkit.httpx.get",
return_value=mock_response,
) as mock_get,
):
result = toolkit.cloud_search_google("test query")

mock_get.assert_called_once()
assert result == [{"result_id": 1, "title": "proxied"}]
Loading