diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index dce1890af..f6bd5e2e7 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -368,7 +368,7 @@ class AgentLoop: WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy) ) self.tools.register(WebFetchTool(proxy=self.web_config.proxy)) - self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) + self.tools.register(MessageTool(send_callback=self.bus.publish_outbound, workspace=self.workspace)) self.tools.register(SpawnTool(manager=self.subagents)) if self.cron_service: self.tools.register( diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index f35f8fe23..6e3d037f0 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -2,6 +2,7 @@ import os from contextvars import ContextVar +from pathlib import Path from typing import Any, Awaitable, Callable from nanobot.agent.tools.base import Tool, tool_parameters @@ -35,8 +36,10 @@ class MessageTool(Tool): default_channel: str = "", default_chat_id: str = "", default_message_id: str | None = None, + workspace: str | Path | None = None, ): self._send_callback = send_callback + self._workspace = Path(workspace).expanduser() if workspace is not None else get_workspace_path() self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel) self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id) self._default_message_id: ContextVar[str | None] = ContextVar( @@ -149,7 +152,7 @@ class MessageTool(Tool): if p.startswith(("http://", "https://")) or os.path.isabs(p): resolved.append(p) else: - resolved.append(str(get_workspace_path() / p)) + resolved.append(str(self._workspace / p)) media = resolved metadata = dict(self._default_metadata.get()) if same_target else {} diff --git a/tests/tools/test_message_tool.py b/tests/tools/test_message_tool.py index d93219f04..915fb0c98 100644 --- a/tests/tools/test_message_tool.py +++ b/tests/tools/test_message_tool.py @@ -110,6 +110,26 @@ async def test_message_tool_resolves_relative_media_paths() -> None: assert sent[0].media == [expected] +@pytest.mark.asyncio +async def test_message_tool_resolves_relative_media_paths_from_active_workspace(tmp_path) -> None: + sent: list[OutboundMessage] = [] + + async def _send(msg: OutboundMessage) -> None: + sent.append(msg) + + workspace = tmp_path / "workspace" + tool = MessageTool(send_callback=_send, workspace=workspace) + + await tool.execute( + content="see attached", + channel="telegram", + chat_id="1", + media=["output/image.png"], + ) + + assert sent[0].media == [str(workspace / "output/image.png")] + + @pytest.mark.asyncio async def test_message_tool_passes_through_absolute_media_paths() -> None: sent: list[OutboundMessage] = []