fix(web): serialize duckduckgo search calls

This commit is contained in:
yeyitech 2026-04-14 13:30:18 +08:00 committed by Xubin Ren
parent a38bc637bd
commit ee061f0595
3 changed files with 102 additions and 6 deletions

View File

@ -96,10 +96,37 @@ class WebSearchTool(Tool):
self.config = config if config is not None else WebSearchConfig() self.config = config if config is not None else WebSearchConfig()
self.proxy = proxy self.proxy = proxy
def _effective_provider(self) -> str:
"""Resolve the backend that execute() will actually use."""
provider = self.config.provider.strip().lower() or "brave"
if provider == "duckduckgo":
return "duckduckgo"
if provider == "brave":
api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "")
return "brave" if api_key else "duckduckgo"
if provider == "tavily":
api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "")
return "tavily" if api_key else "duckduckgo"
if provider == "searxng":
base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip()
return "searxng" if base_url else "duckduckgo"
if provider == "jina":
api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "")
return "jina" if api_key else "duckduckgo"
if provider == "kagi":
api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
return "kagi" if api_key else "duckduckgo"
return provider
@property @property
def read_only(self) -> bool: def read_only(self) -> bool:
return True return True
@property
def exclusive(self) -> bool:
"""DuckDuckGo searches are serialized because ddgs is not concurrency-safe."""
return self._effective_provider() == "duckduckgo"
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
provider = self.config.provider.strip().lower() or "brave" provider = self.config.provider.strip().lower() or "brave"
n = min(max(count or self.config.max_results, 1), 10) n = min(max(count or self.config.max_results, 1), 10)

View File

@ -689,11 +689,20 @@ async def test_runner_keeps_going_when_tool_result_persistence_fails():
class _DelayTool(Tool): class _DelayTool(Tool):
def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]): def __init__(
self,
name: str,
*,
delay: float,
read_only: bool,
shared_events: list[str],
exclusive: bool = False,
):
self._name = name self._name = name
self._delay = delay self._delay = delay
self._read_only = read_only self._read_only = read_only
self._shared_events = shared_events self._shared_events = shared_events
self._exclusive = exclusive
@property @property
def name(self) -> str: def name(self) -> str:
@ -711,6 +720,10 @@ class _DelayTool(Tool):
def read_only(self) -> bool: def read_only(self) -> bool:
return self._read_only return self._read_only
@property
def exclusive(self) -> bool:
return self._exclusive
async def execute(self, **kwargs): async def execute(self, **kwargs):
self._shared_events.append(f"start:{self._name}") self._shared_events.append(f"start:{self._name}")
await asyncio.sleep(self._delay) await asyncio.sleep(self._delay)
@ -756,6 +769,48 @@ async def test_runner_batches_read_only_tools_before_exclusive_work():
assert shared_events[-2:] == ["start:write_a", "end:write_a"] assert shared_events[-2:] == ["start:write_a", "end:write_a"]
@pytest.mark.asyncio
async def test_runner_does_not_batch_exclusive_read_only_tools():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
tools = ToolRegistry()
shared_events: list[str] = []
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
read_b = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=shared_events)
ddg_like = _DelayTool(
"ddg_like",
delay=0.01,
read_only=True,
shared_events=shared_events,
exclusive=True,
)
tools.register(read_a)
tools.register(ddg_like)
tools.register(read_b)
runner = AgentRunner(MagicMock())
await runner._execute_tools(
AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
concurrent_tools=True,
),
[
ToolCallRequest(id="ro1", name="read_a", arguments={}),
ToolCallRequest(id="ddg1", name="ddg_like", arguments={}),
ToolCallRequest(id="ro2", name="read_b", arguments={}),
],
{},
)
assert shared_events[0] == "start:read_a"
assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like")
assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches(): async def test_runner_blocks_repeated_external_fetches():
from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.runner import AgentRunSpec, AgentRunner

View File

@ -1,7 +1,5 @@
"""Tests for multi-provider web search.""" """Tests for multi-provider web search."""
import asyncio
import httpx import httpx
import pytest import pytest
@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
return r return r
def test_duckduckgo_search_is_exclusive():
tool = _tool(provider="duckduckgo")
assert tool.exclusive is True
assert tool.concurrency_safe is False
def test_brave_with_api_key_remains_concurrency_safe():
tool = _tool(provider="brave", api_key="brave-key")
assert tool.exclusive is False
assert tool.concurrency_safe is True
def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch):
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
tool = _tool(provider="brave", api_key="")
assert tool.exclusive is True
assert tool.concurrency_safe is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_brave_search(monkeypatch): async def test_brave_search(monkeypatch):
async def mock_get(self, url, **kw): async def mock_get(self, url, **kw):
@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch):
import nanobot.agent.tools.web as web_mod import nanobot.agent.tools.web as web_mod
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False) monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
from ddgs import DDGS
monkeypatch.setattr("ddgs.DDGS", MockDDGS) monkeypatch.setattr("ddgs.DDGS", MockDDGS)
tool = _tool(provider="duckduckgo") tool = _tool(provider="duckduckgo")
@ -265,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch):
result = await tool.execute(query="test") result = await tool.execute(query="test")
gate.set() gate.set()
assert "Error" in result assert "Error" in result