mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-26 11:32:25 +00:00
fix(agent): resolve message media against active workspace
Made-with: Cursor
This commit is contained in:
parent
9b3e2524ac
commit
9b6f3d7abc
@ -368,7 +368,7 @@ class AgentLoop:
|
|||||||
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||||
)
|
)
|
||||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound, workspace=self.workspace))
|
||||||
self.tools.register(SpawnTool(manager=self.subagents))
|
self.tools.register(SpawnTool(manager=self.subagents))
|
||||||
if self.cron_service:
|
if self.cron_service:
|
||||||
self.tools.register(
|
self.tools.register(
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
@ -35,8 +36,10 @@ class MessageTool(Tool):
|
|||||||
default_channel: str = "",
|
default_channel: str = "",
|
||||||
default_chat_id: str = "",
|
default_chat_id: str = "",
|
||||||
default_message_id: str | None = None,
|
default_message_id: str | None = None,
|
||||||
|
workspace: str | Path | None = None,
|
||||||
):
|
):
|
||||||
self._send_callback = send_callback
|
self._send_callback = send_callback
|
||||||
|
self._workspace = Path(workspace).expanduser() if workspace is not None else get_workspace_path()
|
||||||
self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel)
|
self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel)
|
||||||
self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id)
|
self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id)
|
||||||
self._default_message_id: ContextVar[str | None] = ContextVar(
|
self._default_message_id: ContextVar[str | None] = ContextVar(
|
||||||
@ -149,7 +152,7 @@ class MessageTool(Tool):
|
|||||||
if p.startswith(("http://", "https://")) or os.path.isabs(p):
|
if p.startswith(("http://", "https://")) or os.path.isabs(p):
|
||||||
resolved.append(p)
|
resolved.append(p)
|
||||||
else:
|
else:
|
||||||
resolved.append(str(get_workspace_path() / p))
|
resolved.append(str(self._workspace / p))
|
||||||
media = resolved
|
media = resolved
|
||||||
|
|
||||||
metadata = dict(self._default_metadata.get()) if same_target else {}
|
metadata = dict(self._default_metadata.get()) if same_target else {}
|
||||||
|
|||||||
@ -110,6 +110,26 @@ async def test_message_tool_resolves_relative_media_paths() -> None:
|
|||||||
assert sent[0].media == [expected]
|
assert sent[0].media == [expected]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_tool_resolves_relative_media_paths_from_active_workspace(tmp_path) -> None:
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
|
||||||
|
async def _send(msg: OutboundMessage) -> None:
|
||||||
|
sent.append(msg)
|
||||||
|
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
tool = MessageTool(send_callback=_send, workspace=workspace)
|
||||||
|
|
||||||
|
await tool.execute(
|
||||||
|
content="see attached",
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="1",
|
||||||
|
media=["output/image.png"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sent[0].media == [str(workspace / "output/image.png")]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_tool_passes_through_absolute_media_paths() -> None:
|
async def test_message_tool_passes_through_absolute_media_paths() -> None:
|
||||||
sent: list[OutboundMessage] = []
|
sent: list[OutboundMessage] = []
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user