mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
Add document extraction channel toggle
This commit is contained in:
parent
404b68cdd4
commit
ec4f9e9857
@ -1043,6 +1043,7 @@ Global settings that apply to all channels. Configure under the `channels` secti
|
||||
"channels": {
|
||||
"sendProgress": true,
|
||||
"sendToolHints": false,
|
||||
"extractDocumentText": true,
|
||||
"sendMaxRetries": 3,
|
||||
"transcriptionProvider": "groq",
|
||||
"transcriptionLanguage": null,
|
||||
@ -1056,6 +1057,7 @@ Global settings that apply to all channels. Configure under the `channels` secti
|
||||
| `sendProgress` | `true` | Stream agent's text progress to the channel |
|
||||
| `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) |
|
||||
| `showReasoning` | `true` | Allow channels to surface model reasoning/thinking content (DeepSeek-R1 `reasoning_content`, Anthropic `thinking_blocks`, inline `<think>` tags). Reasoning flows as a dedicated stream with `_reasoning_delta` / `_reasoning_end` markers — channels override `send_reasoning_delta` / `send_reasoning_end` to render in-place updates. Even when set to `true`, channels that do not implement those overrides silently remain no-ops. Currently surfaced on CLI and WebSocket/WebUI (italic shimmer header, auto-collapses after the stream ends); Telegram / Slack / Discord / Feishu / WeChat / Matrix keep the base no-op until their bubble UI is adapted. Independent of `sendProgress`. |
|
||||
| `extractDocumentText` | `true` | Extract supported document/text attachments into the model prompt. Set to `false` to keep document content out of the prompt and include attachment path references instead. |
|
||||
| `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (accepts values from 0 to 10; at least one attempt is always made) |
|
||||
| `transcriptionProvider` | `"groq"` | Voice transcription backend: `"groq"` (free tier, default) or `"openai"`. API key and optional `apiBase` are auto-resolved from the matching provider config. Chat-style bases such as `https://api.groq.com/openai/v1` are normalized to the audio transcription endpoint. |
|
||||
| `transcriptionLanguage` | `null` | Optional ISO-639-1 language hint for audio transcription, e.g. `"en"`, `"ko"`, `"ja"`. |
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import mimetypes
|
||||
import os
|
||||
import time
|
||||
from contextlib import AsyncExitStack, nullcontext, suppress
|
||||
@ -51,7 +52,7 @@ from nanobot.session.webui_turns import (
|
||||
mark_webui_session,
|
||||
)
|
||||
from nanobot.utils.document import extract_documents
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
from nanobot.utils.helpers import detect_image_mime, image_placeholder_text
|
||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||
from nanobot.utils.image_generation_intent import image_generation_prompt
|
||||
from nanobot.utils.llm_runtime import LLMRuntime
|
||||
@ -711,7 +712,7 @@ class AgentLoop:
|
||||
content = pending_msg.content
|
||||
media = pending_msg.media if pending_msg.media else None
|
||||
if media:
|
||||
content, media = extract_documents(content, media)
|
||||
content, media = self._prepare_message_media(content, media)
|
||||
media = media or None
|
||||
user_content = self.context._build_user_content(content, media)
|
||||
return {"role": "user", "content": user_content}
|
||||
@ -1271,7 +1272,7 @@ class AgentLoop:
|
||||
msg = ctx.msg
|
||||
|
||||
if msg.media:
|
||||
new_content, image_only = extract_documents(msg.content, msg.media)
|
||||
new_content, image_only = self._prepare_message_media(msg.content, msg.media)
|
||||
ctx.msg = dataclasses.replace(msg, content=new_content, media=image_only)
|
||||
msg = ctx.msg
|
||||
|
||||
@ -1292,6 +1293,49 @@ class AgentLoop:
|
||||
|
||||
return "ok"
|
||||
|
||||
def _prepare_message_media(self, content: str, media: list[str]) -> tuple[str, list[str]]:
    """Return the ``(content, media)`` pair to use for a message with attachments.

    When document extraction is disabled, non-image attachments are replaced
    by ``[Attachment: <path>]`` references; otherwise the pair produced by
    ``extract_documents`` is returned unchanged.
    """
    if not self._should_extract_document_text():
        return self._reference_non_image_attachments(content, media)
    return extract_documents(content, media)
|
||||
|
||||
def _should_extract_document_text(self) -> bool:
|
||||
cfg = self.channels_config
|
||||
if cfg is None:
|
||||
return True
|
||||
if isinstance(cfg, dict):
|
||||
value = cfg.get("extract_document_text", cfg.get("extractDocumentText", True))
|
||||
else:
|
||||
value = getattr(cfg, "extract_document_text", True)
|
||||
return value is not False
|
||||
|
||||
@staticmethod
def _reference_non_image_attachments(content: str, media: list[str]) -> tuple[str, list[str]]:
    """Split *media* into images to keep and path references appended to *content*.

    Paths that look like images are returned as the media list. Every other
    path becomes an ``[Attachment: <path>]`` line appended after *content*
    (separated by a blank line), or used as the whole content when *content*
    is empty.
    """
    kept_images: list[str] = []
    references: list[str] = []
    for item in media:
        if not AgentLoop._looks_like_image(item):
            references.append(f"[Attachment: {item}]")
            continue
        kept_images.append(item)

    if not references:
        return content, kept_images

    reference_block = "\n".join(references)
    if content:
        content = f"{content}\n\n{reference_block}"
    else:
        content = reference_block
    return content, kept_images
|
||||
|
||||
@staticmethod
def _looks_like_image(path: str) -> bool:
    """Return True when *path* resolves to an ``image/*`` MIME type.

    For an existing file the first 16 bytes are passed to
    ``detect_image_mime``. If that yields nothing (or the file cannot be
    read, or it is not a regular file), the type is guessed from the path
    with ``mimetypes.guess_type``.
    """
    candidate = Path(path)
    sniffed: str | None = None
    if candidate.is_file():
        try:
            with candidate.open("rb") as handle:
                sniffed = detect_image_mime(handle.read(16))
        except OSError:
            # Unreadable file: fall through to the path-based guess.
            sniffed = None

    resolved = sniffed if sniffed else mimetypes.guess_type(path)[0]
    if not resolved:
        return False
    return bool(resolved.startswith("image/"))
|
||||
|
||||
async def _state_compact(self, ctx: TurnContext) -> str:
|
||||
ctx.session, pending = self.auto_compact.prepare_session(ctx.session, ctx.session_key)
|
||||
ctx.pending_summary = pending
|
||||
|
||||
@ -37,6 +37,7 @@ class ChannelsConfig(Base):
|
||||
send_progress: bool = True # stream agent's text progress to the channel
|
||||
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
||||
show_reasoning: bool = True # surface model reasoning when channel implements it
|
||||
extract_document_text: bool = True # extract text from document attachments before sending to the model
|
||||
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
|
||||
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"
|
||||
transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639-1 hint for audio transcription
|
||||
|
||||
168
tests/agent/test_document_extraction_toggle.py
Normal file
168
tests/agent/test_document_extraction_toggle.py
Normal file
@ -0,0 +1,168 @@
|
||||
import asyncio
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop, TurnContext, TurnState
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ChannelsConfig
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path, channels_config: ChannelsConfig | None = None) -> AgentLoop:
    """Build an ``AgentLoop`` backed by a mock provider whose chat call answers "ok"."""
    stub_provider = MagicMock()
    stub_provider.get_default_model.return_value = "test-model"
    stub_provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok"))

    loop_kwargs = {
        "bus": MessageBus(),
        "provider": stub_provider,
        "workspace": tmp_path,
        "model": "test-model",
        "channels_config": channels_config,
    }
    return AgentLoop(**loop_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_state_restore_extracts_documents_by_default(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without a channels config, ``_state_restore`` routes media through ``extract_documents``."""
    agent = _make_loop(tmp_path)
    report = tmp_path / "report.txt"
    report.write_text("Quarterly revenue is $5M", encoding="utf-8")
    recorded: list[tuple[str, list[str]]] = []

    def fake_extract_documents(content: str, media: list[str]) -> tuple[str, list[str]]:
        # Record the call and pretend the document was fully inlined.
        recorded.append((content, media))
        return f"{content}\n\n[File: report.txt]\nQuarterly revenue is $5M", []

    monkeypatch.setattr("nanobot.agent.loop.extract_documents", fake_extract_documents)

    inbound = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="summarize",
        media=[str(report)],
    )
    ctx = TurnContext(
        msg=inbound,
        session_key="cli:c",
        state=TurnState.RESTORE,
        turn_id="turn-1",
    )

    outcome = await agent._state_restore(ctx)

    assert outcome == "ok"
    assert recorded == [("summarize", [str(report)])]
    assert "Quarterly revenue" in ctx.msg.content
    assert ctx.msg.media == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_state_restore_references_documents_when_extraction_disabled(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With ``extract_document_text=False``, restore adds a path reference instead of file text."""
    disabled_config = ChannelsConfig(extract_document_text=False)
    agent = _make_loop(tmp_path, disabled_config)
    report = tmp_path / "report.txt"
    report.write_text("Quarterly revenue is $5M", encoding="utf-8")

    def fail_extract_documents(content: str, media: list[str]) -> tuple[str, list[str]]:
        # Any call here means the toggle was ignored.
        raise AssertionError("document extraction should be disabled")

    monkeypatch.setattr("nanobot.agent.loop.extract_documents", fail_extract_documents)

    inbound = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="summarize",
        media=[str(report)],
    )
    ctx = TurnContext(
        msg=inbound,
        session_key="cli:c",
        state=TurnState.RESTORE,
        turn_id="turn-1",
    )

    outcome = await agent._state_restore(ctx)

    assert outcome == "ok"
    assert "Quarterly revenue" not in ctx.msg.content
    assert f"[Attachment: {report}]" in ctx.msg.content
    assert ctx.msg.media == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pending_followup_references_documents_when_extraction_disabled(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A queued follow-up carrying a document is injected as a path reference, not file text."""
    followup_file = tmp_path / "followup.txt"
    followup_file.write_text("Do not inject this file body", encoding="utf-8")
    captured_messages: list[list[dict]] = []

    async def chat_with_retry(*, messages: list[dict], **kwargs: object) -> LLMResponse:
        # Snapshot each request; the reply is numbered by how many calls were made so far.
        captured_messages.append([dict(entry) for entry in messages])
        return LLMResponse(content=f"answer-{len(captured_messages)}", tool_calls=[], usage={})

    agent = _make_loop(tmp_path, ChannelsConfig(extract_document_text=False))
    agent.provider.chat_with_retry = chat_with_retry
    agent.tools.get_definitions = MagicMock(return_value=[])

    def fail_extract_documents(content: str, media: list[str]) -> tuple[str, list[str]]:
        raise AssertionError("document extraction should be disabled")

    monkeypatch.setattr("nanobot.agent.loop.extract_documents", fail_extract_documents)

    followup = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="check this",
        media=[str(followup_file)],
    )
    pending_queue: asyncio.Queue[InboundMessage] = asyncio.Queue()
    await pending_queue.put(followup)

    initial_messages = [{"role": "user", "content": "hello"}]
    final_content, _, _, _, had_injections = await agent._run_agent_loop(
        initial_messages,
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "answer-2"
    assert had_injections is True

    # The last plain-text user message of the final request is the injected follow-up.
    user_texts = [
        entry["content"]
        for entry in captured_messages[-1]
        if entry.get("role") == "user" and isinstance(entry.get("content"), str)
    ]
    injected_user_content = user_texts[-1]
    assert "check this" in injected_user_content
    assert f"[Attachment: {followup_file}]" in injected_user_content
    assert "Do not inject this file body" not in injected_user_content
|
||||
|
||||
|
||||
def test_document_extraction_disabled_still_preserves_images(tmp_path: Path) -> None:
    """Images stay in the media list while other files become ``[Attachment: ...]`` references."""
    # Base64 payload written to chart.png (decoded below).
    png_payload = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
    )
    chart = tmp_path / "chart.png"
    chart.write_bytes(png_payload)
    report = tmp_path / "report.txt"
    report.write_text("manual extraction target", encoding="utf-8")

    attachments = [str(chart), str(report)]
    content, media = AgentLoop._reference_non_image_attachments("review these", attachments)

    assert media == [str(chart)]
    assert f"[Attachment: {report}]" in content
|
||||
@ -91,6 +91,13 @@ def test_channels_config_builtin_fields_removed():
|
||||
assert not hasattr(cfg, "telegram")
|
||||
assert cfg.send_progress is True
|
||||
assert cfg.send_tool_hints is False
|
||||
assert cfg.extract_document_text is True
|
||||
|
||||
|
||||
def test_channels_config_extract_document_text_accepts_camel_alias():
    """Validating the camelCase key ``extractDocumentText`` sets ``extract_document_text``."""
    raw_settings = {"extractDocumentText": False}
    parsed = ChannelsConfig.model_validate(raw_settings)

    assert parsed.extract_document_text is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user