mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
Extract is_image_file() and reference_non_image_attachments() from AgentLoop private static methods into nanobot/utils/document.py where they belong alongside extract_documents(). Simplify config lookup by removing dead isinstance(dict) branch.
170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
import asyncio
|
|
import base64
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from nanobot.agent.loop import AgentLoop, TurnContext, TurnState
|
|
from nanobot.bus.events import InboundMessage
|
|
from nanobot.bus.queue import MessageBus
|
|
from nanobot.config.schema import ChannelsConfig
|
|
from nanobot.providers.base import LLMResponse
|
|
from nanobot.utils.document import reference_non_image_attachments
|
|
|
|
|
|
def _make_loop(tmp_path: Path, channels_config: ChannelsConfig | None = None) -> AgentLoop:
    """Build an ``AgentLoop`` backed by a stubbed provider.

    The provider is a ``MagicMock`` whose ``get_default_model`` returns
    ``"test-model"`` and whose ``chat_with_retry`` is an ``AsyncMock``
    resolving to ``LLMResponse(content="ok")``.

    Args:
        tmp_path: Directory passed to the loop as its ``workspace``.
        channels_config: Optional channels configuration forwarded unchanged
            to ``AgentLoop``; ``None`` when omitted.

    Returns:
        The constructed ``AgentLoop``.
    """
    stub_provider = MagicMock()
    stub_provider.get_default_model.return_value = "test-model"
    canned_reply = LLMResponse(content="ok")
    stub_provider.chat_with_retry = AsyncMock(return_value=canned_reply)

    agent_loop = AgentLoop(
        bus=MessageBus(),
        provider=stub_provider,
        workspace=tmp_path,
        model="test-model",
        channels_config=channels_config,
    )
    return agent_loop
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_state_restore_extracts_documents_by_default(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When no channels config is supplied, ``_state_restore`` extracts documents.

    ``nanobot.agent.loop.extract_documents`` is replaced with a recording fake;
    the test checks that it is invoked once with the original content and media,
    that its returned text lands in the message content, and that the media
    list ends up empty.
    """
    agent_loop = _make_loop(tmp_path)
    report = tmp_path / "report.txt"
    report.write_text("Quarterly revenue is $5M", encoding="utf-8")
    recorded: list[tuple[str, list[str]]] = []

    def fake_extract_documents(content: str, media: list[str]) -> tuple[str, list[str]]:
        # Record the arguments, then hand back enriched text and no remaining media.
        recorded.append((content, media))
        enriched = f"{content}\n\n[File: report.txt]\nQuarterly revenue is $5M"
        return enriched, []

    monkeypatch.setattr("nanobot.agent.loop.extract_documents", fake_extract_documents)

    inbound = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="summarize",
        media=[str(report)],
    )
    ctx = TurnContext(
        msg=inbound,
        session_key="cli:c",
        state=TurnState.RESTORE,
        turn_id="turn-1",
    )

    outcome = await agent_loop._state_restore(ctx)
    assert outcome == "ok"

    expected_calls = [("summarize", [str(report)])]
    assert recorded == expected_calls
    assert "Quarterly revenue" in ctx.msg.content
    assert ctx.msg.media == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_state_restore_references_documents_when_extraction_disabled(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With ``extract_document_text=False``, ``_state_restore`` only references files.

    ``nanobot.agent.loop.extract_documents`` is patched to raise, so any call to
    it fails the test. The message content must contain an ``[Attachment: ...]``
    marker for the file, must not contain the file body, and the media list
    must be empty afterwards.
    """
    disabled_config = ChannelsConfig(extract_document_text=False)
    agent_loop = _make_loop(tmp_path, disabled_config)
    report = tmp_path / "report.txt"
    report.write_text("Quarterly revenue is $5M", encoding="utf-8")

    def fail_extract_documents(content: str, media: list[str]) -> tuple[str, list[str]]:
        # Extraction must never be reached in this configuration.
        raise AssertionError("document extraction should be disabled")

    monkeypatch.setattr("nanobot.agent.loop.extract_documents", fail_extract_documents)

    inbound = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="summarize",
        media=[str(report)],
    )
    ctx = TurnContext(
        msg=inbound,
        session_key="cli:c",
        state=TurnState.RESTORE,
        turn_id="turn-1",
    )

    outcome = await agent_loop._state_restore(ctx)
    assert outcome == "ok"

    attachment_marker = f"[Attachment: {report}]"
    assert "Quarterly revenue" not in ctx.msg.content
    assert attachment_marker in ctx.msg.content
    assert ctx.msg.media == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pending_followup_references_documents_when_extraction_disabled(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A queued follow-up with a document is referenced, not inlined, when extraction is off.

    A message carrying a text-file attachment is placed on the pending queue
    handed to ``_run_agent_loop``. The provider stub snapshots the messages of
    every request. The last user message of the final request must contain the
    follow-up text and an ``[Attachment: ...]`` marker, but not the file body.
    ``nanobot.agent.loop.extract_documents`` is patched to raise so any
    extraction attempt fails the test.
    """
    followup_file = tmp_path / "followup.txt"
    followup_file.write_text("Do not inject this file body", encoding="utf-8")
    request_snapshots: list[list[dict]] = []
    calls_made = 0

    async def chat_with_retry(*, messages: list[dict], **kwargs: object) -> LLMResponse:
        # Count the call, keep a shallow copy of each message, and answer with a
        # numbered reply so the test can tell which request produced the result.
        nonlocal calls_made
        calls_made += 1
        snapshot = [dict(message) for message in messages]
        request_snapshots.append(snapshot)
        return LLMResponse(content=f"answer-{calls_made}", tool_calls=[], usage={})

    agent_loop = _make_loop(tmp_path, ChannelsConfig(extract_document_text=False))
    agent_loop.provider.chat_with_retry = chat_with_retry
    agent_loop.tools.get_definitions = MagicMock(return_value=[])

    def fail_extract_documents(content: str, media: list[str]) -> tuple[str, list[str]]:
        # Extraction must never be reached in this configuration.
        raise AssertionError("document extraction should be disabled")

    monkeypatch.setattr("nanobot.agent.loop.extract_documents", fail_extract_documents)

    pending_queue: asyncio.Queue[InboundMessage] = asyncio.Queue()
    followup = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="check this",
        media=[str(followup_file)],
    )
    await pending_queue.put(followup)

    initial_messages = [{"role": "user", "content": "hello"}]
    final_content, _, _, _, had_injections = await agent_loop._run_agent_loop(
        initial_messages,
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "answer-2"
    assert had_injections is True

    # Gather the plain-string user messages from the last provider request.
    user_texts = []
    for entry in request_snapshots[-1]:
        if entry.get("role") == "user" and isinstance(entry.get("content"), str):
            user_texts.append(entry["content"])
    injected_user_content = user_texts[-1]

    assert "check this" in injected_user_content
    assert f"[Attachment: {followup_file}]" in injected_user_content
    assert "Do not inject this file body" not in injected_user_content
|
|
|
|
|
|
def test_document_extraction_disabled_still_preserves_images(tmp_path: Path) -> None:
    """``reference_non_image_attachments`` keeps images in media and references other files.

    Given one ``.png`` file and one ``.txt`` file, the returned media list must
    contain only the image path, and the returned content must include an
    ``[Attachment: ...]`` marker for the text file.
    """
    # Base64 payload whose decoded bytes start with the PNG signature.
    png_base64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
    )
    picture = tmp_path / "chart.png"
    picture.write_bytes(base64.b64decode(png_base64))

    document = tmp_path / "report.txt"
    document.write_text("manual extraction target", encoding="utf-8")

    attachments = [str(picture), str(document)]
    content, media = reference_non_image_attachments("review these", attachments)

    assert media == [str(picture)]
    assert f"[Attachment: {document}]" in content
|