mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-31 14:01:17 +00:00
fix(web): serialize duckduckgo search calls
This commit is contained in:
parent
a38bc637bd
commit
ee061f0595
@ -96,10 +96,37 @@ class WebSearchTool(Tool):
|
||||
self.config = config if config is not None else WebSearchConfig()
|
||||
self.proxy = proxy
|
||||
|
||||
def _effective_provider(self) -> str:
|
||||
"""Resolve the backend that execute() will actually use."""
|
||||
provider = self.config.provider.strip().lower() or "brave"
|
||||
if provider == "duckduckgo":
|
||||
return "duckduckgo"
|
||||
if provider == "brave":
|
||||
api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||
return "brave" if api_key else "duckduckgo"
|
||||
if provider == "tavily":
|
||||
api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "")
|
||||
return "tavily" if api_key else "duckduckgo"
|
||||
if provider == "searxng":
|
||||
base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip()
|
||||
return "searxng" if base_url else "duckduckgo"
|
||||
if provider == "jina":
|
||||
api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "")
|
||||
return "jina" if api_key else "duckduckgo"
|
||||
if provider == "kagi":
|
||||
api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
|
||||
return "kagi" if api_key else "duckduckgo"
|
||||
return provider
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
"""DuckDuckGo searches are serialized because ddgs is not concurrency-safe."""
|
||||
return self._effective_provider() == "duckduckgo"
|
||||
|
||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||
provider = self.config.provider.strip().lower() or "brave"
|
||||
n = min(max(count or self.config.max_results, 1), 10)
|
||||
|
||||
@ -689,11 +689,20 @@ async def test_runner_keeps_going_when_tool_result_persistence_fails():
|
||||
|
||||
|
||||
class _DelayTool(Tool):
|
||||
def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
delay: float,
|
||||
read_only: bool,
|
||||
shared_events: list[str],
|
||||
exclusive: bool = False,
|
||||
):
|
||||
self._name = name
|
||||
self._delay = delay
|
||||
self._read_only = read_only
|
||||
self._shared_events = shared_events
|
||||
self._exclusive = exclusive
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -711,6 +720,10 @@ class _DelayTool(Tool):
|
||||
def read_only(self) -> bool:
|
||||
return self._read_only
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return self._exclusive
|
||||
|
||||
async def execute(self, **kwargs):
|
||||
self._shared_events.append(f"start:{self._name}")
|
||||
await asyncio.sleep(self._delay)
|
||||
@ -756,6 +769,48 @@ async def test_runner_batches_read_only_tools_before_exclusive_work():
|
||||
assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_does_not_batch_exclusive_read_only_tools():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
tools = ToolRegistry()
|
||||
shared_events: list[str] = []
|
||||
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
|
||||
read_b = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=shared_events)
|
||||
ddg_like = _DelayTool(
|
||||
"ddg_like",
|
||||
delay=0.01,
|
||||
read_only=True,
|
||||
shared_events=shared_events,
|
||||
exclusive=True,
|
||||
)
|
||||
tools.register(read_a)
|
||||
tools.register(ddg_like)
|
||||
tools.register(read_b)
|
||||
|
||||
runner = AgentRunner(MagicMock())
|
||||
await runner._execute_tools(
|
||||
AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
concurrent_tools=True,
|
||||
),
|
||||
[
|
||||
ToolCallRequest(id="ro1", name="read_a", arguments={}),
|
||||
ToolCallRequest(id="ddg1", name="ddg_like", arguments={}),
|
||||
ToolCallRequest(id="ro2", name="read_b", arguments={}),
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
assert shared_events[0] == "start:read_a"
|
||||
assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like")
|
||||
assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_blocks_repeated_external_fetches():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""Tests for multi-provider web search."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
|
||||
return r
|
||||
|
||||
|
||||
def test_duckduckgo_search_is_exclusive():
|
||||
tool = _tool(provider="duckduckgo")
|
||||
assert tool.exclusive is True
|
||||
assert tool.concurrency_safe is False
|
||||
|
||||
|
||||
def test_brave_with_api_key_remains_concurrency_safe():
|
||||
tool = _tool(provider="brave", api_key="brave-key")
|
||||
assert tool.exclusive is False
|
||||
assert tool.concurrency_safe is True
|
||||
|
||||
|
||||
def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch):
|
||||
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
|
||||
tool = _tool(provider="brave", api_key="")
|
||||
assert tool.exclusive is True
|
||||
assert tool.concurrency_safe is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brave_search(monkeypatch):
|
||||
async def mock_get(self, url, **kw):
|
||||
@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch):
|
||||
import nanobot.agent.tools.web as web_mod
|
||||
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
|
||||
|
||||
from ddgs import DDGS
|
||||
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
||||
|
||||
tool = _tool(provider="duckduckgo")
|
||||
@ -265,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch):
|
||||
result = await tool.execute(query="test")
|
||||
gate.set()
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user