mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-31 14:01:17 +00:00
fix(web): serialize duckduckgo search calls
This commit is contained in:
parent
a38bc637bd
commit
ee061f0595
@ -96,10 +96,37 @@ class WebSearchTool(Tool):
|
|||||||
self.config = config if config is not None else WebSearchConfig()
|
self.config = config if config is not None else WebSearchConfig()
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
|
|
||||||
|
def _effective_provider(self) -> str:
|
||||||
|
"""Resolve the backend that execute() will actually use."""
|
||||||
|
provider = self.config.provider.strip().lower() or "brave"
|
||||||
|
if provider == "duckduckgo":
|
||||||
|
return "duckduckgo"
|
||||||
|
if provider == "brave":
|
||||||
|
api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||||
|
return "brave" if api_key else "duckduckgo"
|
||||||
|
if provider == "tavily":
|
||||||
|
api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "")
|
||||||
|
return "tavily" if api_key else "duckduckgo"
|
||||||
|
if provider == "searxng":
|
||||||
|
base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip()
|
||||||
|
return "searxng" if base_url else "duckduckgo"
|
||||||
|
if provider == "jina":
|
||||||
|
api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "")
|
||||||
|
return "jina" if api_key else "duckduckgo"
|
||||||
|
if provider == "kagi":
|
||||||
|
api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
|
||||||
|
return "kagi" if api_key else "duckduckgo"
|
||||||
|
return provider
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def read_only(self) -> bool:
|
def read_only(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
"""DuckDuckGo searches are serialized because ddgs is not concurrency-safe."""
|
||||||
|
return self._effective_provider() == "duckduckgo"
|
||||||
|
|
||||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||||
provider = self.config.provider.strip().lower() or "brave"
|
provider = self.config.provider.strip().lower() or "brave"
|
||||||
n = min(max(count or self.config.max_results, 1), 10)
|
n = min(max(count or self.config.max_results, 1), 10)
|
||||||
|
|||||||
@ -689,11 +689,20 @@ async def test_runner_keeps_going_when_tool_result_persistence_fails():
|
|||||||
|
|
||||||
|
|
||||||
class _DelayTool(Tool):
|
class _DelayTool(Tool):
|
||||||
def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
delay: float,
|
||||||
|
read_only: bool,
|
||||||
|
shared_events: list[str],
|
||||||
|
exclusive: bool = False,
|
||||||
|
):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._delay = delay
|
self._delay = delay
|
||||||
self._read_only = read_only
|
self._read_only = read_only
|
||||||
self._shared_events = shared_events
|
self._shared_events = shared_events
|
||||||
|
self._exclusive = exclusive
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@ -711,6 +720,10 @@ class _DelayTool(Tool):
|
|||||||
def read_only(self) -> bool:
|
def read_only(self) -> bool:
|
||||||
return self._read_only
|
return self._read_only
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
return self._exclusive
|
||||||
|
|
||||||
async def execute(self, **kwargs):
|
async def execute(self, **kwargs):
|
||||||
self._shared_events.append(f"start:{self._name}")
|
self._shared_events.append(f"start:{self._name}")
|
||||||
await asyncio.sleep(self._delay)
|
await asyncio.sleep(self._delay)
|
||||||
@ -756,6 +769,48 @@ async def test_runner_batches_read_only_tools_before_exclusive_work():
|
|||||||
assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_does_not_batch_exclusive_read_only_tools():
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
|
||||||
|
tools = ToolRegistry()
|
||||||
|
shared_events: list[str] = []
|
||||||
|
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
|
||||||
|
read_b = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=shared_events)
|
||||||
|
ddg_like = _DelayTool(
|
||||||
|
"ddg_like",
|
||||||
|
delay=0.01,
|
||||||
|
read_only=True,
|
||||||
|
shared_events=shared_events,
|
||||||
|
exclusive=True,
|
||||||
|
)
|
||||||
|
tools.register(read_a)
|
||||||
|
tools.register(ddg_like)
|
||||||
|
tools.register(read_b)
|
||||||
|
|
||||||
|
runner = AgentRunner(MagicMock())
|
||||||
|
await runner._execute_tools(
|
||||||
|
AgentRunSpec(
|
||||||
|
initial_messages=[],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
concurrent_tools=True,
|
||||||
|
),
|
||||||
|
[
|
||||||
|
ToolCallRequest(id="ro1", name="read_a", arguments={}),
|
||||||
|
ToolCallRequest(id="ddg1", name="ddg_like", arguments={}),
|
||||||
|
ToolCallRequest(id="ro2", name="read_b", arguments={}),
|
||||||
|
],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert shared_events[0] == "start:read_a"
|
||||||
|
assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like")
|
||||||
|
assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_blocks_repeated_external_fetches():
|
async def test_runner_blocks_repeated_external_fetches():
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
"""Tests for multi-provider web search."""
|
"""Tests for multi-provider web search."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
def test_duckduckgo_search_is_exclusive():
|
||||||
|
tool = _tool(provider="duckduckgo")
|
||||||
|
assert tool.exclusive is True
|
||||||
|
assert tool.concurrency_safe is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_brave_with_api_key_remains_concurrency_safe():
|
||||||
|
tool = _tool(provider="brave", api_key="brave-key")
|
||||||
|
assert tool.exclusive is False
|
||||||
|
assert tool.concurrency_safe is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch):
|
||||||
|
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
|
||||||
|
tool = _tool(provider="brave", api_key="")
|
||||||
|
assert tool.exclusive is True
|
||||||
|
assert tool.concurrency_safe is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_brave_search(monkeypatch):
|
async def test_brave_search(monkeypatch):
|
||||||
async def mock_get(self, url, **kw):
|
async def mock_get(self, url, **kw):
|
||||||
@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch):
|
|||||||
import nanobot.agent.tools.web as web_mod
|
import nanobot.agent.tools.web as web_mod
|
||||||
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
|
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
|
||||||
|
|
||||||
from ddgs import DDGS
|
|
||||||
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
||||||
|
|
||||||
tool = _tool(provider="duckduckgo")
|
tool = _tool(provider="duckduckgo")
|
||||||
@ -265,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch):
|
|||||||
result = await tool.execute(query="test")
|
result = await tool.execute(query="test")
|
||||||
gate.set()
|
gate.set()
|
||||||
assert "Error" in result
|
assert "Error" in result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user