From bc357208bb4f71201cfa62d1a67fae2ab7cb3b22 Mon Sep 17 00:00:00 2001 From: rav-melisono Date: Sun, 29 Mar 2026 15:31:29 +0100 Subject: [PATCH 001/115] feat: add HTTP health endpoint on gateway port MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Binds a lightweight asyncio HTTP server on the configured gateway port (default 18790) alongside the existing agent and channel tasks. Endpoints: GET / -> "nanobot" (plain text, for service discovery) GET /health -> JSON with service, version, status, uptime, channels Zero new dependencies — uses asyncio.start_server. --- nanobot/cli/commands.py | 62 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index cacb61ae6..9ddb46d74 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -674,6 +674,67 @@ def gateway( console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s") + async def _health_server(host: str, health_port: int): + """Lightweight HTTP health endpoint on the gateway port.""" + import json as _json + import time + + start_time = time.monotonic() + + async def handle(reader, writer): + try: + data = await asyncio.wait_for(reader.read(4096), timeout=5) + except (asyncio.TimeoutError, ConnectionError): + writer.close() + return + + request_line = data.split(b"\r\n", 1)[0].decode("utf-8", errors="replace") + method, path = "", "" + parts = request_line.split(" ") + if len(parts) >= 2: + method, path = parts[0], parts[1] + + if method == "GET" and path == "/health": + uptime_s = int(time.monotonic() - start_time) + body = _json.dumps({ + "service": "nanobot", + "version": __version__, + "status": "running", + "uptime_seconds": uptime_s, + "channels": channels.enabled_channels, + }) + resp = ( + f"HTTP/1.0 200 OK\r\n" + f"Content-Type: application/json\r\n" + f"Content-Length: {len(body)}\r\n" + f"\r\n{body}" + ) + elif method == "GET" and path == "/": + body = 
"nanobot" + resp = ( + f"HTTP/1.0 200 OK\r\n" + f"Content-Type: text/plain\r\n" + f"Content-Length: {len(body)}\r\n" + f"\r\n{body}" + ) + else: + body = "Not Found" + resp = ( + f"HTTP/1.0 404 Not Found\r\n" + f"Content-Type: text/plain\r\n" + f"Content-Length: {len(body)}\r\n" + f"\r\n{body}" + ) + + writer.write(resp.encode()) + await writer.drain() + writer.close() + + server = await asyncio.start_server(handle, host, health_port) + console.print(f"[green]✓[/green] Health endpoint: http://{host}:{health_port}/health") + async with server: + await server.serve_forever() + async def run(): try: await cron.start() @@ -681,6 +742,7 @@ def gateway( await asyncio.gather( agent.run(), channels.start_all(), + _health_server(config.gateway.host, port), ) except KeyboardInterrupt: console.print("\nShutting down...") From 26ae90611653835608cc596b5d870f8614500338 Mon Sep 17 00:00:00 2001 From: Ziyan Lin Date: Mon, 30 Mar 2026 15:15:15 +0800 Subject: [PATCH 002/115] fix(providers): enforce role alternation for non-Claude providers Some LLM providers (OpenAI-compat, Azure, vLLM, Ollama) reject requests with consecutive same-role messages or trailing assistant messages. Add _enforce_role_alternation() to merge consecutive same-role user/assistant messages and strip trailing assistant messages before sending to the API. 
--- nanobot/providers/azure_openai_provider.py | 8 +- nanobot/providers/base.py | 36 +++++ nanobot/providers/openai_compat_provider.py | 2 +- .../test_enforce_role_alternation.py | 128 ++++++++++++++++++ 4 files changed, 170 insertions(+), 4 deletions(-) create mode 100644 tests/providers/test_enforce_role_alternation.py diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index d71dae917..4587cb933 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -94,9 +94,11 @@ class AzureOpenAIProvider(LLMProvider): ) -> dict[str, Any]: """Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" payload: dict[str, Any] = { - "messages": self._sanitize_request_messages( - self._sanitize_empty_content(messages), - _AZURE_MSG_KEYS, + "messages": self._enforce_role_alternation( + self._sanitize_request_messages( + self._sanitize_empty_content(messages), + _AZURE_MSG_KEYS, + ) ), "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens } diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 9ce2b0c63..61d780c81 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -196,6 +196,42 @@ class LLMProvider(ABC): err = (content or "").lower() return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) + @staticmethod + def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Merge consecutive same-role messages and drop trailing assistant messages. + + Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests + where the last message is 'assistant' (prefill not supported) or two + consecutive non-system messages share the same role. 
+ """ + if not messages: + return messages + + merged: list[dict[str, Any]] = [] + for msg in messages: + role = msg.get("role") + if ( + merged + and role != "system" + and role not in ("tool",) + and merged[-1].get("role") == role + and role in ("user", "assistant") + ): + prev = merged[-1] + prev_content = prev.get("content") or "" + curr_content = msg.get("content") or "" + if isinstance(prev_content, str) and isinstance(curr_content, str): + prev["content"] = (prev_content + "\n\n" + curr_content).strip() + else: + merged[-1] = dict(msg) + else: + merged.append(dict(msg)) + + while merged and merged[-1].get("role") == "assistant": + merged.pop() + + return merged + @staticmethod def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None: """Replace image_url blocks with text placeholder. Returns None if no images found.""" diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 397b8e797..d456f9a6d 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -215,7 +215,7 @@ class OpenAICompatProvider(LLMProvider): clean["tool_calls"] = normalized if "tool_call_id" in clean and clean["tool_call_id"]: clean["tool_call_id"] = map_id(clean["tool_call_id"]) - return sanitized + return self._enforce_role_alternation(sanitized) # ------------------------------------------------------------------ # Build kwargs diff --git a/tests/providers/test_enforce_role_alternation.py b/tests/providers/test_enforce_role_alternation.py new file mode 100644 index 000000000..1fade6e4b --- /dev/null +++ b/tests/providers/test_enforce_role_alternation.py @@ -0,0 +1,128 @@ +"""Tests for LLMProvider._enforce_role_alternation.""" + +from nanobot.providers.base import LLMProvider + + +class TestEnforceRoleAlternation: + """Verify trailing-assistant removal and consecutive same-role merging.""" + + def test_empty_messages(self): + assert 
LLMProvider._enforce_role_alternation([]) == [] + + def test_no_change_needed(self): + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "Bye"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 4 + assert result[-1]["role"] == "user" + + def test_trailing_assistant_removed(self): + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 1 + assert result[0]["role"] == "user" + + def test_multiple_trailing_assistants_removed(self): + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "A"}, + {"role": "assistant", "content": "B"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 1 + assert result[0]["role"] == "user" + + def test_consecutive_user_messages_merged(self): + msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "How are you?"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 1 + assert "Hello" in result[0]["content"] + assert "How are you?" in result[0]["content"] + + def test_consecutive_assistant_messages_merged(self): + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "assistant", "content": "How can I help?"}, + {"role": "user", "content": "Thanks"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 3 + assert "Hello!" in result[1]["content"] + assert "How can I help?" 
in result[1]["content"] + + def test_system_messages_not_merged(self): + msgs = [ + {"role": "system", "content": "System A"}, + {"role": "system", "content": "System B"}, + {"role": "user", "content": "Hi"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 3 + assert result[0]["content"] == "System A" + assert result[1]["content"] == "System B" + + def test_tool_messages_not_merged(self): + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]}, + {"role": "tool", "content": "result1", "tool_call_id": "1"}, + {"role": "tool", "content": "result2", "tool_call_id": "2"}, + {"role": "user", "content": "Next"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + tool_msgs = [m for m in result if m["role"] == "tool"] + assert len(tool_msgs) == 2 + + def test_non_string_content_uses_latest(self): + msgs = [ + {"role": "user", "content": [{"type": "text", "text": "A"}]}, + {"role": "user", "content": "B"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 1 + assert result[0]["content"] == "B" + + def test_original_messages_not_mutated(self): + msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "World"}, + ] + original_first = dict(msgs[0]) + LLMProvider._enforce_role_alternation(msgs) + assert msgs[0] == original_first + assert len(msgs) == 2 + + def test_only_assistant_messages(self): + msgs = [ + {"role": "assistant", "content": "A"}, + {"role": "assistant", "content": "B"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert result == [] + + def test_realistic_conversation(self): + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + {"role": "user", "content": "And 3+3?"}, + {"role": "user", "content": "(please be quick)"}, + {"role": "assistant", "content": "6"}, + ] + result = 
LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 4 + assert result[2]["role"] == "assistant" + assert result[3]["role"] == "user" + assert "And 3+3?" in result[3]["content"] + assert "(please be quick)" in result[3]["content"] From af6c75141f1e577f2252afb721c60193c273c60e Mon Sep 17 00:00:00 2001 From: stutiredboy Date: Wed, 8 Apr 2026 09:27:47 +0800 Subject: [PATCH 003/115] feat(): telegram support stream edit interval --- nanobot/channels/telegram.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 0ba8ce8e9..2dde232b1 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -166,6 +166,7 @@ def _markdown_to_telegram_html(text: str) -> str: _SEND_MAX_RETRIES = 3 _SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry +_STREAM_EDIT_INTERVAL_DEFAULT = 0.6 # min seconds between edit_message_text calls @dataclass @@ -190,6 +191,7 @@ class TelegramConfig(Base): connection_pool_size: int = 32 pool_timeout: float = 5.0 streaming: bool = True + stream_edit_interval: float = Field(default=_STREAM_EDIT_INTERVAL_DEFAULT, ge=0.1) class TelegramChannel(BaseChannel): @@ -219,8 +221,6 @@ class TelegramChannel(BaseChannel): def default_config(cls) -> dict[str, Any]: return TelegramConfig().model_dump(by_alias=True) - _STREAM_EDIT_INTERVAL = 0.6 # min seconds between edit_message_text calls - def __init__(self, config: Any, bus: MessageBus): if isinstance(config, dict): config = TelegramConfig.model_validate(config) @@ -619,7 +619,7 @@ class TelegramChannel(BaseChannel): except Exception as e: logger.warning("Stream initial send failed: {}", e) raise # Let ChannelManager handle retry - elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: + elif (now - buf.last_edit) >= self.config.stream_edit_interval: try: await self._call_with_retry( self._app.bot.edit_message_text, From b16865722bddd9776bd7bdfb74213db8465f7354 Mon Sep 17 00:00:00 2001 From: 
chengyongru Date: Wed, 8 Apr 2026 11:21:29 +0800 Subject: [PATCH 004/115] fix(tool-hint): fold paths in exec commands and deduplicate by formatted string MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. exec tool hints previously used val[:40] blind character truncation, cutting paths mid-segment. Now detects file paths via regex and abbreviates them with abbreviate_path. Supports Windows, Unix absolute, and ~/ home paths. 2. Deduplication now compares fully formatted hint strings instead of tool names alone. Fixes ls /Desktop and ls /Downloads being incorrectly merged as "ls /Desktop × 2". Co-authored-by: xzq.xu --- nanobot/utils/tool_hints.py | 57 +++++++++++++++++++++-------------- tests/agent/test_tool_hint.py | 52 +++++++++++++++++++++++++++----- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/nanobot/utils/tool_hints.py b/nanobot/utils/tool_hints.py index a907a2700..9b6d29911 100644 --- a/nanobot/utils/tool_hints.py +++ b/nanobot/utils/tool_hints.py @@ -2,6 +2,8 @@ from __future__ import annotations +import re + from nanobot.utils.path import abbreviate_path # Registry: tool_name -> (key_args, template, is_path, is_command) @@ -17,27 +19,37 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = { "list_dir": (["path"], "ls {}", True, False), } +# Matches file paths embedded in shell commands (Windows drive, ~/, or absolute after space) +_PATH_IN_CMD_RE = re.compile( + r"(?:[A-Za-z]:[/\\]|~/|(?<=\s)/)[^\s;&|<>\"']+" +) + def format_tool_hints(tool_calls: list) -> str: """Format tool calls as concise hints with smart abbreviation.""" if not tool_calls: return "" - hints = [] - for name, count, example_tc in _group_consecutive(tool_calls): - fmt = _TOOL_FORMATS.get(name) + formatted = [] + for tc in tool_calls: + fmt = _TOOL_FORMATS.get(tc.name) if fmt: - hint = _fmt_known(example_tc, fmt) - elif name.startswith("mcp_"): - hint = _fmt_mcp(example_tc) + formatted.append(_fmt_known(tc, fmt)) + 
elif tc.name.startswith("mcp_"): + formatted.append(_fmt_mcp(tc)) else: - hint = _fmt_fallback(example_tc) + formatted.append(_fmt_fallback(tc)) - if count > 1: - hint = f"{hint} \u00d7 {count}" - hints.append(hint) + hints = [] + for hint in formatted: + if hints and hints[-1][0] == hint: + hints[-1] = (hint, hints[-1][1] + 1) + else: + hints.append((hint, 1)) - return ", ".join(hints) + return ", ".join( + f"{h} \u00d7 {c}" if c > 1 else h for h, c in hints + ) def _get_args(tc) -> dict: @@ -51,17 +63,6 @@ def _get_args(tc) -> dict: return {} -def _group_consecutive(calls: list) -> list[tuple[str, int, object]]: - """Group consecutive calls to the same tool: [(name, count, first), ...].""" - groups: list[tuple[str, int, object]] = [] - for tc in calls: - if groups and groups[-1][0] == tc.name: - groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2]) - else: - groups.append((tc.name, 1, tc)) - return groups - - def _extract_arg(tc, key_args: list[str]) -> str | None: """Extract the first available value from preferred key names.""" args = _get_args(tc) @@ -85,10 +86,20 @@ def _fmt_known(tc, fmt: tuple) -> str: if fmt[2]: # is_path val = abbreviate_path(val) elif fmt[3]: # is_command - val = val[:40] + "\u2026" if len(val) > 40 else val + val = _abbreviate_command(val) return fmt[1].format(val) +def _abbreviate_command(cmd: str, max_len: int = 40) -> str: + """Abbreviate paths in a command string, then truncate.""" + abbreviated = _PATH_IN_CMD_RE.sub( + lambda m: abbreviate_path(m.group(), max_len=25), cmd + ) + if len(abbreviated) <= max_len: + return abbreviated + return abbreviated[:max_len - 1] + "\u2026" + + def _fmt_mcp(tc) -> str: """Format MCP tool as server::tool.""" name = tc.name diff --git a/tests/agent/test_tool_hint.py b/tests/agent/test_tool_hint.py index 2384cfbb2..080a0b1e3 100644 --- a/tests/agent/test_tool_hint.py +++ b/tests/agent/test_tool_hint.py @@ -52,6 +52,37 @@ class TestToolHintKnownTools: assert result.startswith("$ ") assert 
len(result) <= 50 # reasonable limit + def test_exec_abbreviates_paths_in_command(self): + """Windows paths in exec commands should be folded, not blindly truncated.""" + cmd = "cd D:\\Documents\\GitHub\\nanobot\\.worktree\\tomain\\nanobot && git diff origin/main...pr-2706 --name-only 2>&1" + result = _hint([_tc("exec", {"command": cmd})]) + assert "\u2026/" in result # path should be folded with …/ + assert "worktree" not in result # middle segments should be collapsed + + def test_exec_abbreviates_linux_paths(self): + """Unix absolute paths in exec commands should be folded.""" + cmd = "cd /home/user/projects/nanobot/.worktree/tomain && make build" + result = _hint([_tc("exec", {"command": cmd})]) + assert "\u2026/" in result + assert "projects" not in result + + def test_exec_abbreviates_home_paths(self): + """~/ paths in exec commands should be folded.""" + cmd = "cd ~/projects/nanobot/workspace && pytest tests/" + result = _hint([_tc("exec", {"command": cmd})]) + assert "\u2026/" in result + + def test_exec_short_command_unchanged(self): + result = _hint([_tc("exec", {"command": "npm install typescript"})]) + assert result == "$ npm install typescript" + + def test_exec_chained_commands_truncated_not_mid_path(self): + """Long chained commands should truncate preserving abbreviated paths.""" + cmd = "cd D:\\Documents\\GitHub\\project && npm run build && npm test" + result = _hint([_tc("exec", {"command": cmd})]) + assert "\u2026/" in result # path folded + assert "npm" in result # chained command still visible + def test_web_search(self): result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})]) assert result == 'search "Claude 4 vs GPT-4"' @@ -105,22 +136,30 @@ class TestToolHintFolding: result = _hint(calls) assert "\u00d7" not in result - def test_two_consecutive_same_folded(self): + def test_two_consecutive_different_args_not_folded(self): calls = [ _tc("grep", {"pattern": "*.py"}), _tc("grep", {"pattern": "*.ts"}), ] result = _hint(calls) + 
assert "\u00d7" not in result + + def test_two_consecutive_same_args_folded(self): + calls = [ + _tc("grep", {"pattern": "TODO"}), + _tc("grep", {"pattern": "TODO"}), + ] + result = _hint(calls) assert "\u00d7 2" in result - def test_three_consecutive_same_folded(self): + def test_three_consecutive_different_args_not_folded(self): calls = [ _tc("read_file", {"path": "a.py"}), _tc("read_file", {"path": "b.py"}), _tc("read_file", {"path": "c.py"}), ] result = _hint(calls) - assert "\u00d7 3" in result + assert "\u00d7" not in result def test_different_tools_not_folded(self): calls = [ @@ -187,7 +226,7 @@ class TestToolHintMixedFolding: """G4: Mixed folding groups with interleaved same-tool segments.""" def test_read_read_grep_grep_read(self): - """read×2, grep×2, read — should produce two separate groups.""" + """All different args — each hint listed separately.""" calls = [ _tc("read_file", {"path": "a.py"}), _tc("read_file", {"path": "b.py"}), @@ -196,7 +235,6 @@ class TestToolHintMixedFolding: _tc("read_file", {"path": "c.py"}), ] result = _hint(calls) - assert "\u00d7 2" in result - # Should have 3 groups: read×2, grep×2, read + assert "\u00d7" not in result parts = result.split(", ") - assert len(parts) == 3 + assert len(parts) == 5 From c092896922373ac56602081d7350c5f3b3941aae Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 8 Apr 2026 15:04:03 +0000 Subject: [PATCH 005/115] fix(tool-hint): handle quoted paths in exec hints Preserve path folding for quoted exec command paths with spaces so hint previews do not fall back to mid-path truncation. Add regression coverage for quoted Unix and Windows path cases. 
Made-with: Cursor --- nanobot/utils/tool_hints.py | 17 ++++++++++++----- tests/agent/test_tool_hint.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/nanobot/utils/tool_hints.py b/nanobot/utils/tool_hints.py index 9b6d29911..9758700b1 100644 --- a/nanobot/utils/tool_hints.py +++ b/nanobot/utils/tool_hints.py @@ -19,9 +19,11 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = { "list_dir": (["path"], "ls {}", True, False), } -# Matches file paths embedded in shell commands (Windows drive, ~/, or absolute after space) +# Matches file paths embedded in shell commands, including quoted paths with spaces. _PATH_IN_CMD_RE = re.compile( - r"(?:[A-Za-z]:[/\\]|~/|(?<=\s)/)[^\s;&|<>\"']+" + r'"(?P<double>(?:[A-Za-z]:[/\\]|~/|/)[^"]+)"' + r"|'(?P<single>(?:[A-Za-z]:[/\\]|~/|/)[^']+)'" + r"|(?P<bare>(?:[A-Za-z]:[/\\]|~/|(?<=\s)/)[^\s;&|<>\"']+)" ) @@ -92,9 +94,14 @@ def _fmt_known(tc, fmt: tuple) -> str: def _abbreviate_command(cmd: str, max_len: int = 40) -> str: """Abbreviate paths in a command string, then truncate.""" - abbreviated = _PATH_IN_CMD_RE.sub( - lambda m: abbreviate_path(m.group(), max_len=25), cmd - ) + def _replace_path(match: re.Match[str]) -> str: + if match.group("double") is not None: + return f'"{abbreviate_path(match.group("double"), max_len=25)}"' + if match.group("single") is not None: + return f"'{abbreviate_path(match.group('single'), max_len=25)}'" + return abbreviate_path(match.group("bare"), max_len=25) + + abbreviated = _PATH_IN_CMD_RE.sub(_replace_path, cmd) if len(abbreviated) <= max_len: return abbreviated return abbreviated[:max_len - 1] + "\u2026" diff --git a/tests/agent/test_tool_hint.py b/tests/agent/test_tool_hint.py index 080a0b1e3..b8ba99284 100644 --- a/tests/agent/test_tool_hint.py +++ b/tests/agent/test_tool_hint.py @@ -72,6 +72,22 @@ class TestToolHintKnownTools: result = _hint([_tc("exec", {"command": cmd})]) assert "\u2026/" in result + def test_exec_abbreviates_quoted_linux_paths_with_spaces(self): +
"""Quoted Unix paths with spaces should still be folded.""" + cmd = 'cd "/home/user/My Documents/project" && pytest tests/' + result = _hint([_tc("exec", {"command": cmd})]) + assert "\u2026/" in result + assert '"/home/user/My Documents/project"' not in result + assert '"' in result + + def test_exec_abbreviates_quoted_windows_paths_with_spaces(self): + """Quoted Windows paths with spaces should still be folded.""" + cmd = 'cd "C:/Program Files/Git/project" && git status' + result = _hint([_tc("exec", {"command": cmd})]) + assert "\u2026/" in result + assert '"C:/Program Files/Git/project"' not in result + assert '"' in result + def test_exec_short_command_unchanged(self): result = _hint([_tc("exec", {"command": "npm install typescript"})]) assert result == "$ npm install typescript" From d084d10dc27379c8f0557299af4eef6546266b2e Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 8 Apr 2026 15:21:08 +0000 Subject: [PATCH 006/115] feat(openai): auto-route direct reasoning requests with responses fallback --- nanobot/providers/openai_compat_provider.py | 163 +++++++++++- tests/providers/test_litellm_kwargs.py | 271 +++++++++++++++++++- 2 files changed, 423 insertions(+), 11 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a8782f831..aaa170395 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -26,6 +26,12 @@ else: from openai import AsyncOpenAI from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.openai_responses import ( + consume_sdk_stream, + convert_messages, + convert_tools, + parse_response_output, +) if TYPE_CHECKING: from nanobot.providers.registry import ProviderSpec @@ -113,6 +119,14 @@ def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | No return bool(api_base and "openrouter" in api_base.lower()) +def _is_direct_openai_base(api_base: str | None) -> 
bool: + """Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways.""" + if not api_base: + return True + normalized = api_base.strip().lower().rstrip("/") + return "api.openai.com" in normalized and "openrouter" not in normalized + + class OpenAICompatProvider(LLMProvider): """Unified provider for all OpenAI-compatible APIs. @@ -137,6 +151,7 @@ class OpenAICompatProvider(LLMProvider): self._setup_env(api_key, api_base) effective_base = api_base or (spec.default_api_base if spec else None) or None + self._effective_base = effective_base default_headers = {"x-session-affinity": uuid.uuid4().hex} if _uses_openrouter_attribution(spec, effective_base): default_headers.update(_DEFAULT_OPENROUTER_HEADERS) @@ -321,6 +336,88 @@ class OpenAICompatProvider(LLMProvider): return kwargs + def _should_use_responses_api( + self, + model: str | None, + reasoning_effort: str | None, + ) -> bool: + """Use Responses API only for direct OpenAI requests that benefit from it.""" + if self._spec and self._spec.name != "openai": + return False + if not _is_direct_openai_base(self._effective_base): + return False + + model_name = (model or self.default_model).lower() + if reasoning_effort and reasoning_effort.lower() != "none": + return True + return any(token in model_name for token in ("gpt-5", "o1", "o3", "o4")) + + @staticmethod + def _should_fallback_from_responses_error(e: Exception) -> bool: + """Fallback only for likely Responses API compatibility errors.""" + response = getattr(e, "response", None) + status_code = getattr(e, "status_code", None) + if status_code is None and response is not None: + status_code = getattr(response, "status_code", None) + if status_code not in {400, 404, 422}: + return False + + body = ( + getattr(e, "body", None) + or getattr(e, "doc", None) + or getattr(response, "text", None) + ) + body_text = str(body).lower() if body is not None else "" + compatibility_markers = ( + "responses", + "response api", + "max_output_tokens", + 
"instructions", + "previous_response", + "unsupported", + "not supported", + "unknown parameter", + "unrecognized request argument", + ) + return any(marker in body_text for marker in compatibility_markers) + + def _build_responses_body( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + ) -> dict[str, Any]: + """Build a Responses API body for direct OpenAI requests.""" + model_name = model or self.default_model + sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages)) + instructions, input_items = convert_messages(sanitized_messages) + + body: dict[str, Any] = { + "model": model_name, + "instructions": instructions or None, + "input": input_items, + "max_output_tokens": max(1, max_tokens), + "store": False, + "stream": False, + } + + if self._supports_temperature(model_name, reasoning_effort): + body["temperature"] = temperature + + if reasoning_effort and reasoning_effort.lower() != "none": + body["reasoning"] = {"effort": reasoning_effort} + body["include"] = ["reasoning.encrypted_content"] + + if tools: + body["tools"] = convert_tools(tools) + body["tool_choice"] = tool_choice or "auto" + + return body + # ------------------------------------------------------------------ # Response parsing # ------------------------------------------------------------------ @@ -731,11 +828,22 @@ class OpenAICompatProvider(LLMProvider): reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: - kwargs = self._build_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) try: + if self._should_use_responses_api(model, reasoning_effort): + try: + body = self._build_responses_body( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + return 
parse_response_output(await self._client.responses.create(**body)) + except Exception as responses_error: + if not self._should_fallback_from_responses_error(responses_error): + raise + + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) return self._parse(await self._client.chat.completions.create(**kwargs)) except Exception as e: return self._handle_error(e) @@ -751,14 +859,49 @@ class OpenAICompatProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: - kwargs = self._build_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) - kwargs["stream"] = True - kwargs["stream_options"] = {"include_usage": True} idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) try: + if self._should_use_responses_api(model, reasoning_effort): + try: + body = self._build_responses_body( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + body["stream"] = True + stream = await self._client.responses.create(**body) + + async def _timed_stream(): + stream_iter = stream.__aiter__() + while True: + try: + yield await asyncio.wait_for( + stream_iter.__anext__(), + timeout=idle_timeout_s, + ) + except StopAsyncIteration: + break + + content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream( + _timed_stream(), + on_content_delta, + ) + return LLMResponse( + content=content or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + usage=usage, + reasoning_content=reasoning_content, + ) + except Exception as responses_error: + if not self._should_fallback_from_responses_error(responses_error): + raise + + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + kwargs["stream"] = True + kwargs["stream_options"] = 
{"include_usage": True} stream = await self._client.chat.completions.create(**kwargs) chunks: list[Any] = [] stream_iter = stream.__aiter__() diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 2e885e165..8839ea3f0 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -10,7 +10,7 @@ from __future__ import annotations import asyncio from types import SimpleNamespace -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -54,6 +54,57 @@ def _fake_tool_call_response() -> SimpleNamespace: return SimpleNamespace(choices=[choice], usage=usage) +def _fake_responses_response(content: str = "ok") -> MagicMock: + """Build a minimal Responses API response object.""" + resp = MagicMock() + resp.model_dump.return_value = { + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": content}], + }], + "status": "completed", + "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + } + return resp + + +def _fake_responses_stream(text: str = "ok"): + async def _stream(): + yield SimpleNamespace(type="response.output_text.delta", delta=text) + yield SimpleNamespace( + type="response.completed", + response=SimpleNamespace( + status="completed", + usage=SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15), + output=[], + ), + ) + + return _stream() + + +def _fake_chat_stream(text: str = "ok"): + async def _stream(): + yield SimpleNamespace( + choices=[SimpleNamespace(finish_reason=None, delta=SimpleNamespace(content=text, reasoning_content=None, tool_calls=None))], + usage=None, + ) + yield SimpleNamespace( + choices=[SimpleNamespace(finish_reason="stop", delta=SimpleNamespace(content=None, reasoning_content=None, tool_calls=None))], + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + return _stream() + + +class 
_FakeResponsesError(Exception): + def __init__(self, status_code: int, text: str): + super().__init__(text) + self.status_code = status_code + self.response = SimpleNamespace(status_code=status_code, text=text, headers={}) + + class _StalledStream: def __aiter__(self): return self @@ -226,6 +277,224 @@ def test_openai_model_passthrough() -> None: assert provider.get_default_model() == "gpt-4o" +@pytest.mark.asyncio +async def test_direct_openai_gpt5_uses_responses_api() -> None: + mock_chat = AsyncMock(return_value=_fake_chat_response()) + mock_responses = AsyncMock(return_value=_fake_responses_response("from responses")) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-5-chat", + spec=spec, + ) + result = await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="gpt-5-chat", + ) + + assert result.content == "from responses" + mock_responses.assert_awaited_once() + mock_chat.assert_not_awaited() + call_kwargs = mock_responses.call_args.kwargs + assert call_kwargs["model"] == "gpt-5-chat" + assert call_kwargs["max_output_tokens"] == 4096 + assert "input" in call_kwargs + assert "messages" not in call_kwargs + + +@pytest.mark.asyncio +async def test_direct_openai_reasoning_prefers_responses_api() -> None: + mock_chat = AsyncMock(return_value=_fake_chat_response()) + mock_responses = AsyncMock(return_value=_fake_responses_response("reasoned")) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + 
api_key="sk-test-key", + default_model="gpt-4o", + spec=spec, + ) + await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="gpt-4o", + reasoning_effort="medium", + ) + + mock_responses.assert_awaited_once() + mock_chat.assert_not_awaited() + call_kwargs = mock_responses.call_args.kwargs + assert call_kwargs["reasoning"] == {"effort": "medium"} + assert call_kwargs["include"] == ["reasoning.encrypted_content"] + + +@pytest.mark.asyncio +async def test_direct_openai_gpt4o_stays_on_chat_completions() -> None: + mock_chat = AsyncMock(return_value=_fake_chat_response()) + mock_responses = AsyncMock(return_value=_fake_responses_response()) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-4o", + spec=spec, + ) + await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="gpt-4o", + ) + + mock_chat.assert_awaited_once() + mock_responses.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_openrouter_gpt5_stays_on_chat_completions() -> None: + mock_chat = AsyncMock(return_value=_fake_chat_response()) + mock_responses = AsyncMock(return_value=_fake_responses_response()) + spec = find_by_name("openrouter") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + api_key="sk-or-test-key", + api_base="https://openrouter.ai/api/v1", + default_model="openai/gpt-5", + spec=spec, + ) + await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="openai/gpt-5", + ) + + 
mock_chat.assert_awaited_once() + mock_responses.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_direct_openai_streaming_gpt5_uses_responses_api() -> None: + mock_chat = AsyncMock(return_value=_StalledStream()) + mock_responses = AsyncMock(return_value=_fake_responses_stream("hi")) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-5-chat", + spec=spec, + ) + result = await provider.chat_stream( + messages=[{"role": "user", "content": "hello"}], + model="gpt-5-chat", + ) + + assert result.content == "hi" + assert result.finish_reason == "stop" + mock_responses.assert_awaited_once() + mock_chat.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_direct_openai_responses_404_falls_back_to_chat_completions() -> None: + mock_chat = AsyncMock(return_value=_fake_chat_response("from chat")) + mock_responses = AsyncMock(side_effect=_FakeResponsesError(404, "Responses endpoint not supported")) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-5-chat", + spec=spec, + ) + result = await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="gpt-5-chat", + ) + + assert result.content == "from chat" + mock_responses.assert_awaited_once() + mock_chat.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_direct_openai_stream_responses_unsupported_param_falls_back() -> None: + mock_chat = 
AsyncMock(return_value=_fake_chat_stream("fallback stream")) + mock_responses = AsyncMock( + side_effect=_FakeResponsesError(400, "Unknown parameter: max_output_tokens for Responses API") + ) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-5-chat", + spec=spec, + ) + result = await provider.chat_stream( + messages=[{"role": "user", "content": "hello"}], + model="gpt-5-chat", + ) + + assert result.content == "fallback stream" + mock_responses.assert_awaited_once() + mock_chat.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_direct_openai_responses_rate_limit_does_not_fallback() -> None: + mock_chat = AsyncMock(return_value=_fake_chat_response("from chat")) + mock_responses = AsyncMock(side_effect=_FakeResponsesError(429, "rate limit")) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_chat + client_instance.responses.create = mock_responses + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-5-chat", + spec=spec, + ) + result = await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="gpt-5-chat", + ) + + assert result.finish_reason == "error" + mock_responses.assert_awaited_once() + mock_chat.assert_not_awaited() + + def test_openai_compat_supports_temperature_matches_reasoning_model_rules() -> None: assert OpenAICompatProvider._supports_temperature("gpt-4o") is True assert OpenAICompatProvider._supports_temperature("gpt-5-chat") is False From 0f1e3aa15135ccfd280c1fe75628016abafd54ed Mon Sep 17 00:00:00 2001 From: "xinnan.hou" 
<550747419@qq.com> Date: Wed, 8 Apr 2026 11:35:32 +0800 Subject: [PATCH 007/115] fix --- nanobot/cron/service.py | 118 ++++++++++++++++++++++-------- nanobot/cron/types.py | 7 ++ pyproject.toml | 1 + tests/cron/test_cron_service.py | 63 ++++++++++++---- tests/cron/test_cron_tool_list.py | 12 ++- 5 files changed, 156 insertions(+), 45 deletions(-) diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index 5abe676e6..1807c2926 100644 --- a/nanobot/cron/service.py +++ b/nanobot/cron/service.py @@ -4,10 +4,12 @@ import asyncio import json import time import uuid +from dataclasses import asdict from datetime import datetime from pathlib import Path from typing import Any, Callable, Coroutine, Literal +from filelock import FileLock from loguru import logger from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore @@ -69,28 +71,25 @@ class CronService: self, store_path: Path, on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None, + max_sleep_ms: int = 300_000 # 5 minutes ): self.store_path = store_path + self._action_path = store_path.parent / "action.jsonl" + self._lock = FileLock(str(self._action_path.parent) + ".lock") self.on_job = on_job self._store: CronStore | None = None - self._last_mtime: float = 0.0 self._timer_task: asyncio.Task | None = None self._running = False + self.max_sleep_ms = max_sleep_ms - def _load_store(self) -> CronStore: - """Load jobs from disk. 
Reloads automatically if file was modified externally.""" - if self._store and self.store_path.exists(): - mtime = self.store_path.stat().st_mtime - if mtime != self._last_mtime: - logger.info("Cron: jobs.json modified externally, reloading") - self._store = None - if self._store: - return self._store - + def _load_jobs(self) -> tuple[list[CronJob], int]: + jobs = [] + version = 1 if self.store_path.exists(): try: data = json.loads(self.store_path.read_text(encoding="utf-8")) jobs = [] + version = data.get("version", 1) for j in data.get("jobs", []): jobs.append(CronJob( id=j["id"], @@ -129,13 +128,53 @@ class CronService: updated_at_ms=j.get("updatedAtMs", 0), delete_after_run=j.get("deleteAfterRun", False), )) - self._store = CronStore(jobs=jobs) - self._last_mtime = self.store_path.stat().st_mtime except Exception as e: logger.warning("Failed to load cron store: {}", e) - self._store = CronStore() - else: - self._store = CronStore() + return jobs, version + + def _merge_action(self): + if not self._action_path.exists(): + return + + jobs_map = {j.id: j for j in self._store.jobs} + def _update(params: dict): + j = CronJob.from_dict(params) + jobs_map[j.id] = j + + def _del(params: dict): + if job_id := params.get("job_id"): + jobs_map.pop(job_id) + + with self._lock: + with open(self._action_path, "r", encoding="utf-8") as f: + changed = False + for line in f: + try: + line = line.strip() + action = json.loads(line) + if "action" not in action: + continue + if action["action"] == "del": + _del(action.get("params", {})) + else: + _update(action.get("params", {})) + changed = True + except Exception as exp: + logger.debug(f"load action line error: {exp}") + continue + self._store.jobs = list(jobs_map.values()) + if self._running and changed: + self._action_path.write_text("", encoding="utf-8") + self._save_store() + return + + def _load_store(self) -> CronStore: + """Load jobs from disk. Reloads automatically if file was modified externally. 
+ - Reload every time because it needs to merge operations on the jobs object from other instances. + """ + jobs, version = self._load_jobs() + self._store = CronStore(version=version, jobs=jobs) + self._merge_action() return self._store @@ -230,11 +269,11 @@ class CronService: if self._timer_task: self._timer_task.cancel() - next_wake = self._get_next_wake_ms() - if not next_wake or not self._running: + if not self._running: return - delay_ms = max(0, next_wake - _now_ms()) + next_wake = self._get_next_wake_ms() or 0 + delay_ms = min(self.max_sleep_ms ,max(1000, next_wake - _now_ms())) delay_s = delay_ms / 1000 async def tick(): @@ -303,6 +342,13 @@ class CronService: # Compute next run job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms()) + def _append_action(self, action: Literal["add", "del", "update"], params: dict): + self.store_path.parent.mkdir(parents=True, exist_ok=True) + with self._lock: + with open(self._action_path, "a", encoding="utf-8") as f: + f.write(json.dumps({"action": action, "params": params}, ensure_ascii=False) + "\n") + + # ========== Public API ========== def list_jobs(self, include_disabled: bool = False) -> list[CronJob]: @@ -322,7 +368,6 @@ class CronService: delete_after_run: bool = False, ) -> CronJob: """Add a new job.""" - store = self._load_store() _validate_schedule_for_add(schedule) now = _now_ms() @@ -343,10 +388,13 @@ class CronService: updated_at_ms=now, delete_after_run=delete_after_run, ) - - store.jobs.append(job) - self._save_store() - self._arm_timer() + if self._running: + store = self._load_store() + store.jobs.append(job) + self._save_store() + self._arm_timer() + else: + self._append_action("add", asdict(job)) logger.info("Cron: added job '{}' ({})", name, job.id) return job @@ -380,8 +428,11 @@ class CronService: removed = len(store.jobs) < before if removed: - self._save_store() - self._arm_timer() + if self._running: + self._save_store() + self._arm_timer() + else: + self._append_action("del", 
{"job_id": job_id}) logger.info("Cron: removed job {}", job_id) return "removed" @@ -398,13 +449,20 @@ class CronService: job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms()) else: job.state.next_run_at_ms = None - self._save_store() - self._arm_timer() + if self._running: + self._save_store() + self._arm_timer() + else: + self._append_action("update", asdict(job)) return job return None async def run_job(self, job_id: str, force: bool = False) -> bool: - """Manually run a job.""" + """Manually run a job. For testing purposes + - It's not that the gateway instance cannot run because it doesn't have the on_job method. + - There may be concurrency issues. + """ + self._running = True store = self._load_store() for job in store.jobs: if job.id == job_id: @@ -412,8 +470,10 @@ class CronService: return False await self._execute_job(job) self._save_store() + self._running = False self._arm_timer() return True + self._running = False return False def get_job(self, job_id: str) -> CronJob | None: diff --git a/nanobot/cron/types.py b/nanobot/cron/types.py index e7b2c4391..8a1d1e0f1 100644 --- a/nanobot/cron/types.py +++ b/nanobot/cron/types.py @@ -61,6 +61,13 @@ class CronJob: updated_at_ms: int = 0 delete_after_run: bool = False + @classmethod + def from_dict(cls, kwargs: dict): + kwargs["schedule"] = CronSchedule(**kwargs.get("schedule", {"kind": "every"})) + kwargs["payload"] = CronPayload(**kwargs.get("payload", {})) + kwargs["state"] = CronJobState(**kwargs.get("state", {})) + return cls(**kwargs) + @dataclass class CronStore: diff --git a/pyproject.toml b/pyproject.toml index a5807f962..751716135 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "tiktoken>=0.12.0,<1.0.0", "jinja2>=3.1.0,<4.0.0", "dulwich>=0.22.0,<1.0.0", + "filelock>=3.25.2", ] [project.optional-dependencies] diff --git a/tests/cron/test_cron_service.py b/tests/cron/test_cron_service.py index 8606e4f58..51cff228c 100644 --- 
a/tests/cron/test_cron_service.py +++ b/tests/cron/test_cron_service.py @@ -1,5 +1,6 @@ import asyncio import json +import time import pytest @@ -158,24 +159,27 @@ def test_remove_job_refuses_system_jobs(tmp_path) -> None: assert service.get_job("dream") is not None -def test_reload_jobs(tmp_path): +@pytest.mark.asyncio +async def test_start_server_not_jobs(tmp_path): store_path = tmp_path / "cron" / "jobs.json" - service = CronService(store_path, on_job=lambda _: asyncio.sleep(0)) - service.add_job( - name="hist", - schedule=CronSchedule(kind="every", every_ms=60_000), - message="hello", - ) + called = [] + async def on_job(job): + called.append(job.name) - assert len(service.list_jobs()) == 1 + service = CronService(store_path, on_job=on_job, max_sleep_ms=1000) + await service.start() + assert len(service.list_jobs()) == 0 service2 = CronService(tmp_path / "cron" / "jobs.json") service2.add_job( - name="hist2", - schedule=CronSchedule(kind="every", every_ms=60_000), - message="hello2", + name="hist", + schedule=CronSchedule(kind="every", every_ms=500), + message="hello", ) - assert len(service.list_jobs()) == 2 + assert len(service.list_jobs()) == 1 + await asyncio.sleep(2) + assert len(called) != 0 + service.stop() @pytest.mark.asyncio @@ -204,7 +208,40 @@ async def test_running_service_picks_up_external_add(tmp_path): message="ping", ) - await asyncio.sleep(0.6) + await asyncio.sleep(2) assert "external" in called finally: service.stop() + + +@pytest.mark.asyncio +async def test_add_job_during_jobs_exec(tmp_path): + store_path = tmp_path / "cron" / "jobs.json" + run_once = True + + async def on_job(job): + nonlocal run_once + if run_once: + service2 = CronService(store_path, on_job=lambda x: asyncio.sleep(0)) + service2.add_job( + name="test", + schedule=CronSchedule(kind="every", every_ms=150), + message="tick", + ) + run_once = False + + service = CronService(store_path, on_job=on_job) + service.add_job( + name="heartbeat", + 
schedule=CronSchedule(kind="every", every_ms=150), + message="tick", + ) + assert len(service.list_jobs()) == 1 + await service.start() + try: + await asyncio.sleep(3) + jobs = service.list_jobs() + assert len(jobs) == 2 + assert "test" in [j.name for j in jobs] + finally: + service.stop() diff --git a/tests/cron/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py index e57ab26bd..86f3055cf 100644 --- a/tests/cron/test_cron_tool_list.py +++ b/tests/cron/test_cron_tool_list.py @@ -2,9 +2,12 @@ from datetime import datetime, timezone +import pytest + from nanobot.agent.tools.cron import CronTool from nanobot.cron.service import CronService from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule +from tests.test_openai_api import pytest_plugins def _make_tool(tmp_path) -> CronTool: @@ -215,8 +218,10 @@ def test_list_at_job_shows_iso_timestamp(tmp_path) -> None: assert "Asia/Shanghai" in result -def test_list_shows_last_run_state(tmp_path) -> None: +@pytest.mark.asyncio +async def test_list_shows_last_run_state(tmp_path) -> None: tool = _make_tool(tmp_path) + tool._cron._running = True job = tool._cron.add_job( name="Stateful job", schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"), @@ -232,9 +237,10 @@ def test_list_shows_last_run_state(tmp_path) -> None: assert "ok" in result assert "(UTC)" in result - -def test_list_shows_error_message(tmp_path) -> None: +@pytest.mark.asyncio +async def test_list_shows_error_message(tmp_path) -> None: tool = _make_tool(tmp_path) + tool._cron._running = True job = tool._cron.add_job( name="Failed job", schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"), From 142cb46956db801b213280afb91ebdf3deeca892 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 8 Apr 2026 15:27:20 +0000 Subject: [PATCH 008/115] fix(cron): preserve manual run state and merged history Keep manual runs from flipping the scheduler's running flag, rebuild merged run history records from action logs, and avoid 
delaying sub-second jobs to a one-second floor. Add regression coverage for disabled/manual runs, merged history persistence, and sub-second timers. Made-with: Cursor --- nanobot/cron/service.py | 39 ++++++++-------- nanobot/cron/types.py | 7 ++- tests/cron/test_cron_service.py | 82 +++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 19 deletions(-) diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index 1807c2926..1259d3d72 100644 --- a/nanobot/cron/service.py +++ b/nanobot/cron/service.py @@ -71,7 +71,7 @@ class CronService: self, store_path: Path, on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None, - max_sleep_ms: int = 300_000 # 5 minutes + max_sleep_ms: int = 300_000, # 5 minutes ): self.store_path = store_path self._action_path = store_path.parent / "action.jsonl" @@ -272,8 +272,11 @@ class CronService: if not self._running: return - next_wake = self._get_next_wake_ms() or 0 - delay_ms = min(self.max_sleep_ms ,max(1000, next_wake - _now_ms())) + next_wake = self._get_next_wake_ms() + if next_wake is None: + delay_ms = self.max_sleep_ms + else: + delay_ms = min(self.max_sleep_ms, max(0, next_wake - _now_ms())) delay_s = delay_ms / 1000 async def tick(): @@ -458,23 +461,23 @@ class CronService: return None async def run_job(self, job_id: str, force: bool = False) -> bool: - """Manually run a job. For testing purposes - - It's not that the gateway instance cannot run because it doesn't have the on_job method. - - There may be concurrency issues. 
- """ + """Manually run a job without disturbing the service's running state.""" + was_running = self._running self._running = True - store = self._load_store() - for job in store.jobs: - if job.id == job_id: - if not force and not job.enabled: - return False - await self._execute_job(job) - self._save_store() - self._running = False + try: + store = self._load_store() + for job in store.jobs: + if job.id == job_id: + if not force and not job.enabled: + return False + await self._execute_job(job) + self._save_store() + return True + return False + finally: + self._running = was_running + if was_running: self._arm_timer() - return True - self._running = False - return False def get_job(self, job_id: str) -> CronJob | None: """Get a job by ID.""" diff --git a/nanobot/cron/types.py b/nanobot/cron/types.py index 8a1d1e0f1..c38542e17 100644 --- a/nanobot/cron/types.py +++ b/nanobot/cron/types.py @@ -63,9 +63,14 @@ class CronJob: @classmethod def from_dict(cls, kwargs: dict): + state_kwargs = dict(kwargs.get("state", {})) + state_kwargs["run_history"] = [ + record if isinstance(record, CronRunRecord) else CronRunRecord(**record) + for record in state_kwargs.get("run_history", []) + ] kwargs["schedule"] = CronSchedule(**kwargs.get("schedule", {"kind": "every"})) kwargs["payload"] = CronPayload(**kwargs.get("payload", {})) - kwargs["state"] = CronJobState(**kwargs.get("state", {})) + kwargs["state"] = CronJobState(**state_kwargs) return cls(**kwargs) diff --git a/tests/cron/test_cron_service.py b/tests/cron/test_cron_service.py index 51cff228c..b54cf5e20 100644 --- a/tests/cron/test_cron_service.py +++ b/tests/cron/test_cron_service.py @@ -115,6 +115,41 @@ async def test_run_history_persisted_to_disk(tmp_path) -> None: assert loaded.state.run_history[0].status == "ok" +@pytest.mark.asyncio +async def test_run_job_disabled_does_not_flip_running_state(tmp_path) -> None: + store_path = tmp_path / "cron" / "jobs.json" + service = CronService(store_path, on_job=lambda _: 
asyncio.sleep(0)) + job = service.add_job( + name="disabled", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + service.enable_job(job.id, enabled=False) + + result = await service.run_job(job.id) + + assert result is False + assert service._running is False + + +@pytest.mark.asyncio +async def test_run_job_preserves_running_service_state(tmp_path) -> None: + store_path = tmp_path / "cron" / "jobs.json" + service = CronService(store_path, on_job=lambda _: asyncio.sleep(0)) + service._running = True + job = service.add_job( + name="manual", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + + result = await service.run_job(job.id, force=True) + + assert result is True + assert service._running is True + service.stop() + + @pytest.mark.asyncio async def test_running_service_honors_external_disable(tmp_path) -> None: store_path = tmp_path / "cron" / "jobs.json" @@ -182,6 +217,28 @@ async def test_start_server_not_jobs(tmp_path): service.stop() +@pytest.mark.asyncio +async def test_subsecond_job_not_delayed_to_one_second(tmp_path): + store_path = tmp_path / "cron" / "jobs.json" + called = [] + + async def on_job(job): + called.append(job.name) + + service = CronService(store_path, on_job=on_job, max_sleep_ms=5000) + service.add_job( + name="fast", + schedule=CronSchedule(kind="every", every_ms=100), + message="hello", + ) + await service.start() + try: + await asyncio.sleep(0.35) + assert called + finally: + service.stop() + + @pytest.mark.asyncio async def test_running_service_picks_up_external_add(tmp_path): """A running service should detect and execute a job added by another instance.""" @@ -245,3 +302,28 @@ async def test_add_job_during_jobs_exec(tmp_path): assert "test" in [j.name for j in jobs] finally: service.stop() + + +@pytest.mark.asyncio +async def test_external_update_preserves_run_history_records(tmp_path): + store_path = tmp_path / "cron" / "jobs.json" + service = CronService(store_path, 
on_job=lambda _: asyncio.sleep(0)) + job = service.add_job( + name="history", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + await service.run_job(job.id, force=True) + + external = CronService(store_path) + updated = external.enable_job(job.id, enabled=False) + assert updated is not None + + fresh = CronService(store_path) + loaded = fresh.get_job(job.id) + assert loaded is not None + assert loaded.state.run_history + assert loaded.state.run_history[0].status == "ok" + + fresh._running = True + fresh._save_store() From d88be08bfd80875fda94d1d84f7d3595115cea47 Mon Sep 17 00:00:00 2001 From: Lingao Meng Date: Wed, 8 Apr 2026 09:10:44 +0800 Subject: [PATCH 009/115] refactor(hook): add reraise flag to AgentHook and remove _LoopHookChain Add reraise parameter to AgentHook so hooks can opt out of exception swallowing in CompositeHook._for_each_hook_safe. _LoopHook sets reraise=True to let its exceptions propagate. _LoopHookChain is removed and replaced with CompositeHook([loop_hook] + extra_hooks). 
Signed-off-by: Lingao Meng --- nanobot/agent/hook.py | 7 +++++++ nanobot/agent/loop.py | 41 ++--------------------------------------- 2 files changed, 9 insertions(+), 39 deletions(-) diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py index 827831ebd..33db416ed 100644 --- a/nanobot/agent/hook.py +++ b/nanobot/agent/hook.py @@ -29,6 +29,9 @@ class AgentHookContext: class AgentHook: """Minimal lifecycle surface for shared runner customization.""" + def __init__(self, reraise: bool = False) -> None: + self._reraise = reraise + def wants_streaming(self) -> bool: return False @@ -69,6 +72,10 @@ class CompositeHook(AgentHook): async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None: for h in self._hooks: + if h._reraise: + await getattr(h, method_name)(*args, **kwargs) + continue + try: await getattr(h, method_name)(*args, **kwargs) except Exception: diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 66d765d00..d549d6d4a 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -54,6 +54,7 @@ class _LoopHook(AgentHook): chat_id: str = "direct", message_id: str | None = None, ) -> None: + super().__init__(reraise=True) self._loop = agent_loop self._on_progress = on_progress self._on_stream = on_stream @@ -108,44 +109,6 @@ class _LoopHook(AgentHook): def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return self._loop._strip_think(content) - -class _LoopHookChain(AgentHook): - """Run the core hook before extra hooks.""" - - __slots__ = ("_primary", "_extras") - - def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None: - self._primary = primary - self._extras = CompositeHook(extra_hooks) - - def wants_streaming(self) -> bool: - return self._primary.wants_streaming() or self._extras.wants_streaming() - - async def before_iteration(self, context: AgentHookContext) -> None: - await self._primary.before_iteration(context) - await 
self._extras.before_iteration(context) - - async def on_stream(self, context: AgentHookContext, delta: str) -> None: - await self._primary.on_stream(context, delta) - await self._extras.on_stream(context, delta) - - async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: - await self._primary.on_stream_end(context, resuming=resuming) - await self._extras.on_stream_end(context, resuming=resuming) - - async def before_execute_tools(self, context: AgentHookContext) -> None: - await self._primary.before_execute_tools(context) - await self._extras.before_execute_tools(context) - - async def after_iteration(self, context: AgentHookContext) -> None: - await self._primary.after_iteration(context) - await self._extras.after_iteration(context) - - def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: - content = self._primary.finalize_content(context, content) - return self._extras.finalize_content(context, content) - - class AgentLoop: """ The agent loop is the core processing engine. @@ -359,7 +322,7 @@ class AgentLoop: message_id=message_id, ) hook: AgentHook = ( - _LoopHookChain(loop_hook, self._extra_hooks) + CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook ) From 6bf101c79bb73b2836256c8a6428d5fda6633009 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 8 Apr 2026 15:38:41 +0000 Subject: [PATCH 010/115] fix(hook): keep composite hooks backward compatible Avoid AttributeError regressions when hooks define their own __init__ or when a CompositeHook wraps another composite. 
Made-with: Cursor --- nanobot/agent/hook.py | 3 ++- nanobot/agent/subagent.py | 1 + tests/agent/test_hook_composite.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py index 33db416ed..ad6b31ab1 100644 --- a/nanobot/agent/hook.py +++ b/nanobot/agent/hook.py @@ -65,6 +65,7 @@ class CompositeHook(AgentHook): __slots__ = ("_hooks",) def __init__(self, hooks: list[AgentHook]) -> None: + super().__init__() self._hooks = list(hooks) def wants_streaming(self) -> bool: @@ -72,7 +73,7 @@ class CompositeHook(AgentHook): async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None: for h in self._hooks: - if h._reraise: + if getattr(h, "_reraise", False): await getattr(h, method_name)(*args, **kwargs) continue diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 585139972..63aa7ad7a 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -27,6 +27,7 @@ class _SubagentHook(AgentHook): """Logging-only hook for subagent execution.""" def __init__(self, task_id: str) -> None: + super().__init__() self._task_id = task_id async def before_execute_tools(self, context: AgentHookContext) -> None: diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 590d8db64..c6077d526 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -232,6 +232,35 @@ async def test_composite_empty_hooks_no_ops(): assert hook.finalize_content(ctx, "test") == "test" +@pytest.mark.asyncio +async def test_composite_supports_legacy_hook_init_without_super(): + calls: list[str] = [] + + class LegacyHook(AgentHook): + def __init__(self, label: str) -> None: + self.label = label + + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append(self.label) + + hook = CompositeHook([LegacyHook("legacy")]) + await hook.before_iteration(_ctx()) + assert calls == 
["legacy"] + + +@pytest.mark.asyncio +async def test_composite_can_wrap_another_composite(): + calls: list[str] = [] + + class Inner(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append("inner") + + hook = CompositeHook([CompositeHook([Inner()])]) + await hook.before_iteration(_ctx()) + assert calls == ["inner"] + + # --------------------------------------------------------------------------- # Integration: AgentLoop with extra hooks # --------------------------------------------------------------------------- From 1700166945a28e3cab9cc20bd62ed2dd01165e18 Mon Sep 17 00:00:00 2001 From: bahtya Date: Mon, 6 Apr 2026 23:22:34 +0800 Subject: [PATCH 011/115] fix: use importlib.metadata for version to prevent mismatch with pyproject.toml Fixes #2856 Previously __version__ was hardcoded as '0.4.1' in __init__.py while pyproject.toml declared version '0.1.5'. This caused nanobot gateway to report version 0.4.1 on startup while pip showed 0.1.5. Now __version__ reads from importlib.metadata.version('nanobot-ai'), keeping pyproject.toml as the single source of truth. --- nanobot/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nanobot/__init__.py b/nanobot/__init__.py index 433dbe139..93be6664c 100644 --- a/nanobot/__init__.py +++ b/nanobot/__init__.py @@ -2,7 +2,9 @@ nanobot - A lightweight AI agent framework """ -__version__ = "0.1.5" +from importlib.metadata import version as _pkg_version + +__version__ = _pkg_version("nanobot-ai") __logo__ = "🐈" from nanobot.nanobot import Nanobot, RunResult From 715f2a79be46bc316d23e77ee2fa17ca90a03158 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 8 Apr 2026 15:45:24 +0000 Subject: [PATCH 012/115] fix(version): fall back to pyproject in source checkouts Keep importlib.metadata as the primary source for installed packages, but avoid PackageNotFoundError when nanobot is imported directly from a source tree. 
Made-with: Cursor --- nanobot/__init__.py | 24 ++++++++++++++++++-- tests/test_package_version.py | 41 +++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 tests/test_package_version.py diff --git a/nanobot/__init__.py b/nanobot/__init__.py index 93be6664c..0bce848df 100644 --- a/nanobot/__init__.py +++ b/nanobot/__init__.py @@ -2,9 +2,29 @@ nanobot - A lightweight AI agent framework """ -from importlib.metadata import version as _pkg_version +from importlib.metadata import PackageNotFoundError, version as _pkg_version +from pathlib import Path +import tomllib -__version__ = _pkg_version("nanobot-ai") + +def _read_pyproject_version() -> str | None: + """Read the source-tree version when package metadata is unavailable.""" + pyproject = Path(__file__).resolve().parent.parent / "pyproject.toml" + if not pyproject.exists(): + return None + data = tomllib.loads(pyproject.read_text(encoding="utf-8")) + return data.get("project", {}).get("version") + + +def _resolve_version() -> str: + try: + return _pkg_version("nanobot-ai") + except PackageNotFoundError: + # Source checkouts often import nanobot without installed dist-info. 
+ return _read_pyproject_version() or "0.1.5" + + +__version__ = _resolve_version() __logo__ = "🐈" from nanobot.nanobot import Nanobot, RunResult diff --git a/tests/test_package_version.py b/tests/test_package_version.py new file mode 100644 index 000000000..4780757d6 --- /dev/null +++ b/tests/test_package_version.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import subprocess +import sys +import textwrap +from pathlib import Path + +import tomllib + + +def test_source_checkout_import_uses_pyproject_version_without_metadata() -> None: + repo_root = Path(__file__).resolve().parents[1] + expected = tomllib.loads((repo_root / "pyproject.toml").read_text(encoding="utf-8"))["project"][ + "version" + ] + script = textwrap.dedent( + f""" + import sys + import types + + sys.path.insert(0, {str(repo_root)!r}) + fake = types.ModuleType("nanobot.nanobot") + fake.Nanobot = object + fake.RunResult = object + sys.modules["nanobot.nanobot"] = fake + + import nanobot + + print(nanobot.__version__) + """ + ) + + proc = subprocess.run( + [sys.executable, "-S", "-c", script], + capture_output=True, + text=True, + check=False, + ) + + assert proc.returncode == 0, proc.stderr + assert proc.stdout.strip() == expected From e49b6c0c96ba716e176f61e08264b4da6320c431 Mon Sep 17 00:00:00 2001 From: SHLE1 Date: Wed, 8 Apr 2026 07:20:02 +0000 Subject: [PATCH 013/115] fix(discord): enable streaming replies --- README.md | 4 +- nanobot/channels/discord.py | 112 +++++++++++++++++++++++++ tests/channels/test_discord_channel.py | 44 ++++++++++ 3 files changed, 159 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a2ea20f8c..0747b25ed 100644 --- a/README.md +++ b/README.md @@ -394,7 +394,8 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso "enabled": true, "token": "YOUR_BOT_TOKEN", "allowFrom": ["YOUR_USER_ID"], - "groupPolicy": "mention" + "groupPolicy": "mention", + "streaming": true } } } @@ -405,6 +406,7 @@ If you prefer to 
configure manually, add the following to `~/.nanobot/config.jso > - `"open"` — Respond to all messages > DMs always respond when the sender is in `allowFrom`. > - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session. +> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies. **5. Invite the bot** - OAuth2 → URL Generator diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 9bf4d919c..9e68bb46b 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -4,6 +4,8 @@ from __future__ import annotations import asyncio import importlib.util +import time +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -34,6 +36,16 @@ MAX_MESSAGE_LEN = 2000 # Discord message character limit TYPING_INTERVAL_S = 8 +@dataclass +class _StreamBuf: + """Per-chat streaming accumulator for progressive Discord message edits.""" + + text: str = "" + message: Any | None = None + last_edit: float = 0.0 + stream_id: str | None = None + + class DiscordConfig(Base): """Discord channel configuration.""" @@ -45,6 +57,7 @@ class DiscordConfig(Base): read_receipt_emoji: str = "👀" working_emoji: str = "🔧" working_emoji_delay: float = 2.0 + streaming: bool = True if DISCORD_AVAILABLE: @@ -242,6 +255,7 @@ class DiscordChannel(BaseChannel): name = "discord" display_name = "Discord" + _STREAM_EDIT_INTERVAL = 0.8 @classmethod def default_config(cls) -> dict[str, Any]: @@ -263,6 +277,7 @@ class DiscordChannel(BaseChannel): self._bot_user_id: str | None = None self._pending_reactions: dict[str, Any] = {} # chat_id -> message object self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {} + self._stream_bufs: dict[str, _StreamBuf] = {} async def start(self) -> None: """Start the Discord client.""" @@ -320,6 +335,61 @@ class 
DiscordChannel(BaseChannel): await self._stop_typing(msg.chat_id) await self._clear_reactions(msg.chat_id) + async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + """Progressive Discord delivery: send once, then edit until the stream ends.""" + client = self._client + if client is None or not client.is_ready(): + logger.warning("Discord client not ready; dropping stream delta") + return + + meta = metadata or {} + stream_id = meta.get("_stream_id") + + if meta.get("_stream_end"): + buf = self._stream_bufs.get(chat_id) + if not buf or buf.message is None or not buf.text: + return + if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id: + return + await self._finalize_stream(chat_id, buf) + return + + buf = self._stream_bufs.get(chat_id) + if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id): + buf = _StreamBuf(stream_id=stream_id) + self._stream_bufs[chat_id] = buf + elif buf.stream_id is None: + buf.stream_id = stream_id + + buf.text += delta + if not buf.text.strip(): + return + + target = await self._resolve_channel(chat_id) + if target is None: + logger.warning("Discord stream target {} unavailable", chat_id) + return + + now = time.monotonic() + if buf.message is None: + try: + buf.message = await target.send(content=buf.text) + buf.last_edit = now + except Exception as e: + logger.warning("Discord stream initial send failed: {}", e) + raise + return + + if (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL: + return + + try: + await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0]) + buf.last_edit = now + except Exception as e: + logger.warning("Discord stream edit failed: {}", e) + raise + async def _handle_discord_message(self, message: discord.Message) -> None: """Handle incoming Discord messages from discord.py.""" if message.author.bot: @@ -373,6 +443,47 @@ class DiscordChannel(BaseChannel): 
"""Backward-compatible alias for legacy tests/callers.""" await self._handle_discord_message(message) + async def _resolve_channel(self, chat_id: str) -> Any | None: + """Resolve a Discord channel from cache first, then network fetch.""" + client = self._client + if client is None or not client.is_ready(): + return None + channel_id = int(chat_id) + channel = client.get_channel(channel_id) + if channel is not None: + return channel + try: + return await client.fetch_channel(channel_id) + except Exception as e: + logger.warning("Discord channel {} unavailable: {}", chat_id, e) + return None + + async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None: + """Commit the final streamed content and flush overflow chunks.""" + chunks = DiscordBotClient._build_chunks(buf.text, [], False) + if not chunks: + self._stream_bufs.pop(chat_id, None) + return + + try: + await buf.message.edit(content=chunks[0]) + except Exception as e: + logger.warning("Discord final stream edit failed: {}", e) + raise + + target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id) + if target is None: + logger.warning("Discord stream follow-up target {} unavailable", chat_id) + self._stream_bufs.pop(chat_id, None) + return + + for extra_chunk in chunks[1:]: + await target.send(content=extra_chunk) + + self._stream_bufs.pop(chat_id, None) + await self._stop_typing(chat_id) + await self._clear_reactions(chat_id) + def _should_accept_inbound( self, message: discord.Message, @@ -507,6 +618,7 @@ class DiscordChannel(BaseChannel): async def _reset_runtime_state(self, close_client: bool) -> None: """Reset client and typing state.""" await self._cancel_all_typing() + self._stream_bufs.clear() if close_client and self._client is not None and not self._client.is_closed(): try: await self._client.close() diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 845c03c57..f588334ba 100644 --- 
a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -71,11 +71,25 @@ class _FakePartialMessage: self.id = message_id +class _FakeSentMessage: + # Sent-message double supporting edit() for streaming tests. + def __init__(self, channel, content: str) -> None: + self.channel = channel + self.content = content + self.edits: list[dict] = [] + + async def edit(self, **kwargs) -> None: + self.edits.append(dict(kwargs)) + if "content" in kwargs: + self.content = kwargs["content"] + + class _FakeChannel: # Channel double that records outbound payloads and typing activity. def __init__(self, channel_id: int = 123) -> None: self.id = channel_id self.sent_payloads: list[dict] = [] + self.sent_messages: list[_FakeSentMessage] = [] self.trigger_typing_calls = 0 self.typing_enter_hook = None @@ -85,6 +99,9 @@ class _FakeChannel: payload["file_name"] = payload["file"].filename del payload["file"] self.sent_payloads.append(payload) + message = _FakeSentMessage(self, payload.get("content", "")) + self.sent_messages.append(message) + return message def get_partial_message(self, message_id: int) -> _FakePartialMessage: return _FakePartialMessage(message_id) @@ -427,6 +444,33 @@ async def test_send_fetches_channel_when_not_cached() -> None: assert target.sent_payloads == [{"content": "hello"}] +def test_supports_streaming_enabled_by_default() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + assert channel.supports_streaming is True + + +@pytest.mark.asyncio +async def test_send_delta_streams_by_editing_message(monkeypatch) -> None: + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(owner, intents=None) + owner._client = client + owner._running = True + target = _FakeChannel(channel_id=123) + client.channels[123] = target + + times = iter([1.0, 3.0, 5.0]) + monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: 
next(times, 5.0)) + + await owner.send_delta("123", "hel", {"_stream_delta": True, "_stream_id": "s1"}) + await owner.send_delta("123", "lo", {"_stream_delta": True, "_stream_id": "s1"}) + await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"}) + + assert target.sent_payloads[0] == {"content": "hel"} + assert target.sent_messages[0].edits == [{"content": "hello"}, {"content": "hello"}] + assert owner._stream_bufs == {} + + @pytest.mark.asyncio async def test_slash_new_forwards_when_user_is_allowlisted() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) From 61dd5ac13ad3de2ad04261732540123a710d0e5e Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 8 Apr 2026 16:21:18 +0000 Subject: [PATCH 014/115] test(discord): cover streamed reply overflow Lock the Discord streaming path with a regression test for final chunk splitting so oversized replies stay safe to merge and ship. Made-with: Cursor --- tests/channels/test_discord_channel.py | 29 +++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index f588334ba..09b80740f 100644 --- a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -9,7 +9,7 @@ discord = pytest.importorskip("discord") from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig +from nanobot.channels.discord import MAX_MESSAGE_LEN, DiscordBotClient, DiscordChannel, DiscordConfig from nanobot.command.builtin import build_help_text @@ -471,6 +471,33 @@ async def test_send_delta_streams_by_editing_message(monkeypatch) -> None: assert owner._stream_bufs == {} +@pytest.mark.asyncio +async def test_send_delta_stream_end_splits_oversized_reply(monkeypatch) -> None: + owner = DiscordChannel(DiscordConfig(enabled=True, 
allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(owner, intents=None) + owner._client = client + owner._running = True + target = _FakeChannel(channel_id=123) + client.channels[123] = target + + prefix = "a" * (MAX_MESSAGE_LEN - 100) + suffix = "b" * 150 + full_text = prefix + suffix + chunks = DiscordBotClient._build_chunks(full_text, [], False) + assert len(chunks) == 2 + + times = iter([1.0, 3.0]) + monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 3.0)) + + await owner.send_delta("123", prefix, {"_stream_delta": True, "_stream_id": "s1"}) + await owner.send_delta("123", suffix, {"_stream_delta": True, "_stream_id": "s1"}) + await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"}) + + assert target.sent_payloads == [{"content": prefix}, {"content": chunks[1]}] + assert target.sent_messages[0].edits == [{"content": chunks[0]}, {"content": chunks[0]}] + assert owner._stream_bufs == {} + + @pytest.mark.asyncio async def test_slash_new_forwards_when_user_is_allowlisted() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) From 66409784f4bf70b0ca8aaefb63a8e50f3c24e8c2 Mon Sep 17 00:00:00 2001 From: Leo fu Date: Wed, 8 Apr 2026 12:54:17 -0400 Subject: [PATCH 015/115] fix(status): use consistent divisor (1000) for token count display The /status command divided context_used by 1000 but context_total by 1024, producing inconsistent values. For example a 128000-token window displayed as 125k instead of 128k. Tokens are not a binary unit, so both should use 1000. 
--- nanobot/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 7267bac2a..9c2f48960 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -417,7 +417,7 @@ def build_status_content( ctx_total = max(context_window_tokens, 0) ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0 ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate) - ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a" + ctx_total_str = f"{ctx_total // 1000}k" if ctx_total > 0 else "n/a" token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out" if cached and last_in: token_line += f" ({cached * 100 // last_in}% cached)" From 42624f5bf378f2f4498a18e490041b9e890badb8 Mon Sep 17 00:00:00 2001 From: Leo fu Date: Wed, 8 Apr 2026 12:57:22 -0400 Subject: [PATCH 016/115] test: update expected token display to match consistent 1000 divisor The test fixtures use 65536 as context_window_tokens. With the divisor corrected from 1024 to 1000, the display changes from 64k to 65k. 
--- tests/cli/test_restart_command.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cli/test_restart_command.py b/tests/cli/test_restart_command.py index 8b079d4e7..697d5fc17 100644 --- a/tests/cli/test_restart_command.py +++ b/tests/cli/test_restart_command.py @@ -148,7 +148,7 @@ class TestRestartCommand: assert response is not None assert "Model: test-model" in response.content assert "Tokens: 0 in / 0 out" in response.content - assert "Context: 20k/64k (31%)" in response.content + assert "Context: 20k/65k (31%)" in response.content assert "Session: 3 messages" in response.content assert "Uptime: 2m 5s" in response.content assert response.metadata == {"render_as": "text"} @@ -186,7 +186,7 @@ class TestRestartCommand: assert response is not None assert "Tokens: 1200 in / 34 out" in response.content - assert "Context: 1k/64k (1%)" in response.content + assert "Context: 1k/65k (1%)" in response.content @pytest.mark.asyncio async def test_process_direct_preserves_render_metadata(self): From 3cc2ebeef70a4b86607ea0c00d86b1aebd8f9876 Mon Sep 17 00:00:00 2001 From: Rohit_Dayanand123 <66650100+RohitDayanand@users.noreply.github.com> Date: Wed, 8 Apr 2026 11:45:21 -0400 Subject: [PATCH 017/115] Added bug fix to Dingtalk by zipping html to prevent raw failure --- nanobot/channels/dingtalk.py | 29 ++++++++++ tests/channels/test_dingtalk_channel.py | 77 +++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index ab12211e8..39b5818bd 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -5,6 +5,8 @@ import json import mimetypes import os import time +import zipfile +from io import BytesIO from pathlib import Path from typing import Any from urllib.parse import unquote, urlparse @@ -171,6 +173,7 @@ class DingTalkChannel(BaseChannel): _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"} _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", 
".m4a", ".aac"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"} + _ZIP_BEFORE_UPLOAD_EXTS = {".htm", ".html"} @classmethod def default_config(cls) -> dict[str, Any]: @@ -287,6 +290,31 @@ class DingTalkChannel(BaseChannel): name = os.path.basename(urlparse(media_ref).path) return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin") + @staticmethod + def _zip_bytes(filename: str, data: bytes) -> tuple[bytes, str, str]: + stem = Path(filename).stem or "attachment" + safe_name = filename or "attachment.bin" + zip_name = f"{stem}.zip" + buffer = BytesIO() + with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + archive.writestr(safe_name, data) + return buffer.getvalue(), zip_name, "application/zip" + + def _normalize_upload_payload( + self, + filename: str, + data: bytes, + content_type: str | None, + ) -> tuple[bytes, str, str | None]: + ext = Path(filename).suffix.lower() + if ext in self._ZIP_BEFORE_UPLOAD_EXTS or content_type == "text/html": + logger.info( + "DingTalk does not accept raw HTML attachments, zipping {} before upload", + filename, + ) + return self._zip_bytes(filename, data) + return data, filename, content_type + async def _read_media_bytes( self, media_ref: str, @@ -444,6 +472,7 @@ class DingTalkChannel(BaseChannel): return False filename = filename or self._guess_filename(media_ref, upload_type) + data, filename, content_type = self._normalize_upload_payload(filename, data, content_type) file_type = Path(filename).suffix.lower().lstrip(".") if not file_type: guessed = mimetypes.guess_extension(content_type or "") diff --git a/tests/channels/test_dingtalk_channel.py b/tests/channels/test_dingtalk_channel.py index 6894c8683..f743c4e62 100644 --- a/tests/channels/test_dingtalk_channel.py +++ b/tests/channels/test_dingtalk_channel.py @@ -1,4 +1,6 @@ import asyncio +import zipfile +from io import BytesIO from types import SimpleNamespace import pytest @@ -221,3 
+223,78 @@ async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None: assert "messageFiles/download" in channel._http.calls[0]["url"] assert channel._http.calls[0]["json"]["downloadCode"] == "code123" assert channel._http.calls[1]["method"] == "GET" + + +def test_normalize_upload_payload_zips_html_attachment() -> None: + channel = DingTalkChannel( + DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]), + MessageBus(), + ) + + data, filename, content_type = channel._normalize_upload_payload( + "report.html", + b"Hello", + "text/html", + ) + + assert filename == "report.zip" + assert content_type == "application/zip" + + archive = zipfile.ZipFile(BytesIO(data)) + assert archive.namelist() == ["report.html"] + assert archive.read("report.html") == b"Hello" + + +@pytest.mark.asyncio +async def test_send_media_ref_zips_html_before_upload(tmp_path, monkeypatch) -> None: + channel = DingTalkChannel( + DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]), + MessageBus(), + ) + + html_path = tmp_path / "report.html" + html_path.write_text("Hello", encoding="utf-8") + + captured: dict[str, object] = {} + + async def fake_upload_media(*, token, data, media_type, filename, content_type): + captured.update( + { + "token": token, + "data": data, + "media_type": media_type, + "filename": filename, + "content_type": content_type, + } + ) + return "media-123" + + async def fake_send_batch_message(token, chat_id, msg_key, msg_param): + captured.update( + { + "sent_token": token, + "chat_id": chat_id, + "msg_key": msg_key, + "msg_param": msg_param, + } + ) + return True + + monkeypatch.setattr(channel, "_upload_media", fake_upload_media) + monkeypatch.setattr(channel, "_send_batch_message", fake_send_batch_message) + + ok = await channel._send_media_ref("token-123", "user-1", str(html_path)) + + assert ok is True + assert captured["media_type"] == "file" + assert captured["filename"] == "report.zip" + assert 
captured["content_type"] == "application/zip" + assert captured["msg_key"] == "sampleFile" + assert captured["msg_param"] == { + "mediaId": "media-123", + "fileName": "report.zip", + "fileType": "zip", + } + + archive = zipfile.ZipFile(BytesIO(captured["data"])) + assert archive.namelist() == ["report.html"] From bfec06a2c197bd794fb74130e0322d36234e0b56 Mon Sep 17 00:00:00 2001 From: chensp <1051393758@qq.com> Date: Wed, 8 Apr 2026 15:28:58 +0800 Subject: [PATCH 018/115] Fix Windows exec env for Docker Desktop plugin discovery nanobot's Windows exec environment was not forwarding ProgramFiles and related variables, so docker desktop start could not discover the desktop CLI plugin and reported unknown command. Forward the missing variables and add a regression test that covers the Windows env shape. --- nanobot/agent/tools/shell.py | 6 ++++++ tests/tools/test_exec_platform.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 23da2c10f..eb786e9f4 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -218,6 +218,12 @@ class ExecTool(Tool): "TMP": os.environ.get("TMP", f"{sr}\\Temp"), "PATHEXT": os.environ.get("PATHEXT", ".COM;.EXE;.BAT;.CMD"), "PATH": os.environ.get("PATH", f"{sr}\\system32;{sr}"), + "APPDATA": os.environ.get("APPDATA", ""), + "LOCALAPPDATA": os.environ.get("LOCALAPPDATA", ""), + "ProgramData": os.environ.get("ProgramData", ""), + "ProgramFiles": os.environ.get("ProgramFiles", ""), + "ProgramFiles(x86)": os.environ.get("ProgramFiles(x86)", ""), + "ProgramW6432": os.environ.get("ProgramW6432", ""), } home = os.environ.get("HOME", "/tmp") return { diff --git a/tests/tools/test_exec_platform.py b/tests/tools/test_exec_platform.py index aa3ffee71..b24d01ac4 100644 --- a/tests/tools/test_exec_platform.py +++ b/tests/tools/test_exec_platform.py @@ -5,12 +5,18 @@ strategy, and sandbox behaviour per platform — without actually running 
platform-specific binaries (all subprocess calls are mocked). """ +import sys from unittest.mock import AsyncMock, patch import pytest from nanobot.agent.tools.shell import ExecTool +_WINDOWS_ENV_KEYS = { + "APPDATA", "LOCALAPPDATA", "ProgramData", + "ProgramFiles", "ProgramFiles(x86)", "ProgramW6432", +} + # --------------------------------------------------------------------------- # _build_env @@ -21,7 +27,10 @@ class TestBuildEnvUnix: def test_expected_keys(self): with patch("nanobot.agent.tools.shell._IS_WINDOWS", False): env = ExecTool()._build_env() - assert set(env) == {"HOME", "LANG", "TERM"} + expected = {"HOME", "LANG", "TERM"} + assert expected <= set(env) + if sys.platform != "win32": + assert set(env) == expected def test_home_from_environ(self, monkeypatch): monkeypatch.setenv("HOME", "/Users/dev") @@ -45,6 +54,7 @@ class TestBuildEnvWindows: _EXPECTED_KEYS = { "SYSTEMROOT", "COMSPEC", "USERPROFILE", "HOMEDRIVE", "HOMEPATH", "TEMP", "TMP", "PATHEXT", "PATH", + *_WINDOWS_ENV_KEYS, } def test_expected_keys(self): From 743e73da3fc7baef7159eab78562c1ab8908cacf Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 7 Apr 2026 21:47:58 +0800 Subject: [PATCH 019/115] feat(session): add unified_session config to share one session across all channels --- nanobot/agent/loop.py | 6 +- nanobot/cli/commands.py | 3 + nanobot/config/schema.py | 1 + nanobot/nanobot.py | 1 + tests/agent/test_unified_session.py | 195 ++++++++++++++++++++++++++++ 5 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 tests/agent/test_unified_session.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index d549d6d4a..593331c3f 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -143,6 +143,7 @@ class AgentLoop: channels_config: ChannelsConfig | None = None, timezone: str | None = None, hooks: list[AgentHook] | None = None, + unified_session: bool = False, ): from nanobot.config.schema import ExecToolConfig, WebToolsConfig @@ -189,7 +190,7 @@ class 
AgentLoop: exec_config=self.exec_config, restrict_to_workspace=restrict_to_workspace, ) - + self._unified_session = unified_session self._running = False self._mcp_servers = mcp_servers or {} self._mcp_stack: AsyncExitStack | None = None @@ -390,6 +391,9 @@ class AgentLoop: async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" + if self._unified_session and not msg.session_key_override: + import dataclasses + msg = dataclasses.replace(msg, session_key_override="unified:default") lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() async with lock, gate: diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index a1fb7c0e0..7c4d31f3e 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -590,6 +590,7 @@ def serve( mcp_servers=runtime_config.tools.mcp_servers, channels_config=runtime_config.channels, timezone=runtime_config.agents.defaults.timezone, + unified_session=runtime_config.agents.defaults.unified_session, ) model_name = runtime_config.agents.defaults.model @@ -681,6 +682,7 @@ def gateway( mcp_servers=config.tools.mcp_servers, channels_config=config.channels, timezone=config.agents.defaults.timezone, + unified_session=config.agents.defaults.unified_session, ) # Set cron callback (needs agent) @@ -912,6 +914,7 @@ def agent( mcp_servers=config.tools.mcp_servers, channels_config=config.channels, timezone=config.agents.defaults.timezone, + unified_session=config.agents.defaults.unified_session, ) restart_notice = consume_restart_notice_from_env() if restart_notice and should_show_cli_restart_notice(restart_notice, session_id): diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index dce4c19dc..b011d765f 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -76,6 +76,7 @@ class AgentDefaults(Base): provider_retry_mode: Literal["standard", "persistent"] = "standard" 
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" + unified_session: bool = False # Share one session across all channels (single-user multi-device) dream: DreamConfig = Field(default_factory=DreamConfig) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 85e9e1ddb..9166acb27 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -81,6 +81,7 @@ class Nanobot: restrict_to_workspace=config.tools.restrict_to_workspace, mcp_servers=config.tools.mcp_servers, timezone=defaults.timezone, + unified_session=defaults.unified_session, ) return cls(loop) diff --git a/tests/agent/test_unified_session.py b/tests/agent/test_unified_session.py new file mode 100644 index 000000000..1b3c5fb97 --- /dev/null +++ b/tests/agent/test_unified_session.py @@ -0,0 +1,195 @@ +"""Tests for unified_session feature. + +Covers: +- AgentLoop._dispatch() rewrites session_key to "unified:default" when enabled +- Existing session_key_override is respected (not overwritten) +- Feature is off by default (no behavior change for existing users) +- Config schema serialises unified_session as camelCase "unifiedSession" +- onboard-generated config.json contains "unifiedSession" key +""" + +import json +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.bus.events import InboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.config.schema import AgentDefaults, Config + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_loop(tmp_path: Path, unified_session: bool = False) -> AgentLoop: + """Create a minimal AgentLoop for dispatch-level tests.""" + bus = MessageBus() + provider = MagicMock() + 
provider.get_default_model.return_value = "test-model" + + with patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as MockSubMgr, \ + patch("nanobot.agent.loop.Dream"): + MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) + loop = AgentLoop( + bus=bus, + provider=provider, + workspace=tmp_path, + unified_session=unified_session, + ) + return loop + + +def _make_msg(channel: str = "telegram", chat_id: str = "111", + session_key_override: str | None = None) -> InboundMessage: + return InboundMessage( + channel=channel, + chat_id=chat_id, + sender_id="user1", + content="hello", + session_key_override=session_key_override, + ) + + +# --------------------------------------------------------------------------- +# TestUnifiedSessionDispatch — core behaviour +# --------------------------------------------------------------------------- + +class TestUnifiedSessionDispatch: + """AgentLoop._dispatch() session key rewriting logic.""" + + @pytest.mark.asyncio + async def test_unified_session_rewrites_key_to_unified_default(self, tmp_path: Path): + """When unified_session=True, all messages use 'unified:default' as session key.""" + loop = _make_loop(tmp_path, unified_session=True) + + captured: list[str] = [] + + async def fake_process(msg, **kwargs): + captured.append(msg.session_key) + return None + + loop._process_message = fake_process # type: ignore[method-assign] + + msg = _make_msg(channel="telegram", chat_id="111") + await loop._dispatch(msg) + + assert captured == ["unified:default"] + + @pytest.mark.asyncio + async def test_unified_session_different_channels_share_same_key(self, tmp_path: Path): + """Messages from different channels all resolve to the same session key.""" + loop = _make_loop(tmp_path, unified_session=True) + + captured: list[str] = [] + + async def fake_process(msg, **kwargs): + captured.append(msg.session_key) + return None + + loop._process_message = fake_process # type: 
ignore[method-assign] + + await loop._dispatch(_make_msg(channel="telegram", chat_id="111")) + await loop._dispatch(_make_msg(channel="discord", chat_id="222")) + await loop._dispatch(_make_msg(channel="cli", chat_id="direct")) + + assert captured == ["unified:default", "unified:default", "unified:default"] + + @pytest.mark.asyncio + async def test_unified_session_disabled_preserves_original_key(self, tmp_path: Path): + """When unified_session=False (default), session key is channel:chat_id as usual.""" + loop = _make_loop(tmp_path, unified_session=False) + + captured: list[str] = [] + + async def fake_process(msg, **kwargs): + captured.append(msg.session_key) + return None + + loop._process_message = fake_process # type: ignore[method-assign] + + msg = _make_msg(channel="telegram", chat_id="999") + await loop._dispatch(msg) + + assert captured == ["telegram:999"] + + @pytest.mark.asyncio + async def test_unified_session_respects_existing_override(self, tmp_path: Path): + """If session_key_override is already set (e.g. 
Telegram thread), it is NOT overwritten.""" + loop = _make_loop(tmp_path, unified_session=True) + + captured: list[str] = [] + + async def fake_process(msg, **kwargs): + captured.append(msg.session_key) + return None + + loop._process_message = fake_process # type: ignore[method-assign] + + msg = _make_msg(channel="telegram", chat_id="111", session_key_override="telegram:thread:42") + await loop._dispatch(msg) + + assert captured == ["telegram:thread:42"] + + def test_unified_session_default_is_false(self, tmp_path: Path): + """unified_session defaults to False — no behavior change for existing users.""" + loop = _make_loop(tmp_path) + assert loop._unified_session is False + + +# --------------------------------------------------------------------------- +# TestUnifiedSessionConfig — schema & serialisation +# --------------------------------------------------------------------------- + +class TestUnifiedSessionConfig: + """Config schema and onboard serialisation for unified_session.""" + + def test_agent_defaults_unified_session_default_is_false(self): + """AgentDefaults.unified_session defaults to False.""" + defaults = AgentDefaults() + assert defaults.unified_session is False + + def test_agent_defaults_unified_session_can_be_enabled(self): + """AgentDefaults.unified_session can be set to True.""" + defaults = AgentDefaults(unified_session=True) + assert defaults.unified_session is True + + def test_config_serialises_unified_session_as_camel_case(self): + """model_dump(by_alias=True) outputs 'unifiedSession' (camelCase) for JSON.""" + config = Config() + data = config.model_dump(mode="json", by_alias=True) + agents_defaults = data["agents"]["defaults"] + assert "unifiedSession" in agents_defaults + assert agents_defaults["unifiedSession"] is False + + def test_config_parses_unified_session_from_camel_case(self): + """Config can be loaded from JSON with camelCase 'unifiedSession'.""" + raw = {"agents": {"defaults": {"unifiedSession": True}}} + config = 
Config.model_validate(raw) + assert config.agents.defaults.unified_session is True + + def test_config_parses_unified_session_from_snake_case(self): + """Config also accepts snake_case 'unified_session' (populate_by_name=True).""" + raw = {"agents": {"defaults": {"unified_session": True}}} + config = Config.model_validate(raw) + assert config.agents.defaults.unified_session is True + + def test_onboard_generated_config_contains_unified_session(self, tmp_path: Path): + """save_config() writes 'unifiedSession' into config.json (simulates nanobot onboard).""" + from nanobot.config.loader import save_config + + config = Config() + config_path = tmp_path / "config.json" + save_config(config, config_path) + + with open(config_path, encoding="utf-8") as f: + data = json.load(f) + + agents_defaults = data["agents"]["defaults"] + assert "unifiedSession" in agents_defaults, ( + "onboard-generated config.json must contain 'unifiedSession' key" + ) + assert agents_defaults["unifiedSession"] is False From 985f9c443ba6d98282190bf2bf1e8f28e1190a5f Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 8 Apr 2026 06:22:19 +0800 Subject: [PATCH 020/115] tests: add unified_session coverage for /new and consolidation --- tests/agent/test_unified_session.py | 205 ++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) diff --git a/tests/agent/test_unified_session.py b/tests/agent/test_unified_session.py index 1b3c5fb97..1d9eaad64 100644 --- a/tests/agent/test_unified_session.py +++ b/tests/agent/test_unified_session.py @@ -6,10 +6,15 @@ Covers: - Feature is off by default (no behavior change for existing users) - Config schema serialises unified_session as camelCase "unifiedSession" - onboard-generated config.json contains "unifiedSession" key +- /new command correctly clears the shared session in unified mode +- /new is NOT a priority command (goes through _dispatch, key rewrite applies) +- Context window consolidation is unaffected by unified_session """ +import asyncio import json 
from pathlib import Path +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,7 +22,10 @@ import pytest from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus +from nanobot.command.builtin import cmd_new, register_builtin_commands +from nanobot.command.router import CommandContext, CommandRouter from nanobot.config.schema import AgentDefaults, Config +from nanobot.session.manager import Session, SessionManager # --------------------------------------------------------------------------- @@ -193,3 +201,200 @@ class TestUnifiedSessionConfig: "onboard-generated config.json must contain 'unifiedSession' key" ) assert agents_defaults["unifiedSession"] is False + + +# --------------------------------------------------------------------------- +# TestCmdNewUnifiedSession — /new command behaviour in unified mode +# --------------------------------------------------------------------------- + +class TestCmdNewUnifiedSession: + """/new command routing and session-clear behaviour in unified mode.""" + + def test_new_is_not_a_priority_command(self): + """/new must NOT be in the priority table — it must go through _dispatch() + so the unified session key rewrite applies before cmd_new runs.""" + router = CommandRouter() + register_builtin_commands(router) + assert router.is_priority("/new") is False + + def test_new_is_an_exact_command(self): + """/new must be registered as an exact command.""" + router = CommandRouter() + register_builtin_commands(router) + assert "/new" in router._exact + + @pytest.mark.asyncio + async def test_cmd_new_clears_unified_session(self, tmp_path: Path): + """cmd_new called with key='unified:default' clears the shared session.""" + sessions = SessionManager(tmp_path) + + # Pre-populate the shared session with some messages + shared = sessions.get_or_create("unified:default") + shared.add_message("user", "hello from 
telegram") + shared.add_message("assistant", "hi there") + sessions.save(shared) + assert len(sessions.get_or_create("unified:default").messages) == 2 + + # _schedule_background is a *sync* method that schedules a coroutine via + # asyncio.create_task(). Mirror that exactly so the coroutine is consumed + # and no RuntimeWarning is emitted. + loop = SimpleNamespace( + sessions=sessions, + consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)), + ) + loop._schedule_background = lambda coro: asyncio.ensure_future(coro) + + msg = InboundMessage( + channel="telegram", sender_id="user1", chat_id="111", content="/new", + session_key_override="unified:default", # as _dispatch() would set it + ) + ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop) + + result = await cmd_new(ctx) + + assert "New session started" in result.content + # Invalidate cache and reload from disk to confirm persistence + sessions.invalidate("unified:default") + reloaded = sessions.get_or_create("unified:default") + assert reloaded.messages == [] + + @pytest.mark.asyncio + async def test_cmd_new_in_unified_mode_does_not_affect_other_sessions(self, tmp_path: Path): + """Clearing unified:default must not touch other sessions on disk.""" + sessions = SessionManager(tmp_path) + + other = sessions.get_or_create("discord:999") + other.add_message("user", "discord message") + sessions.save(other) + + shared = sessions.get_or_create("unified:default") + shared.add_message("user", "shared message") + sessions.save(shared) + + loop = SimpleNamespace( + sessions=sessions, + consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)), + ) + loop._schedule_background = lambda coro: asyncio.ensure_future(coro) + + msg = InboundMessage( + channel="telegram", sender_id="user1", chat_id="111", content="/new", + session_key_override="unified:default", + ) + ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop) + await 
cmd_new(ctx) + + sessions.invalidate("unified:default") + sessions.invalidate("discord:999") + assert sessions.get_or_create("unified:default").messages == [] + assert len(sessions.get_or_create("discord:999").messages) == 1 + + +# --------------------------------------------------------------------------- +# TestConsolidationUnaffectedByUnifiedSession — consolidation is key-agnostic +# --------------------------------------------------------------------------- + +class TestConsolidationUnaffectedByUnifiedSession: + """maybe_consolidate_by_tokens() behaviour is identical regardless of session key.""" + + @pytest.mark.asyncio + async def test_consolidation_skips_empty_session_for_unified_key(self): + """Empty unified:default session → consolidation exits immediately, archive not called.""" + from nanobot.agent.memory import Consolidator, MemoryStore + + store = MagicMock(spec=MemoryStore) + mock_provider = MagicMock() + mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary")) + # Use spec= so MagicMock doesn't auto-generate AsyncMock for non-async methods, + # which would leave unawaited coroutines and trigger RuntimeWarning. 
+ sessions = MagicMock(spec=SessionManager) + + consolidator = Consolidator( + store=store, + provider=mock_provider, + model="test-model", + sessions=sessions, + context_window_tokens=1000, + build_messages=MagicMock(return_value=[]), + get_tool_definitions=MagicMock(return_value=[]), + max_completion_tokens=100, + ) + consolidator.archive = AsyncMock() + + session = Session(key="unified:default") + session.messages = [] + + await consolidator.maybe_consolidate_by_tokens(session) + + consolidator.archive.assert_not_called() + + @pytest.mark.asyncio + async def test_consolidation_behaviour_identical_for_any_key(self): + """archive call count is the same for 'telegram:123' and 'unified:default' + under identical token conditions.""" + from nanobot.agent.memory import Consolidator, MemoryStore + + archive_calls: dict[str, int] = {} + + for key in ("telegram:123", "unified:default"): + store = MagicMock(spec=MemoryStore) + mock_provider = MagicMock() + mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary")) + sessions = MagicMock(spec=SessionManager) + + consolidator = Consolidator( + store=store, + provider=mock_provider, + model="test-model", + sessions=sessions, + context_window_tokens=1000, + build_messages=MagicMock(return_value=[]), + get_tool_definitions=MagicMock(return_value=[]), + max_completion_tokens=100, + ) + + session = Session(key=key) + session.messages = [] # empty → exits immediately for both keys + + consolidator.archive = AsyncMock() + await consolidator.maybe_consolidate_by_tokens(session) + archive_calls[key] = consolidator.archive.call_count + + assert archive_calls["telegram:123"] == archive_calls["unified:default"] == 0 + + @pytest.mark.asyncio + async def test_consolidation_triggers_when_over_budget_unified_key(self): + """When tokens exceed budget, consolidation attempts to find a boundary — + behaviour is identical to any other session key.""" + from nanobot.agent.memory import Consolidator, MemoryStore + + 
store = MagicMock(spec=MemoryStore) + mock_provider = MagicMock() + sessions = MagicMock(spec=SessionManager) + + consolidator = Consolidator( + store=store, + provider=mock_provider, + model="test-model", + sessions=sessions, + context_window_tokens=1000, + build_messages=MagicMock(return_value=[]), + get_tool_definitions=MagicMock(return_value=[]), + max_completion_tokens=100, + ) + + session = Session(key="unified:default") + session.messages = [{"role": "user", "content": "msg"}] + + # Simulate over-budget: estimated > budget + consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(950, "tiktoken")) + # No valid boundary found → returns gracefully without archiving + consolidator.pick_consolidation_boundary = MagicMock(return_value=None) + consolidator.archive = AsyncMock() + + await consolidator.maybe_consolidate_by_tokens(session) + + # estimate was called (consolidation was attempted) + consolidator.estimate_session_prompt_tokens.assert_called_once_with(session) + # but archive was not called (no valid boundary) + consolidator.archive.assert_not_called() From b4c7cd654ee69ac167c5a49a61e2e04641c086ab Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 8 Apr 2026 21:39:12 +0800 Subject: [PATCH 021/115] fix: use effective session key for _active_tasks in unified mode --- nanobot/agent/loop.py | 14 ++-- tests/agent/test_unified_session.py | 102 ++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 593331c3f..76bed4158 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -3,7 +3,9 @@ from __future__ import annotations import asyncio +import dataclasses import json +import re import os import time from contextlib import AsyncExitStack, nullcontext @@ -40,6 +42,8 @@ if TYPE_CHECKING: from nanobot.cron.service import CronService +# Named constant for unified session key, used across multiple locations +UNIFIED_SESSION_KEY = "unified:default" class 
_LoopHook(AgentHook): """Core hook for the main loop.""" @@ -385,15 +389,17 @@ class AgentLoop: if result: await self.bus.publish_outbound(result) continue + # Compute the effective session key before dispatching + # This ensures /stop command can find tasks correctly when unified session is enabled + effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key task = asyncio.create_task(self._dispatch(msg)) - self._active_tasks.setdefault(msg.session_key, []).append(task) - task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) + self._active_tasks.setdefault(effective_key, []).append(task) + task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" if self._unified_session and not msg.session_key_override: - import dataclasses - msg = dataclasses.replace(msg, session_key_override="unified:default") + msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY) lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() async with lock, gate: diff --git a/tests/agent/test_unified_session.py b/tests/agent/test_unified_session.py index 1d9eaad64..557beaca7 100644 --- a/tests/agent/test_unified_session.py +++ b/tests/agent/test_unified_session.py @@ -398,3 +398,105 @@ class TestConsolidationUnaffectedByUnifiedSession: consolidator.estimate_session_prompt_tokens.assert_called_once_with(session) # but archive was not called (no valid boundary) consolidator.archive.assert_not_called() + + +# --------------------------------------------------------------------------- +# TestStopCommandWithUnifiedSession — 
/stop command integration +# --------------------------------------------------------------------------- + + +class TestStopCommandWithUnifiedSession: + """Verify /stop command works correctly with unified session enabled.""" + + @pytest.mark.asyncio + async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path): + """When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + + loop = _make_loop(tmp_path, unified_session=True) + + # Create a message from telegram channel + msg = _make_msg(channel="telegram", chat_id="123456") + + # Mock _dispatch to complete immediately + async def fake_dispatch(m): + pass + + loop._dispatch = fake_dispatch # type: ignore[method-assign] + + # Simulate the task creation flow (from _run loop) + effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key + task = asyncio.create_task(loop._dispatch(msg)) + loop._active_tasks.setdefault(effective_key, []).append(task) + + # Wait for task to complete + await task + + # Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id + assert UNIFIED_SESSION_KEY in loop._active_tasks + assert "telegram:123456" not in loop._active_tasks + + @pytest.mark.asyncio + async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path): + """cmd_stop can cancel tasks when unified_session=True.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.command.builtin import cmd_stop + + loop = _make_loop(tmp_path, unified_session=True) + + # Create a long-running task stored under UNIFIED_SESSION_KEY + async def long_running(): + await asyncio.sleep(10) # Will be cancelled + + task = asyncio.create_task(long_running()) + loop._active_tasks[UNIFIED_SESSION_KEY] = [task] + + # Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch + msg = InboundMessage( + channel="telegram", + 
chat_id="123456", + sender_id="user1", + content="/stop", + session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state + ) + + ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop) + + # Execute /stop + result = await cmd_stop(ctx) + + # Verify task was cancelled + assert task.cancelled() or task.done() + assert "Stopped 1 task" in result.content + + @pytest.mark.asyncio + async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path): + """In unified mode, /stop from one channel cancels tasks from another channel.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.command.builtin import cmd_stop + + loop = _make_loop(tmp_path, unified_session=True) + + # Create tasks from different channels, all stored under UNIFIED_SESSION_KEY + async def long_running(): + await asyncio.sleep(10) + + task1 = asyncio.create_task(long_running()) + task2 = asyncio.create_task(long_running()) + loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2] + + # /stop from discord should cancel tasks started from telegram + msg = InboundMessage( + channel="discord", + chat_id="789012", + sender_id="user2", + content="/stop", + session_key_override=UNIFIED_SESSION_KEY, + ) + + ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop) + + result = await cmd_stop(ctx) + + # Both tasks should be cancelled + assert "Stopped 2 task" in result.content \ No newline at end of file From be1b34ed7c3c5c5d359f7fa2d6f48a0dac531b0b Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 8 Apr 2026 21:53:22 +0800 Subject: [PATCH 022/115] fix: remove unused import re --- nanobot/agent/loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 76bed4158..213631618 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -5,7 +5,6 @@ from __future__ import annotations import asyncio import dataclasses import json -import re import os 
import time from contextlib import AsyncExitStack, nullcontext From cf02408fc0c5ebe3da42c19fc93e74ff1f616b88 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 9 Apr 2026 03:04:21 +0000 Subject: [PATCH 023/115] Merge origin/main; remove stale comment and fix blank-line style Made-with: Cursor --- nanobot/agent/loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 213631618..9128b8840 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -41,8 +41,8 @@ if TYPE_CHECKING: from nanobot.cron.service import CronService -# Named constant for unified session key, used across multiple locations UNIFIED_SESSION_KEY = "unified:default" + class _LoopHook(AgentHook): """Core hook for the main loop.""" From 1dd2d5486e1553d67a0e0b9b5b4e79342c86bb05 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 9 Apr 2026 03:15:21 +0000 Subject: [PATCH 024/115] docs: add unified session configuration to README for cross-channel continuity --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/README.md b/README.md index 0747b25ed..d43fac43e 100644 --- a/README.md +++ b/README.md @@ -1519,6 +1519,32 @@ Common examples: `UTC`, `America/New_York`, `America/Los_Angeles`, `Europe/Londo > Need another timezone? Browse the full [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). +### Unified Session + +By default, each channel × chat ID combination gets its own session. If you use nanobot across multiple channels (e.g. Telegram + Discord + CLI) and want them to share the same conversation, enable `unifiedSession`: + +```json +{ + "agents": { + "defaults": { + "unifiedSession": true + } + } +} +``` + +When enabled, all incoming messages — regardless of which channel they arrive on — are routed into a single shared session. Switching from Telegram to Discord (or any other channel) continues the same conversation seamlessly. 
+ +| Behavior | `false` (default) | `true` | +|----------|-------------------|--------| +| Session key | `channel:chat_id` | `unified:default` | +| Cross-channel continuity | No | Yes | +| `/new` clears | Current channel session | Shared session | +| `/stop` finds tasks | By channel session | By shared session | +| Existing `session_key_override` (e.g. Telegram thread) | Respected | Still respected — not overwritten | + +> This is designed for single-user, multi-device setups. It is **off by default** — existing users see zero behavior change. + ## 🧩 Multiple Instances Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance. From 6d74c88014922b19e4784a261f7a25f27f27fe7c Mon Sep 17 00:00:00 2001 From: Alfredo Arenas Date: Thu, 2 Apr 2026 08:56:08 -0600 Subject: [PATCH 025/115] fix(helpers): ensure assistant message content is never None --- nanobot/utils/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 9c2f48960..86dc205e4 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -252,7 +252,7 @@ def split_message(content: str, max_len: int = 2000) -> list[str]: while content: if len(content) <= max_len: chunks.append(content) - break + breakmsg: dict[str, Any] = {"role": "assistant", "content": content} cut = content[:max_len] # Try to break at newline first, then space, then hard break pos = cut.rfind('\n') @@ -272,7 +272,7 @@ def build_assistant_message( thinking_blocks: list[dict] | None = None, ) -> dict[str, Any]: """Build a provider-safe assistant message with optional reasoning fields.""" - msg: dict[str, Any] = {"role": "assistant", "content": content} + msg: dict[str, Any] = {"role": "assistant", "content": content or ""} if tool_calls: msg["tool_calls"] = 
tool_calls if reasoning_content is not None or thinking_blocks: From 6445b3b0cfb6fc03cdb0ef19a33e4b7a2a0ddc2d Mon Sep 17 00:00:00 2001 From: Alfredo Arenas Date: Thu, 2 Apr 2026 09:03:19 -0600 Subject: [PATCH 026/115] fix(helpers): repair corrupted split_message and ensure content never None Fix accidental line corruption in split_message() where 'break' was merged with unrelated code during manual editing. The actual fix: build_assistant_message() now returns content or "" instead of content (which could be None), preventing providers like MiMo V2 Omni from rejecting tool-call messages with missing text field. Fixes #2519 --- nanobot/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 86dc205e4..1f14cb36e 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -252,7 +252,7 @@ def split_message(content: str, max_len: int = 2000) -> list[str]: while content: if len(content) <= max_len: chunks.append(content) - breakmsg: dict[str, Any] = {"role": "assistant", "content": content} + break cut = content[:max_len] # Try to break at newline first, then space, then hard break pos = cut.rfind('\n') From 1e3057d0d6d2c94da18e3db22c59fbeb3b5dec97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E6=98=9F=E6=9D=B0?= <1198425718@qq.com> Date: Wed, 8 Apr 2026 14:44:19 +0800 Subject: [PATCH 027/115] fix(cli): remove default green style from Enabled column in tables The Enabled column in channels status and plugins list commands had a default green style that overrode the dim markup for disabled items. This caused no values to appear green instead of dimmed. Remove the default style to let cell-level markup control the display correctly. 
--- nanobot/cli/commands.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 7c4d31f3e..5ce8b7937 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -1119,7 +1119,7 @@ def channels_status( table = Table(title="Channel Status") table.add_column("Channel", style="cyan") - table.add_column("Enabled", style="green") + table.add_column("Enabled") for name, cls in sorted(discover_all().items()): section = getattr(config.channels, name, None) @@ -1254,7 +1254,7 @@ def plugins_list(): table = Table(title="Channel Plugins") table.add_column("Name", style="cyan") table.add_column("Source", style="magenta") - table.add_column("Enabled", style="green") + table.add_column("Enabled") for name in sorted(all_channels): cls = all_channels[name] From e9c4fe682411f4dabde6907f1d1c9241192ed593 Mon Sep 17 00:00:00 2001 From: chenyahui Date: Thu, 9 Apr 2026 14:11:47 +0800 Subject: [PATCH 028/115] feat(skills): add disabled_skills config to exclude skills from loading Introduce a disabled_skills option in the config schema that allows users to specify a list of skill names to be excluded. The setting is threaded from config through Nanobot -> AgentLoop -> ContextBuilder -> SkillsLoader. Disabled skills are filtered out from list_skills, get_always_skills, and build_skills_summary. Four new test cases cover the filtering behavior. 
--- nanobot/agent/context.py | 4 +-- nanobot/agent/loop.py | 3 +- nanobot/agent/skills.py | 6 +++- nanobot/cli/commands.py | 3 ++ nanobot/config/schema.py | 1 + nanobot/nanobot.py | 1 + tests/agent/test_skills_loader.py | 60 +++++++++++++++++++++++++++++++ 7 files changed, 74 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 3ac19e7f3..56e42d845 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -21,11 +21,11 @@ class ContextBuilder: _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" _MAX_RECENT_HISTORY = 50 - def __init__(self, workspace: Path, timezone: str | None = None): + def __init__(self, workspace: Path, timezone: str | None = None, disabled_skills: list[str] | None = None): self.workspace = workspace self.timezone = timezone self.memory = MemoryStore(workspace) - self.skills = SkillsLoader(workspace) + self.skills = SkillsLoader(workspace, disabled_skills=set(disabled_skills) if disabled_skills else None) def build_system_prompt( self, diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 9128b8840..80205ceae 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -147,6 +147,7 @@ class AgentLoop: timezone: str | None = None, hooks: list[AgentHook] | None = None, unified_session: bool = False, + disabled_skills: list[str] | None = None, ): from nanobot.config.schema import ExecToolConfig, WebToolsConfig @@ -179,7 +180,7 @@ class AgentLoop: self._last_usage: dict[str, int] = {} self._extra_hooks: list[AgentHook] = hooks or [] - self.context = ContextBuilder(workspace, timezone=timezone) + self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) self.sessions = session_manager or SessionManager(workspace) self.tools = ToolRegistry() self.runner = AgentRunner(provider) diff --git a/nanobot/agent/skills.py b/nanobot/agent/skills.py index ca215cc96..e9ef1986f 100644 --- a/nanobot/agent/skills.py +++ 
b/nanobot/agent/skills.py @@ -28,10 +28,11 @@ class SkillsLoader: specific tools or perform certain tasks. """ - def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None): + def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None, disabled_skills: set[str] | None = None): self.workspace = workspace self.workspace_skills = workspace / "skills" self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR + self.disabled_skills = disabled_skills or set() def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]: if not base.exists(): @@ -66,6 +67,9 @@ class SkillsLoader: self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names) ) + if self.disabled_skills: + skills = [s for s in skills if s["name"] not in self.disabled_skills] + if filter_unavailable: return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))] return skills diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 5ce8b7937..04a21b3f9 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -591,6 +591,7 @@ def serve( channels_config=runtime_config.channels, timezone=runtime_config.agents.defaults.timezone, unified_session=runtime_config.agents.defaults.unified_session, + disabled_skills=runtime_config.agents.defaults.disabled_skills, ) model_name = runtime_config.agents.defaults.model @@ -683,6 +684,7 @@ def gateway( channels_config=config.channels, timezone=config.agents.defaults.timezone, unified_session=config.agents.defaults.unified_session, + disabled_skills=config.agents.defaults.disabled_skills, ) # Set cron callback (needs agent) @@ -915,6 +917,7 @@ def agent( channels_config=config.channels, timezone=config.agents.defaults.timezone, unified_session=config.agents.defaults.unified_session, + disabled_skills=config.agents.defaults.disabled_skills, ) restart_notice = consume_restart_notice_from_env() 
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id): diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index b011d765f..d6e7f9045 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -77,6 +77,7 @@ class AgentDefaults(Base): reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" unified_session: bool = False # Share one session across all channels (single-user multi-device) + disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"]) dream: DreamConfig = Field(default_factory=DreamConfig) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 9166acb27..75d030d7a 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -82,6 +82,7 @@ class Nanobot: mcp_servers=config.tools.mcp_servers, timezone=defaults.timezone, unified_session=defaults.unified_session, + disabled_skills=defaults.disabled_skills, ) return cls(loop) diff --git a/tests/agent/test_skills_loader.py b/tests/agent/test_skills_loader.py index 46923c806..4284fa0c6 100644 --- a/tests/agent/test_skills_loader.py +++ b/tests/agent/test_skills_loader.py @@ -250,3 +250,63 @@ def test_list_skills_openclaw_metadata_parsed_for_requirements( assert entries == [ {"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"}, ] + + +def test_disabled_skills_excluded_from_list(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + ws_skills = workspace / "skills" + ws_skills.mkdir(parents=True) + _write_skill(ws_skills, "alpha", body="# Alpha") + beta_path = _write_skill(ws_skills, "beta", body="# Beta") + builtin = tmp_path / "builtin" + builtin.mkdir() + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"}) + entries = loader.list_skills(filter_unavailable=False) + assert len(entries) == 1 + assert 
entries[0]["name"] == "beta" + assert entries[0]["path"] == str(beta_path) + + +def test_disabled_skills_empty_set_no_effect(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + ws_skills = workspace / "skills" + ws_skills.mkdir(parents=True) + _write_skill(ws_skills, "alpha", body="# Alpha") + _write_skill(ws_skills, "beta", body="# Beta") + builtin = tmp_path / "builtin" + builtin.mkdir() + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills=set()) + entries = loader.list_skills(filter_unavailable=False) + assert len(entries) == 2 + + +def test_disabled_skills_excluded_from_build_skills_summary(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + ws_skills = workspace / "skills" + ws_skills.mkdir(parents=True) + _write_skill(ws_skills, "alpha", body="# Alpha") + _write_skill(ws_skills, "beta", body="# Beta") + builtin = tmp_path / "builtin" + builtin.mkdir() + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"}) + summary = loader.build_skills_summary() + assert "alpha" not in summary + assert "beta" in summary + + +def test_disabled_skills_excluded_from_get_always_skills(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + ws_skills = workspace / "skills" + ws_skills.mkdir(parents=True) + _write_skill(ws_skills, "alpha", metadata_json={"always": True}, body="# Alpha") + _write_skill(ws_skills, "beta", metadata_json={"always": True}, body="# Beta") + builtin = tmp_path / "builtin" + builtin.mkdir() + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"}) + always = loader.get_always_skills() + assert "alpha" not in always + assert "beta" in always From ad57bcd127b1c768ec4726651577e3b484f4eb30 Mon Sep 17 00:00:00 2001 From: Jack Lu <46274946+JackLuguibin@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:39:35 +0800 Subject: [PATCH 029/115] feat(channels): add WebSocket server channel and tests Port Python implementation from 
a1ec7b192ad97ffd58250a720891ff09bbb73888 (websocket channel module and channel tests; excludes webui debug app). --- nanobot/channels/websocket.py | 418 +++++++++++++++++++++++ tests/channels/test_websocket_channel.py | 329 ++++++++++++++++++ 2 files changed, 747 insertions(+) create mode 100644 nanobot/channels/websocket.py create mode 100644 tests/channels/test_websocket_channel.py diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py new file mode 100644 index 000000000..e09e6303e --- /dev/null +++ b/nanobot/channels/websocket.py @@ -0,0 +1,418 @@ +"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients.""" + +from __future__ import annotations + +import asyncio +import email.utils +import hmac +import http +import json +import secrets +import ssl +import time +import uuid +from typing import Any, Self +from urllib.parse import parse_qs, urlparse + +from loguru import logger +from pydantic import Field, field_validator, model_validator +from websockets.asyncio.server import ServerConnection, serve +from websockets.datastructures import Headers +from websockets.http11 import Request as WsRequest, Response + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.config.schema import Base + + +def _strip_trailing_slash(path: str) -> str: + if len(path) > 1 and path.endswith("/"): + return path.rstrip("/") + return path or "/" + + +def _normalize_config_path(path: str) -> str: + return _strip_trailing_slash(path) + + +class WebSocketConfig(Base): + """WebSocket server channel configuration. + + Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``. + - ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged. 
+ - ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens + from ``token_issue_path`` are also accepted. + - ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON + ``{"token": "...", "expires_in": }``; use ``?token=...`` when opening the WebSocket. + Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as + nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call + blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine. + - ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer `` or + ``X-Nanobot-Auth: ``. + - ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired). + - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally. + """ + + enabled: bool = False + host: str = "127.0.0.1" + port: int = 8765 + path: str = "/" + token: str = "" + token_issue_path: str = "" + token_issue_secret: str = "" + token_ttl_s: int = Field(default=300, ge=30, le=86_400) + websocket_requires_token: bool = False + allow_from: list[str] = Field(default_factory=lambda: ["*"]) + streaming: bool = True + max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216) + ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0) + ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0) + ssl_certfile: str = "" + ssl_keyfile: str = "" + + @field_validator("path") + @classmethod + def path_must_start_with_slash(cls, value: str) -> str: + if not value.startswith("/"): + raise ValueError('path must start with "/"') + return value + + @field_validator("token_issue_path") + @classmethod + def token_issue_path_format(cls, value: str) -> str: + value = value.strip() + if not value: + return "" + if not value.startswith("/"): + raise ValueError('token_issue_path must start with "/"') + return 
_normalize_config_path(value) + + @model_validator(mode="after") + def token_issue_path_differs_from_ws_path(self) -> Self: + if not self.token_issue_path: + return self + if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path): + raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)") + return self + + +def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response: + body = json.dumps(data, ensure_ascii=False).encode("utf-8") + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "application/json; charset=utf-8"), + ] + ) + reason = http.HTTPStatus(status).phrase + return Response(status, reason, headers, body) + + +def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]: + """Parse normalized path and query parameters in one pass.""" + parsed = urlparse("ws://x" + path_with_query) + path = _strip_trailing_slash(parsed.path or "/") + return path, parse_qs(parsed.query) + + +def _normalize_http_path(path_with_query: str) -> str: + """Return the path component (no query string), with trailing slash normalized (root stays ``/``).""" + return _parse_request_path(path_with_query)[0] + + +def _parse_query(path_with_query: str) -> dict[str, list[str]]: + return _parse_request_path(path_with_query)[1] + + +def _parse_inbound_payload(raw: str) -> str | None: + """Parse a client frame into text; return None for empty or unrecognized content.""" + text = raw.strip() + if not text: + return None + if text.startswith("{"): + try: + data = json.loads(text) + except json.JSONDecodeError: + return text + if isinstance(data, dict): + for key in ("content", "text", "message"): + value = data.get(key) + if isinstance(value, str) and value.strip(): + return value + return None + return None + return text + + +def _issue_route_secret_matches(headers: Any, configured_secret: str) 
-> bool: + """Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``.""" + if not configured_secret: + return True + authorization = headers.get("Authorization") or headers.get("authorization") + if authorization and authorization.lower().startswith("bearer "): + supplied = authorization[7:].strip() + return hmac.compare_digest(supplied, configured_secret) + header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth") + if not header_token: + return False + return hmac.compare_digest(header_token.strip(), configured_secret) + + +class WebSocketChannel(BaseChannel): + """Run a local WebSocket server; forward text/JSON messages to the message bus.""" + + name = "websocket" + display_name = "WebSocket" + + def __init__(self, config: Any, bus: MessageBus): + if isinstance(config, dict): + config = WebSocketConfig.model_validate(config) + super().__init__(config, bus) + self.config: WebSocketConfig = config + self._connections: dict[str, Any] = {} + self._issued_tokens: dict[str, float] = {} + self._stop_event: asyncio.Event | None = None + self._server_task: asyncio.Task[None] | None = None + + @classmethod + def default_config(cls) -> dict[str, Any]: + return WebSocketConfig().model_dump(by_alias=True) + + def _expected_path(self) -> str: + return _normalize_config_path(self.config.path) + + def _build_ssl_context(self) -> ssl.SSLContext | None: + cert = self.config.ssl_certfile.strip() + key = self.config.ssl_keyfile.strip() + if not cert and not key: + return None + if not cert or not key: + raise ValueError( + "websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty" + ) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ctx.load_cert_chain(certfile=cert, keyfile=key) + return ctx + + def _purge_expired_issued_tokens(self) -> None: + now = time.monotonic() + for token_key, expiry in list(self._issued_tokens.items()): + if now > expiry: + self._issued_tokens.pop(token_key, None) + 
+ def _take_issued_token_if_valid(self, token_value: str | None) -> bool: + """Validate and consume one issued token (single use per connection attempt).""" + if not token_value: + return False + self._purge_expired_issued_tokens() + expiry = self._issued_tokens.get(token_value) + if expiry is None: + return False + if time.monotonic() > expiry: + self._issued_tokens.pop(token_value, None) + return False + self._issued_tokens.pop(token_value, None) + return True + + def _handle_token_issue_http(self, connection: Any, request: Any) -> Any: + secret = self.config.token_issue_secret.strip() + if secret: + if not _issue_route_secret_matches(request.headers, secret): + return connection.respond(401, "Unauthorized") + else: + logger.warning( + "websocket: token_issue_path is set but token_issue_secret is empty; " + "any client can obtain connection tokens — set token_issue_secret for production." + ) + self._purge_expired_issued_tokens() + token_value = f"nbwt_{secrets.token_urlsafe(32)}" + self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s) + + return _http_json_response( + {"token": token_value, "expires_in": self.config.token_ttl_s} + ) + + def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any: + query = _parse_query(request_path) + supplied = (query.get("token") or [None])[0] + static_token = self.config.token.strip() + + if static_token: + if supplied == static_token: + return None + if supplied and self._take_issued_token_if_valid(supplied): + return None + return connection.respond(401, "Unauthorized") + + if self.config.websocket_requires_token: + if supplied and self._take_issued_token_if_valid(supplied): + return None + return connection.respond(401, "Unauthorized") + + if supplied: + self._take_issued_token_if_valid(supplied) + return None + + async def start(self) -> None: + self._running = True + self._stop_event = asyncio.Event() + + ssl_context = self._build_ssl_context() + scheme = "wss" if 
ssl_context else "ws" + + async def process_request( + connection: ServerConnection, + request: WsRequest, + ) -> Any: + got, _ = _parse_request_path(request.path) + if self.config.token_issue_path: + issue_expected = _normalize_config_path(self.config.token_issue_path) + if got == issue_expected: + return self._handle_token_issue_http(connection, request) + + expected_ws = self._expected_path() + if got != expected_ws: + return connection.respond(404, "Not Found") + return self._authorize_websocket_handshake(connection, request.path) + + async def handler(connection: ServerConnection) -> None: + await self._connection_loop(connection) + + logger.info( + "WebSocket server listening on {}://{}:{}{}", + scheme, + self.config.host, + self.config.port, + self.config.path, + ) + if self.config.token_issue_path: + logger.info( + "WebSocket token issue route: {}://{}:{}{}", + scheme, + self.config.host, + self.config.port, + _normalize_config_path(self.config.token_issue_path), + ) + + async def runner() -> None: + async with serve( + handler, + self.config.host, + self.config.port, + process_request=process_request, + max_size=self.config.max_message_bytes, + ping_interval=self.config.ping_interval_s, + ping_timeout=self.config.ping_timeout_s, + ssl=ssl_context, + ): + assert self._stop_event is not None + await self._stop_event.wait() + + self._server_task = asyncio.create_task(runner()) + await self._server_task + + async def _connection_loop(self, connection: Any) -> None: + request = connection.request + path_part = request.path if request else "/" + _, query = _parse_request_path(path_part) + client_id_raw = (query.get("client_id") or [None])[0] + client_id = client_id_raw.strip() if client_id_raw else "" + if not client_id: + client_id = f"anon-{uuid.uuid4().hex[:12]}" + + chat_id = str(uuid.uuid4()) + self._connections[chat_id] = connection + + await connection.send( + json.dumps( + { + "event": "ready", + "chat_id": chat_id, + "client_id": client_id, + }, + 
ensure_ascii=False, + ) + ) + + try: + async for raw in connection: + if isinstance(raw, bytes): + try: + raw = raw.decode("utf-8") + except UnicodeDecodeError: + logger.warning("websocket: ignoring non-utf8 binary frame") + continue + content = _parse_inbound_payload(raw) + if content is None: + continue + await self._handle_message( + sender_id=client_id, + chat_id=chat_id, + content=content, + metadata={"remote": getattr(connection, "remote_address", None)}, + ) + except Exception as e: + logger.debug("websocket connection ended: {}", e) + finally: + self._connections.pop(chat_id, None) + + async def stop(self) -> None: + self._running = False + if self._stop_event: + self._stop_event.set() + if self._server_task: + await self._server_task + self._server_task = None + self._connections.clear() + self._issued_tokens.clear() + + async def send(self, msg: OutboundMessage) -> None: + connection = self._connections.get(msg.chat_id) + if connection is None: + logger.warning("websocket: no active connection for chat_id={}", msg.chat_id) + return + payload: dict[str, Any] = { + "event": "message", + "text": msg.content, + } + if msg.media: + payload["media"] = msg.media + if msg.reply_to: + payload["reply_to"] = msg.reply_to + raw = json.dumps(payload, ensure_ascii=False) + try: + await connection.send(raw) + except Exception as e: + logger.error("websocket send failed: {}", e) + raise + + async def send_delta( + self, + chat_id: str, + delta: str, + metadata: dict[str, Any] | None = None, + ) -> None: + connection = self._connections.get(chat_id) + if connection is None: + return + meta = metadata or {} + if meta.get("_stream_end"): + body: dict[str, Any] = {"event": "stream_end"} + if meta.get("_stream_id") is not None: + body["stream_id"] = meta["_stream_id"] + else: + body = { + "event": "delta", + "text": delta, + } + if meta.get("_stream_id") is not None: + body["stream_id"] = meta["_stream_id"] + raw = json.dumps(body, ensure_ascii=False) + try: + await 
connection.send(raw) + except Exception as e: + logger.error("websocket stream send failed: {}", e) + raise diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py new file mode 100644 index 000000000..e4c5ad635 --- /dev/null +++ b/tests/channels/test_websocket_channel.py @@ -0,0 +1,329 @@ +"""Unit and lightweight integration tests for the WebSocket channel.""" + +import asyncio +import functools +import json +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +import websockets + +from nanobot.bus.events import OutboundMessage +from nanobot.channels.websocket import ( + WebSocketChannel, + WebSocketConfig, + _issue_route_secret_matches, + _normalize_config_path, + _normalize_http_path, + _parse_inbound_payload, + _parse_query, + _parse_request_path, +) + + +async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response: + """Run GET in a thread to avoid blocking the asyncio loop shared with websockets.""" + return await asyncio.to_thread( + functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0) + ) + + +def test_normalize_http_path_strips_trailing_slash_except_root() -> None: + assert _normalize_http_path("/chat/") == "/chat" + assert _normalize_http_path("/chat?x=1") == "/chat" + assert _normalize_http_path("/") == "/" + + +def test_parse_request_path_matches_normalize_and_query() -> None: + path, query = _parse_request_path("/ws/?token=secret&client_id=u1") + assert path == _normalize_http_path("/ws/?token=secret&client_id=u1") + assert query == _parse_query("/ws/?token=secret&client_id=u1") + + +def test_normalize_config_path_matches_request() -> None: + assert _normalize_config_path("/ws/") == "/ws" + assert _normalize_config_path("/") == "/" + + +def test_parse_query_extracts_token_and_client_id() -> None: + query = _parse_query("/?token=secret&client_id=u1") + assert query.get("token") == ["secret"] + assert query.get("client_id") == ["u1"] + + 
+@pytest.mark.parametrize( + ("raw", "expected"), + [ + ("plain", "plain"), + ('{"content": "hi"}', "hi"), + ('{"text": "there"}', "there"), + ('{"message": "x"}', "x"), + (" ", None), + ("{}", None), + ], +) +def test_parse_inbound_payload(raw: str, expected: str | None) -> None: + assert _parse_inbound_payload(raw) == expected + + +def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None: + assert _parse_inbound_payload("{not json") == "{not json" + + +def test_web_socket_config_path_must_start_with_slash() -> None: + with pytest.raises(ValueError, match='path must start with "/"'): + WebSocketConfig(path="bad") + + +def test_ssl_context_requires_both_cert_and_key_files() -> None: + bus = MagicMock() + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""}, + bus, + ) + with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"): + channel._build_ssl_context() + + +def test_default_config_includes_safe_bind_and_streaming() -> None: + defaults = WebSocketChannel.default_config() + assert defaults["enabled"] is False + assert defaults["host"] == "127.0.0.1" + assert defaults["streaming"] is True + assert defaults["allowFrom"] == ["*"] + assert defaults.get("tokenIssuePath", "") == "" + + +def test_token_issue_path_must_differ_from_websocket_path() -> None: + with pytest.raises(ValueError, match="token_issue_path must differ"): + WebSocketConfig(path="/ws", token_issue_path="/ws") + + +def test_issue_route_secret_matches_bearer_and_header() -> None: + from websockets.datastructures import Headers + + secret = "my-secret" + bearer_headers = Headers([("Authorization", "Bearer my-secret")]) + assert _issue_route_secret_matches(bearer_headers, secret) is True + x_headers = Headers([("X-Nanobot-Auth", "my-secret")]) + assert _issue_route_secret_matches(x_headers, secret) is True + wrong = Headers([("Authorization", "Bearer other")]) + assert _issue_route_secret_matches(wrong, secret) is 
False + + +@pytest.mark.asyncio +async def test_send_delivers_json_message_with_media_and_reply() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._connections["chat-1"] = mock_ws + + msg = OutboundMessage( + channel="websocket", + chat_id="chat-1", + content="hello", + reply_to="m1", + media=["/tmp/a.png"], + ) + await channel.send(msg) + + mock_ws.send.assert_awaited_once() + payload = json.loads(mock_ws.send.call_args[0][0]) + assert payload["event"] == "message" + assert payload["text"] == "hello" + assert payload["reply_to"] == "m1" + assert payload["media"] == ["/tmp/a.png"] + + +@pytest.mark.asyncio +async def test_send_missing_connection_is_noop_without_error() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + msg = OutboundMessage(channel="websocket", chat_id="missing", content="x") + await channel.send(msg) + + +@pytest.mark.asyncio +async def test_send_delta_emits_delta_and_stream_end() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) + mock_ws = AsyncMock() + channel._connections["chat-1"] = mock_ws + + await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"}) + await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"}) + + assert mock_ws.send.await_count == 2 + first = json.loads(mock_ws.send.call_args_list[0][0][0]) + second = json.loads(mock_ws.send.call_args_list[1][0][0]) + assert first["event"] == "delta" + assert first["text"] == "part" + assert first["stream_id"] == "sid" + assert second["event"] == "stream_end" + assert second["stream_id"] == "sid" + + +@pytest.mark.asyncio +async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None: + bus = MagicMock() + bus.publish_inbound = AsyncMock() + port = 29876 + channel = WebSocketChannel( + { + "enabled": True, + 
"allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/ws", + }, + bus, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client: + ready_raw = await client.recv() + ready = json.loads(ready_raw) + assert ready["event"] == "ready" + assert ready["client_id"] == "tester" + chat_id = ready["chat_id"] + + await client.send(json.dumps({"content": "ping from client"})) + await asyncio.sleep(0.08) + + bus.publish_inbound.assert_awaited() + inbound = bus.publish_inbound.call_args[0][0] + assert inbound.channel == "websocket" + assert inbound.sender_id == "tester" + assert inbound.chat_id == chat_id + assert inbound.content == "ping from client" + + await client.send("plain text frame") + await asyncio.sleep(0.08) + assert bus.publish_inbound.await_count >= 2 + second = [c[0][0] for c in bus.publish_inbound.call_args_list][-1] + assert second.content == "plain text frame" + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_token_rejects_handshake_when_mismatch() -> None: + bus = MagicMock() + port = 29877 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/", + "token": "secret", + }, + bus, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo: + async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"): + pass + assert excinfo.value.response.status_code == 401 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_wrong_path_returns_404() -> None: + bus = MagicMock() + port = 29878 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/ws", + }, + bus, + ) + + server_task = 
asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo: + async with websockets.connect(f"ws://127.0.0.1:{port}/other"): + pass + assert excinfo.value.response.status_code == 404 + finally: + await channel.stop() + await server_task + + +def test_registry_discovers_websocket_channel() -> None: + from nanobot.channels.registry import load_channel_class + + cls = load_channel_class("websocket") + assert cls.name == "websocket" + + +@pytest.mark.asyncio +async def test_http_route_issues_token_then_websocket_requires_it() -> None: + bus = MagicMock() + bus.publish_inbound = AsyncMock() + port = 29879 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/ws", + "tokenIssuePath": "/auth/token", + "tokenIssueSecret": "route-secret", + "websocketRequiresToken": True, + }, + bus, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + deny = await _http_get(f"http://127.0.0.1:{port}/auth/token") + assert deny.status_code == 401 + + issue = await _http_get( + f"http://127.0.0.1:{port}/auth/token", + headers={"Authorization": "Bearer route-secret"}, + ) + assert issue.status_code == 200 + token = issue.json()["token"] + assert token.startswith("nbwt_") + + with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"): + pass + assert missing_token.value.response.status_code == 401 + + uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller" + async with websockets.connect(uri) as client: + ready = json.loads(await client.recv()) + assert ready["event"] == "ready" + assert ready["client_id"] == "caller" + + with pytest.raises(websockets.exceptions.InvalidStatus) as reuse: + async with websockets.connect(uri): + pass + assert reuse.value.response.status_code == 401 + finally: + await 
channel.stop() + await server_task From e0ccc401c0c3a5189fd5d2c64f63cc6855e6c75a Mon Sep 17 00:00:00 2001 From: chengyongru Date: Thu, 9 Apr 2026 13:47:53 +0800 Subject: [PATCH 030/115] fix(websocket): handle ConnectionClosed gracefully in send and send_delta --- nanobot/channels/websocket.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index e09e6303e..1660cbe7e 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -18,6 +18,7 @@ from loguru import logger from pydantic import Field, field_validator, model_validator from websockets.asyncio.server import ServerConnection, serve from websockets.datastructures import Headers +from websockets.exceptions import ConnectionClosed from websockets.http11 import Request as WsRequest, Response from nanobot.bus.events import OutboundMessage @@ -52,6 +53,8 @@ class WebSocketConfig(Base): ``X-Nanobot-Auth: ``. - ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired). - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally. + - ``media`` field in outbound messages contains local filesystem paths; remote clients need a + shared filesystem or an HTTP file server to access these files. 
""" enabled: bool = False @@ -385,6 +388,9 @@ class WebSocketChannel(BaseChannel): raw = json.dumps(payload, ensure_ascii=False) try: await connection.send(raw) + except ConnectionClosed: + self._connections.pop(msg.chat_id, None) + logger.warning("websocket: connection gone for chat_id={}", msg.chat_id) except Exception as e: logger.error("websocket send failed: {}", e) raise @@ -413,6 +419,9 @@ class WebSocketChannel(BaseChannel): raw = json.dumps(body, ensure_ascii=False) try: await connection.send(raw) + except ConnectionClosed: + self._connections.pop(chat_id, None) + logger.warning("websocket: stream connection gone for chat_id={}", chat_id) except Exception as e: logger.error("websocket stream send failed: {}", e) raise From 56a5906db514e687ba7f379a2d1ca9d201df6a29 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Thu, 9 Apr 2026 15:18:02 +0800 Subject: [PATCH 031/115] fix(websocket): harden security and robustness - Use hmac.compare_digest for timing-safe static token comparison - Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response - Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window - Enforce TLSv1.2 minimum version for SSL connections - Extract _safe_send helper for consistent ConnectionClosed handling - Move connection registration after ready send to prevent out-of-order delivery - Add HTTP-level allow_from check and client_id truncation in process_request - Make stop() idempotent with graceful shutdown error handling - Normalize path via validator instead of leaving raw value - Default websocket_requires_token to True for secure-by-default behavior - Add integration tests and ws_test_client helper - Refactor tests to use shared _ch factory and bus fixture --- nanobot/channels/websocket.py | 124 +++-- tests/channels/test_websocket_channel.py | 373 +++++++++++++-- tests/channels/test_websocket_integration.py | 477 +++++++++++++++++++ tests/channels/ws_test_client.py | 227 +++++++++ 4 files changed, 1102 
insertions(+), 99 deletions(-) create mode 100644 tests/channels/test_websocket_integration.py create mode 100644 tests/channels/ws_test_client.py diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 1660cbe7e..2af61d6f7 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -65,7 +65,7 @@ class WebSocketConfig(Base): token_issue_path: str = "" token_issue_secret: str = "" token_ttl_s: int = Field(default=300, ge=30, le=86_400) - websocket_requires_token: bool = False + websocket_requires_token: bool = True allow_from: list[str] = Field(default_factory=lambda: ["*"]) streaming: bool = True max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216) @@ -79,7 +79,7 @@ class WebSocketConfig(Base): def path_must_start_with_slash(cls, value: str) -> str: if not value.startswith("/"): raise ValueError('path must start with "/"') - return value + return _normalize_config_path(value) @field_validator("token_issue_path") @classmethod @@ -130,6 +130,12 @@ def _parse_query(path_with_query: str) -> dict[str, list[str]]: return _parse_request_path(path_with_query)[1] +def _query_first(query: dict[str, list[str]], key: str) -> str | None: + """Return the first value for *key*, or None.""" + values = query.get(key) + return values[0] if values else None + + def _parse_inbound_payload(raw: str) -> str | None: """Parse a client frame into text; return None for empty or unrecognized content.""" text = raw.strip() @@ -197,9 +203,12 @@ class WebSocketChannel(BaseChannel): "websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty" ) ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 ctx.load_cert_chain(certfile=cert, keyfile=key) return ctx + _MAX_ISSUED_TOKENS = 10_000 + def _purge_expired_issued_tokens(self) -> None: now = time.monotonic() for token_key, expiry in list(self._issued_tokens.items()): @@ -207,17 +216,19 @@ class 
WebSocketChannel(BaseChannel): self._issued_tokens.pop(token_key, None) def _take_issued_token_if_valid(self, token_value: str | None) -> bool: - """Validate and consume one issued token (single use per connection attempt).""" + """Validate and consume one issued token (single use per connection attempt). + + Uses single-step pop to minimize the window between lookup and removal; + safe under asyncio's single-threaded cooperative model. + """ if not token_value: return False self._purge_expired_issued_tokens() - expiry = self._issued_tokens.get(token_value) + expiry = self._issued_tokens.pop(token_value, None) if expiry is None: return False if time.monotonic() > expiry: - self._issued_tokens.pop(token_value, None) return False - self._issued_tokens.pop(token_value, None) return True def _handle_token_issue_http(self, connection: Any, request: Any) -> Any: @@ -231,6 +242,12 @@ class WebSocketChannel(BaseChannel): "any client can obtain connection tokens — set token_issue_secret for production." 
) self._purge_expired_issued_tokens() + if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS: + logger.error( + "websocket: too many outstanding issued tokens ({}), rejecting issuance", + len(self._issued_tokens), + ) + return _http_json_response({"error": "too many outstanding tokens"}, status=429) token_value = f"nbwt_{secrets.token_urlsafe(32)}" self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s) @@ -238,13 +255,12 @@ class WebSocketChannel(BaseChannel): {"token": token_value, "expires_in": self.config.token_ttl_s} ) - def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any: - query = _parse_query(request_path) - supplied = (query.get("token") or [None])[0] + def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any: + supplied = _query_first(query, "token") static_token = self.config.token.strip() if static_token: - if supplied == static_token: + if supplied and hmac.compare_digest(supplied, static_token): return None if supplied and self._take_issued_token_if_valid(supplied): return None @@ -279,7 +295,15 @@ class WebSocketChannel(BaseChannel): expected_ws = self._expected_path() if got != expected_ws: return connection.respond(404, "Not Found") - return self._authorize_websocket_handshake(connection, request.path) + # Early reject before WebSocket upgrade to avoid unnecessary overhead; + # _handle_message() performs a second check as defense-in-depth. 
+ query = _parse_query(request.path) + client_id = _query_first(query, "client_id") or "" + if len(client_id) > 128: + client_id = client_id[:128] + if not self.is_allowed(client_id): + return connection.respond(403, "Forbidden") + return self._authorize_websocket_handshake(connection, query) async def handler(connection: ServerConnection) -> None: await self._connection_loop(connection) @@ -321,26 +345,30 @@ class WebSocketChannel(BaseChannel): request = connection.request path_part = request.path if request else "/" _, query = _parse_request_path(path_part) - client_id_raw = (query.get("client_id") or [None])[0] + client_id_raw = _query_first(query, "client_id") client_id = client_id_raw.strip() if client_id_raw else "" if not client_id: client_id = f"anon-{uuid.uuid4().hex[:12]}" + elif len(client_id) > 128: + logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id)) + client_id = client_id[:128] chat_id = str(uuid.uuid4()) - self._connections[chat_id] = connection - - await connection.send( - json.dumps( - { - "event": "ready", - "chat_id": chat_id, - "client_id": client_id, - }, - ensure_ascii=False, - ) - ) try: + await connection.send( + json.dumps( + { + "event": "ready", + "chat_id": chat_id, + "client_id": client_id, + }, + ensure_ascii=False, + ) + ) + # Register only after ready is successfully sent to avoid out-of-order sends + self._connections[chat_id] = connection + async for raw in connection: if isinstance(raw, bytes): try: @@ -363,15 +391,34 @@ class WebSocketChannel(BaseChannel): self._connections.pop(chat_id, None) async def stop(self) -> None: + if not self._running: + return self._running = False if self._stop_event: self._stop_event.set() if self._server_task: - await self._server_task + try: + await self._server_task + except Exception as e: + logger.warning("websocket: server task error during shutdown: {}", e) self._server_task = None self._connections.clear() self._issued_tokens.clear() + async def 
_safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None: + """Send a raw frame, cleaning up dead connections on ConnectionClosed.""" + connection = self._connections.get(chat_id) + if connection is None: + return + try: + await connection.send(raw) + except ConnectionClosed: + self._connections.pop(chat_id, None) + logger.warning("websocket{}connection gone for chat_id={}", label, chat_id) + except Exception as e: + logger.error("websocket{}send failed: {}", label, e) + raise + async def send(self, msg: OutboundMessage) -> None: connection = self._connections.get(msg.chat_id) if connection is None: @@ -386,14 +433,7 @@ class WebSocketChannel(BaseChannel): if msg.reply_to: payload["reply_to"] = msg.reply_to raw = json.dumps(payload, ensure_ascii=False) - try: - await connection.send(raw) - except ConnectionClosed: - self._connections.pop(msg.chat_id, None) - logger.warning("websocket: connection gone for chat_id={}", msg.chat_id) - except Exception as e: - logger.error("websocket send failed: {}", e) - raise + await self._safe_send(msg.chat_id, raw, label=" ") async def send_delta( self, @@ -401,27 +441,17 @@ class WebSocketChannel(BaseChannel): delta: str, metadata: dict[str, Any] | None = None, ) -> None: - connection = self._connections.get(chat_id) - if connection is None: + if self._connections.get(chat_id) is None: return meta = metadata or {} if meta.get("_stream_end"): body: dict[str, Any] = {"event": "stream_end"} - if meta.get("_stream_id") is not None: - body["stream_id"] = meta["_stream_id"] else: body = { "event": "delta", "text": delta, } - if meta.get("_stream_id") is not None: - body["stream_id"] = meta["_stream_id"] + if meta.get("_stream_id") is not None: + body["stream_id"] = meta["_stream_id"] raw = json.dumps(body, ensure_ascii=False) - try: - await connection.send(raw) - except ConnectionClosed: - self._connections.pop(chat_id, None) - logger.warning("websocket: stream connection gone for chat_id={}", chat_id) - except Exception as 
e: - logger.error("websocket stream send failed: {}", e) - raise + await self._safe_send(chat_id, raw, label=" stream ") diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index e4c5ad635..89a330a18 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -3,11 +3,15 @@ import asyncio import functools import json +import time +from typing import Any from unittest.mock import AsyncMock, MagicMock import httpx import pytest import websockets +from websockets.exceptions import ConnectionClosed +from websockets.frames import Close from nanobot.bus.events import OutboundMessage from nanobot.channels.websocket import ( @@ -21,6 +25,30 @@ from nanobot.channels.websocket import ( _parse_request_path, ) +# -- Shared helpers (aligned with test_websocket_integration.py) --------------- + +_PORT = 29876 + + +def _ch(bus: Any, **kw: Any) -> WebSocketChannel: + cfg: dict[str, Any] = { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": _PORT, + "path": "/ws", + "websocketRequiresToken": False, + } + cfg.update(kw) + return WebSocketChannel(cfg, bus) + + +@pytest.fixture() +def bus() -> MagicMock: + b = MagicMock() + b.publish_inbound = AsyncMock() + return b + async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response: """Run GET in a thread to avoid blocking the asyncio loop shared with websockets.""" @@ -71,6 +99,21 @@ def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None: assert _parse_inbound_payload("{not json") == "{not json" +@pytest.mark.parametrize( + ("raw", "expected"), + [ + ('{"content": ""}', None), # empty string content + ('{"content": 123}', None), # non-string content + ('{"content": " "}', None), # whitespace-only content + ('["hello"]', '["hello"]'), # JSON array: not a dict, treated as plain text + ('{"unknown_key": "val"}', None), # unrecognized key + ('{"content": null}', None), # null content + ], 
+) +def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None: + assert _parse_inbound_payload(raw) == expected + + def test_web_socket_config_path_must_start_with_slash() -> None: with pytest.raises(ValueError, match='path must start with "/"'): WebSocketConfig(path="bad") @@ -112,6 +155,14 @@ def test_issue_route_secret_matches_bearer_and_header() -> None: assert _issue_route_secret_matches(wrong, secret) is False +def test_issue_route_secret_matches_empty_secret() -> None: + from websockets.datastructures import Headers + + # Empty secret always returns True regardless of headers + assert _issue_route_secret_matches(Headers([]), "") is True + assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True + + @pytest.mark.asyncio async def test_send_delivers_json_message_with_media_and_reply() -> None: bus = MagicMock() @@ -144,6 +195,33 @@ async def test_send_missing_connection_is_noop_without_error() -> None: await channel.send(msg) +@pytest.mark.asyncio +async def test_send_removes_connection_on_connection_closed() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) + channel._connections["chat-1"] = mock_ws + + msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello") + await channel.send(msg) + + assert "chat-1" not in channel._connections + + +@pytest.mark.asyncio +async def test_send_delta_removes_connection_on_connection_closed() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) + mock_ws = AsyncMock() + mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) + channel._connections["chat-1"] = mock_ws + + await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"}) + + assert "chat-1" not in 
channel._connections + + @pytest.mark.asyncio async def test_send_delta_emits_delta_and_stream_end() -> None: bus = MagicMock() @@ -165,20 +243,39 @@ async def test_send_delta_emits_delta_and_stream_end() -> None: @pytest.mark.asyncio -async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None: +async def test_send_non_connection_closed_exception_is_raised() -> None: bus = MagicMock() - bus.publish_inbound = AsyncMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + mock_ws.send.side_effect = RuntimeError("unexpected") + channel._connections["chat-1"] = mock_ws + + msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello") + with pytest.raises(RuntimeError, match="unexpected"): + await channel.send(msg) + + +@pytest.mark.asyncio +async def test_send_delta_missing_connection_is_noop() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) + # No exception, no error — just a no-op + await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"}) + + +@pytest.mark.asyncio +async def test_stop_is_idempotent() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + # stop() before start() should not raise + await channel.stop() + await channel.stop() + + +@pytest.mark.asyncio +async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None: port = 29876 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/ws", - }, - bus, - ) + channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -212,20 +309,9 @@ async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None @pytest.mark.asyncio -async def test_token_rejects_handshake_when_mismatch() -> None: - bus = MagicMock() 
+async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None: port = 29877 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/", - "token": "secret", - }, - bus, - ) + channel = _ch(bus, port=port, path="/", token="secret") server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -241,19 +327,9 @@ async def test_token_rejects_handshake_when_mismatch() -> None: @pytest.mark.asyncio -async def test_wrong_path_returns_404() -> None: - bus = MagicMock() +async def test_wrong_path_returns_404(bus: MagicMock) -> None: port = 29878 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/ws", - }, - bus, - ) + channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -276,22 +352,13 @@ def test_registry_discovers_websocket_channel() -> None: @pytest.mark.asyncio -async def test_http_route_issues_token_then_websocket_requires_it() -> None: - bus = MagicMock() - bus.publish_inbound = AsyncMock() +async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None: port = 29879 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/ws", - "tokenIssuePath": "/auth/token", - "tokenIssueSecret": "route-secret", - "websocketRequiresToken": True, - }, - bus, + channel = _ch( + bus, port=port, + tokenIssuePath="/auth/token", + tokenIssueSecret="route-secret", + websocketRequiresToken=True, ) server_task = asyncio.create_task(channel.start()) @@ -327,3 +394,205 @@ async def test_http_route_issues_token_then_websocket_requires_it() -> None: finally: await channel.stop() await server_task + + +@pytest.mark.asyncio +async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None: + port = 29880 + channel = _ch(bus, port=port, 
streaming=True) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client: + ready_raw = await client.recv() + ready = json.loads(ready_raw) + chat_id = ready["chat_id"] + + # Server pushes deltas directly + await channel.send_delta( + chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"} + ) + await channel.send_delta( + chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"} + ) + await channel.send_delta( + chat_id, "", {"_stream_end": True, "_stream_id": "s1"} + ) + + delta1 = json.loads(await client.recv()) + assert delta1["event"] == "delta" + assert delta1["text"] == "Hello " + assert delta1["stream_id"] == "s1" + + delta2 = json.loads(await client.recv()) + assert delta2["event"] == "delta" + assert delta2["text"] == "world" + assert delta2["stream_id"] == "s1" + + end = json.loads(await client.recv()) + assert end["event"] == "stream_end" + assert end["stream_id"] == "s1" + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None: + port = 29881 + channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s") + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + # Fill issued tokens to capacity + channel._issued_tokens = { + f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS) + } + + resp = await _http_get( + f"http://127.0.0.1:{port}/auth/token", + headers={"Authorization": "Bearer s"}, + ) + assert resp.status_code == 429 + data = resp.json() + assert "error" in data + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None: + port = 29882 + channel = _ch(bus, port=port, allowFrom=["alice", "bob"]) + + server_task = 
asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"): + pass + assert exc_info.value.response.status_code == 403 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_client_id_truncation(bus: MagicMock) -> None: + port = 29883 + channel = _ch(bus, port=port) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + long_id = "x" * 200 + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client: + ready = json.loads(await client.recv()) + assert ready["client_id"] == "x" * 128 + assert len(ready["client_id"]) == 128 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None: + port = 29884 + channel = _ch(bus, port=port) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client: + await client.recv() # consume ready + # Send non-UTF-8 bytes + await client.send(b"\xff\xfe\xfd") + await asyncio.sleep(0.05) + # publish_inbound should NOT have been called + bus.publish_inbound.assert_not_awaited() + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None: + port = 29885 + channel = _ch( + bus, port=port, + token="static-secret", + tokenIssuePath="/auth/token", + tokenIssueSecret="route-secret", + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + # Get an issued token + resp = await _http_get( + f"http://127.0.0.1:{port}/auth/token", + headers={"Authorization": "Bearer route-secret"}, + ) + assert resp.status_code == 200 + issued_token = 
resp.json()["token"] + + # Connect using issued token (not the static one) + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client: + ready = json.loads(await client.recv()) + assert ready["event"] == "ready" + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None: + port = 29886 + channel = _ch(bus, port=port, allowFrom=[]) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"): + pass + assert exc_info.value.response.status_code == 403 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None: + """When websocket_requires_token is True but no token or issue path configured, all connections are rejected.""" + port = 29887 + channel = _ch(bus, port=port, websocketRequiresToken=True) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + # No token at all → 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"): + pass + assert exc_info.value.response.status_code == 401 + + # Wrong token → 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"): + pass + assert exc_info.value.response.status_code == 401 + finally: + await channel.stop() + await server_task diff --git a/tests/channels/test_websocket_integration.py b/tests/channels/test_websocket_integration.py new file mode 100644 index 000000000..2cf0331ab --- /dev/null +++ b/tests/channels/test_websocket_integration.py @@ -0,0 +1,477 @@ +"""Integration 
tests for the WebSocket channel using WsTestClient. + +Complements the unit/lightweight tests in test_websocket_channel.py by covering +multi-client scenarios, edge cases, and realistic usage patterns. +""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +import websockets + +from nanobot.channels.websocket import WebSocketChannel +from nanobot.bus.events import OutboundMessage +from ws_test_client import WsTestClient, issue_token, issue_token_ok + + +def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel: + cfg: dict[str, Any] = { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/", + "websocketRequiresToken": False, + } + cfg.update(kw) + return WebSocketChannel(cfg, bus) + + +@pytest.fixture() +def bus() -> MagicMock: + b = MagicMock() + b.publish_inbound = AsyncMock() + return b + + +# -- Connection basics ---------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ready_event_fields(bus: MagicMock) -> None: + ch = _ch(bus, 29901) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c: + r = await c.recv_ready() + assert r.event == "ready" + assert len(r.chat_id) == 36 + assert r.client_id == "c1" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None: + ch = _ch(bus, 29902) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c: + r = await c.recv_ready() + assert r.client_id.startswith("anon-") + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_each_connection_unique_chat_id(bus: MagicMock) -> None: + ch = _ch(bus, 29903) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with 
WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1: + async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2: + assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id + finally: + await ch.stop(); await t + + +# -- Inbound messages (client -> server) ---------------------------------- + + +@pytest.mark.asyncio +async def test_plain_text(bus: MagicMock) -> None: + ch = _ch(bus, 29904) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c: + await c.recv_ready() + await c.send_text("hello world") + await asyncio.sleep(0.1) + inbound = bus.publish_inbound.call_args[0][0] + assert inbound.content == "hello world" + assert inbound.sender_id == "p" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_json_content_field(bus: MagicMock) -> None: + ch = _ch(bus, 29905) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c: + await c.recv_ready() + await c.send_json({"content": "structured"}) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "structured" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_json_text_and_message_fields(bus: MagicMock) -> None: + ch = _ch(bus, 29906) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c: + await c.recv_ready() + await c.send_json({"text": "via text"}) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "via text" + await c.send_json({"message": "via message"}) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "via message" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_empty_payload_ignored(bus: MagicMock) -> None: + ch = _ch(bus, 29907) + t = 
asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c: + await c.recv_ready() + await c.send_text(" ") + await c.send_json({}) + await asyncio.sleep(0.1) + bus.publish_inbound.assert_not_awaited() + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_messages_preserve_order(bus: MagicMock) -> None: + ch = _ch(bus, 29908) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c: + await c.recv_ready() + for i in range(5): + await c.send_text(f"msg-{i}") + await asyncio.sleep(0.2) + contents = [call[0][0].content for call in bus.publish_inbound.call_args_list] + assert contents == [f"msg-{i}" for i in range(5)] + finally: + await ch.stop(); await t + + +# -- Outbound messages (server -> client) --------------------------------- + + +@pytest.mark.asyncio +async def test_server_send_message(bus: MagicMock) -> None: + ch = _ch(bus, 29909) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c: + ready = await c.recv_ready() + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content="reply", + )) + msg = await c.recv_message() + assert msg.text == "reply" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_server_send_with_media_and_reply(bus: MagicMock) -> None: + ch = _ch(bus, 29910) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c: + ready = await c.recv_ready() + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content="img", + media=["/tmp/a.png"], reply_to="m1", + )) + msg = await c.recv_message() + assert msg.text == "img" + assert msg.media == ["/tmp/a.png"] + assert msg.reply_to == "m1" + finally: + await 
ch.stop(); await t + + +# -- Streaming ------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_streaming_deltas_and_end(bus: MagicMock) -> None: + ch = _ch(bus, 29911, streaming=True) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c: + cid = (await c.recv_ready()).chat_id + for part in ("Hello", " ", "world", "!"): + await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"}) + await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"}) + + msgs = await c.collect_stream() + deltas = [m for m in msgs if m.event == "delta"] + assert "".join(d.text for d in deltas) == "Hello world!" + ends = [m for m in msgs if m.event == "stream_end"] + assert len(ends) == 1 + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_interleaved_streams(bus: MagicMock) -> None: + ch = _ch(bus, 29912, streaming=True) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c: + cid = (await c.recv_ready()).chat_id + await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"}) + await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"}) + await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"}) + await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"}) + await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"}) + await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"}) + + msgs = await c.recv_n(6) + sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa") + sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb") + assert sa == "A1A2" + assert sb == "B1B2" + finally: + await ch.stop(); await t + + +# -- Multi-client --------------------------------------------------------- 
+ + +@pytest.mark.asyncio +async def test_independent_sessions(bus: MagicMock) -> None: + ch = _ch(bus, 29913) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1: + async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2: + r1, r2 = await c1.recv_ready(), await c2.recv_ready() + await ch.send(OutboundMessage( + channel="websocket", chat_id=r1.chat_id, content="for-u1", + )) + assert (await c1.recv_message()).text == "for-u1" + await ch.send(OutboundMessage( + channel="websocket", chat_id=r2.chat_id, content="for-u2", + )) + assert (await c2.recv_message()).text == "for-u2" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_disconnected_client_cleanup(bus: MagicMock) -> None: + ch = _ch(bus, 29914) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c: + chat_id = (await c.recv_ready()).chat_id + # disconnected + await ch.send(OutboundMessage( + channel="websocket", chat_id=chat_id, content="orphan", + )) + assert chat_id not in ch._connections + finally: + await ch.stop(); await t + + +# -- Authentication ------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_static_token_accepted(bus: MagicMock) -> None: + ch = _ch(bus, 29915, token="secret") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c: + assert (await c.recv_ready()).client_id == "a" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_static_token_rejected(bus: MagicMock) -> None: + ch = _ch(bus, 29916, token="correct") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with WsTestClient("ws://127.0.0.1:29916/", 
client_id="b", token="wrong"): + pass + assert exc.value.response.status_code == 401 + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_token_issue_full_flow(bus: MagicMock) -> None: + ch = _ch(bus, 29917, path="/ws", + tokenIssuePath="/auth/token", tokenIssueSecret="s", + websocketRequiresToken=True) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + # no secret -> 401 + _, status = await issue_token(port=29917, issue_path="/auth/token") + assert status == 401 + + # with secret -> token + token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s") + + # no token -> 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"): + pass + assert exc.value.response.status_code == 401 + + # valid token -> ok + async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c: + assert (await c.recv_ready()).client_id == "ok" + + # reuse -> 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token): + pass + assert exc.value.response.status_code == 401 + finally: + await ch.stop(); await t + + +# -- Path routing --------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_custom_path(bus: MagicMock) -> None: + ch = _ch(bus, 29918, path="/my-chat") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c: + assert (await c.recv_ready()).event == "ready" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_wrong_path_404(bus: MagicMock) -> None: + ch = _ch(bus, 29919, path="/ws") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with 
WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"): + pass + assert exc.value.response.status_code == 404 + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_trailing_slash_normalized(bus: MagicMock) -> None: + ch = _ch(bus, 29920, path="/ws") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c: + assert (await c.recv_ready()).event == "ready" + finally: + await ch.stop(); await t + + +# -- Edge cases ----------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_large_message(bus: MagicMock) -> None: + ch = _ch(bus, 29921) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c: + await c.recv_ready() + big = "x" * 100_000 + await c.send_text(big) + await asyncio.sleep(0.2) + assert bus.publish_inbound.call_args[0][0].content == big + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_unicode_roundtrip(bus: MagicMock) -> None: + ch = _ch(bus, 29922) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c: + ready = await c.recv_ready() + text = "你好世界 🌍 日本語テスト" + await c.send_text(text) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == text + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content=text, + )) + assert (await c.recv_message()).text == text + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_rapid_fire(bus: MagicMock) -> None: + ch = _ch(bus, 29923) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c: + ready = await c.recv_ready() + for i in range(50): + await c.send_text(f"in-{i}") + await asyncio.sleep(0.5) 
+ assert bus.publish_inbound.await_count == 50 + for i in range(50): + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content=f"out-{i}", + )) + received = [(await c.recv_message()).text for _ in range(50)] + assert received == [f"out-{i}" for i in range(50)] + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_invalid_json_as_plain_text(bus: MagicMock) -> None: + ch = _ch(bus, 29924) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c: + await c.recv_ready() + await c.send_text("{broken json") + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "{broken json" + finally: + await ch.stop(); await t diff --git a/tests/channels/ws_test_client.py b/tests/channels/ws_test_client.py new file mode 100644 index 000000000..ec3ba1460 --- /dev/null +++ b/tests/channels/ws_test_client.py @@ -0,0 +1,227 @@ +"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel. 
+ +Provides an async ``WsTestClient`` class and token-issuance helpers that +integration tests can import and use directly:: + + from ws_test_client import WsTestClient + + async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c: + ready = await c.recv_ready() + await c.send_text("hello") + msg = await c.recv_message() +""" + +from __future__ import annotations + +import asyncio +import json +from dataclasses import dataclass, field +from typing import Any + +import httpx +import websockets +from websockets.asyncio.client import ClientConnection + + +@dataclass +class WsMessage: + """A parsed message received from the WebSocket server.""" + + event: str + raw: dict[str, Any] = field(repr=False) + + @property + def text(self) -> str | None: + return self.raw.get("text") + + @property + def chat_id(self) -> str | None: + return self.raw.get("chat_id") + + @property + def client_id(self) -> str | None: + return self.raw.get("client_id") + + @property + def media(self) -> list[str] | None: + return self.raw.get("media") + + @property + def reply_to(self) -> str | None: + return self.raw.get("reply_to") + + @property + def stream_id(self) -> str | None: + return self.raw.get("stream_id") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, WsMessage): + return NotImplemented + return self.event == other.event and self.raw == other.raw + + +class WsTestClient: + """Async WebSocket test client with helper methods for common operations. 
+ + Usage:: + + async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client: + ready = await client.recv_ready() + await client.send_text("hello") + msg = await client.recv_message(timeout=5.0) + """ + + def __init__( + self, + uri: str, + *, + client_id: str = "test-client", + token: str = "", + extra_headers: dict[str, str] | None = None, + ) -> None: + params: list[str] = [] + if client_id: + params.append(f"client_id={client_id}") + if token: + params.append(f"token={token}") + sep = "&" if "?" in uri else "?" + self._uri = uri + sep + "&".join(params) if params else uri + self._extra_headers = extra_headers + self._ws: ClientConnection | None = None + + async def connect(self) -> None: + self._ws = await websockets.connect( + self._uri, + additional_headers=self._extra_headers, + ) + + async def close(self) -> None: + if self._ws: + await self._ws.close() + self._ws = None + + async def __aenter__(self) -> WsTestClient: + await self.connect() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + @property + def ws(self) -> ClientConnection: + assert self._ws is not None, "Client is not connected" + return self._ws + + # -- Receiving -------------------------------------------------------- + + async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]: + """Receive and parse one raw JSON message with timeout.""" + raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout) + return json.loads(raw) + + async def recv(self, timeout: float = 10.0) -> WsMessage: + """Receive one message, returning a WsMessage wrapper.""" + data = await self.recv_raw(timeout) + return WsMessage(event=data.get("event", ""), raw=data) + + async def recv_ready(self, timeout: float = 5.0) -> WsMessage: + """Receive and validate the 'ready' event.""" + msg = await self.recv(timeout) + assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'" + return msg + + async def recv_message(self, timeout: float = 10.0) 
-> WsMessage: + """Receive and validate a 'message' event.""" + msg = await self.recv(timeout) + assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'" + return msg + + async def recv_delta(self, timeout: float = 10.0) -> WsMessage: + """Receive and validate a 'delta' event.""" + msg = await self.recv(timeout) + assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'" + return msg + + async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage: + """Receive and validate a 'stream_end' event.""" + msg = await self.recv(timeout) + assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'" + return msg + + async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]: + """Collect all deltas and the final stream_end into a list.""" + messages: list[WsMessage] = [] + while True: + msg = await self.recv(timeout) + messages.append(msg) + if msg.event == "stream_end": + break + return messages + + async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]: + """Receive exactly *n* messages.""" + return [await self.recv(timeout) for _ in range(n)] + + # -- Sending ---------------------------------------------------------- + + async def send_text(self, text: str) -> None: + """Send a plain text frame.""" + await self.ws.send(text) + + async def send_json(self, data: dict[str, Any]) -> None: + """Send a JSON frame.""" + await self.ws.send(json.dumps(data, ensure_ascii=False)) + + async def send_content(self, content: str) -> None: + """Send content in the preferred JSON format ``{"content": ...}``.""" + await self.send_json({"content": content}) + + # -- Connection introspection ----------------------------------------- + + @property + def closed(self) -> bool: + return self._ws is None or self._ws.closed + + +# -- Token issuance helpers ----------------------------------------------- + + +async def issue_token( + host: str = "127.0.0.1", + port: int = 8765, + issue_path: str = 
"/auth/token", + secret: str = "", +) -> tuple[dict[str, Any] | None, int]: + """Request a short-lived token from the token-issue HTTP endpoint. + + Returns ``(parsed_json_or_None, status_code)``. + """ + url = f"http://{host}:{port}{issue_path}" + headers: dict[str, str] = {} + if secret: + headers["Authorization"] = f"Bearer {secret}" + + loop = asyncio.get_running_loop() + resp = await loop.run_in_executor( + None, lambda: httpx.get(url, headers=headers, timeout=5.0) + ) + try: + data = resp.json() + except Exception: + data = None + return data, resp.status_code + + +async def issue_token_ok( + host: str = "127.0.0.1", + port: int = 8765, + issue_path: str = "/auth/token", + secret: str = "", +) -> str: + """Request a token, asserting success, and return the token string.""" + (data, status) = await issue_token(host, port, issue_path, secret) + assert status == 200, f"Token issue failed with status {status}" + assert data is not None + token = data["token"] + assert token.startswith("nbwt_"), f"Unexpected token format: {token}" + return token From 42de13a1a9bb1e790741f8e9ab8f3c8202cfecf1 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Thu, 9 Apr 2026 15:48:39 +0800 Subject: [PATCH 032/115] docs(websocket): add WebSocket channel documentation Comprehensive guide covering wire protocol, configuration reference, token issuance, security notes, and common deployment patterns. --- docs/WEBSOCKET.md | 331 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 docs/WEBSOCKET.md diff --git a/docs/WEBSOCKET.md b/docs/WEBSOCKET.md new file mode 100644 index 000000000..5cc867b49 --- /dev/null +++ b/docs/WEBSOCKET.md @@ -0,0 +1,331 @@ +# WebSocket Server Channel + +Nanobot can act as a WebSocket server, allowing external clients (web apps, CLIs, scripts) to interact with the agent in real time via persistent connections. 
+ +## Features + +- Bidirectional real-time communication over WebSocket +- Streaming support — receive agent responses token by token +- Token-based authentication (static tokens and short-lived issued tokens) +- Per-connection sessions — each connection gets a unique `chat_id` +- TLS/SSL support (WSS) with enforced TLSv1.2 minimum +- Client allow-list via `allowFrom` +- Auto-cleanup of dead connections + +## Quick Start + +### 1. Configure + +Add to `config.json` under `channels.websocket`: + +```json +{ + "channels": { + "websocket": { + "enabled": true, + "host": "127.0.0.1", + "port": 8765, + "path": "/", + "websocketRequiresToken": false, + "allowFrom": ["*"], + "streaming": true + } + } +} +``` + +### 2. Start nanobot + +```bash +nanobot gateway +``` + +You should see: + +``` +WebSocket server listening on ws://127.0.0.1:8765/ +``` + +### 3. Connect a client + +```bash +# Using websocat +websocat ws://127.0.0.1:8765/?client_id=alice + +# Using Python +import asyncio, json, websockets + +async def main(): + async with websockets.connect("ws://127.0.0.1:8765/?client_id=alice") as ws: + ready = json.loads(await ws.recv()) + print(ready) # {"event": "ready", "chat_id": "...", "client_id": "alice"} + await ws.send(json.dumps({"content": "Hello nanobot!"})) + reply = json.loads(await ws.recv()) + print(reply["text"]) + +asyncio.run(main()) +``` + +## Connection URL + +``` +ws://{host}:{port}{path}?client_id={id}&token={token} +``` + +| Parameter | Required | Description | +|-----------|----------|-------------| +| `client_id` | No | Identifier for `allowFrom` authorization. Auto-generated as `anon-xxxxxxxxxxxx` if omitted. Truncated to 128 chars. | +| `token` | Conditional | Authentication token. Required when `websocketRequiresToken` is `true` or `token` (static secret) is configured. | + +## Wire Protocol + +All frames are JSON text. Each message has an `event` field. 
+ +### Server → Client + +**`ready`** — sent immediately after connection is established: + +```json +{ + "event": "ready", + "chat_id": "uuid-v4", + "client_id": "alice" +} +``` + +**`message`** — full agent response: + +```json +{ + "event": "message", + "text": "Hello! How can I help?", + "media": ["/tmp/image.png"], + "reply_to": "msg-id" +} +``` + +`media` and `reply_to` are only present when applicable. + +**`delta`** — streaming text chunk (only when `streaming: true`): + +```json +{ + "event": "delta", + "text": "Hello", + "stream_id": "s1" +} +``` + +**`stream_end`** — signals the end of a streaming segment: + +```json +{ + "event": "stream_end", + "stream_id": "s1" +} +``` + +### Client → Server + +Send plain text: + +```json +"Hello nanobot!" +``` + +Or send a JSON object with a recognized text field: + +```json +{"content": "Hello nanobot!"} +``` + +Recognized fields: `content`, `text`, `message` (checked in that order). Invalid JSON is treated as plain text. + +## Configuration Reference + +All fields go under `channels.websocket` in `config.json`. + +### Connection + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `enabled` | bool | `false` | Enable the WebSocket server. | +| `host` | string | `"127.0.0.1"` | Bind address. Use `"0.0.0.0"` to accept external connections. | +| `port` | int | `8765` | Listen port. | +| `path` | string | `"/"` | WebSocket upgrade path. Trailing slashes are normalized (root `/` is preserved). | +| `maxMessageBytes` | int | `1048576` | Maximum inbound message size in bytes (1 KB – 16 MB). | + +### Authentication + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `token` | string | `""` | Static shared secret. When set, clients must provide `?token=` matching this secret (timing-safe comparison). Issued tokens are also accepted as a fallback. 
| +| `websocketRequiresToken` | bool | `true` | When `true` and no static `token` is configured, clients must still present a valid issued token. Set to `false` to allow unauthenticated connections (only safe for local/trusted networks). | +| `tokenIssuePath` | string | `""` | HTTP path for issuing short-lived tokens. Must differ from `path`. See [Token Issuance](#token-issuance). | +| `tokenIssueSecret` | string | `""` | Secret required to obtain tokens via the issue endpoint. If empty, any client can obtain tokens (logged as a warning). | +| `tokenTtlS` | int | `300` | Time-to-live for issued tokens in seconds (30 – 86,400). | + +### Access Control + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `allowFrom` | list of string | `["*"]` | Allowed `client_id` values. `"*"` allows all; `[]` denies all. | + +### Streaming + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `streaming` | bool | `true` | Enable streaming mode. The agent sends `delta` + `stream_end` frames instead of a single `message`. | + +### Keep-alive + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `pingIntervalS` | float | `20.0` | WebSocket ping interval in seconds (5 – 300). | +| `pingTimeoutS` | float | `20.0` | Time to wait for a pong before closing the connection (5 – 300). | + +### TLS/SSL + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `sslCertfile` | string | `""` | Path to the TLS certificate file (PEM). Both `sslCertfile` and `sslKeyfile` must be set to enable WSS. | +| `sslKeyfile` | string | `""` | Path to the TLS private key file (PEM). Minimum TLS version is enforced as TLSv1.2. | + +## Token Issuance + +For production deployments where `websocketRequiresToken: true`, use short-lived tokens instead of embedding static secrets in clients. + +### How it works + +1. 
Client sends `GET {tokenIssuePath}` with `Authorization: Bearer {tokenIssueSecret}` (or `X-Nanobot-Auth` header). +2. Server responds with a one-time-use token: + +```json +{"token": "nbwt_aBcDeFg...", "expires_in": 300} +``` + +3. Client opens WebSocket with `?token=nbwt_aBcDeFg...&client_id=...`. +4. The token is consumed (single use) and cannot be reused. + +### Example setup + +```json +{ + "channels": { + "websocket": { + "enabled": true, + "port": 8765, + "path": "/ws", + "tokenIssuePath": "/auth/token", + "tokenIssueSecret": "your-secret-here", + "tokenTtlS": 300, + "websocketRequiresToken": true, + "allowFrom": ["*"], + "streaming": true + } + } +} +``` + +Client flow: + +```bash +# 1. Obtain a token +curl -H "Authorization: Bearer your-secret-here" http://127.0.0.1:8765/auth/token + +# 2. Connect using the token +websocat "ws://127.0.0.1:8765/ws?client_id=alice&token=nbwt_aBcDeFg..." +``` + +### Limits + +- Issued tokens are single-use — each token can only complete one handshake. +- Outstanding tokens are capped at 10,000. Requests beyond this return HTTP 429. +- Expired tokens are purged lazily on each issue or validation request. + +## Security Notes + +- **Timing-safe comparison**: Static token validation uses `hmac.compare_digest` to prevent timing attacks. +- **Defense in depth**: `allowFrom` is checked at both the HTTP handshake level and the message level. +- **Token isolation**: Each WebSocket connection gets a unique `chat_id`. Clients cannot access other sessions. +- **TLS enforcement**: When SSL is enabled, TLSv1.2 is the minimum allowed version. +- **Default-secure**: `websocketRequiresToken` defaults to `true`. Explicitly set it to `false` only on trusted networks. + +## Media Files + +Outbound `message` events may include a `media` field containing local filesystem paths. 
Remote clients cannot access these files directly — they need either: + +- A shared filesystem mount, or +- An HTTP file server serving the nanobot media directory + +## Common Patterns + +### Trusted local network (no auth) + +```json +{ + "channels": { + "websocket": { + "enabled": true, + "host": "0.0.0.0", + "port": 8765, + "websocketRequiresToken": false, + "allowFrom": ["*"], + "streaming": true + } + } +} +``` + +### Static token (simple auth) + +```json +{ + "channels": { + "websocket": { + "enabled": true, + "token": "my-shared-secret", + "allowFrom": ["alice", "bob"] + } + } +} +``` + +Clients connect with `?token=my-shared-secret&client_id=alice`. + +### Public endpoint with issued tokens + +```json +{ + "channels": { + "websocket": { + "enabled": true, + "host": "0.0.0.0", + "port": 8765, + "path": "/ws", + "tokenIssuePath": "/auth/token", + "tokenIssueSecret": "production-secret", + "websocketRequiresToken": true, + "sslCertfile": "/etc/ssl/certs/server.pem", + "sslKeyfile": "/etc/ssl/private/server-key.pem", + "allowFrom": ["*"] + } + } +} +``` + +### Custom path + +```json +{ + "channels": { + "websocket": { + "enabled": true, + "path": "/chat/ws", + "allowFrom": ["*"] + } + } +} +``` + +Clients connect to `ws://127.0.0.1:8765/chat/ws?client_id=...`. Trailing slashes are normalized, so `/chat/ws/` works the same. 
From ba8bce0f45509c8fc6fce491ae2775c75d8471f9 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 9 Apr 2026 10:16:28 +0000 Subject: [PATCH 033/115] fix(tests): add missing `from typing import Any` in websocket integration tests Made-with: Cursor --- tests/channels/test_websocket_integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/channels/test_websocket_integration.py b/tests/channels/test_websocket_integration.py index 2cf0331ab..8ff158666 100644 --- a/tests/channels/test_websocket_integration.py +++ b/tests/channels/test_websocket_integration.py @@ -8,6 +8,7 @@ from __future__ import annotations import asyncio import json +from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest From 10f6c875a5b352bf789daf009c6d52b3484bec48 Mon Sep 17 00:00:00 2001 From: yanghan-cyber Date: Thu, 9 Apr 2026 15:32:03 +0800 Subject: [PATCH 034/115] fix(agent): deliver LLM errors to streaming channels and avoid polluting session context When the LLM returns an error (e.g. 429 quota exceeded, stream timeout), streaming channels silently drop the error message because `_streamed=True` is set in metadata even though no content was actually streamed. 
This change: - Skips setting `_streamed` when stop_reason is "error", so error messages go through the normal channel.send() path and reach the user - Stops appending error content to session history, preventing error messages from polluting subsequent conversation context - Exposes stop_reason from _run_agent_loop to enable the above check --- nanobot/agent/loop.py | 10 +++++----- nanobot/agent/runner.py | 1 - tests/agent/test_hook_composite.py | 6 +++--- tests/agent/test_runner.py | 6 +++--- tests/tools/test_message_tool_suppress.py | 2 +- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 9128b8840..54bb29c5d 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -308,7 +308,7 @@ class AgentLoop: channel: str = "cli", chat_id: str = "direct", message_id: str | None = None, - ) -> tuple[str | None, list[str], list[dict]]: + ) -> tuple[str | None, list[str], list[dict], str]: """Run the agent iteration loop. *on_stream*: called with each content delta during streaming. 
@@ -358,7 +358,7 @@ class AgentLoop: logger.warning("Max iterations ({}) reached", self.max_iterations) elif result.stop_reason == "error": logger.error("LLM returned error: {}", (result.final_content or "")[:200]) - return result.final_content, result.tools_used, result.messages + return result.final_content, result.tools_used, result.messages, result.stop_reason async def run(self) -> None: """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" @@ -505,7 +505,7 @@ class AgentLoop: current_message=msg.content, channel=channel, chat_id=chat_id, current_role=current_role, ) - final_content, _, all_msgs = await self._run_agent_loop( + final_content, _, all_msgs, _ = await self._run_agent_loop( messages, session=session, channel=channel, chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) @@ -553,7 +553,7 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, )) - final_content, _, all_msgs = await self._run_agent_loop( + final_content, _, all_msgs, stop_reason = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, on_stream=on_stream, @@ -578,7 +578,7 @@ class AgentLoop: logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) meta = dict(msg.metadata or {}) - if on_stream is not None: + if on_stream is not None and stop_reason != "error": meta["_streamed"] = True return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=final_content, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index abc7edf09..bc1a26aba 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -257,7 +257,6 @@ class AgentRunner: final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE stop_reason = "error" error = final_content - self._append_final_message(messages, final_content) context.final_content = final_content context.error = error context.stop_reason = stop_reason diff --git 
a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index c6077d526..672f38ed2 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -307,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, tools_used, messages = await loop._run_agent_loop( + content, tools_used, messages, _ = await loop._run_agent_loop( [{"role": "user", "content": "hi"}] ) @@ -331,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, _, _ = await loop._run_agent_loop( + content, _, _, _ = await loop._run_agent_loop( [{"role": "user", "content": "hi"}] ) @@ -373,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - content, tools_used, _ = await loop._run_agent_loop([]) + content, tools_used, _, _ = await loop._run_agent_loop([]) assert content == ( "I reached the maximum number of tool call iterations (2) " "without completing the task. You can try breaking the task into smaller steps." 
diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index a0804396e..36d5de846 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -798,7 +798,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - final_content, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _ = await loop._run_agent_loop([]) assert final_content == ( "I reached the maximum number of tool call iterations (2) " @@ -825,7 +825,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp async def on_stream_end(*, resuming: bool = False) -> None: endings.append(resuming) - final_content, _, _ = await loop._run_agent_loop( + final_content, _, _, _ = await loop._run_agent_loop( [], on_stream=on_stream, on_stream_end=on_stream_end, @@ -849,7 +849,7 @@ async def test_loop_retries_think_only_final_response(tmp_path): loop.provider.chat_with_retry = chat_with_retry - final_content, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _ = await loop._run_agent_loop([]) assert final_content == "Recovered answer" assert call_count["n"] == 2 diff --git a/tests/tools/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py index 26d12085f..3f06b4a70 100644 --- a/tests/tools/test_message_tool_suppress.py +++ b/tests/tools/test_message_tool_suppress.py @@ -107,7 +107,7 @@ class TestMessageToolSuppressLogic: async def on_progress(content: str, *, tool_hint: bool = False) -> None: progress.append((content, tool_hint)) - final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + final_content, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) assert final_content == "Done" assert progress == [ From c625c0c2a74bf46269064d641e9b0c7840fa26de Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 9 Apr 2026 15:08:24 +0000 Subject: [PATCH 035/115] Merge origin/main and add regression tests for 
streaming error delivery - Merged latest main (no conflicts) - Added test_llm_error_not_appended_to_session_messages: verifies error content stays out of session messages - Added test_streamed_flag_not_set_on_llm_error: verifies _streamed is not set when LLM returns an error, so ChannelManager delivers it Made-with: Cursor --- tests/agent/test_runner.py | 63 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 36d5de846..afb06634f 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -855,6 +855,69 @@ async def test_loop_retries_think_only_final_response(tmp_path): assert call_count["n"] == 2 +@pytest.mark.asyncio +async def test_llm_error_not_appended_to_session_messages(): + """When LLM returns finish_reason='error', the error content must NOT be + appended to the messages list (prevents polluting session history).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}, + )) + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.stop_reason == "error" + assert result.final_content == "429 rate limit exceeded" + assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"] + assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \ + "Error content should not appear in session messages" + + +@pytest.mark.asyncio +async def test_streamed_flag_not_set_on_llm_error(tmp_path): + """When LLM errors during a streaming-capable channel interaction, + _streamed must NOT be 
set so ChannelManager delivers the error.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + error_resp = LLMResponse( + content="503 service unavailable", finish_reason="error", tool_calls=[], usage={}, + ) + loop.provider.chat_with_retry = AsyncMock(return_value=error_resp) + loop.provider.chat_stream_with_retry = AsyncMock(return_value=error_resp) + loop.tools.get_definitions = MagicMock(return_value=[]) + + msg = InboundMessage( + channel="feishu", sender_id="u1", chat_id="c1", content="hi", + ) + result = await loop._process_message( + msg, + on_stream=AsyncMock(), + on_stream_end=AsyncMock(), + ) + + assert result is not None + assert "503" in result.content + assert not result.metadata.get("_streamed"), \ + "_streamed must not be set when stop_reason is error" + + @pytest.mark.asyncio async def test_runner_tool_error_sets_final_content(): from nanobot.agent.runner import AgentRunSpec, AgentRunner From 0e6331b66d4c25ab1634bc041635a192515ab17f Mon Sep 17 00:00:00 2001 From: chenyahui Date: Thu, 9 Apr 2026 15:32:57 +0800 Subject: [PATCH 036/115] feat(exec): support allowed_env_keys to pass specified env vars to subprocess Add allowed_env_keys config field to selectively forward host environment variables (e.g. GOPATH, JAVA_HOME) into the sandboxed subprocess environment, while keeping the default allow-list unchanged. 
--- nanobot/agent/loop.py | 1 + nanobot/agent/tools/shell.py | 16 ++++++++++++++-- nanobot/config/schema.py | 1 + tests/tools/test_exec_env.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 54bb29c5d..16a086fdb 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -242,6 +242,7 @@ class AgentLoop: restrict_to_workspace=self.restrict_to_workspace, sandbox=self.exec_config.sandbox, path_append=self.exec_config.path_append, + allowed_env_keys=self.exec_config.allowed_env_keys, )) if self.web_config.enable: self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index eb786e9f4..d80b69fbe 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -46,6 +46,7 @@ class ExecTool(Tool): restrict_to_workspace: bool = False, sandbox: str = "", path_append: str = "", + allowed_env_keys: list[str] | None = None, ): self.timeout = timeout self.working_dir = working_dir @@ -64,6 +65,7 @@ class ExecTool(Tool): self.allow_patterns = allow_patterns or [] self.restrict_to_workspace = restrict_to_workspace self.path_append = path_append + self.allowed_env_keys = allowed_env_keys or [] @property def name(self) -> str: @@ -208,7 +210,7 @@ class ExecTool(Tool): """ if _IS_WINDOWS: sr = os.environ.get("SYSTEMROOT", r"C:\Windows") - return { + env = { "SYSTEMROOT": sr, "COMSPEC": os.environ.get("COMSPEC", f"{sr}\\system32\\cmd.exe"), "USERPROFILE": os.environ.get("USERPROFILE", ""), @@ -225,12 +227,22 @@ class ExecTool(Tool): "ProgramFiles(x86)": os.environ.get("ProgramFiles(x86)", ""), "ProgramW6432": os.environ.get("ProgramW6432", ""), } + for key in self.allowed_env_keys: + val = os.environ.get(key) + if val is not None: + env[key] = val + return env home = os.environ.get("HOME", "/tmp") - return { + env = { "HOME": home, "LANG": 
os.environ.get("LANG", "C.UTF-8"), "TERM": os.environ.get("TERM", "dumb"), } + for key in self.allowed_env_keys: + val = os.environ.get(key) + if val is not None: + env[key] = val + return env def _guard_command(self, command: str, cwd: str) -> str | None: """Best-effort safety guard for potentially destructive commands.""" diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index b011d765f..2d31c8bf9 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -177,6 +177,7 @@ class ExecToolConfig(Base): timeout: int = 60 path_append: str = "" sandbox: str = "" # sandbox backend: "" (none) or "bwrap" + allowed_env_keys: list[str] = Field(default_factory=list) # Env var names to pass through to subprocess (e.g. ["GOPATH", "JAVA_HOME"]) class MCPServerConfig(Base): """MCP server connection configuration (stdio or HTTP).""" diff --git a/tests/tools/test_exec_env.py b/tests/tools/test_exec_env.py index a05510af4..47b2c313d 100644 --- a/tests/tools/test_exec_env.py +++ b/tests/tools/test_exec_env.py @@ -43,3 +43,34 @@ async def test_exec_path_append_preserves_system_path(): tool = ExecTool(path_append="/opt/custom/bin") result = await tool.execute(command="ls /") assert "Exit code: 0" in result + + +@_UNIX_ONLY +@pytest.mark.asyncio +async def test_exec_allowed_env_keys_passthrough(monkeypatch): + """Env vars listed in allowed_env_keys should be visible to commands.""" + monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config") + tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"]) + result = await tool.execute(command="printenv MY_CUSTOM_VAR") + assert "hello-from-config" in result + + +@_UNIX_ONLY +@pytest.mark.asyncio +async def test_exec_allowed_env_keys_does_not_leak_others(monkeypatch): + """Env vars NOT in allowed_env_keys should still be blocked.""" + monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config") + monkeypatch.setenv("MY_SECRET_VAR", "secret-value") + tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"]) + result = await 
tool.execute(command="printenv MY_SECRET_VAR") + assert "secret-value" not in result + + +@_UNIX_ONLY +@pytest.mark.asyncio +async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch): + """If an allowed key is not set in the parent process, it should be silently skipped.""" + monkeypatch.delenv("NONEXISTENT_VAR_12345", raising=False) + tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"]) + result = await tool.execute(command="printenv NONEXISTENT_VAR_12345") + assert "Exit code: 1" in result From 7506af7104b3dbe77c1ca03a31d71654ee2d451b Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 9 Apr 2026 14:24:04 +0800 Subject: [PATCH 037/115] feat(channel): add proxy support for Discord channel - Add proxy, proxy_username, proxy_password fields to DiscordConfig - Pass proxy and proxy_auth to discord.Client - Add aiohttp.BasicAuth when credentials are provided - Add tests for proxy configuration scenarios --- nanobot/channels/discord.py | 56 ++++++++++-- tests/channels/test_discord_channel.py | 120 ++++++++++++++++++++++--- 2 files changed, 157 insertions(+), 19 deletions(-) diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 9e68bb46b..c50b4ff19 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -22,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None if TYPE_CHECKING: + import aiohttp import discord from discord import app_commands from discord.abc import Messageable @@ -58,6 +59,9 @@ class DiscordConfig(Base): working_emoji: str = "🔧" working_emoji_delay: float = 2.0 streaming: bool = True + proxy: str | None = None + proxy_username: str | None = None + proxy_password: str | None = None if DISCORD_AVAILABLE: @@ -65,8 +69,15 @@ if DISCORD_AVAILABLE: class DiscordBotClient(discord.Client): """discord.py client that forwards events to the channel.""" - def __init__(self, channel: DiscordChannel, *, intents: 
discord.Intents) -> None: - super().__init__(intents=intents) + def __init__( + self, + channel: DiscordChannel, + *, + intents: discord.Intents, + proxy: str | None = None, + proxy_auth: aiohttp.BasicAuth | None = None, + ) -> None: + super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth) self._channel = channel self.tree = app_commands.CommandTree(self) self._register_app_commands() @@ -130,6 +141,7 @@ if DISCORD_AVAILABLE: ) for name, description, command_text in commands: + @self.tree.command(name=name, description=description) async def command_handler( interaction: discord.Interaction, @@ -186,7 +198,9 @@ if DISCORD_AVAILABLE: else: failed_media.append(Path(media_path).name) - for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)): + for index, chunk in enumerate( + self._build_chunks(msg.content or "", failed_media, sent_media) + ): kwargs: dict[str, Any] = {"content": chunk} if index == 0 and reference is not None and not sent_media: kwargs["reference"] = reference @@ -292,7 +306,22 @@ class DiscordChannel(BaseChannel): try: intents = discord.Intents.none() intents.value = self.config.intents - self._client = DiscordBotClient(self, intents=intents) + + proxy_auth = None + if self.config.proxy_username and self.config.proxy_password: + import aiohttp + + proxy_auth = aiohttp.BasicAuth( + login=self.config.proxy_username, + password=self.config.proxy_password, + ) + + self._client = DiscordBotClient( + self, + intents=intents, + proxy=self.config.proxy, + proxy_auth=proxy_auth, + ) except Exception as e: logger.error("Failed to initialize Discord client: {}", e) self._client = None @@ -335,7 +364,9 @@ class DiscordChannel(BaseChannel): await self._stop_typing(msg.chat_id) await self._clear_reactions(msg.chat_id) - async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + async def send_delta( + self, chat_id: str, delta: str, metadata: dict[str, Any] | None = 
None + ) -> None: """Progressive Discord delivery: send once, then edit until the stream ends.""" client = self._client if client is None or not client.is_ready(): @@ -355,7 +386,9 @@ class DiscordChannel(BaseChannel): return buf = self._stream_bufs.get(chat_id) - if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id): + if buf is None or ( + stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id + ): buf = _StreamBuf(stream_id=stream_id) self._stream_bufs[chat_id] = buf elif buf.stream_id is None: @@ -534,7 +567,11 @@ class DiscordChannel(BaseChannel): @staticmethod def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]: """Build metadata for inbound Discord messages.""" - reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None + reply_to = ( + str(message.reference.message_id) + if message.reference and message.reference.message_id + else None + ) return { "message_id": str(message.id), "guild_id": str(message.guild.id) if message.guild else None, @@ -549,7 +586,9 @@ class DiscordChannel(BaseChannel): if self.config.group_policy == "mention": bot_user_id = self._bot_user_id if bot_user_id is None: - logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id) + logger.debug( + "Discord message in {} ignored (bot identity unavailable)", message.channel.id + ) return False if any(str(user.id) == bot_user_id for user in message.mentions): @@ -591,7 +630,6 @@ class DiscordChannel(BaseChannel): except asyncio.CancelledError: pass - async def _clear_reactions(self, chat_id: str) -> None: """Remove all pending reactions after bot replies.""" # Cancel delayed working emoji if it hasn't fired yet diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 09b80740f..3f0f3388a 100644 --- a/tests/channels/test_discord_channel.py +++ 
b/tests/channels/test_discord_channel.py @@ -5,11 +5,17 @@ from pathlib import Path from types import SimpleNamespace import pytest + discord = pytest.importorskip("discord") from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.discord import MAX_MESSAGE_LEN, DiscordBotClient, DiscordChannel, DiscordConfig +from nanobot.channels.discord import ( + MAX_MESSAGE_LEN, + DiscordBotClient, + DiscordChannel, + DiscordConfig, +) from nanobot.command.builtin import build_help_text @@ -18,9 +24,11 @@ class _FakeDiscordClient: instances: list["_FakeDiscordClient"] = [] start_error: Exception | None = None - def __init__(self, owner, *, intents) -> None: + def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None: self.owner = owner self.intents = intents + self.proxy = proxy + self.proxy_auth = proxy_auth self.closed = False self.ready = True self.channels: dict[int, object] = {} @@ -53,7 +61,9 @@ class _FakeDiscordClient: class _FakeAttachment: # Attachment double that can simulate successful or failing save() calls. 
- def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None: + def __init__( + self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False + ) -> None: self.id = attachment_id self.filename = filename self.size = size @@ -211,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None: MessageBus(), ) - def _boom(owner, *, intents): + def _boom(owner, *, intents, proxy=None, proxy_auth=None): raise RuntimeError("bad client") monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom) @@ -514,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None: assert new_cmd is not None await new_cmd.callback(interaction) - assert interaction.response.messages == [ - {"content": "Processing /new...", "ephemeral": True} - ] + assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}] assert len(handled) == 1 assert handled[0]["content"] == "/new" assert handled[0]["sender_id"] == "123" @@ -590,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None: assert help_cmd is not None await help_cmd.callback(interaction) - assert interaction.response.messages == [ - {"content": build_help_text(), "ephemeral": True} - ] + assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}] assert handled == [] @@ -727,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> def typing(self): async def _waiter(): await release.wait() + # Hold the loop so task remains active until explicitly stopped. 
class _Ctx(_TypingCtx): async def __aenter__(self): await super().__aenter__() await _waiter() + return _Ctx() typing_channel = _NoTriggerChannel(channel_id=123) @@ -745,3 +753,95 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> await asyncio.sleep(0) assert channel._typing_tasks == {} + + +def test_config_accepts_proxy_fields() -> None: + config = DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + proxy_username="user", + proxy_password="pass", + ) + assert config.proxy == "http://127.0.0.1:7890" + assert config.proxy_username == "user" + assert config.proxy_password == "pass" + + +def test_config_proxy_defaults_to_none() -> None: + config = DiscordConfig(enabled=True, token="token", allow_from=["*"]) + assert config.proxy is None + assert config.proxy_username is None + assert config.proxy_password is None + + +@pytest.mark.asyncio +async def test_start_passes_proxy_to_client(monkeypatch) -> None: + _FakeDiscordClient.instances.clear() + channel = DiscordChannel( + DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + ), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert len(_FakeDiscordClient.instances) == 1 + assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890" + assert _FakeDiscordClient.instances[0].proxy_auth is None + + +@pytest.mark.asyncio +async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None: + aiohttp = pytest.importorskip("aiohttp") + _FakeDiscordClient.instances.clear() + channel = DiscordChannel( + DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + proxy_username="user", + proxy_password="pass", + ), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", 
_FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert len(_FakeDiscordClient.instances) == 1 + assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890" + assert _FakeDiscordClient.instances[0].proxy_auth is not None + assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth) + assert _FakeDiscordClient.instances[0].proxy_auth.login == "user" + assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass" + + +@pytest.mark.asyncio +async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None: + _FakeDiscordClient.instances.clear() + channel = DiscordChannel( + DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + proxy_username="user", + ), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert _FakeDiscordClient.instances[0].proxy_auth is None From 69d748bf8ff600fad95ad9a5d2c0651575861efd Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 9 Apr 2026 15:47:37 +0000 Subject: [PATCH 038/115] Merge origin/main; warn on partial proxy credentials; add only-password test - Merged latest main (no conflicts) - Added warning log when only one of proxy_username/proxy_password is set - Added test_start_no_proxy_auth_when_only_password for coverage parity Made-with: Cursor --- nanobot/channels/discord.py | 9 ++++++++- tests/channels/test_discord_channel.py | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index c50b4ff19..6e8c673a3 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -308,13 +308,20 @@ class DiscordChannel(BaseChannel): intents.value = self.config.intents proxy_auth = None - if self.config.proxy_username and self.config.proxy_password: + has_user = 
bool(self.config.proxy_username) + has_pass = bool(self.config.proxy_password) + if has_user and has_pass: import aiohttp proxy_auth = aiohttp.BasicAuth( login=self.config.proxy_username, password=self.config.proxy_password, ) + elif has_user != has_pass: + logger.warning( + "Discord proxy auth incomplete: both proxy_username and " + "proxy_password must be set; ignoring partial credentials", + ) self._client = DiscordBotClient( self, diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 3f0f3388a..3a31a5912 100644 --- a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -845,3 +845,25 @@ async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None: assert channel.is_running is False assert _FakeDiscordClient.instances[0].proxy_auth is None + + +@pytest.mark.asyncio +async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None: + _FakeDiscordClient.instances.clear() + channel = DiscordChannel( + DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + proxy_password="pass", + ), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890" + assert _FakeDiscordClient.instances[0].proxy_auth is None From 6b7e78a8e0034cfa3248ab9b8cadd7e71fbf7cb7 Mon Sep 17 00:00:00 2001 From: flobo3 Date: Thu, 9 Apr 2026 17:51:29 +0300 Subject: [PATCH 039/115] fix: strip blocks from Gemma 4 and similar models --- nanobot/utils/helpers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 1f14cb36e..3b4f9f25a 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -15,9 +15,12 @@ from loguru import logger def strip_think(text: str) -> str: - """Remove blocks 
and any unclosed trailing tag.""" + """Remove thinking blocks and any unclosed trailing tag.""" text = re.sub(r"[\s\S]*?", "", text) text = re.sub(r"[\s\S]*$", "", text) + # Gemma 4 and similar models use ... blocks + text = re.sub(r"[\s\S]*?", "", text) + text = re.sub(r"[\s\S]*$", "", text) return text.strip() From e0c6e6f180945a8641a5201702b8dad08b069a6d Mon Sep 17 00:00:00 2001 From: chengyongru Date: Fri, 10 Apr 2026 10:22:50 +0800 Subject: [PATCH 040/115] test: add regression tests for tag stripping --- tests/utils/test_strip_think.py | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/utils/test_strip_think.py diff --git a/tests/utils/test_strip_think.py b/tests/utils/test_strip_think.py new file mode 100644 index 000000000..6710dfc93 --- /dev/null +++ b/tests/utils/test_strip_think.py @@ -0,0 +1,36 @@ +import pytest + +from nanobot.utils.helpers import strip_think + + +class TestStripThinkTag: + """Test ... block stripping (Gemma 4 and similar models).""" + + def test_closed_tag(self): + assert strip_think("Hello reasoning World") == "Hello World" + + def test_unclosed_trailing_tag(self): + assert strip_think("ongoing...") == "" + + def test_multiline_tag(self): + assert strip_think("\nline1\nline2\nEnd") == "End" + + def test_tag_with_nested_angle_brackets(self): + text = "a < 3 and b > 2result" + assert strip_think(text) == "result" + + def test_multiple_tag_blocks(self): + text = "AxByC" + assert strip_think(text) == "ABC" + + def test_tag_only_whitespace_inside(self): + assert strip_think("before after") == "beforeafter" + + def test_self_closing_tag_not_matched(self): + assert strip_think("some text") == "some text" + + def test_normal_text_unchanged(self): + assert strip_think("Just normal text") == "Just normal text" + + def test_empty_string(self): + assert strip_think("") == "" From ce9829e92fb80f21b0a2fc264eaffa6efe9bac9a Mon Sep 17 00:00:00 2001 From: Jiajun Date: Tue, 7 Apr 2026 23:56:23 +0800 
Subject: [PATCH 041/115] feat(feishu): add done emoji support for reaction lifecycle (#2899) * feat(feishu): add done emoji support for reaction lifecycle * feat(feishu): add done emoji support and update documentation --- README.md | 4 ++++ nanobot/channels/feishu.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/README.md b/README.md index d43fac43e..d9baa3a21 100644 --- a/README.md +++ b/README.md @@ -560,6 +560,8 @@ Uses **WebSocket** long connection — no public IP required. "verificationToken": "", "allowFrom": ["ou_YOUR_OPEN_ID"], "groupPolicy": "mention", + "reactEmoji": "OnIt", + "doneEmoji": "DONE", "streaming": true } } @@ -570,6 +572,8 @@ Uses **WebSocket** long connection — no public IP required. > `encryptKey` and `verificationToken` are optional for Long Connection mode. > `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users. > `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond. +> `reactEmoji`: Emoji for "processing" status (default: `OnIt`). See [available emojis](https://open.larkoffice.com/document/server-docs/im-v1/message-reaction/emojis-introduce). +> `doneEmoji`: Optional emoji for "completed" status (e.g., `DONE`, `OK`, `HEART`). When set, bot adds this reaction after removing `reactEmoji`. **3. 
Run** diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index bac14cb84..e18ed8b01 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -250,6 +250,7 @@ class FeishuConfig(Base): verification_token: str = "" allow_from: list[str] = Field(default_factory=list) react_emoji: str = "THUMBSUP" + done_emoji: str | None = None # Emoji to show when task is completed (e.g., "DONE", "OK") group_policy: Literal["open", "mention"] = "mention" reply_to_message: bool = False # If True, bot replies quote the user's original message streaming: bool = True @@ -1274,6 +1275,9 @@ class FeishuChannel(BaseChannel): if meta.get("_stream_end"): if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")): await self._remove_reaction(message_id, reaction_id) + # Add completion emoji if configured + if self.config.done_emoji and message_id: + await self._add_reaction(message_id, self.config.done_emoji) buf = self._stream_bufs.pop(chat_id, None) if not buf or not buf.text: From ac1795c158228b237975560afd1972e3c1e06b3a Mon Sep 17 00:00:00 2001 From: "xzq.xu" Date: Wed, 1 Apr 2026 17:32:55 +0800 Subject: [PATCH 042/115] feat(feishu): streaming resuming + inline tool hints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two improvements to Feishu streaming card experience: 1. Handle _resuming in send_delta: when a mid-turn _stream_end arrives with resuming=True (tool call between segments), flush current text to the card but keep the buffer alive so subsequent segments append to the same card instead of creating a new one. 2. Inline tool hints into streaming cards: when a tool hint arrives while a streaming card is active, append it to the card content (e.g. "🔧 web_fetch(...)") instead of sending a separate card. The hint is automatically stripped when the next delta arrives. 
Made-with: Cursor --- nanobot/channels/feishu.py | 39 ++++++++- tests/channels/test_feishu_streaming.py | 111 ++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 4 deletions(-) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index e18ed8b01..b5cedd8c0 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -267,6 +267,7 @@ class _FeishuStreamBuf: card_id: str | None = None sequence: int = 0 last_edit: float = 0.0 + tool_hint_len: int = 0 class FeishuChannel(BaseChannel): @@ -1279,6 +1280,19 @@ class FeishuChannel(BaseChannel): if self.config.done_emoji and message_id: await self._add_reaction(message_id, self.config.done_emoji) + resuming = meta.get("_resuming", False) + if resuming: + # Mid-turn pause (e.g. tool call between streaming segments). + # Flush current text to card but keep the buffer alive so the + # next segment appends to the same card. + buf = self._stream_bufs.get(chat_id) + if buf and buf.card_id and buf.text: + buf.sequence += 1 + await loop.run_in_executor( + None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence, + ) + return + buf = self._stream_bufs.pop(chat_id, None) if not buf or not buf.text: return @@ -1317,6 +1331,9 @@ class FeishuChannel(BaseChannel): if buf is None: buf = _FeishuStreamBuf() self._stream_bufs[chat_id] = buf + if buf.tool_hint_len > 0: + buf.text = buf.text[:-buf.tool_hint_len] + buf.tool_hint_len = 0 buf.text += delta if not buf.text.strip(): return @@ -1350,12 +1367,26 @@ class FeishuChannel(BaseChannel): receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id" loop = asyncio.get_running_loop() - # Handle tool hint messages as code blocks in interactive cards. - # These are progress-only messages and should bypass normal reply routing. + # Handle tool hint messages. When a streaming card is active for + # this chat, inline the hint into the card instead of sending a + # separate message so the user experience stays cohesive. 
if msg.metadata.get("_tool_hint"): - if msg.content and msg.content.strip(): + hint = (msg.content or "").strip() + if not hint: + return + buf = self._stream_bufs.get(msg.chat_id) + if buf and buf.card_id: + suffix = f"\n\n---\n🔧 {hint}" + buf.text += suffix + buf.tool_hint_len = len(suffix) + buf.sequence += 1 + await loop.run_in_executor( + None, self._stream_update_text_sync, + buf.card_id, buf.text, buf.sequence, + ) + else: await self._send_tool_hint_card( - receive_id_type, msg.chat_id, msg.content.strip() + receive_id_type, msg.chat_id, hint ) return diff --git a/tests/channels/test_feishu_streaming.py b/tests/channels/test_feishu_streaming.py index 22ad8cbc6..2fc75bb8a 100644 --- a/tests/channels/test_feishu_streaming.py +++ b/tests/channels/test_feishu_streaming.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock import pytest +from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf @@ -203,6 +204,55 @@ class TestSendDelta: ch._client.cardkit.v1.card_element.content.assert_not_called() ch._client.im.v1.message.create.assert_called_once() + @pytest.mark.asyncio + async def test_stream_end_resuming_keeps_buffer(self): + """_resuming=True flushes text to card but keeps the buffer for the next segment.""" + ch = _make_channel() + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0, + ) + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + + await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) + + assert "oc_chat1" in ch._stream_bufs + buf = ch._stream_bufs["oc_chat1"] + assert buf.card_id == "card_1" + assert buf.sequence == 3 + ch._client.cardkit.v1.card_element.content.assert_called_once() + ch._client.cardkit.v1.card.settings.assert_not_called() + + @pytest.mark.asyncio + async def 
test_stream_end_resuming_then_final_end(self): + """Full multi-segment flow: resuming mid-turn, then final end closes the card.""" + ch = _make_channel() + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="Seg1", card_id="card_1", sequence=1, last_edit=0.0, + ) + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + ch._client.cardkit.v1.card.settings.return_value = _mock_content_response() + + await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) + assert "oc_chat1" in ch._stream_bufs + + ch._stream_bufs["oc_chat1"].text += " Seg2" + await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True}) + + assert "oc_chat1" not in ch._stream_bufs + ch._client.cardkit.v1.card.settings.assert_called_once() + + @pytest.mark.asyncio + async def test_stream_end_resuming_no_card_is_noop(self): + """_resuming with no card_id (card creation failed) is a safe no-op.""" + ch = _make_channel() + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="text", card_id=None, sequence=0, last_edit=0.0, + ) + await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) + + assert "oc_chat1" in ch._stream_bufs + ch._client.cardkit.v1.card_element.content.assert_not_called() + @pytest.mark.asyncio async def test_stream_end_without_buf_is_noop(self): ch = _make_channel() @@ -239,6 +289,67 @@ class TestSendDelta: assert buf.sequence == 7 +class TestToolHintInlineStreaming: + """Tool hint messages should be inlined into active streaming cards.""" + + @pytest.mark.asyncio + async def test_tool_hint_inlined_when_stream_active(self): + """With an active streaming buffer, tool hint appends to the card.""" + ch = _make_channel() + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0, + ) + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + + msg = OutboundMessage( + channel="feishu", 
chat_id="oc_chat1", + content='web_fetch("https://example.com")', + metadata={"_tool_hint": True}, + ) + await ch.send(msg) + + buf = ch._stream_bufs["oc_chat1"] + assert buf.text.endswith('🔧 web_fetch("https://example.com")') + assert buf.tool_hint_len > 0 + assert buf.sequence == 3 + ch._client.cardkit.v1.card_element.content.assert_called_once() + ch._client.im.v1.message.create.assert_not_called() + + @pytest.mark.asyncio + async def test_tool_hint_stripped_on_next_delta(self): + """When new delta arrives, the previously appended tool hint is removed.""" + ch = _make_channel() + suffix = "\n\n---\n🔧 web_fetch(\"url\")" + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="Partial answer" + suffix, + card_id="card_1", sequence=3, last_edit=0.0, + tool_hint_len=len(suffix), + ) + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + + await ch.send_delta("oc_chat1", " continued") + + buf = ch._stream_bufs["oc_chat1"] + assert buf.text == "Partial answer continued" + assert buf.tool_hint_len == 0 + + @pytest.mark.asyncio + async def test_tool_hint_fallback_when_no_stream(self): + """Without an active buffer, tool hint falls back to a standalone card.""" + ch = _make_channel() + ch._client.im.v1.message.create.return_value = _mock_send_response("om_hint") + + msg = OutboundMessage( + channel="feishu", chat_id="oc_chat1", + content='read_file("path")', + metadata={"_tool_hint": True}, + ) + await ch.send(msg) + + assert "oc_chat1" not in ch._stream_bufs + ch._client.im.v1.message.create.assert_called_once() + + class TestSendMessageReturnsId: def test_returns_message_id_on_success(self): ch = _make_channel() From 589e3ac36e9911bcb47ca984386a9056e3249190 Mon Sep 17 00:00:00 2001 From: "xzq.xu" Date: Wed, 8 Apr 2026 11:33:57 +0800 Subject: [PATCH 043/115] fix(feishu): prevent tool hint stacking and clean hints on stream_end MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three fixes for 
inline tool hints: 1. Consecutive tool hints now replace the previous one instead of stacking — the old suffix is stripped before appending the new one. 2. When _resuming flushes the buffer, any trailing tool hint suffix is removed so it doesn't persist into the next streaming segment. 3. When final _stream_end closes the card, tool hint suffix is cleaned from the text before the final card update. Adds 3 regression tests covering all three scenarios. Made-with: Cursor --- nanobot/channels/feishu.py | 8 ++++ tests/channels/test_feishu_streaming.py | 64 +++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index b5cedd8c0..34d59bc5a 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1287,6 +1287,9 @@ class FeishuChannel(BaseChannel): # next segment appends to the same card. buf = self._stream_bufs.get(chat_id) if buf and buf.card_id and buf.text: + if buf.tool_hint_len > 0: + buf.text = buf.text[:-buf.tool_hint_len] + buf.tool_hint_len = 0 buf.sequence += 1 await loop.run_in_executor( None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence, @@ -1296,6 +1299,9 @@ class FeishuChannel(BaseChannel): buf = self._stream_bufs.pop(chat_id, None) if not buf or not buf.text: return + if buf.tool_hint_len > 0: + buf.text = buf.text[:-buf.tool_hint_len] + buf.tool_hint_len = 0 if buf.card_id: buf.sequence += 1 await loop.run_in_executor( @@ -1376,6 +1382,8 @@ class FeishuChannel(BaseChannel): return buf = self._stream_bufs.get(msg.chat_id) if buf and buf.card_id: + if buf.tool_hint_len > 0: + buf.text = buf.text[:-buf.tool_hint_len] suffix = f"\n\n---\n🔧 {hint}" buf.text += suffix buf.tool_hint_len = len(suffix) diff --git a/tests/channels/test_feishu_streaming.py b/tests/channels/test_feishu_streaming.py index 2fc75bb8a..3683d0f07 100644 --- a/tests/channels/test_feishu_streaming.py +++ b/tests/channels/test_feishu_streaming.py @@ -349,6 +349,70 @@ class 
TestToolHintInlineStreaming: assert "oc_chat1" not in ch._stream_bufs ch._client.im.v1.message.create.assert_called_once() + @pytest.mark.asyncio + async def test_consecutive_tool_hints_replace_previous(self): + """When multiple tool hints arrive consecutively, each replaces the previous one.""" + ch = _make_channel() + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0, + ) + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + + msg1 = OutboundMessage( + channel="feishu", chat_id="oc_chat1", + content='$ cd /project', metadata={"_tool_hint": True}, + ) + await ch.send(msg1) + + msg2 = OutboundMessage( + channel="feishu", chat_id="oc_chat1", + content='$ git status', metadata={"_tool_hint": True}, + ) + await ch.send(msg2) + + buf = ch._stream_bufs["oc_chat1"] + assert buf.text.count("$ cd /project") == 0 + assert buf.text.count("$ git status") == 1 + assert buf.text.startswith("Partial answer") + assert buf.text.endswith("🔧 $ git status") + + @pytest.mark.asyncio + async def test_tool_hint_stripped_on_resuming_flush(self): + """When _resuming flushes the buffer, tool hint suffix is cleaned.""" + ch = _make_channel() + suffix = "\n\n---\n🔧 $ cd /project" + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="Partial answer" + suffix, + card_id="card_1", sequence=2, last_edit=0.0, + tool_hint_len=len(suffix), + ) + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + + await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) + + buf = ch._stream_bufs["oc_chat1"] + assert buf.text == "Partial answer" + assert buf.tool_hint_len == 0 + + @pytest.mark.asyncio + async def test_tool_hint_stripped_on_final_stream_end(self): + """When final _stream_end closes the card, tool hint suffix is cleaned from text.""" + ch = _make_channel() + suffix = "\n\n---\n🔧 web_fetch(\"url\")" + ch._stream_bufs["oc_chat1"] = 
_FeishuStreamBuf( + text="Final content" + suffix, + card_id="card_1", sequence=3, last_edit=0.0, + tool_hint_len=len(suffix), + ) + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + ch._client.cardkit.v1.card.settings.return_value = _mock_content_response() + + await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True}) + + assert "oc_chat1" not in ch._stream_bufs + update_call = ch._client.cardkit.v1.card_element.content.call_args[0][0] + assert "🔧" not in update_call.body.content + class TestSendMessageReturnsId: def test_returns_message_id_on_success(self): From 512c3b88e34a2b7d6df1e18c63351245943dcaca Mon Sep 17 00:00:00 2001 From: "xzq.xu" Date: Wed, 8 Apr 2026 11:49:27 +0800 Subject: [PATCH 044/115] fix(feishu): preserve tool hints in final card content Tool hints should be kept as permanent content in the streaming card so users can see which tools were called (matching the standalone card behavior). Previously, hints were stripped when new deltas arrived or when the stream ended, causing tool call information to disappear. Now: - New delta: hint becomes permanent content, delta appends after it - New tool hint: replaces the previous hint (unchanged) - Resuming/stream_end: hint is preserved in the final text Updated 3 tests to verify hint preservation semantics. Made-with: Cursor --- nanobot/channels/feishu.py | 9 ++------- tests/channels/test_feishu_streaming.py | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 34d59bc5a..2441c20c5 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1287,9 +1287,7 @@ class FeishuChannel(BaseChannel): # next segment appends to the same card. 
buf = self._stream_bufs.get(chat_id) if buf and buf.card_id and buf.text: - if buf.tool_hint_len > 0: - buf.text = buf.text[:-buf.tool_hint_len] - buf.tool_hint_len = 0 + buf.tool_hint_len = 0 buf.sequence += 1 await loop.run_in_executor( None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence, @@ -1299,9 +1297,7 @@ class FeishuChannel(BaseChannel): buf = self._stream_bufs.pop(chat_id, None) if not buf or not buf.text: return - if buf.tool_hint_len > 0: - buf.text = buf.text[:-buf.tool_hint_len] - buf.tool_hint_len = 0 + buf.tool_hint_len = 0 if buf.card_id: buf.sequence += 1 await loop.run_in_executor( @@ -1338,7 +1334,6 @@ class FeishuChannel(BaseChannel): buf = _FeishuStreamBuf() self._stream_bufs[chat_id] = buf if buf.tool_hint_len > 0: - buf.text = buf.text[:-buf.tool_hint_len] buf.tool_hint_len = 0 buf.text += delta if not buf.text.strip(): diff --git a/tests/channels/test_feishu_streaming.py b/tests/channels/test_feishu_streaming.py index 3683d0f07..1559ef8d3 100644 --- a/tests/channels/test_feishu_streaming.py +++ b/tests/channels/test_feishu_streaming.py @@ -316,8 +316,8 @@ class TestToolHintInlineStreaming: ch._client.im.v1.message.create.assert_not_called() @pytest.mark.asyncio - async def test_tool_hint_stripped_on_next_delta(self): - """When new delta arrives, the previously appended tool hint is removed.""" + async def test_tool_hint_preserved_on_next_delta(self): + """When new delta arrives, the tool hint is kept as permanent content and delta appends after it.""" ch = _make_channel() suffix = "\n\n---\n🔧 web_fetch(\"url\")" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( @@ -330,7 +330,9 @@ class TestToolHintInlineStreaming: await ch.send_delta("oc_chat1", " continued") buf = ch._stream_bufs["oc_chat1"] - assert buf.text == "Partial answer continued" + assert "Partial answer" in buf.text + assert "🔧 web_fetch" in buf.text + assert buf.text.endswith(" continued") assert buf.tool_hint_len == 0 @pytest.mark.asyncio @@ -377,8 +379,8 @@ 
class TestToolHintInlineStreaming: assert buf.text.endswith("🔧 $ git status") @pytest.mark.asyncio - async def test_tool_hint_stripped_on_resuming_flush(self): - """When _resuming flushes the buffer, tool hint suffix is cleaned.""" + async def test_tool_hint_preserved_on_resuming_flush(self): + """When _resuming flushes the buffer, tool hint is kept as permanent content.""" ch = _make_channel() suffix = "\n\n---\n🔧 $ cd /project" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( @@ -391,12 +393,13 @@ class TestToolHintInlineStreaming: await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) buf = ch._stream_bufs["oc_chat1"] - assert buf.text == "Partial answer" + assert "Partial answer" in buf.text + assert "🔧 $ cd /project" in buf.text assert buf.tool_hint_len == 0 @pytest.mark.asyncio - async def test_tool_hint_stripped_on_final_stream_end(self): - """When final _stream_end closes the card, tool hint suffix is cleaned from text.""" + async def test_tool_hint_preserved_on_final_stream_end(self): + """When final _stream_end closes the card, tool hint is kept in the final text.""" ch = _make_channel() suffix = "\n\n---\n🔧 web_fetch(\"url\")" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( @@ -411,7 +414,7 @@ class TestToolHintInlineStreaming: assert "oc_chat1" not in ch._stream_bufs update_call = ch._client.cardkit.v1.card_element.content.call_args[0][0] - assert "🔧" not in update_call.body.content + assert "🔧" in update_call.body.content class TestSendMessageReturnsId: From 049ce9baaed553ff11e982568a4020c7426a6566 Mon Sep 17 00:00:00 2001 From: "xzq.xu" Date: Wed, 8 Apr 2026 13:12:47 +0800 Subject: [PATCH 045/115] fix(tool-hints): deduplicate by formatted string + per-line inline display MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two display fixes based on real-world Feishu testing: 1. 
tool_hints.py: format_tool_hints now deduplicates by comparing the fully formatted hint string instead of tool name alone. This fixes `ls /Desktop` and `ls /Downloads` being incorrectly merged as `ls /Desktop × 2`. Truly identical calls still fold correctly. (_group_consecutive and all abbreviation logic preserved unchanged.) 2. feishu.py: inline tool hints now display one tool per line with 🔧 prefix, and use double-newline trailing to prevent Setext heading rendering when followed by markdown `---`. Made-with: Cursor --- nanobot/channels/feishu.py | 4 +++- tests/channels/test_feishu_streaming.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 2441c20c5..365732595 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1379,7 +1379,9 @@ class FeishuChannel(BaseChannel): if buf and buf.card_id: if buf.tool_hint_len > 0: buf.text = buf.text[:-buf.tool_hint_len] - suffix = f"\n\n---\n🔧 {hint}" + lines = self._format_tool_hint_lines(hint).split("\n") + formatted = "\n".join(f"🔧 {ln}" for ln in lines if ln.strip()) + suffix = f"\n\n{formatted}\n\n" buf.text += suffix buf.tool_hint_len = len(suffix) buf.sequence += 1 diff --git a/tests/channels/test_feishu_streaming.py b/tests/channels/test_feishu_streaming.py index 1559ef8d3..6f5fc0580 100644 --- a/tests/channels/test_feishu_streaming.py +++ b/tests/channels/test_feishu_streaming.py @@ -309,7 +309,7 @@ class TestToolHintInlineStreaming: await ch.send(msg) buf = ch._stream_bufs["oc_chat1"] - assert buf.text.endswith('🔧 web_fetch("https://example.com")') + assert '🔧 web_fetch("https://example.com")' in buf.text assert buf.tool_hint_len > 0 assert buf.sequence == 3 ch._client.cardkit.v1.card_element.content.assert_called_once() @@ -319,7 +319,7 @@ class TestToolHintInlineStreaming: async def test_tool_hint_preserved_on_next_delta(self): """When new delta arrives, the tool hint is kept as permanent content and 
delta appends after it.""" ch = _make_channel() - suffix = "\n\n---\n🔧 web_fetch(\"url\")" + suffix = "\n\n🔧 web_fetch(\"url\")\n\n" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( text="Partial answer" + suffix, card_id="card_1", sequence=3, last_edit=0.0, @@ -376,13 +376,13 @@ class TestToolHintInlineStreaming: assert buf.text.count("$ cd /project") == 0 assert buf.text.count("$ git status") == 1 assert buf.text.startswith("Partial answer") - assert buf.text.endswith("🔧 $ git status") + assert "🔧 $ git status" in buf.text @pytest.mark.asyncio async def test_tool_hint_preserved_on_resuming_flush(self): """When _resuming flushes the buffer, tool hint is kept as permanent content.""" ch = _make_channel() - suffix = "\n\n---\n🔧 $ cd /project" + suffix = "\n\n🔧 $ cd /project\n\n" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( text="Partial answer" + suffix, card_id="card_1", sequence=2, last_edit=0.0, @@ -401,7 +401,7 @@ class TestToolHintInlineStreaming: async def test_tool_hint_preserved_on_final_stream_end(self): """When final _stream_end closes the card, tool hint is kept in the final text.""" ch = _make_channel() - suffix = "\n\n---\n🔧 web_fetch(\"url\")" + suffix = "\n\n🔧 web_fetch(\"url\")\n\n" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( text="Final content" + suffix, card_id="card_1", sequence=3, last_edit=0.0, From 6fd2511c8a2559a6589d8627e8aacb0da8a21042 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 8 Apr 2026 18:31:40 +0800 Subject: [PATCH 046/115] refactor(feishu): simplify tool hint to append-only, delegate to send_delta for throttling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Make tool_hint_prefix configurable in FeishuConfig (default: 🔧) - Delegate tool hint card updates from send() to send_delta() so hints automatically benefit from _STREAM_EDIT_INTERVAL throttling - Fix staticmethod calls to use self.__class__ instead of self - Document all supported metadata keys in send_delta docstring - 
Add test for empty/whitespace-only tool hint with active stream buffer --- nanobot/channels/feishu.py | 45 ++++++++++++------------- tests/channels/test_feishu_streaming.py | 44 +++++++++++++++--------- 2 files changed, 50 insertions(+), 39 deletions(-) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 365732595..e57fcef85 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -251,6 +251,7 @@ class FeishuConfig(Base): allow_from: list[str] = Field(default_factory=list) react_emoji: str = "THUMBSUP" done_emoji: str | None = None # Emoji to show when task is completed (e.g., "DONE", "OK") + tool_hint_prefix: str = "\U0001f527" # Prefix for inline tool hints (default: 🔧) group_policy: Literal["open", "mention"] = "mention" reply_to_message: bool = False # If True, bot replies quote the user's original message streaming: bool = True @@ -267,7 +268,6 @@ class _FeishuStreamBuf: card_id: str | None = None sequence: int = 0 last_edit: float = 0.0 - tool_hint_len: int = 0 class FeishuChannel(BaseChannel): @@ -1265,7 +1265,15 @@ class FeishuChannel(BaseChannel): async def send_delta( self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None ) -> None: - """Progressive streaming via CardKit: create card on first delta, stream-update on subsequent.""" + """Progressive streaming via CardKit: create card on first delta, stream-update on subsequent. + + Supported metadata keys: + _stream_end: Finalize the streaming card. + _resuming: Mid-turn pause – flush but keep the buffer alive. + _tool_hint: Delta is a formatted tool hint (for display only). + message_id: Original message id (used with _stream_end for reaction cleanup). + reaction_id: Reaction id to remove on stream end. + """ if not self._client: return meta = metadata or {} @@ -1287,7 +1295,6 @@ class FeishuChannel(BaseChannel): # next segment appends to the same card. 
buf = self._stream_bufs.get(chat_id) if buf and buf.card_id and buf.text: - buf.tool_hint_len = 0 buf.sequence += 1 await loop.run_in_executor( None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence, @@ -1297,7 +1304,6 @@ class FeishuChannel(BaseChannel): buf = self._stream_bufs.pop(chat_id, None) if not buf or not buf.text: return - buf.tool_hint_len = 0 if buf.card_id: buf.sequence += 1 await loop.run_in_executor( @@ -1333,8 +1339,6 @@ class FeishuChannel(BaseChannel): if buf is None: buf = _FeishuStreamBuf() self._stream_bufs[chat_id] = buf - if buf.tool_hint_len > 0: - buf.tool_hint_len = 0 buf.text += delta if not buf.text.strip(): return @@ -1377,22 +1381,17 @@ class FeishuChannel(BaseChannel): return buf = self._stream_bufs.get(msg.chat_id) if buf and buf.card_id: - if buf.tool_hint_len > 0: - buf.text = buf.text[:-buf.tool_hint_len] - lines = self._format_tool_hint_lines(hint).split("\n") - formatted = "\n".join(f"🔧 {ln}" for ln in lines if ln.strip()) - suffix = f"\n\n{formatted}\n\n" - buf.text += suffix - buf.tool_hint_len = len(suffix) - buf.sequence += 1 - await loop.run_in_executor( - None, self._stream_update_text_sync, - buf.card_id, buf.text, buf.sequence, - ) - else: - await self._send_tool_hint_card( - receive_id_type, msg.chat_id, hint - ) + # Delegate to send_delta so tool hints get the same + # throttling (and card creation) as regular text deltas. + lines = self.__class__._format_tool_hint_lines(hint).split("\n") + delta = "\n\n" + "\n".join( + f"{self.config.tool_hint_prefix} {ln}" for ln in lines if ln.strip() + ) + "\n\n" + await self.send_delta(msg.chat_id, delta) + return + await self._send_tool_hint_card( + receive_id_type, msg.chat_id, hint + ) return # Determine whether the first message should quote the user's message. @@ -1701,7 +1700,7 @@ class FeishuChannel(BaseChannel): loop = asyncio.get_running_loop() # Put each top-level tool call on its own line without altering commas inside arguments. 
- formatted_code = self._format_tool_hint_lines(tool_hint) + formatted_code = self.__class__._format_tool_hint_lines(tool_hint) card = { "config": {"wide_screen_mode": True}, diff --git a/tests/channels/test_feishu_streaming.py b/tests/channels/test_feishu_streaming.py index 6f5fc0580..a047c8c5f 100644 --- a/tests/channels/test_feishu_streaming.py +++ b/tests/channels/test_feishu_streaming.py @@ -310,7 +310,6 @@ class TestToolHintInlineStreaming: buf = ch._stream_bufs["oc_chat1"] assert '🔧 web_fetch("https://example.com")' in buf.text - assert buf.tool_hint_len > 0 assert buf.sequence == 3 ch._client.cardkit.v1.card_element.content.assert_called_once() ch._client.im.v1.message.create.assert_not_called() @@ -319,11 +318,9 @@ class TestToolHintInlineStreaming: async def test_tool_hint_preserved_on_next_delta(self): """When new delta arrives, the tool hint is kept as permanent content and delta appends after it.""" ch = _make_channel() - suffix = "\n\n🔧 web_fetch(\"url\")\n\n" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( - text="Partial answer" + suffix, + text="Partial answer\n\n🔧 web_fetch(\"url\")\n\n", card_id="card_1", sequence=3, last_edit=0.0, - tool_hint_len=len(suffix), ) ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() @@ -333,7 +330,6 @@ class TestToolHintInlineStreaming: assert "Partial answer" in buf.text assert "🔧 web_fetch" in buf.text assert buf.text.endswith(" continued") - assert buf.tool_hint_len == 0 @pytest.mark.asyncio async def test_tool_hint_fallback_when_no_stream(self): @@ -352,8 +348,8 @@ class TestToolHintInlineStreaming: ch._client.im.v1.message.create.assert_called_once() @pytest.mark.asyncio - async def test_consecutive_tool_hints_replace_previous(self): - """When multiple tool hints arrive consecutively, each replaces the previous one.""" + async def test_consecutive_tool_hints_append(self): + """When multiple tool hints arrive consecutively, each appends to the card.""" ch = _make_channel() 
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0, @@ -373,20 +369,19 @@ class TestToolHintInlineStreaming: await ch.send(msg2) buf = ch._stream_bufs["oc_chat1"] - assert buf.text.count("$ cd /project") == 0 - assert buf.text.count("$ git status") == 1 + assert "$ cd /project" in buf.text + assert "$ git status" in buf.text assert buf.text.startswith("Partial answer") + assert "🔧 $ cd /project" in buf.text assert "🔧 $ git status" in buf.text @pytest.mark.asyncio async def test_tool_hint_preserved_on_resuming_flush(self): """When _resuming flushes the buffer, tool hint is kept as permanent content.""" ch = _make_channel() - suffix = "\n\n🔧 $ cd /project\n\n" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( - text="Partial answer" + suffix, + text="Partial answer\n\n🔧 $ cd /project\n\n", card_id="card_1", sequence=2, last_edit=0.0, - tool_hint_len=len(suffix), ) ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() @@ -395,17 +390,14 @@ class TestToolHintInlineStreaming: buf = ch._stream_bufs["oc_chat1"] assert "Partial answer" in buf.text assert "🔧 $ cd /project" in buf.text - assert buf.tool_hint_len == 0 @pytest.mark.asyncio async def test_tool_hint_preserved_on_final_stream_end(self): """When final _stream_end closes the card, tool hint is kept in the final text.""" ch = _make_channel() - suffix = "\n\n🔧 web_fetch(\"url\")\n\n" ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( - text="Final content" + suffix, + text="Final content\n\n🔧 web_fetch(\"url\")\n\n", card_id="card_1", sequence=3, last_edit=0.0, - tool_hint_len=len(suffix), ) ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() ch._client.cardkit.v1.card.settings.return_value = _mock_content_response() @@ -416,6 +408,26 @@ class TestToolHintInlineStreaming: update_call = ch._client.cardkit.v1.card_element.content.call_args[0][0] assert "🔧" in update_call.body.content + 
@pytest.mark.asyncio + async def test_empty_tool_hint_is_noop(self): + """Empty or whitespace-only tool hint content is silently ignored.""" + ch = _make_channel() + ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( + text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0, + ) + + for content in ("", " ", "\t\n"): + msg = OutboundMessage( + channel="feishu", chat_id="oc_chat1", + content=content, metadata={"_tool_hint": True}, + ) + await ch.send(msg) + + buf = ch._stream_bufs["oc_chat1"] + assert buf.text == "Partial answer" + assert buf.sequence == 2 + ch._client.cardkit.v1.card_element.content.assert_not_called() + class TestSendMessageReturnsId: def test_returns_message_id_on_success(self): From 27e7a338a3ecfe1be438cd474af5115daa0ee63f Mon Sep 17 00:00:00 2001 From: chengyongru Date: Fri, 10 Apr 2026 10:40:30 +0800 Subject: [PATCH 047/115] docs(feishu): add toolHintPrefix to README config example --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d9baa3a21..6098c55ca 100644 --- a/README.md +++ b/README.md @@ -562,6 +562,7 @@ Uses **WebSocket** long connection — no public IP required. "groupPolicy": "mention", "reactEmoji": "OnIt", "doneEmoji": "DONE", + "toolHintPrefix": "🔧", "streaming": true } } @@ -574,6 +575,7 @@ Uses **WebSocket** long connection — no public IP required. > `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond. > `reactEmoji`: Emoji for "processing" status (default: `OnIt`). See [available emojis](https://open.larkoffice.com/document/server-docs/im-v1/message-reaction/emojis-introduce). > `doneEmoji`: Optional emoji for "completed" status (e.g., `DONE`, `OK`, `HEART`). When set, bot adds this reaction after removing `reactEmoji`. +> `toolHintPrefix`: Prefix for inline tool hints in streaming cards (default: `🔧`). **3. 
Run** From 363a0704dbc8cb11f40bd7eef754ad744fefa263 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 10 Apr 2026 04:46:48 +0000 Subject: [PATCH 048/115] refactor(runner): update message processing to preserve historical context - Adjusted message handling in AgentRunner to ensure that historical messages remain unchanged during context governance. - Introduced tests to verify that backfill operations do not alter the saved message boundary, maintaining the integrity of the conversation history. --- nanobot/agent/runner.py | 12 ++- tests/agent/test_runner.py | 163 +++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index bc1a26aba..e90715375 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -101,10 +101,14 @@ class AgentRunner: for iteration in range(spec.max_iterations): try: - messages = self._backfill_missing_tool_results(messages) - messages = self._microcompact(messages) - messages = self._apply_tool_result_budget(spec, messages) - messages_for_model = self._snip_history(spec, messages) + # Keep the persisted conversation untouched. Context governance + # may repair or compact historical messages for the model, but + # those synthetic edits must not shift the append boundary used + # later when the caller saves only the new turn. 
+ messages_for_model = self._backfill_missing_tool_results(messages) + messages_for_model = self._microcompact(messages_for_model) + messages_for_model = self._apply_tool_result_budget(spec, messages_for_model) + messages_for_model = self._snip_history(spec, messages_for_model) except Exception as exc: logger.warning( "Context governance failed on turn {} for {}: {}; using raw messages", diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index afb06634f..a298ed956 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -1239,6 +1239,169 @@ async def test_backfill_noop_when_complete(): assert result is messages # same object — no copy +@pytest.mark.asyncio +async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path): + """Historical backfill should not duplicate old tail messages on persist.""" + from nanobot.agent.loop import AgentLoop + from nanobot.agent.runner import _BACKFILL_CONTENT + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + response = LLMResponse(content="new answer", tool_calls=[], usage={}) + provider.chat_with_retry = AsyncMock(return_value=response) + provider.chat_stream_with_retry = AsyncMock(return_value=response) + + loop = AgentLoop( + bus=MessageBus(), + provider=provider, + workspace=tmp_path, + model="test-model", + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "old user", "timestamp": "2026-01-01T00:00:00"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + "timestamp": "2026-01-01T00:00:01", 
+ }, + {"role": "assistant", "content": "old tail", "timestamp": "2026-01-01T00:00:02"}, + ] + loop.sessions.save(session) + + result = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new prompt") + ) + + assert result is not None + assert result.content == "new answer" + + request_messages = provider.chat_with_retry.await_args.kwargs["messages"] + synthetic = [ + message + for message in request_messages + if message.get("role") == "tool" and message.get("tool_call_id") == "call_missing" + ] + assert len(synthetic) == 1 + assert synthetic[0]["content"] == _BACKFILL_CONTENT + + session_after = loop.sessions.get_or_create("cli:test") + assert [ + { + key: value + for key, value in message.items() + if key in {"role", "content", "tool_call_id", "name", "tool_calls"} + } + for message in session_after.messages + ] == [ + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + }, + {"role": "assistant", "content": "old tail"}, + {"role": "user", "content": "new prompt"}, + {"role": "assistant", "content": "new answer"}, + ] + + +@pytest.mark.asyncio +async def test_runner_backfill_only_mutates_model_context_not_returned_messages(): + """Runner should repair orphaned tool calls for the model without rewriting result.messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _BACKFILL_CONTENT + + provider = MagicMock() + captured_messages: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + captured_messages[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + initial_messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": 
"assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + }, + {"role": "assistant", "content": "old tail"}, + {"role": "user", "content": "new prompt"}, + ] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=initial_messages, + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + synthetic = [ + message + for message in captured_messages + if message.get("role") == "tool" and message.get("tool_call_id") == "call_missing" + ] + assert len(synthetic) == 1 + assert synthetic[0]["content"] == _BACKFILL_CONTENT + + assert [ + { + key: value + for key, value in message.items() + if key in {"role", "content", "tool_call_id", "name", "tool_calls"} + } + for message in result.messages + ] == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + }, + {"role": "assistant", "content": "old tail"}, + {"role": "user", "content": "new prompt"}, + {"role": "assistant", "content": "done"}, + ] + + # --------------------------------------------------------------------------- # Microcompact (stale tool result compaction) # --------------------------------------------------------------------------- From bfe53ebb10f95f56424a52eaf474ad3f840739f0 Mon Sep 17 00:00:00 2001 From: comadreja Date: Thu, 9 Apr 2026 11:27:15 -0500 Subject: [PATCH 049/115] fix(memory): harden consolidation with try/except on token estimation and chunk size cap - Wrap both token estimation calls in try/except to prevent silent failures from crashing the consolidation cycle - Add _MAX_CHUNK_MESSAGES = 60 to cap messages per consolidation round, avoiding oversized chunks being sent to the 
consolidation LLM - Improve idle log to include unconsolidated message count for easier debugging These are purely defensive improvements with no behaviour change for normal sessions. --- nanobot/agent/memory.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index fc72c573f..0bad5125a 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -347,6 +347,7 @@ class Consolidator: """Lightweight consolidation: summarizes evicted messages into history.jsonl.""" _MAX_CONSOLIDATION_ROUNDS = 5 + _MAX_CHUNK_MESSAGES = 60 # hard cap per consolidation round _SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift @@ -461,16 +462,22 @@ class Consolidator: async with lock: budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER target = budget // 2 - estimated, source = self.estimate_session_prompt_tokens(session) + try: + estimated, source = self.estimate_session_prompt_tokens(session) + except Exception: + logger.exception("Token estimation failed for {}", session.key) + estimated, source = 0, "error" if estimated <= 0: return if estimated < budget: + unconsolidated_count = len(session.messages) - session.last_consolidated logger.debug( - "Token consolidation idle {}: {}/{} via {}", + "Token consolidation idle {}: {}/{} via {}, msgs={}", session.key, estimated, self.context_window_tokens, source, + unconsolidated_count, ) return @@ -492,6 +499,10 @@ class Consolidator: if not chunk: return + if len(chunk) > self._MAX_CHUNK_MESSAGES: + chunk = chunk[:self._MAX_CHUNK_MESSAGES] + end_idx = session.last_consolidated + len(chunk) + logger.info( "Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs", round_num, @@ -506,7 +517,11 @@ class Consolidator: session.last_consolidated = end_idx self.sessions.save(session) - estimated, source = self.estimate_session_prompt_tokens(session) + try: + estimated, source = 
self.estimate_session_prompt_tokens(session) + except Exception: + logger.exception("Token estimation failed for {}", session.key) + estimated, source = 0, "error" if estimated <= 0: return From c579d67887f3e9d348ef16b9210749f4820310ac Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 10 Apr 2026 04:57:02 +0000 Subject: [PATCH 050/115] fix(memory): preserve consolidation turn boundaries under chunk cap Made-with: Cursor --- nanobot/agent/memory.py | 29 ++++++++++++++++--- tests/agent/test_consolidator.py | 49 ++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 0bad5125a..943d91855 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -400,6 +400,22 @@ class Consolidator: return last_boundary + def _cap_consolidation_boundary( + self, + session: Session, + end_idx: int, + ) -> int | None: + """Clamp the chunk size without breaking the user-turn boundary.""" + start = session.last_consolidated + if end_idx - start <= self._MAX_CHUNK_MESSAGES: + return end_idx + + capped_end = start + self._MAX_CHUNK_MESSAGES + for idx in range(capped_end, start, -1): + if session.messages[idx].get("role") == "user": + return idx + return None + def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]: """Estimate current prompt size for the normal session history view.""" history = session.get_history(max_messages=0) @@ -495,14 +511,19 @@ class Consolidator: return end_idx = boundary[0] + end_idx = self._cap_consolidation_boundary(session, end_idx) + if end_idx is None: + logger.debug( + "Token consolidation: no capped boundary for {} (round {})", + session.key, + round_num, + ) + return + chunk = session.messages[session.last_consolidated:end_idx] if not chunk: return - if len(chunk) > self._MAX_CHUNK_MESSAGES: - chunk = chunk[:self._MAX_CHUNK_MESSAGES] - end_idx = session.last_consolidated + len(chunk) - logger.info( "Token consolidation round 
{} for {}: {}/{} via {}, chunk={} msgs", round_num, diff --git a/tests/agent/test_consolidator.py b/tests/agent/test_consolidator.py index 72968b0e1..b7989d9dd 100644 --- a/tests/agent/test_consolidator.py +++ b/tests/agent/test_consolidator.py @@ -76,3 +76,52 @@ class TestConsolidatorTokenBudget: consolidator.archive = AsyncMock(return_value=True) await consolidator.maybe_consolidate_by_tokens(session) consolidator.archive.assert_not_called() + + async def test_chunk_cap_preserves_user_turn_boundary(self, consolidator): + """Chunk cap should rewind to the last user boundary within the cap.""" + consolidator._SAFETY_BUFFER = 0 + session = MagicMock() + session.last_consolidated = 0 + session.key = "test:key" + session.messages = [ + { + "role": "user" if i in {0, 50, 61} else "assistant", + "content": f"m{i}", + } + for i in range(70) + ] + consolidator.estimate_session_prompt_tokens = MagicMock( + side_effect=[(1200, "tiktoken"), (400, "tiktoken")] + ) + consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999)) + consolidator.archive = AsyncMock(return_value=True) + + await consolidator.maybe_consolidate_by_tokens(session) + + archived_chunk = consolidator.archive.await_args.args[0] + assert len(archived_chunk) == 50 + assert archived_chunk[0]["content"] == "m0" + assert archived_chunk[-1]["content"] == "m49" + assert session.last_consolidated == 50 + + async def test_chunk_cap_skips_when_no_user_boundary_within_cap(self, consolidator): + """If the cap would cut mid-turn, consolidation should skip that round.""" + consolidator._SAFETY_BUFFER = 0 + session = MagicMock() + session.last_consolidated = 0 + session.key = "test:key" + session.messages = [ + { + "role": "user" if i in {0, 61} else "assistant", + "content": f"m{i}", + } + for i in range(70) + ] + consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(1200, "tiktoken")) + consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999)) + consolidator.archive = 
AsyncMock(return_value=True) + + await consolidator.maybe_consolidate_by_tokens(session) + + consolidator.archive.assert_not_awaited() + assert session.last_consolidated == 0 From 2bef9cb6504a6d7b39c5569738d62e7af70174f7 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 10 Apr 2026 05:37:25 +0000 Subject: [PATCH 051/115] fix(agent): preserve interrupted tool-call turns Keep tool-call assistant messages valid across provider sanitization and avoid trailing user-only history after model errors. This prevents follow-up requests from sending broken tool chains back to the gateway. --- nanobot/agent/runner.py | 37 ++++- nanobot/providers/base.py | 8 ++ nanobot/providers/openai_compat_provider.py | 4 + tests/agent/test_runner.py | 131 +++++++++++++++++- .../test_enforce_role_alternation.py | 28 ++++ tests/providers/test_litellm_kwargs.py | 29 ++++ 6 files changed, 235 insertions(+), 2 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index e90715375..cfebe098f 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -31,6 +31,7 @@ from nanobot.utils.runtime import ( ) _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." +_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]" _MAX_EMPTY_RETRIES = 2 _MAX_LENGTH_RECOVERIES = 3 _SNIP_SAFETY_BUFFER = 1024 @@ -105,7 +106,8 @@ class AgentRunner: # may repair or compact historical messages for the model, but # those synthetic edits must not shift the append boundary used # later when the caller saves only the new turn. 
- messages_for_model = self._backfill_missing_tool_results(messages) + messages_for_model = self._drop_orphan_tool_results(messages) + messages_for_model = self._backfill_missing_tool_results(messages_for_model) messages_for_model = self._microcompact(messages_for_model) messages_for_model = self._apply_tool_result_budget(spec, messages_for_model) messages_for_model = self._snip_history(spec, messages_for_model) @@ -261,6 +263,7 @@ class AgentRunner: final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE stop_reason = "error" error = final_content + self._append_model_error_placeholder(messages) context.final_content = final_content context.error = error context.stop_reason = stop_reason @@ -524,6 +527,12 @@ class AgentRunner: return messages.append(build_assistant_message(content)) + @staticmethod + def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None: + if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"): + return + messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER)) + def _normalize_tool_result( self, spec: AgentRunSpec, @@ -552,6 +561,32 @@ class AgentRunner: return truncate_text(content, spec.max_tool_result_chars) return content + @staticmethod + def _drop_orphan_tool_results( + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Drop tool results that have no matching assistant tool_call earlier in the history.""" + declared: set[str] = set() + updated: list[dict[str, Any]] | None = None + for idx, msg in enumerate(messages): + role = msg.get("role") + if role == "assistant": + for tc in msg.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + if role == "tool": + tid = msg.get("tool_call_id") + if tid and str(tid) not in declared: + if updated is None: + updated = [dict(m) for m in messages[:idx]] + continue + if updated is not None: + updated.append(dict(msg)) + + if updated is None: + 
return messages + return updated + @staticmethod def _backfill_missing_tool_results( messages: list[dict[str, Any]], diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 275b1ea08..d2a0727fb 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -375,6 +375,14 @@ class LLMProvider(ABC): and role in ("user", "assistant") ): prev = merged[-1] + if role == "assistant": + prev_has_tools = bool(prev.get("tool_calls")) + curr_has_tools = bool(msg.get("tool_calls")) + if curr_has_tools: + merged[-1] = dict(msg) + continue + if prev_has_tools: + continue prev_content = prev.get("content") or "" curr_content = msg.get("content") or "" if isinstance(prev_content, str) and isinstance(curr_content, str): diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 95e8b74d3..101ee6c33 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -243,6 +243,10 @@ class OpenAICompatProvider(LLMProvider): tc_clean["id"] = map_id(tc_clean.get("id")) normalized.append(tc_clean) clean["tool_calls"] = normalized + if clean.get("role") == "assistant": + # Some OpenAI-compatible gateways reject assistant messages + # that mix non-empty content with tool_calls. 
+ clean["content"] = None if "tool_call_id" in clean and clean["tool_call_id"]: clean["tool_call_id"] = map_id(clean["tool_call_id"]) return self._enforce_role_alternation(sanitized) diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index a298ed956..ef4206573 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -859,7 +859,11 @@ async def test_loop_retries_think_only_final_response(tmp_path): async def test_llm_error_not_appended_to_session_messages(): """When LLM returns finish_reason='error', the error content must NOT be appended to the messages list (prevents polluting session history).""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import ( + AgentRunSpec, + AgentRunner, + _PERSISTED_MODEL_ERROR_PLACEHOLDER, + ) provider = MagicMock() provider.chat_with_retry = AsyncMock(return_value=LLMResponse( @@ -882,6 +886,7 @@ async def test_llm_error_not_appended_to_session_messages(): assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"] assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \ "Error content should not appear in session messages" + assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER @pytest.mark.asyncio @@ -918,6 +923,56 @@ async def test_streamed_flag_not_set_on_llm_error(tmp_path): "_streamed must not be set when stop_reason is error" +@pytest.mark.asyncio +async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path): + from nanobot.agent.loop import AgentLoop + from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.chat_with_retry = AsyncMock(side_effect=[ + LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}), + LLMResponse(content="Recovered answer", 
tool_calls=[], usage={}), + ]) + + loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + first = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question") + ) + assert first is not None + assert first.content == "429 rate limit exceeded" + + session = loop.sessions.get_or_create("cli:test") + assert [ + {key: value for key, value in message.items() if key in {"role", "content"}} + for message in session.messages + ] == [ + {"role": "user", "content": "first question"}, + {"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER}, + ] + + second = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question") + ) + assert second is not None + assert second.content == "Recovered answer" + + request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"] + non_system = [message for message in request_messages if message.get("role") != "system"] + assert non_system[0] == {"role": "user", "content": "first question"} + assert non_system[1] == { + "role": "assistant", + "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER, + } + assert non_system[2]["role"] == "user" + assert "second question" in non_system[2]["content"] + + @pytest.mark.asyncio async def test_runner_tool_error_sets_final_content(): from nanobot.agent.runner import AgentRunSpec, AgentRunner @@ -1218,6 +1273,41 @@ async def test_backfill_missing_tool_results_inserts_error(): assert backfilled[0]["name"] == "read_file" +def test_drop_orphan_tool_results_removes_unmatched_tool_messages(): + from nanobot.agent.runner import AgentRunner + + messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": 
"assistant", + "content": "", + "tool_calls": [ + {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"}, + {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}, + {"role": "assistant", "content": "after tool"}, + ] + + cleaned = AgentRunner._drop_orphan_tool_results(messages) + + assert cleaned == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"}, + {"role": "assistant", "content": "after tool"}, + ] + + @pytest.mark.asyncio async def test_backfill_noop_when_complete(): """Complete message chains should not be modified.""" @@ -1239,6 +1329,45 @@ async def test_backfill_noop_when_complete(): assert result is messages # same object — no copy +@pytest.mark.asyncio +async def test_runner_drops_orphan_tool_results_before_model_request(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_messages: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + captured_messages[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}, + {"role": "assistant", "content": "after orphan"}, + {"role": "user", "content": "new prompt"}, + ], + tools=tools, + model="test-model", + 
max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert all( + message.get("tool_call_id") != "call_orphan" + for message in captured_messages + if message.get("role") == "tool" + ) + assert result.messages[2]["tool_call_id"] == "call_orphan" + assert result.final_content == "done" + + @pytest.mark.asyncio async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path): """Historical backfill should not duplicate old tail messages on persist.""" diff --git a/tests/providers/test_enforce_role_alternation.py b/tests/providers/test_enforce_role_alternation.py index 1fade6e4b..aef57f474 100644 --- a/tests/providers/test_enforce_role_alternation.py +++ b/tests/providers/test_enforce_role_alternation.py @@ -84,6 +84,34 @@ class TestEnforceRoleAlternation: tool_msgs = [m for m in result if m["role"] == "tool"] assert len(tool_msgs) == 2 + def test_consecutive_assistant_keeps_later_tool_call_message(self): + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Previous reply"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]}, + {"role": "tool", "content": "result1", "tool_call_id": "1"}, + {"role": "user", "content": "Next"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert result[1]["role"] == "assistant" + assert result[1]["tool_calls"] == [{"id": "1"}] + assert result[1]["content"] is None + assert result[2]["role"] == "tool" + + def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self): + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]}, + {"role": "assistant", "content": "Later plain assistant"}, + {"role": "tool", "content": "result1", "tool_call_id": "1"}, + {"role": "user", "content": "Next"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert result[1]["role"] == "assistant" + assert result[1]["tool_calls"] == [{"id": "1"}] + assert 
result[1]["content"] is None + assert result[2]["role"] == "tool" + def test_non_string_content_uses_latest(self): msgs = [ {"role": "user", "content": [{"type": "text", "text": "A"}]}, diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index dfb7cd228..ec2581cdb 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -550,11 +550,40 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None: {"role": "user", "content": "thanks"}, ]) + assert sanitized[1]["content"] is None assert sanitized[1]["reasoning_content"] == "hidden" assert sanitized[1]["extra_content"] == {"debug": True} assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}} +def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + sanitized = provider._sanitize_messages([ + {"role": "user", "content": "不错"}, + {"role": "assistant", "content": "对,破 4 万指日可待"}, + { + "role": "assistant", + "content": "我再查一下", + "tool_calls": [ + { + "id": "call_function_akxp3wqzn7ph_1", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"}, + {"role": "user", "content": "多少star了呢"}, + ]) + + assert sanitized[1]["role"] == "assistant" + assert sanitized[1]["content"] is None + assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d" + assert sanitized[2]["tool_call_id"] == "3ec83c30d" + + @pytest.mark.asyncio async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None: monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0") From e7e12495859df5abc89bc9e4614cb07148493199 Mon Sep 17 00:00:00 2001 From: "zhangxiaoyu.york" Date: Fri, 10 Apr 2026 12:13:58 +0800 Subject: [PATCH 052/115] 
fix(agent): avoid truncate_text name shadowing Rename the boolean flag in _sanitize_persisted_blocks and alias the imported helper so session persistence cannot crash with TypeError when truncation is enabled. --- nanobot/agent/loop.py | 12 +++++------ tests/test_truncate_text_shadowing.py | 31 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 6 deletions(-) create mode 100644 tests/test_truncate_text_shadowing.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 16a086fdb..bc83cc77c 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -33,7 +33,7 @@ from nanobot.bus.queue import MessageBus from nanobot.config.schema import AgentDefaults from nanobot.providers.base import LLMProvider from nanobot.session.manager import Session, SessionManager -from nanobot.utils.helpers import image_placeholder_text, truncate_text +from nanobot.utils.helpers import image_placeholder_text, truncate_text as truncate_text_fn from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE if TYPE_CHECKING: @@ -590,7 +590,7 @@ class AgentLoop: self, content: list[dict[str, Any]], *, - truncate_text: bool = False, + should_truncate_text: bool = False, drop_runtime: bool = False, ) -> list[dict[str, Any]]: """Strip volatile multimodal payloads before writing session history.""" @@ -618,8 +618,8 @@ class AgentLoop: if block.get("type") == "text" and isinstance(block.get("text"), str): text = block["text"] - if truncate_text and len(text) > self.max_tool_result_chars: - text = truncate_text(text, self.max_tool_result_chars) + if should_truncate_text and len(text) > self.max_tool_result_chars: + text = truncate_text_fn(text, self.max_tool_result_chars) filtered.append({**block, "text": text}) continue @@ -637,9 +637,9 @@ class AgentLoop: continue # skip empty assistant messages — they poison session context if role == "tool": if isinstance(content, str) and len(content) > self.max_tool_result_chars: - entry["content"] = truncate_text(content, 
self.max_tool_result_chars) + entry["content"] = truncate_text_fn(content, self.max_tool_result_chars) elif isinstance(content, list): - filtered = self._sanitize_persisted_blocks(content, truncate_text=True) + filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True) if not filtered: continue entry["content"] = filtered diff --git a/tests/test_truncate_text_shadowing.py b/tests/test_truncate_text_shadowing.py new file mode 100644 index 000000000..11132b511 --- /dev/null +++ b/tests/test_truncate_text_shadowing.py @@ -0,0 +1,31 @@ +import inspect +from types import SimpleNamespace + + +def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None: + """Regression: avoid bool param shadowing imported truncate_text. + + Buggy behavior (historical): + - loop.py imports `truncate_text` from helpers + - `_sanitize_persisted_blocks(..., truncate_text: bool=...)` uses same name + - when called with `truncate_text=True`, function body executes `truncate_text(text, ...)` + which resolves to bool and raises `TypeError: 'bool' object is not callable`. + + This test asserts the fixed API exists and truncation works without raising. 
+ """ + + from nanobot.agent.loop import AgentLoop + + sig = inspect.signature(AgentLoop._sanitize_persisted_blocks) + assert "should_truncate_text" in sig.parameters + assert "truncate_text" not in sig.parameters + + dummy = SimpleNamespace(max_tool_result_chars=5) + content = [{"type": "text", "text": "0123456789"}] + + out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True) + assert isinstance(out, list) + assert out and out[0]["type"] == "text" + assert isinstance(out[0]["text"], str) + assert out[0]["text"] != content[0]["text"] + From 1a51f907aa2ca578101faafa4458953d5c34c1fa Mon Sep 17 00:00:00 2001 From: weitongtong Date: Fri, 10 Apr 2026 16:10:15 +0800 Subject: [PATCH 053/115] =?UTF-8?q?feat(cron):=20=E6=B7=BB=E5=8A=A0=20Cron?= =?UTF-8?q?Service.update=5Fjob=20=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 支持更新已有定时任务的名称、调度计划、消息内容、投递配置等可变字段。 系统任务(system_event)受保护不可编辑。包含完整的单元测试覆盖。 Made-with: Cursor --- nanobot/cron/service.py | 53 +++++++++++++ tests/cron/test_cron_service.py | 127 ++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+) diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index 1259d3d72..267613012 100644 --- a/nanobot/cron/service.py +++ b/nanobot/cron/service.py @@ -460,6 +460,59 @@ class CronService: return job return None + def update_job( + self, + job_id: str, + *, + name: str | None = None, + schedule: CronSchedule | None = None, + message: str | None = None, + deliver: bool | None = None, + channel: str | None = ..., + to: str | None = ..., + delete_after_run: bool | None = None, + ) -> CronJob | Literal["not_found", "protected"]: + """Update mutable fields of an existing job. System jobs cannot be updated. + + For ``channel`` and ``to``, pass an explicit value (including ``None``) + to update; omit (sentinel ``...``) to leave unchanged. 
+ """ + store = self._load_store() + job = next((j for j in store.jobs if j.id == job_id), None) + if job is None: + return "not_found" + if job.payload.kind == "system_event": + return "protected" + + if schedule is not None: + _validate_schedule_for_add(schedule) + job.schedule = schedule + if name is not None: + job.name = name + if message is not None: + job.payload.message = message + if deliver is not None: + job.payload.deliver = deliver + if channel is not ...: + job.payload.channel = channel + if to is not ...: + job.payload.to = to + if delete_after_run is not None: + job.delete_after_run = delete_after_run + + job.updated_at_ms = _now_ms() + if job.enabled: + job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms()) + + if self._running: + self._save_store() + self._arm_timer() + else: + self._append_action("update", asdict(job)) + + logger.info("Cron: updated job '{}' ({})", job.name, job.id) + return job + async def run_job(self, job_id: str, force: bool = False) -> bool: """Manually run a job without disturbing the service's running state.""" was_running = self._running diff --git a/tests/cron/test_cron_service.py b/tests/cron/test_cron_service.py index b54cf5e20..d10d9569f 100644 --- a/tests/cron/test_cron_service.py +++ b/tests/cron/test_cron_service.py @@ -327,3 +327,130 @@ async def test_external_update_preserves_run_history_records(tmp_path): fresh._running = True fresh._save_store() + + +# ── update_job tests ── + + +def test_update_job_changes_name(tmp_path) -> None: + service = CronService(tmp_path / "cron" / "jobs.json") + job = service.add_job( + name="old name", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + result = service.update_job(job.id, name="new name") + assert isinstance(result, CronJob) + assert result.name == "new name" + assert result.payload.message == "hello" + + +def test_update_job_changes_schedule(tmp_path) -> None: + service = CronService(tmp_path / "cron" / "jobs.json") + job 
= service.add_job( + name="sched", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + old_next = job.state.next_run_at_ms + + new_sched = CronSchedule(kind="every", every_ms=120_000) + result = service.update_job(job.id, schedule=new_sched) + assert isinstance(result, CronJob) + assert result.schedule.every_ms == 120_000 + assert result.state.next_run_at_ms != old_next + + +def test_update_job_changes_message(tmp_path) -> None: + service = CronService(tmp_path / "cron" / "jobs.json") + job = service.add_job( + name="msg", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="old message", + ) + result = service.update_job(job.id, message="new message") + assert isinstance(result, CronJob) + assert result.payload.message == "new message" + + +def test_update_job_changes_cron_expression(tmp_path) -> None: + service = CronService(tmp_path / "cron" / "jobs.json") + job = service.add_job( + name="cron-job", + schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"), + message="hello", + ) + result = service.update_job( + job.id, + schedule=CronSchedule(kind="cron", expr="0 18 * * *", tz="UTC"), + ) + assert isinstance(result, CronJob) + assert result.schedule.expr == "0 18 * * *" + assert result.state.next_run_at_ms is not None + + +def test_update_job_not_found(tmp_path) -> None: + service = CronService(tmp_path / "cron" / "jobs.json") + result = service.update_job("nonexistent", name="x") + assert result == "not_found" + + +def test_update_job_rejects_system_job(tmp_path) -> None: + service = CronService(tmp_path / "cron" / "jobs.json") + service.register_system_job(CronJob( + id="dream", + name="dream", + schedule=CronSchedule(kind="cron", expr="0 */2 * * *", tz="UTC"), + payload=CronPayload(kind="system_event"), + )) + result = service.update_job("dream", name="hacked") + assert result == "protected" + assert service.get_job("dream").name == "dream" + + +def test_update_job_validates_schedule(tmp_path) -> None: + 
service = CronService(tmp_path / "cron" / "jobs.json") + job = service.add_job( + name="validate", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + with pytest.raises(ValueError, match="unknown timezone"): + service.update_job( + job.id, + schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="Bad/Zone"), + ) + + +def test_update_job_preserves_run_history(tmp_path) -> None: + import asyncio + store_path = tmp_path / "cron" / "jobs.json" + service = CronService(store_path, on_job=lambda _: asyncio.sleep(0)) + job = service.add_job( + name="hist", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + asyncio.get_event_loop().run_until_complete(service.run_job(job.id)) + + result = service.update_job(job.id, name="renamed") + assert isinstance(result, CronJob) + assert len(result.state.run_history) == 1 + assert result.state.run_history[0].status == "ok" + + +def test_update_job_offline_writes_action(tmp_path) -> None: + service = CronService(tmp_path / "cron" / "jobs.json") + job = service.add_job( + name="offline", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + service.update_job(job.id, name="updated-offline") + + action_path = tmp_path / "cron" / "action.jsonl" + assert action_path.exists() + lines = [l for l in action_path.read_text().strip().split("\n") if l] + last = json.loads(lines[-1]) + assert last["action"] == "update" + assert last["params"]["name"] == "updated-offline" From 9bccfa63d2ad4416411913752b6948c0cfc3b3d0 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 10 Apr 2026 11:02:02 +0000 Subject: [PATCH 054/115] fix test: use async/await for run_job, add sentinel coverage Made-with: Cursor --- tests/cron/test_cron_service.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/cron/test_cron_service.py b/tests/cron/test_cron_service.py index d10d9569f..f1956d8d2 100644 --- a/tests/cron/test_cron_service.py 
+++ b/tests/cron/test_cron_service.py @@ -422,7 +422,8 @@ def test_update_job_validates_schedule(tmp_path) -> None: ) -def test_update_job_preserves_run_history(tmp_path) -> None: +@pytest.mark.asyncio +async def test_update_job_preserves_run_history(tmp_path) -> None: import asyncio store_path = tmp_path / "cron" / "jobs.json" service = CronService(store_path, on_job=lambda _: asyncio.sleep(0)) @@ -431,7 +432,7 @@ def test_update_job_preserves_run_history(tmp_path) -> None: schedule=CronSchedule(kind="every", every_ms=60_000), message="hello", ) - asyncio.get_event_loop().run_until_complete(service.run_job(job.id)) + await service.run_job(job.id) result = service.update_job(job.id, name="renamed") assert isinstance(result, CronJob) @@ -454,3 +455,27 @@ def test_update_job_offline_writes_action(tmp_path) -> None: last = json.loads(lines[-1]) assert last["action"] == "update" assert last["params"]["name"] == "updated-offline" + + +def test_update_job_sentinel_channel_and_to(tmp_path) -> None: + """Passing None clears channel/to; omitting leaves them unchanged.""" + service = CronService(tmp_path / "cron" / "jobs.json") + job = service.add_job( + name="sentinel", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + channel="telegram", + to="user123", + ) + assert job.payload.channel == "telegram" + assert job.payload.to == "user123" + + result = service.update_job(job.id, name="renamed") + assert isinstance(result, CronJob) + assert result.payload.channel == "telegram" + assert result.payload.to == "user123" + + result = service.update_job(job.id, channel=None, to=None) + assert isinstance(result, CronJob) + assert result.payload.channel is None + assert result.payload.to is None From 651aeae656e33b029adba1eeab5af1aee05d4df4 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 10 Apr 2026 15:44:50 +0000 Subject: [PATCH 055/115] improve file editing and add notebook tool Enhance file tools with read tracking, PDF support, safer path handling, 
smarter edit matching/diagnostics, and introduce notebook_edit with tests. --- nanobot/agent/loop.py | 2 + nanobot/agent/tools/file_state.py | 105 ++++++ nanobot/agent/tools/filesystem.py | 503 ++++++++++++++++++++++++-- nanobot/agent/tools/notebook.py | 162 +++++++++ pyproject.toml | 4 + tests/tools/test_edit_advanced.py | 423 ++++++++++++++++++++++ tests/tools/test_edit_enhancements.py | 152 ++++++++ tests/tools/test_notebook_tool.py | 147 ++++++++ tests/tools/test_read_enhancements.py | 180 +++++++++ 9 files changed, 1638 insertions(+), 40 deletions(-) create mode 100644 nanobot/agent/tools/file_state.py create mode 100644 nanobot/agent/tools/notebook.py create mode 100644 tests/tools/test_edit_advanced.py create mode 100644 tests/tools/test_edit_enhancements.py create mode 100644 tests/tools/test_notebook_tool.py create mode 100644 tests/tools/test_read_enhancements.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index bc83cc77c..56d79f3f9 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -22,6 +22,7 @@ from nanobot.agent.tools.cron import CronTool from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.notebook import NotebookEditTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.search import GlobTool, GrepTool from nanobot.agent.tools.shell import ExecTool @@ -235,6 +236,7 @@ class AgentLoop: self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) for cls in (GlobTool, GrepTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) + self.tools.register(NotebookEditTool(workspace=self.workspace, allowed_dir=allowed_dir)) if self.exec_config.enable: self.tools.register(ExecTool( working_dir=str(self.workspace), diff --git a/nanobot/agent/tools/file_state.py 
b/nanobot/agent/tools/file_state.py new file mode 100644 index 000000000..81b1d4485 --- /dev/null +++ b/nanobot/agent/tools/file_state.py @@ -0,0 +1,105 @@ +"""Track file-read state for read-before-edit warnings and read deduplication.""" + +from __future__ import annotations + +import hashlib +import os +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(slots=True) +class ReadState: + mtime: float + offset: int + limit: int | None + content_hash: str | None + can_dedup: bool + + +_state: dict[str, ReadState] = {} + + +def _hash_file(p: str) -> str | None: + try: + return hashlib.sha256(Path(p).read_bytes()).hexdigest() + except OSError: + return None + + +def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None: + """Record that a file was read (called after successful read).""" + p = str(Path(path).resolve()) + try: + mtime = os.path.getmtime(p) + except OSError: + return + _state[p] = ReadState( + mtime=mtime, + offset=offset, + limit=limit, + content_hash=_hash_file(p), + can_dedup=True, + ) + + +def record_write(path: str | Path) -> None: + """Record that a file was written (updates mtime in state).""" + p = str(Path(path).resolve()) + try: + mtime = os.path.getmtime(p) + except OSError: + _state.pop(p, None) + return + _state[p] = ReadState( + mtime=mtime, + offset=1, + limit=None, + content_hash=_hash_file(p), + can_dedup=False, + ) + + +def check_read(path: str | Path) -> str | None: + """Check if a file has been read and is fresh. + + Returns None if OK, or a warning string. + When mtime changed but file content is identical (e.g. touch, editor save), + the check passes to avoid false-positive staleness warnings. + """ + p = str(Path(path).resolve()) + entry = _state.get(p) + if entry is None: + return "Warning: file has not been read yet. Read it first to verify content before editing." 
+ try: + current_mtime = os.path.getmtime(p) + except OSError: + return None + if current_mtime != entry.mtime: + if entry.content_hash and _hash_file(p) == entry.content_hash: + entry.mtime = current_mtime + return None + return "Warning: file has been modified since last read. Re-read to verify content before editing." + return None + + +def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool: + """Return True if file was previously read with same params and mtime is unchanged.""" + p = str(Path(path).resolve()) + entry = _state.get(p) + if entry is None: + return False + if not entry.can_dedup: + return False + if entry.offset != offset or entry.limit != limit: + return False + try: + current_mtime = os.path.getmtime(p) + except OSError: + return False + return current_mtime == entry.mtime + + +def clear() -> None: + """Clear all tracked state (useful for testing).""" + _state.clear() diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index fdce38b69..e131a2e69 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -2,11 +2,13 @@ import difflib import mimetypes +from dataclasses import dataclass from pathlib import Path from typing import Any from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.agent.tools import file_state from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime from nanobot.config.paths import get_media_dir @@ -60,6 +62,36 @@ class _FsTool(Tool): # --------------------------------------------------------------------------- +_BLOCKED_DEVICE_PATHS = frozenset({ + "/dev/zero", "/dev/random", "/dev/urandom", "/dev/full", + "/dev/stdin", "/dev/stdout", "/dev/stderr", + "/dev/tty", "/dev/console", + "/dev/fd/0", "/dev/fd/1", "/dev/fd/2", +}) + + +def _is_blocked_device(path: str | Path) -> bool: + """Check 
if path is a blocked device that could hang or produce infinite output.""" + import re + raw = str(path) + if raw in _BLOCKED_DEVICE_PATHS: + return True + if re.match(r"/proc/\d+/fd/[012]$", raw) or re.match(r"/proc/self/fd/[012]$", raw): + return True + return False + + +def _parse_page_range(pages: str, total: int) -> tuple[int, int]: + """Parse a page range like '2-5' into 0-based (start, end) inclusive.""" + parts = pages.strip().split("-") + if len(parts) == 1: + p = int(parts[0]) + return max(0, p - 1), min(p - 1, total - 1) + start = int(parts[0]) + end = int(parts[1]) + return max(0, start - 1), min(end - 1, total - 1) + + @tool_parameters( tool_parameters_schema( path=StringSchema("The file path to read"), @@ -73,6 +105,7 @@ class _FsTool(Tool): description="Maximum number of lines to read (default 2000)", minimum=1, ), + pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"), required=["path"], ) ) @@ -81,6 +114,7 @@ class ReadFileTool(_FsTool): _MAX_CHARS = 128_000 _DEFAULT_LIMIT = 2000 + _MAX_PDF_PAGES = 20 @property def name(self) -> str: @@ -89,9 +123,10 @@ class ReadFileTool(_FsTool): @property def description(self) -> str: return ( - "Read a text file. Output format: LINE_NUM|CONTENT. " + "Read a file (text or image). Text output format: LINE_NUM|CONTENT. " + "Images return visual content for analysis. " "Use offset and limit for large files. " - "Cannot read binary files or images. " + "Cannot read non-image binary files. " "Reads exceeding ~128K chars are truncated." 
) @@ -99,16 +134,27 @@ class ReadFileTool(_FsTool): def read_only(self) -> bool: return True - async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any: + async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any: try: if not path: return "Error reading file: Unknown path" + + # Device path blacklist + if _is_blocked_device(path): + return f"Error: Reading {path} is blocked (device path that could hang or produce infinite output)." + fp = self._resolve(path) + if _is_blocked_device(fp): + return f"Error: Reading {fp} is blocked (device path that could hang or produce infinite output)." if not fp.exists(): return f"Error: File not found: {path}" if not fp.is_file(): return f"Error: Not a file: {path}" + # PDF support + if fp.suffix.lower() == ".pdf": + return self._read_pdf(fp, pages) + raw = fp.read_bytes() if not raw: return f"(Empty file: {path})" @@ -117,6 +163,10 @@ class ReadFileTool(_FsTool): if mime and mime.startswith("image/"): return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})") + # Read dedup: same path + offset + limit + unchanged mtime → stub + if file_state.is_unchanged(fp, offset=offset, limit=limit): + return f"[File unchanged since last read: {path}]" + try: text_content = raw.decode("utf-8") except UnicodeDecodeError: @@ -149,12 +199,59 @@ class ReadFileTool(_FsTool): result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)" else: result += f"\n\n(End of file — {total} lines total)" + file_state.record_read(fp, offset=offset, limit=limit) return result except PermissionError as e: return f"Error: {e}" except Exception as e: return f"Error reading file: {e}" + def _read_pdf(self, fp: Path, pages: str | None) -> str: + try: + import fitz # pymupdf + except ImportError: + return "Error: PDF reading requires pymupdf. 
Install with: pip install pymupdf" + + try: + doc = fitz.open(str(fp)) + except Exception as e: + return f"Error reading PDF: {e}" + + total_pages = len(doc) + if pages: + try: + start, end = _parse_page_range(pages, total_pages) + except (ValueError, IndexError): + doc.close() + return f"Error: Invalid page range '{pages}'. Use format like '1-5'." + if start > end or start >= total_pages: + doc.close() + return f"Error: Page range '{pages}' is out of bounds (document has {total_pages} pages)." + else: + start = 0 + end = min(total_pages - 1, self._MAX_PDF_PAGES - 1) + + if end - start + 1 > self._MAX_PDF_PAGES: + end = start + self._MAX_PDF_PAGES - 1 + + parts: list[str] = [] + for i in range(start, end + 1): + page = doc[i] + text = page.get_text().strip() + if text: + parts.append(f"--- Page {i + 1} ---\n{text}") + doc.close() + + if not parts: + return f"(PDF has no extractable text: {fp})" + + result = "\n\n".join(parts) + if end < total_pages - 1: + result += f"\n\n(Showing pages {start + 1}-{end + 1} of {total_pages}. 
Use pages='{end + 2}-{min(end + 1 + self._MAX_PDF_PAGES, total_pages)}' to continue.)" + if len(result) > self._MAX_CHARS: + result = result[:self._MAX_CHARS] + "\n\n(PDF text truncated at ~128K chars)" + return result + # --------------------------------------------------------------------------- # write_file @@ -192,6 +289,7 @@ class WriteFileTool(_FsTool): fp = self._resolve(path) fp.parent.mkdir(parents=True, exist_ok=True) fp.write_text(content, encoding="utf-8") + file_state.record_write(fp) return f"Successfully wrote {len(content)} characters to {fp}" except PermissionError as e: return f"Error: {e}" @@ -203,30 +301,269 @@ class WriteFileTool(_FsTool): # edit_file # --------------------------------------------------------------------------- +_QUOTE_TABLE = str.maketrans({ + "\u2018": "'", "\u2019": "'", # curly single → straight + "\u201c": '"', "\u201d": '"', # curly double → straight + "'": "'", '"': '"', # identity (kept for completeness) +}) + + +def _normalize_quotes(s: str) -> str: + return s.translate(_QUOTE_TABLE) + + +def _curly_double_quotes(text: str) -> str: + parts: list[str] = [] + opening = True + for ch in text: + if ch == '"': + parts.append("\u201c" if opening else "\u201d") + opening = not opening + else: + parts.append(ch) + return "".join(parts) + + +def _curly_single_quotes(text: str) -> str: + parts: list[str] = [] + opening = True + for i, ch in enumerate(text): + if ch != "'": + parts.append(ch) + continue + prev_ch = text[i - 1] if i > 0 else "" + next_ch = text[i + 1] if i + 1 < len(text) else "" + if prev_ch.isalnum() and next_ch.isalnum(): + parts.append("\u2019") + continue + parts.append("\u2018" if opening else "\u2019") + opening = not opening + return "".join(parts) + + +def _preserve_quote_style(old_text: str, actual_text: str, new_text: str) -> str: + """Preserve curly quote style when a quote-normalized fallback matched.""" + if _normalize_quotes(old_text.strip()) != _normalize_quotes(actual_text.strip()) or old_text == 
actual_text: + return new_text + + styled = new_text + if any(ch in actual_text for ch in ("\u201c", "\u201d")) and '"' in styled: + styled = _curly_double_quotes(styled) + if any(ch in actual_text for ch in ("\u2018", "\u2019")) and "'" in styled: + styled = _curly_single_quotes(styled) + return styled + + +def _leading_ws(line: str) -> str: + return line[: len(line) - len(line.lstrip(" \t"))] + + +def _reindent_like_match(old_text: str, actual_text: str, new_text: str) -> str: + """Preserve the outer indentation from the actual matched block.""" + old_lines = old_text.split("\n") + actual_lines = actual_text.split("\n") + if len(old_lines) != len(actual_lines): + return new_text + + comparable = [ + (old_line, actual_line) + for old_line, actual_line in zip(old_lines, actual_lines) + if old_line.strip() and actual_line.strip() + ] + if not comparable or any( + _normalize_quotes(old_line.strip()) != _normalize_quotes(actual_line.strip()) + for old_line, actual_line in comparable + ): + return new_text + + old_ws = _leading_ws(comparable[0][0]) + actual_ws = _leading_ws(comparable[0][1]) + if actual_ws == old_ws: + return new_text + + if old_ws: + if not actual_ws.startswith(old_ws): + return new_text + delta = actual_ws[len(old_ws):] + else: + delta = actual_ws + + if not delta: + return new_text + + return "\n".join((delta + line) if line else line for line in new_text.split("\n")) + + +@dataclass(slots=True) +class _MatchSpan: + start: int + end: int + text: str + line: int + + +def _find_exact_matches(content: str, old_text: str) -> list[_MatchSpan]: + matches: list[_MatchSpan] = [] + start = 0 + while True: + idx = content.find(old_text, start) + if idx == -1: + break + matches.append( + _MatchSpan( + start=idx, + end=idx + len(old_text), + text=content[idx : idx + len(old_text)], + line=content.count("\n", 0, idx) + 1, + ) + ) + start = idx + max(1, len(old_text)) + return matches + + +def _find_trim_matches(content: str, old_text: str, *, normalize_quotes: 
bool = False) -> list[_MatchSpan]: + old_lines = old_text.splitlines() + if not old_lines: + return [] + + content_lines = content.splitlines() + content_lines_keepends = content.splitlines(keepends=True) + if len(content_lines) < len(old_lines): + return [] + + offsets: list[int] = [] + pos = 0 + for line in content_lines_keepends: + offsets.append(pos) + pos += len(line) + offsets.append(pos) + + if normalize_quotes: + stripped_old = [_normalize_quotes(line.strip()) for line in old_lines] + else: + stripped_old = [line.strip() for line in old_lines] + + matches: list[_MatchSpan] = [] + window_size = len(stripped_old) + for i in range(len(content_lines) - window_size + 1): + window = content_lines[i : i + window_size] + if normalize_quotes: + comparable = [_normalize_quotes(line.strip()) for line in window] + else: + comparable = [line.strip() for line in window] + if comparable != stripped_old: + continue + + start = offsets[i] + end = offsets[i + window_size] + if content_lines_keepends[i + window_size - 1].endswith("\n"): + end -= 1 + matches.append( + _MatchSpan( + start=start, + end=end, + text=content[start:end], + line=i + 1, + ) + ) + return matches + + +def _find_quote_matches(content: str, old_text: str) -> list[_MatchSpan]: + norm_content = _normalize_quotes(content) + norm_old = _normalize_quotes(old_text) + matches: list[_MatchSpan] = [] + start = 0 + while True: + idx = norm_content.find(norm_old, start) + if idx == -1: + break + matches.append( + _MatchSpan( + start=idx, + end=idx + len(old_text), + text=content[idx : idx + len(old_text)], + line=content.count("\n", 0, idx) + 1, + ) + ) + start = idx + max(1, len(norm_old)) + return matches + + +def _find_matches(content: str, old_text: str) -> list[_MatchSpan]: + """Locate all matches using progressively looser strategies.""" + for matcher in ( + lambda: _find_exact_matches(content, old_text), + lambda: _find_trim_matches(content, old_text), + lambda: _find_trim_matches(content, old_text, 
normalize_quotes=True), + lambda: _find_quote_matches(content, old_text), + ): + matches = matcher() + if matches: + return matches + return [] + + +def _find_match_line_numbers(content: str, old_text: str) -> list[int]: + """Return 1-based starting line numbers for the current matching strategies.""" + return [match.line for match in _find_matches(content, old_text)] + + +def _collapse_internal_whitespace(text: str) -> str: + return "\n".join(" ".join(line.split()) for line in text.splitlines()) + + +def _diagnose_near_match(old_text: str, actual_text: str) -> list[str]: + """Return actionable hints describing why text was close but not exact.""" + hints: list[str] = [] + + if old_text.lower() == actual_text.lower() and old_text != actual_text: + hints.append("letter case differs") + if _collapse_internal_whitespace(old_text) == _collapse_internal_whitespace(actual_text) and old_text != actual_text: + hints.append("whitespace differs") + if old_text.rstrip("\n") == actual_text.rstrip("\n") and old_text != actual_text: + hints.append("trailing newline differs") + if _normalize_quotes(old_text) == _normalize_quotes(actual_text) and old_text != actual_text: + hints.append("quote style differs") + + return hints + + +def _best_window(old_text: str, content: str) -> tuple[float, int, list[str], list[str]]: + """Find the closest line-window match and return ratio/start/snippet/hints.""" + lines = content.splitlines(keepends=True) + old_lines = old_text.splitlines(keepends=True) + window = max(1, len(old_lines)) + + best_ratio, best_start = -1.0, 0 + best_window_lines: list[str] = [] + + for i in range(max(1, len(lines) - window + 1)): + current = lines[i : i + window] + ratio = difflib.SequenceMatcher(None, old_lines, current).ratio() + if ratio > best_ratio: + best_ratio, best_start = ratio, i + best_window_lines = current + + actual_text = "".join(best_window_lines).replace("\r\n", "\n").rstrip("\n") + hints = _diagnose_near_match(old_text.replace("\r\n", 
"\n").rstrip("\n"), actual_text) + return best_ratio, best_start, best_window_lines, hints + + def _find_match(content: str, old_text: str) -> tuple[str | None, int]: - """Locate old_text in content: exact first, then line-trimmed sliding window. + """Locate old_text in content with a multi-level fallback chain: + + 1. Exact substring match + 2. Line-trimmed sliding window (handles indentation differences) + 3. Smart quote normalization (curly ↔ straight quotes) Both inputs should use LF line endings (caller normalises CRLF). Returns (matched_fragment, count) or (None, 0). """ - if old_text in content: - return old_text, content.count(old_text) - - old_lines = old_text.splitlines() - if not old_lines: + matches = _find_matches(content, old_text) + if not matches: return None, 0 - stripped_old = [l.strip() for l in old_lines] - content_lines = content.splitlines() - - candidates = [] - for i in range(len(content_lines) - len(stripped_old) + 1): - window = content_lines[i : i + len(stripped_old)] - if [l.strip() for l in window] == stripped_old: - candidates.append("\n".join(window)) - - if candidates: - return candidates[0], len(candidates) - return None, 0 + return matches[0].text, len(matches) @tool_parameters( @@ -241,6 +578,9 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]: class EditFileTool(_FsTool): """Edit a file by replacing text with fallback matching.""" + _MAX_EDIT_FILE_SIZE = 1024 * 1024 * 1024 # 1 GiB + _MARKDOWN_EXTS = frozenset({".md", ".mdx", ".markdown"}) + @property def name(self) -> str: return "edit_file" @@ -249,11 +589,16 @@ class EditFileTool(_FsTool): def description(self) -> str: return ( "Edit a file by replacing old_text with new_text. " - "Tolerates minor whitespace/indentation differences. " + "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. " "If old_text matches multiple times, you must provide more context " "or set replace_all=true. 
Shows a diff of the closest match on failure." ) + @staticmethod + def _strip_trailing_ws(text: str) -> str: + """Strip trailing whitespace from each line.""" + return "\n".join(line.rstrip() for line in text.split("\n")) + async def execute( self, path: str | None = None, old_text: str | None = None, new_text: str | None = None, @@ -267,55 +612,133 @@ class EditFileTool(_FsTool): if new_text is None: raise ValueError("Unknown new_text") + # .ipynb detection + if path.endswith(".ipynb"): + return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file." + fp = self._resolve(path) + + # Create-file semantics: old_text='' + file doesn't exist → create if not fp.exists(): - return f"Error: File not found: {path}" + if old_text == "": + fp.parent.mkdir(parents=True, exist_ok=True) + fp.write_text(new_text, encoding="utf-8") + file_state.record_write(fp) + return f"Successfully created {fp}" + return self._file_not_found_msg(path, fp) + + # File size protection + try: + fsize = fp.stat().st_size + except OSError: + fsize = 0 + if fsize > self._MAX_EDIT_FILE_SIZE: + return f"Error: File too large to edit ({fsize / (1024**3):.1f} GiB). Maximum is 1 GiB." + + # Create-file: old_text='' but file exists and not empty → reject + if old_text == "": + raw = fp.read_bytes() + content = raw.decode("utf-8") + if content.strip(): + return f"Error: Cannot create file — {path} already exists and is not empty." 
+ fp.write_text(new_text, encoding="utf-8") + file_state.record_write(fp) + return f"Successfully edited {fp}" + + # Read-before-edit check + warning = file_state.check_read(fp) raw = fp.read_bytes() uses_crlf = b"\r\n" in raw content = raw.decode("utf-8").replace("\r\n", "\n") - match, count = _find_match(content, old_text.replace("\r\n", "\n")) + norm_old = old_text.replace("\r\n", "\n") + matches = _find_matches(content, norm_old) - if match is None: + if not matches: return self._not_found_msg(old_text, content, path) + count = len(matches) if count > 1 and not replace_all: + line_numbers = [match.line for match in matches] + preview = ", ".join(f"line {n}" for n in line_numbers[:3]) + if len(line_numbers) > 3: + preview += ", ..." + location_hint = f" at {preview}" if preview else "" return ( - f"Warning: old_text appears {count} times. " + f"Warning: old_text appears {count} times{location_hint}. " "Provide more context to make it unique, or set replace_all=true." ) norm_new = new_text.replace("\r\n", "\n") - new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1) + + # Trailing whitespace stripping (skip markdown to preserve double-space line breaks) + if fp.suffix.lower() not in self._MARKDOWN_EXTS: + norm_new = self._strip_trailing_ws(norm_new) + + selected = matches if replace_all else matches[:1] + new_content = content + for match in reversed(selected): + replacement = _preserve_quote_style(norm_old, match.text, norm_new) + replacement = _reindent_like_match(norm_old, match.text, replacement) + + # Delete-line cleanup: when deleting text (new_text=''), consume trailing + # newline to avoid leaving a blank line + end = match.end + if replacement == "" and not match.text.endswith("\n") and content[end:end + 1] == "\n": + end += 1 + + new_content = new_content[: match.start] + replacement + new_content[end:] if uses_crlf: new_content = new_content.replace("\n", "\r\n") 
fp.write_bytes(new_content.encode("utf-8")) - return f"Successfully edited {fp}" + file_state.record_write(fp) + msg = f"Successfully edited {fp}" + if warning: + msg = f"{warning}\n{msg}" + return msg except PermissionError as e: return f"Error: {e}" except Exception as e: return f"Error editing file: {e}" + def _file_not_found_msg(self, path: str, fp: Path) -> str: + """Build an error message with 'Did you mean ...?' suggestions.""" + parent = fp.parent + suggestions: list[str] = [] + if parent.is_dir(): + siblings = [f.name for f in parent.iterdir() if f.is_file()] + close = difflib.get_close_matches(fp.name, siblings, n=3, cutoff=0.6) + suggestions = [str(parent / c) for c in close] + parts = [f"Error: File not found: {path}"] + if suggestions: + parts.append("Did you mean: " + ", ".join(suggestions) + "?") + return "\n".join(parts) + @staticmethod def _not_found_msg(old_text: str, content: str, path: str) -> str: - lines = content.splitlines(keepends=True) - old_lines = old_text.splitlines(keepends=True) - window = len(old_lines) - - best_ratio, best_start = 0.0, 0 - for i in range(max(1, len(lines) - window + 1)): - ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio() - if ratio > best_ratio: - best_ratio, best_start = ratio, i - + best_ratio, best_start, best_window_lines, hints = _best_window(old_text, content) if best_ratio > 0.5: diff = "\n".join(difflib.unified_diff( - old_lines, lines[best_start : best_start + window], + old_text.splitlines(keepends=True), + best_window_lines, fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})", lineterm="", )) - return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" + hint_text = "" + if hints: + hint_text = "\nPossible cause: " + ", ".join(hints) + "." + return ( + f"Error: old_text not found in {path}." 
+ f"{hint_text}\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" + ) + + if hints: + return ( + f"Error: old_text not found in {path}. " + f"Possible cause: {', '.join(hints)}. " + "Copy the exact text from read_file and try again." + ) return f"Error: old_text not found in {path}. No similar text found. Verify the file content." diff --git a/nanobot/agent/tools/notebook.py b/nanobot/agent/tools/notebook.py new file mode 100644 index 000000000..8c4be110d --- /dev/null +++ b/nanobot/agent/tools/notebook.py @@ -0,0 +1,162 @@ +"""NotebookEditTool — edit Jupyter .ipynb notebooks.""" + +from __future__ import annotations + +import json +import uuid +from pathlib import Path +from typing import Any + +from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.agent.tools.filesystem import _FsTool + + +def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict: + cell: dict[str, Any] = { + "cell_type": cell_type, + "source": source, + "metadata": {}, + } + if cell_type == "code": + cell["outputs"] = [] + cell["execution_count"] = None + if generate_id: + cell["id"] = uuid.uuid4().hex[:8] + return cell + + +def _make_empty_notebook() -> dict: + return { + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, + "language_info": {"name": "python"}, + }, + "cells": [], + } + + +@tool_parameters( + tool_parameters_schema( + path=StringSchema("Path to the .ipynb notebook file"), + cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0), + new_source=StringSchema("New source content for the cell"), + cell_type=StringSchema( + "Cell type: 'code' or 'markdown' (default: code)", + enum=["code", "markdown"], + ), + edit_mode=StringSchema( + "Mode: 'replace' (default), 'insert' (after target), or 'delete'", + 
enum=["replace", "insert", "delete"], + ), + required=["path", "cell_index"], + ) +) +class NotebookEditTool(_FsTool): + """Edit Jupyter notebook cells: replace, insert, or delete.""" + + _VALID_CELL_TYPES = frozenset({"code", "markdown"}) + _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"}) + + @property + def name(self) -> str: + return "notebook_edit" + + @property + def description(self) -> str: + return ( + "Edit a Jupyter notebook (.ipynb) cell. " + "Modes: replace (default) replaces cell content, " + "insert adds a new cell after the target index, " + "delete removes the cell at the index. " + "cell_index is 0-based." + ) + + async def execute( + self, + path: str | None = None, + cell_index: int = 0, + new_source: str = "", + cell_type: str = "code", + edit_mode: str = "replace", + **kwargs: Any, + ) -> str: + try: + if not path: + return "Error: path is required" + + if not path.endswith(".ipynb"): + return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files." + + if edit_mode not in self._VALID_EDIT_MODES: + return ( + f"Error: Invalid edit_mode '{edit_mode}'. " + "Use one of: replace, insert, delete." + ) + + if cell_type not in self._VALID_CELL_TYPES: + return ( + f"Error: Invalid cell_type '{cell_type}'. " + "Use one of: code, markdown." 
+ ) + + fp = self._resolve(path) + + # Create new notebook if file doesn't exist and mode is insert + if not fp.exists(): + if edit_mode != "insert": + return f"Error: File not found: {path}" + nb = _make_empty_notebook() + cell = _new_cell(new_source, cell_type, generate_id=True) + nb["cells"].append(cell) + fp.parent.mkdir(parents=True, exist_ok=True) + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully created {fp} with 1 cell" + + try: + nb = json.loads(fp.read_text(encoding="utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + return f"Error: Failed to parse notebook: {e}" + + cells = nb.get("cells", []) + nbformat_minor = nb.get("nbformat_minor", 0) + generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5 + + if edit_mode == "delete": + if cell_index < 0 or cell_index >= len(cells): + return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" + cells.pop(cell_index) + nb["cells"] = cells + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully deleted cell {cell_index} from {fp}" + + if edit_mode == "insert": + insert_at = min(cell_index + 1, len(cells)) + cell = _new_cell(new_source, cell_type, generate_id=generate_id) + cells.insert(insert_at, cell) + nb["cells"] = cells + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully inserted cell at index {insert_at} in {fp}" + + # Default: replace + if cell_index < 0 or cell_index >= len(cells): + return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" + cells[cell_index]["source"] = new_source + if cell_type and cells[cell_index].get("cell_type") != cell_type: + cells[cell_index]["cell_type"] = cell_type + if cell_type == "code": + cells[cell_index].setdefault("outputs", []) + cells[cell_index].setdefault("execution_count", None) + elif "outputs" in cells[cell_index]: + del 
cells[cell_index]["outputs"] + cells[cell_index].pop("execution_count", None) + nb["cells"] = cells + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully edited cell {cell_index} in {fp}" + + except PermissionError as e: + return f"Error: {e}" + except Exception as e: + return f"Error editing notebook: {e}" diff --git a/pyproject.toml b/pyproject.toml index 751716135..b2a25bfad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,12 +76,16 @@ discord = [ langsmith = [ "langsmith>=0.1.0", ] +pdf = [ + "pymupdf>=1.25.0", +] dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", "aiohttp>=3.9.0,<4.0.0", "pytest-cov>=6.0.0,<7.0.0", "ruff>=0.1.0", + "pymupdf>=1.25.0", ] [project.scripts] diff --git a/tests/tools/test_edit_advanced.py b/tests/tools/test_edit_advanced.py new file mode 100644 index 000000000..baf0eb02f --- /dev/null +++ b/tests/tools/test_edit_advanced.py @@ -0,0 +1,423 @@ +"""Tests for advanced EditFileTool enhancements inspired by claude-code: +- Delete-line newline cleanup +- Smart quote normalization (curly ↔ straight) +- Quote style preservation in replacements +- Indentation preservation when fallback match is trimmed +- Trailing whitespace stripping for new_text +- File size protection +- Stale detection with content-equality fallback +""" + +import os +import time + +import pytest + +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, _find_match +from nanobot.agent.tools import file_state + + +@pytest.fixture(autouse=True) +def _clear_file_state(): + file_state.clear() + yield + file_state.clear() + + +# --------------------------------------------------------------------------- +# Delete-line newline cleanup +# --------------------------------------------------------------------------- + + +class TestDeleteLineCleanup: + """When new_text='' and deleting a line, trailing newline should be consumed.""" + + @pytest.fixture() + def tool(self, tmp_path): + return 
EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_delete_line_consumes_trailing_newline(self, tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("line1\nline2\nline3\n", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="line2", new_text="") + assert "Successfully" in result + content = f.read_text() + # Should not leave a blank line where line2 was + assert content == "line1\nline3\n" + + @pytest.mark.asyncio + async def test_delete_line_with_explicit_newline_in_old_text(self, tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("line1\nline2\nline3\n", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="line2\n", new_text="") + assert "Successfully" in result + assert f.read_text() == "line1\nline3\n" + + @pytest.mark.asyncio + async def test_delete_preserves_content_when_not_trailing_newline(self, tool, tmp_path): + """Deleting a word mid-line should not consume extra characters.""" + f = tmp_path / "a.py" + f.write_text("hello world here\n", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="world ", new_text="") + assert "Successfully" in result + assert f.read_text() == "hello here\n" + + +# --------------------------------------------------------------------------- +# Smart quote normalization +# --------------------------------------------------------------------------- + + +class TestSmartQuoteNormalization: + """_find_match should handle curly ↔ straight quote fallback.""" + + def test_curly_double_quotes_match_straight(self): + content = 'She said \u201chello\u201d to him' + old_text = 'She said "hello" to him' + match, count = _find_match(content, old_text) + assert match is not None + assert count == 1 + # Returned match should be the ORIGINAL content with curly quotes + assert "\u201c" in match + + def test_curly_single_quotes_match_straight(self): + content = "it\u2019s a test" + old_text = "it's a test" + match, count = _find_match(content, old_text) + assert 
match is not None + assert count == 1 + assert "\u2019" in match + + def test_straight_matches_curly_in_old_text(self): + content = 'x = "hello"' + old_text = 'x = \u201chello\u201d' + match, count = _find_match(content, old_text) + assert match is not None + assert count == 1 + + def test_exact_match_still_preferred_over_quote_normalization(self): + content = 'x = "hello"' + old_text = 'x = "hello"' + match, count = _find_match(content, old_text) + assert match == old_text + assert count == 1 + + +class TestQuoteStylePreservation: + """When quote-normalized matching occurs, replacement should preserve actual quote style.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_replacement_preserves_curly_double_quotes(self, tool, tmp_path): + f = tmp_path / "quotes.txt" + f.write_text('message = “hello”\n', encoding="utf-8") + result = await tool.execute( + path=str(f), + old_text='message = "hello"', + new_text='message = "goodbye"', + ) + assert "Successfully" in result + assert f.read_text(encoding="utf-8") == 'message = “goodbye”\n' + + @pytest.mark.asyncio + async def test_replacement_preserves_curly_apostrophe(self, tool, tmp_path): + f = tmp_path / "apostrophe.txt" + f.write_text("it’s fine\n", encoding="utf-8") + result = await tool.execute( + path=str(f), + old_text="it's fine", + new_text="it's better", + ) + assert "Successfully" in result + assert f.read_text(encoding="utf-8") == "it’s better\n" + + +# --------------------------------------------------------------------------- +# Indentation preservation +# --------------------------------------------------------------------------- + + +class TestIndentationPreservation: + """Replacement should keep outer indentation when trim fallback matched.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def 
test_trim_fallback_preserves_outer_indentation(self, tool, tmp_path): + f = tmp_path / "indent.py" + f.write_text( + "if True:\n" + " def foo():\n" + " pass\n", + encoding="utf-8", + ) + result = await tool.execute( + path=str(f), + old_text="def foo():\n pass", + new_text="def bar():\n return 1", + ) + assert "Successfully" in result + assert f.read_text(encoding="utf-8") == ( + "if True:\n" + " def bar():\n" + " return 1\n" + ) + + +# --------------------------------------------------------------------------- +# Failure diagnostics +# --------------------------------------------------------------------------- + + +class TestEditDiagnostics: + """Failure paths should offer actionable hints.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_ambiguous_match_reports_candidate_lines(self, tool, tmp_path): + f = tmp_path / "dup.py" + f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx") + assert "appears 2 times" in result.lower() + assert "line 1" in result.lower() + assert "line 3" in result.lower() + assert "replace_all=true" in result + + @pytest.mark.asyncio + async def test_not_found_reports_whitespace_hint(self, tool, tmp_path): + f = tmp_path / "space.py" + f.write_text("value = 1\n", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="value = 1", new_text="value = 2") + assert "Error" in result + assert "whitespace" in result.lower() + + @pytest.mark.asyncio + async def test_not_found_reports_case_hint(self, tool, tmp_path): + f = tmp_path / "case.py" + f.write_text("HelloWorld\n", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="helloworld", new_text="goodbye") + assert "Error" in result + assert "letter case differs" in result.lower() + + +# --------------------------------------------------------------------------- +# Advanced fallback replacement 
behavior +# --------------------------------------------------------------------------- + + +class TestAdvancedReplaceAll: + """replace_all should work correctly for fallback-based matches too.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path): + f = tmp_path / "indent_multi.py" + f.write_text( + "if a:\n" + " def foo():\n" + " pass\n" + "if b:\n" + " def foo():\n" + " pass\n", + encoding="utf-8", + ) + result = await tool.execute( + path=str(f), + old_text="def foo():\n pass", + new_text="def bar():\n return 1", + replace_all=True, + ) + assert "Successfully" in result + assert f.read_text(encoding="utf-8") == ( + "if a:\n" + " def bar():\n" + " return 1\n" + "if b:\n" + " def bar():\n" + " return 1\n" + ) + + @pytest.mark.asyncio + async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path): + f = tmp_path / "quote_indent.py" + f.write_text(" message = “hello”\n", encoding="utf-8") + result = await tool.execute( + path=str(f), + old_text='message = "hello"', + new_text='message = "goodbye"', + ) + assert "Successfully" in result + assert f.read_text(encoding="utf-8") == " message = “goodbye”\n" + + +# --------------------------------------------------------------------------- +# Advanced fallback replacement behavior +# --------------------------------------------------------------------------- + + +class TestAdvancedReplaceAll: + """replace_all should work correctly for fallback-based matches too.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path): + f = tmp_path / "indent_multi.py" + f.write_text( + "if a:\n" + " def foo():\n" + " pass\n" + "if b:\n" + " def foo():\n" + " pass\n", + encoding="utf-8", + ) + result = await tool.execute( + 
path=str(f), + old_text="def foo():\n pass", + new_text="def bar():\n return 1", + replace_all=True, + ) + assert "Successfully" in result + assert f.read_text(encoding="utf-8") == ( + "if a:\n" + " def bar():\n" + " return 1\n" + "if b:\n" + " def bar():\n" + " return 1\n" + ) + + @pytest.mark.asyncio + async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path): + f = tmp_path / "quote_indent.py" + f.write_text(" message = “hello”\n", encoding="utf-8") + result = await tool.execute( + path=str(f), + old_text='message = "hello"', + new_text='message = "goodbye"', + ) + assert "Successfully" in result + assert f.read_text(encoding="utf-8") == " message = “goodbye”\n" + + +# --------------------------------------------------------------------------- +# Trailing whitespace stripping on new_text +# --------------------------------------------------------------------------- + + +class TestTrailingWhitespaceStrip: + """new_text trailing whitespace should be stripped (except .md files).""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_strips_trailing_whitespace_from_new_text(self, tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("x = 1\n", encoding="utf-8") + result = await tool.execute( + path=str(f), old_text="x = 1", new_text="x = 2 \ny = 3 ", + ) + assert "Successfully" in result + content = f.read_text() + assert "x = 2\ny = 3\n" == content + + @pytest.mark.asyncio + async def test_preserves_trailing_whitespace_in_markdown(self, tool, tmp_path): + f = tmp_path / "doc.md" + f.write_text("# Title\n", encoding="utf-8") + # Markdown uses trailing double-space for line breaks + result = await tool.execute( + path=str(f), old_text="# Title", new_text="# Title \nSubtitle ", + ) + assert "Successfully" in result + content = f.read_text() + # Trailing spaces should be preserved for markdown + assert "Title " in content + assert "Subtitle " in content + + +# 
--------------------------------------------------------------------------- +# File size protection +# --------------------------------------------------------------------------- + + +class TestFileSizeProtection: + """Editing extremely large files should be rejected.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_rejects_file_over_size_limit(self, tool, tmp_path): + f = tmp_path / "huge.txt" + f.write_text("x", encoding="utf-8") + # Monkey-patch the file size check by creating a stat mock + original_stat = f.stat + + class FakeStat: + def __init__(self, real_stat): + self._real = real_stat + + def __getattr__(self, name): + return getattr(self._real, name) + + @property + def st_size(self): + return 2 * 1024 * 1024 * 1024 # 2 GiB + + import unittest.mock + with unittest.mock.patch.object(type(f), 'stat', return_value=FakeStat(f.stat())): + result = await tool.execute(path=str(f), old_text="x", new_text="y") + assert "Error" in result + assert "too large" in result.lower() or "size" in result.lower() + + +# --------------------------------------------------------------------------- +# Stale detection with content-equality fallback +# --------------------------------------------------------------------------- + + +class TestStaleDetectionContentFallback: + """When mtime changed but file content is unchanged, edit should proceed without warning.""" + + @pytest.fixture() + def read_tool(self, tmp_path): + return ReadFileTool(workspace=tmp_path) + + @pytest.fixture() + def edit_tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_mtime_bump_same_content_no_warning(self, read_tool, edit_tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello world", encoding="utf-8") + await read_tool.execute(path=str(f)) + + # Touch the file to bump mtime without changing content + time.sleep(0.05) + original_content = f.read_text() + 
f.write_text(original_content, encoding="utf-8") + + result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth") + assert "Successfully" in result + # Should NOT warn about modification since content is the same + assert "modified" not in result.lower() diff --git a/tests/tools/test_edit_enhancements.py b/tests/tools/test_edit_enhancements.py new file mode 100644 index 000000000..7ad098960 --- /dev/null +++ b/tests/tools/test_edit_enhancements.py @@ -0,0 +1,152 @@ +"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions, +.ipynb detection, and create-file semantics.""" + +import pytest + +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools import file_state + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _clear_file_state(): + """Reset global read-state between tests.""" + file_state.clear() + yield + file_state.clear() + + +# --------------------------------------------------------------------------- +# Read-before-edit tracking +# --------------------------------------------------------------------------- + +class TestEditReadTracking: + """edit_file should warn when file hasn't been read first.""" + + @pytest.fixture() + def read_tool(self, tmp_path): + return ReadFileTool(workspace=tmp_path) + + @pytest.fixture() + def edit_tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello world", encoding="utf-8") + result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth") + # Should still succeed but include a warning + assert "Successfully" in result + assert "not been read" in result.lower() or "warning" in result.lower() 
+ + @pytest.mark.asyncio + async def test_edit_succeeds_cleanly_after_read(self, read_tool, edit_tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello world", encoding="utf-8") + await read_tool.execute(path=str(f)) + result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth") + assert "Successfully" in result + # No warning when file was read first + assert "not been read" not in result.lower() + assert f.read_text() == "hello earth" + + @pytest.mark.asyncio + async def test_edit_warns_if_file_modified_since_read(self, read_tool, edit_tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello world", encoding="utf-8") + await read_tool.execute(path=str(f)) + # External modification + f.write_text("hello universe", encoding="utf-8") + result = await edit_tool.execute(path=str(f), old_text="universe", new_text="earth") + assert "Successfully" in result + assert "modified" in result.lower() or "warning" in result.lower() + + +# --------------------------------------------------------------------------- +# Create-file semantics +# --------------------------------------------------------------------------- + +class TestEditCreateFile: + """edit_file with old_text='' creates new file if not exists.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_create_new_file_with_empty_old_text(self, tool, tmp_path): + f = tmp_path / "subdir" / "new.py" + result = await tool.execute(path=str(f), old_text="", new_text="print('hi')") + assert "created" in result.lower() or "Successfully" in result + assert f.exists() + assert f.read_text() == "print('hi')" + + @pytest.mark.asyncio + async def test_create_fails_if_file_already_exists_and_not_empty(self, tool, tmp_path): + f = tmp_path / "existing.py" + f.write_text("existing content", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="", new_text="new content") + assert "Error" in result or 
"already exists" in result.lower() + # File should be unchanged + assert f.read_text() == "existing content" + + @pytest.mark.asyncio + async def test_create_succeeds_if_file_exists_but_empty(self, tool, tmp_path): + f = tmp_path / "empty.py" + f.write_text("", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="", new_text="print('hi')") + assert "Successfully" in result + assert f.read_text() == "print('hi')" + + +# --------------------------------------------------------------------------- +# .ipynb detection +# --------------------------------------------------------------------------- + +class TestEditIpynbDetection: + """edit_file should refuse .ipynb and suggest notebook_edit.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path): + f = tmp_path / "analysis.ipynb" + f.write_text('{"cells": []}', encoding="utf-8") + result = await tool.execute(path=str(f), old_text="x", new_text="y") + assert "notebook" in result.lower() + + +# --------------------------------------------------------------------------- +# Path suggestion on not-found +# --------------------------------------------------------------------------- + +class TestEditPathSuggestion: + """edit_file should suggest similar paths on not-found.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_suggests_similar_filename(self, tool, tmp_path): + f = tmp_path / "config.py" + f.write_text("x = 1", encoding="utf-8") + # Typo: conifg.py + result = await tool.execute( + path=str(tmp_path / "conifg.py"), old_text="x = 1", new_text="x = 2", + ) + assert "Error" in result + assert "config.py" in result + + @pytest.mark.asyncio + async def test_shows_cwd_in_error(self, tool, tmp_path): + result = await tool.execute( + path=str(tmp_path / "nonexistent.py"), old_text="a", 
new_text="b", + ) + assert "Error" in result diff --git a/tests/tools/test_notebook_tool.py b/tests/tools/test_notebook_tool.py new file mode 100644 index 000000000..232f13c4b --- /dev/null +++ b/tests/tools/test_notebook_tool.py @@ -0,0 +1,147 @@ +"""Tests for NotebookEditTool — Jupyter .ipynb editing.""" + +import json + +import pytest + +from nanobot.agent.tools.notebook import NotebookEditTool + + +def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict: + """Build a minimal valid .ipynb structure.""" + return { + "nbformat": nbformat, + "nbformat_minor": nbformat_minor, + "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, + "cells": cells or [], + } + + +def _code_cell(source: str, cell_id: str | None = None) -> dict: + cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None} + if cell_id: + cell["id"] = cell_id + return cell + + +def _md_cell(source: str, cell_id: str | None = None) -> dict: + cell = {"cell_type": "markdown", "source": source, "metadata": {}} + if cell_id: + cell["id"] = cell_id + return cell + + +def _write_nb(tmp_path, name: str, nb: dict) -> str: + p = tmp_path / name + p.write_text(json.dumps(nb), encoding="utf-8") + return str(p) + + +class TestNotebookEdit: + + @pytest.fixture() + def tool(self, tmp_path): + return NotebookEditTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_replace_cell_content(self, tool, tmp_path): + nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")]) + path = _write_nb(tmp_path, "test.ipynb", nb) + result = await tool.execute(path=path, cell_index=0, new_source="print('world')") + assert "Successfully" in result + saved = json.loads((tmp_path / "test.ipynb").read_text()) + assert saved["cells"][0]["source"] == "print('world')" + assert saved["cells"][1]["source"] == "x = 1" + + @pytest.mark.asyncio + async def 
test_insert_cell_after_target(self, tool, tmp_path): + nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")]) + path = _write_nb(tmp_path, "test.ipynb", nb) + result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert") + assert "Successfully" in result + saved = json.loads((tmp_path / "test.ipynb").read_text()) + assert len(saved["cells"]) == 3 + assert saved["cells"][0]["source"] == "cell 0" + assert saved["cells"][1]["source"] == "inserted" + assert saved["cells"][2]["source"] == "cell 1" + + @pytest.mark.asyncio + async def test_delete_cell(self, tool, tmp_path): + nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")]) + path = _write_nb(tmp_path, "test.ipynb", nb) + result = await tool.execute(path=path, cell_index=1, edit_mode="delete") + assert "Successfully" in result + saved = json.loads((tmp_path / "test.ipynb").read_text()) + assert len(saved["cells"]) == 2 + assert saved["cells"][0]["source"] == "A" + assert saved["cells"][1]["source"] == "C" + + @pytest.mark.asyncio + async def test_create_new_notebook_from_scratch(self, tool, tmp_path): + path = str(tmp_path / "new.ipynb") + result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown") + assert "Successfully" in result or "created" in result.lower() + saved = json.loads((tmp_path / "new.ipynb").read_text()) + assert saved["nbformat"] == 4 + assert len(saved["cells"]) == 1 + assert saved["cells"][0]["cell_type"] == "markdown" + assert saved["cells"][0]["source"] == "# Hello" + + @pytest.mark.asyncio + async def test_invalid_cell_index_error(self, tool, tmp_path): + nb = _make_notebook([_code_cell("only cell")]) + path = _write_nb(tmp_path, "test.ipynb", nb) + result = await tool.execute(path=path, cell_index=5, new_source="x") + assert "Error" in result + + @pytest.mark.asyncio + async def test_non_ipynb_rejected(self, tool, tmp_path): + f = tmp_path / "script.py" + 
f.write_text("pass") + result = await tool.execute(path=str(f), cell_index=0, new_source="x") + assert "Error" in result + assert ".ipynb" in result + + @pytest.mark.asyncio + async def test_preserves_metadata_and_outputs(self, tool, tmp_path): + cell = _code_cell("old") + cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}] + cell["execution_count"] = 42 + nb = _make_notebook([cell]) + path = _write_nb(tmp_path, "test.ipynb", nb) + await tool.execute(path=path, cell_index=0, new_source="new") + saved = json.loads((tmp_path / "test.ipynb").read_text()) + assert saved["metadata"]["kernelspec"]["language"] == "python" + + @pytest.mark.asyncio + async def test_nbformat_45_generates_cell_id(self, tool, tmp_path): + nb = _make_notebook([], nbformat_minor=5) + path = _write_nb(tmp_path, "test.ipynb", nb) + await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert") + saved = json.loads((tmp_path / "test.ipynb").read_text()) + assert "id" in saved["cells"][0] + assert len(saved["cells"][0]["id"]) > 0 + + @pytest.mark.asyncio + async def test_insert_with_cell_type_markdown(self, tool, tmp_path): + nb = _make_notebook([_code_cell("code")]) + path = _write_nb(tmp_path, "test.ipynb", nb) + await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown") + saved = json.loads((tmp_path / "test.ipynb").read_text()) + assert saved["cells"][1]["cell_type"] == "markdown" + + @pytest.mark.asyncio + async def test_invalid_edit_mode_rejected(self, tool, tmp_path): + nb = _make_notebook([_code_cell("code")]) + path = _write_nb(tmp_path, "test.ipynb", nb) + result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae") + assert "Error" in result + assert "edit_mode" in result + + @pytest.mark.asyncio + async def test_invalid_cell_type_rejected(self, tool, tmp_path): + nb = _make_notebook([_code_cell("code")]) + path = _write_nb(tmp_path, "test.ipynb", nb) + result = await 
tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw") + assert "Error" in result + assert "cell_type" in result diff --git a/tests/tools/test_read_enhancements.py b/tests/tools/test_read_enhancements.py new file mode 100644 index 000000000..a703ba6e4 --- /dev/null +++ b/tests/tools/test_read_enhancements.py @@ -0,0 +1,180 @@ +"""Tests for ReadFileTool enhancements: description fix, read dedup, PDF support, device blacklist.""" + +import pytest + +from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool +from nanobot.agent.tools import file_state + + +@pytest.fixture(autouse=True) +def _clear_file_state(): + file_state.clear() + yield + file_state.clear() + + +# --------------------------------------------------------------------------- +# Description fix +# --------------------------------------------------------------------------- + +class TestReadDescriptionFix: + + def test_description_mentions_image_support(self): + tool = ReadFileTool() + assert "image" in tool.description.lower() + + def test_description_no_longer_says_cannot_read_images(self): + tool = ReadFileTool() + assert "cannot read binary files or images" not in tool.description.lower() + + +# --------------------------------------------------------------------------- +# Read deduplication +# --------------------------------------------------------------------------- + +class TestReadDedup: + """Same file + same offset/limit + unchanged mtime -> short stub.""" + + @pytest.fixture() + def tool(self, tmp_path): + return ReadFileTool(workspace=tmp_path) + + @pytest.fixture() + def write_tool(self, tmp_path): + return WriteFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_second_read_returns_unchanged_stub(self, tool, tmp_path): + f = tmp_path / "data.txt" + f.write_text("\n".join(f"line {i}" for i in range(100)), encoding="utf-8") + first = await tool.execute(path=str(f)) + assert "line 0" in first + second = await tool.execute(path=str(f)) + assert 
"unchanged" in second.lower() + # Stub should not contain file content + assert "line 0" not in second + + @pytest.mark.asyncio + async def test_read_after_external_modification_returns_full(self, tool, tmp_path): + f = tmp_path / "data.txt" + f.write_text("original", encoding="utf-8") + await tool.execute(path=str(f)) + # Modify the file externally + f.write_text("modified content", encoding="utf-8") + second = await tool.execute(path=str(f)) + assert "modified content" in second + + @pytest.mark.asyncio + async def test_different_offset_returns_full(self, tool, tmp_path): + f = tmp_path / "data.txt" + f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8") + await tool.execute(path=str(f), offset=1, limit=5) + second = await tool.execute(path=str(f), offset=6, limit=5) + # Different offset → full read, not stub + assert "line 6" in second + + @pytest.mark.asyncio + async def test_first_read_after_write_returns_full_content(self, tool, write_tool, tmp_path): + f = tmp_path / "fresh.txt" + result = await write_tool.execute(path=str(f), content="hello") + assert "Successfully" in result + read_result = await tool.execute(path=str(f)) + assert "hello" in read_result + assert "unchanged" not in read_result.lower() + + @pytest.mark.asyncio + async def test_dedup_does_not_apply_to_images(self, tool, tmp_path): + f = tmp_path / "img.png" + f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data") + first = await tool.execute(path=str(f)) + assert isinstance(first, list) + second = await tool.execute(path=str(f)) + # Images should always return full content blocks, not a stub + assert isinstance(second, list) + + +# --------------------------------------------------------------------------- +# PDF support +# --------------------------------------------------------------------------- + +class TestReadPdf: + + @pytest.fixture() + def tool(self, tmp_path): + return ReadFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def 
test_pdf_returns_text_content(self, tool, tmp_path): + fitz = pytest.importorskip("fitz") + pdf_path = tmp_path / "test.pdf" + doc = fitz.open() + page = doc.new_page() + page.insert_text((72, 72), "Hello PDF World") + doc.save(str(pdf_path)) + doc.close() + + result = await tool.execute(path=str(pdf_path)) + assert "Hello PDF World" in result + + @pytest.mark.asyncio + async def test_pdf_pages_parameter(self, tool, tmp_path): + fitz = pytest.importorskip("fitz") + pdf_path = tmp_path / "multi.pdf" + doc = fitz.open() + for i in range(5): + page = doc.new_page() + page.insert_text((72, 72), f"Page {i + 1} content") + doc.save(str(pdf_path)) + doc.close() + + result = await tool.execute(path=str(pdf_path), pages="2-3") + assert "Page 2 content" in result + assert "Page 3 content" in result + assert "Page 1 content" not in result + + @pytest.mark.asyncio + async def test_pdf_file_not_found_error(self, tool, tmp_path): + result = await tool.execute(path=str(tmp_path / "nope.pdf")) + assert "Error" in result + assert "not found" in result + + +# --------------------------------------------------------------------------- +# Device path blacklist +# --------------------------------------------------------------------------- + +class TestReadDeviceBlacklist: + + @pytest.fixture() + def tool(self): + return ReadFileTool() + + @pytest.mark.asyncio + async def test_dev_random_blocked(self, tool): + result = await tool.execute(path="/dev/random") + assert "Error" in result + assert "blocked" in result.lower() or "device" in result.lower() + + @pytest.mark.asyncio + async def test_dev_urandom_blocked(self, tool): + result = await tool.execute(path="/dev/urandom") + assert "Error" in result + + @pytest.mark.asyncio + async def test_dev_zero_blocked(self, tool): + result = await tool.execute(path="/dev/zero") + assert "Error" in result + + @pytest.mark.asyncio + async def test_proc_fd_blocked(self, tool): + result = await tool.execute(path="/proc/self/fd/0") + assert "Error" in 
result + + @pytest.mark.asyncio + async def test_symlink_to_dev_zero_blocked(self, tmp_path): + tool = ReadFileTool(workspace=tmp_path) + link = tmp_path / "zero-link" + link.symlink_to("/dev/zero") + result = await tool.execute(path=str(link)) + assert "Error" in result + assert "blocked" in result.lower() or "device" in result.lower() From a167959027d8ffc614c800ea764e9585c5c967f3 Mon Sep 17 00:00:00 2001 From: worenidewen Date: Fri, 10 Apr 2026 23:51:50 +0800 Subject: [PATCH 056/115] fix(mcp): support multiple MCP servers by connecting each in isolated task Each MCP server now connects in its own asyncio.Task to isolate anyio cancel scopes and prevent 'exit cancel scope in different task' errors when multiple servers (especially mixed transport types) are configured. Changes: - connect_mcp_servers() returns dict[str, AsyncExitStack] instead of None - Each server runs in separate task via asyncio.gather() - AgentLoop uses _mcp_stacks dict to track per-server stacks - Tests updated to handle new API --- nanobot/agent/loop.py | 245 ++++++++++++++++++++++------------- nanobot/agent/tools/mcp.py | 92 ++++++++----- tests/tools/test_mcp_tool.py | 108 +++++---------- 3 files changed, 247 insertions(+), 198 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index bc83cc77c..f7afbe901 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: UNIFIED_SESSION_KEY = "unified:default" + class _LoopHook(AgentHook): """Core hook for the main loop.""" @@ -76,7 +77,7 @@ class _LoopHook(AgentHook): prev_clean = strip_think(self._stream_buf) self._stream_buf += delta new_clean = strip_think(self._stream_buf) - incremental = new_clean[len(prev_clean):] + incremental = new_clean[len(prev_clean) :] if incremental and self._on_stream: await self._on_stream(incremental) @@ -112,6 +113,7 @@ class _LoopHook(AgentHook): def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return 
self._loop._strip_think(content) + class AgentLoop: """ The agent loop is the core processing engine. @@ -196,7 +198,7 @@ class AgentLoop: self._unified_session = unified_session self._running = False self._mcp_servers = mcp_servers or {} - self._mcp_stack: AsyncExitStack | None = None + self._mcp_stacks: dict[str, AsyncExitStack] = {} self._mcp_connected = False self._mcp_connecting = False self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks @@ -228,24 +230,34 @@ class AgentLoop: def _register_default_tools(self) -> None: """Register the default set of tools.""" - allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None + allowed_dir = ( + self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None + ) extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None - self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read)) + self.tools.register( + ReadFileTool( + workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read + ) + ) for cls in (WriteFileTool, EditFileTool, ListDirTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) for cls in (GlobTool, GrepTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) if self.exec_config.enable: - self.tools.register(ExecTool( - working_dir=str(self.workspace), - timeout=self.exec_config.timeout, - restrict_to_workspace=self.restrict_to_workspace, - sandbox=self.exec_config.sandbox, - path_append=self.exec_config.path_append, - allowed_env_keys=self.exec_config.allowed_env_keys, - )) + self.tools.register( + ExecTool( + working_dir=str(self.workspace), + timeout=self.exec_config.timeout, + restrict_to_workspace=self.restrict_to_workspace, + sandbox=self.exec_config.sandbox, + path_append=self.exec_config.path_append, + allowed_env_keys=self.exec_config.allowed_env_keys, + ) + ) if 
self.web_config.enable: - self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)) + self.tools.register( + WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy) + ) self.tools.register(WebFetchTool(proxy=self.web_config.proxy)) self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) self.tools.register(SpawnTool(manager=self.subagents)) @@ -260,19 +272,16 @@ class AgentLoop: return self._mcp_connecting = True from nanobot.agent.tools.mcp import connect_mcp_servers + try: - self._mcp_stack = AsyncExitStack() - await self._mcp_stack.__aenter__() - await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) + self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools) self._mcp_connected = True + except asyncio.CancelledError: + logger.warning("MCP connection cancelled (will retry next message)") + self._mcp_stacks.clear() except BaseException as e: logger.error("Failed to connect MCP servers (will retry next message): {}", e) - if self._mcp_stack: - try: - await self._mcp_stack.aclose() - except Exception: - pass - self._mcp_stack = None + self._mcp_stacks.clear() finally: self._mcp_connecting = False @@ -289,6 +298,7 @@ class AgentLoop: if not text: return None from nanobot.utils.helpers import strip_think + return strip_think(text) or None @staticmethod @@ -327,9 +337,7 @@ class AgentLoop: message_id=message_id, ) hook: AgentHook = ( - CompositeHook([loop_hook] + self._extra_hooks) - if self._extra_hooks - else loop_hook + CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook ) async def _checkpoint(payload: dict[str, Any]) -> None: @@ -337,23 +345,25 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) - result = await self.runner.run(AgentRunSpec( - initial_messages=initial_messages, - tools=self.tools, - model=self.model, - max_iterations=self.max_iterations, - 
max_tool_result_chars=self.max_tool_result_chars, - hook=hook, - error_message="Sorry, I encountered an error calling the AI model.", - concurrent_tools=True, - workspace=self.workspace, - session_key=session.key if session else None, - context_window_tokens=self.context_window_tokens, - context_block_limit=self.context_block_limit, - provider_retry_mode=self.provider_retry_mode, - progress_callback=on_progress, - checkpoint_callback=_checkpoint, - )) + result = await self.runner.run( + AgentRunSpec( + initial_messages=initial_messages, + tools=self.tools, + model=self.model, + max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, + hook=hook, + error_message="Sorry, I encountered an error calling the AI model.", + concurrent_tools=True, + workspace=self.workspace, + session_key=session.key if session else None, + context_window_tokens=self.context_window_tokens, + context_block_limit=self.context_block_limit, + provider_retry_mode=self.provider_retry_mode, + progress_callback=on_progress, + checkpoint_callback=_checkpoint, + ) + ) self._last_usage = result.usage if result.stop_reason == "max_iterations": logger.warning("Max iterations ({}) reached", self.max_iterations) @@ -391,10 +401,19 @@ class AgentLoop: continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled - effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key + effective_key = ( + UNIFIED_SESSION_KEY + if self._unified_session and not msg.session_key_override + else msg.session_key + ) task = asyncio.create_task(self._dispatch(msg)) self._active_tasks.setdefault(effective_key, []).append(task) - task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) + task.add_done_callback( + lambda t, k=effective_key: 
self._active_tasks.get(k, []) + and self._active_tasks[k].remove(t) + if t in self._active_tasks.get(k, []) + else None + ) async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" @@ -417,11 +436,14 @@ class AgentLoop: meta = dict(msg.metadata or {}) meta["_stream_delta"] = True meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content=delta, - metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=delta, + metadata=meta, + ) + ) async def on_stream_end(*, resuming: bool = False) -> None: nonlocal stream_segment @@ -429,44 +451,56 @@ class AgentLoop: meta["_stream_end"] = True meta["_resuming"] = resuming meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", - metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata=meta, + ) + ) stream_segment += 1 response = await self._process_message( - msg, on_stream=on_stream, on_stream_end=on_stream_end, + msg, + on_stream=on_stream, + on_stream_end=on_stream_end, ) if response is not None: await self.bus.publish_outbound(response) elif msg.channel == "cli": - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", metadata=msg.metadata or {}, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata=msg.metadata or {}, + ) + ) except asyncio.CancelledError: logger.info("Task cancelled for session {}", msg.session_key) raise except Exception: logger.exception("Error processing message for session {}", msg.session_key) - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, 
chat_id=msg.chat_id, - content="Sorry, I encountered an error.", - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="Sorry, I encountered an error.", + ) + ) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" if self._background_tasks: await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - if self._mcp_stack: + for name, stack in self._mcp_stacks.items(): try: - await self._mcp_stack.aclose() + await stack.aclose() except (RuntimeError, BaseExceptionGroup): - pass # MCP SDK cancel scope cleanup is noisy but harmless - self._mcp_stack = None + logger.debug("MCP server '{}' cleanup error (can be ignored)", name) + self._mcp_stacks.clear() def _schedule_background(self, coro) -> None: """Schedule a coroutine as a tracked background task (drained on shutdown).""" @@ -490,8 +524,9 @@ class AgentLoop: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") if msg.channel == "system": - channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id - else ("cli", msg.chat_id)) + channel, chat_id = ( + msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id) + ) logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) @@ -503,19 +538,27 @@ class AgentLoop: current_role = "assistant" if msg.sender_id == "subagent" else "user" messages = self.context.build_messages( history=history, - current_message=msg.content, channel=channel, chat_id=chat_id, + current_message=msg.content, + channel=channel, + chat_id=chat_id, current_role=current_role, ) final_content, _, all_msgs, _ = await self._run_agent_loop( - messages, session=session, channel=channel, chat_id=chat_id, + messages, + session=session, + channel=channel, + chat_id=chat_id, 
message_id=msg.metadata.get("message_id"), ) self._save_turn(session, all_msgs, 1 + len(history)) self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - return OutboundMessage(channel=channel, chat_id=chat_id, - content=final_content or "Background task completed.") + return OutboundMessage( + channel=channel, + chat_id=chat_id, + content=final_content or "Background task completed.", + ) preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) @@ -543,16 +586,22 @@ class AgentLoop: history=history, current_message=msg.content, media=msg.media if msg.media else None, - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, ) async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: meta = dict(msg.metadata or {}) meta["_progress"] = True meta["_tool_hint"] = tool_hint - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=content, + metadata=meta, + ) + ) final_content, _, all_msgs, stop_reason = await self._run_agent_loop( initial_messages, @@ -560,7 +609,8 @@ class AgentLoop: on_stream=on_stream, on_stream_end=on_stream_end, session=session, - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), ) @@ -582,7 +632,9 @@ class AgentLoop: if on_stream is not None and stop_reason != "error": meta["_streamed"] = True return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, content=final_content, + channel=msg.channel, + chat_id=msg.chat_id, + content=final_content, metadata=meta, ) @@ -608,10 +660,9 @@ class AgentLoop: ): continue - if ( - 
block.get("type") == "image_url" - and block.get("image_url", {}).get("url", "").startswith("data:image/") - ): + if block.get("type") == "image_url" and block.get("image_url", {}).get( + "url", "" + ).startswith("data:image/"): path = (block.get("_meta") or {}).get("path", "") filtered.append({"type": "text", "text": image_placeholder_text(path)}) continue @@ -630,6 +681,7 @@ class AgentLoop: def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: """Save new-turn messages into session, truncating large tool results.""" from datetime import datetime + for m in messages[skip:]: entry = dict(m) role, content = entry.get("role"), entry.get("content") @@ -644,7 +696,9 @@ class AgentLoop: continue entry["content"] = filtered elif role == "user": - if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): + if isinstance(content, str) and content.startswith( + ContextBuilder._RUNTIME_CONTEXT_TAG + ): # Strip the runtime-context prefix, keep only the user text. 
parts = content.split("\n\n", 1) if len(parts) > 1 and parts[1].strip(): @@ -708,13 +762,15 @@ class AgentLoop: continue tool_id = tool_call.get("id") name = ((tool_call.get("function") or {}).get("name")) or "tool" - restored_messages.append({ - "role": "tool", - "tool_call_id": tool_id, - "name": name, - "content": "Error: Task interrupted before this tool finished.", - "timestamp": datetime.now().isoformat(), - }) + restored_messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "name": name, + "content": "Error: Task interrupted before this tool finished.", + "timestamp": datetime.now().isoformat(), + } + ) overlap = 0 max_overlap = min(len(session.messages), len(restored_messages)) @@ -746,6 +802,9 @@ class AgentLoop: await self._connect_mcp() msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) return await self._process_message( - msg, session_key=session_key, on_progress=on_progress, - on_stream=on_stream, on_stream_end=on_stream_end, + msg, + session_key=session_key, + on_progress=on_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, ) diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 86df9d744..1b5a71322 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]: if "properties" in normalized and isinstance(normalized["properties"], dict): normalized["properties"] = { - name: _normalize_schema_for_openai(prop) - if isinstance(prop, dict) - else prop + name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop for name, prop in normalized["properties"].items() } @@ -138,9 +136,7 @@ class MCPToolWrapper(Tool): class MCPResourceWrapper(Tool): """Wraps an MCP resource URI as a read-only nanobot Tool.""" - def __init__( - self, session, server_name: str, resource_def, resource_timeout: int = 30 - ): + def __init__(self, session, server_name: str, resource_def, 
resource_timeout: int = 30): self._session = session self._uri = resource_def.uri self._name = f"mcp_{server_name}_resource_{resource_def.name}" @@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool): class MCPPromptWrapper(Tool): """Wraps an MCP prompt as a read-only nanobot Tool.""" - def __init__( - self, session, server_name: str, prompt_def, prompt_timeout: int = 30 - ): + def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30): self._session = session self._prompt_name = prompt_def.name self._name = f"mcp_{server_name}_prompt_{prompt_def.name}" @@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool): timeout=self._prompt_timeout, ) except asyncio.TimeoutError: - logger.warning( - "MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout - ) + logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout) return f"(MCP prompt call timed out after {self._prompt_timeout}s)" except asyncio.CancelledError: task = asyncio.current_task() @@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool): except McpError as exc: logger.error( "MCP prompt '{}' failed: code={} message={}", - self._name, exc.error.code, exc.error.message, + self._name, + exc.error.code, + exc.error.message, ) return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])" except Exception as exc: logger.exception( "MCP prompt '{}' failed: {}: {}", - self._name, type(exc).__name__, exc, + self._name, + type(exc).__name__, + exc, ) return f"(MCP prompt call failed: {type(exc).__name__})" @@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool): async def connect_mcp_servers( - mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack -) -> None: - """Connect to configured MCP servers and register their tools, resources, and prompts.""" + mcp_servers: dict, registry: ToolRegistry +) -> dict[str, AsyncExitStack]: + """Connect to configured MCP servers and register their tools, resources, prompts. 
+ + Returns a dict mapping server name -> its dedicated AsyncExitStack. + Each server gets its own stack and runs in its own task to prevent + cancel scope conflicts when multiple MCP servers are configured. + """ from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamable_http_client - for name, cfg in mcp_servers.items(): + async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]: + server_stack = AsyncExitStack() + await server_stack.__aenter__() + try: transport_type = cfg.type if not transport_type: if cfg.command: transport_type = "stdio" elif cfg.url: - # Convention: URLs ending with /sse use SSE transport; others use streamableHttp transport_type = ( "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp" ) else: logger.warning("MCP server '{}': no command or url configured, skipping", name) - continue + await server_stack.aclose() + return name, None if transport_type == "stdio": params = StdioServerParameters( command=cfg.command, args=cfg.args, env=cfg.env or None ) - read, write = await stack.enter_async_context(stdio_client(params)) + read, write = await server_stack.enter_async_context(stdio_client(params)) elif transport_type == "sse": + def httpx_client_factory( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, @@ -353,27 +358,26 @@ async def connect_mcp_servers( auth=auth, ) - read, write = await stack.enter_async_context( + read, write = await server_stack.enter_async_context( sse_client(cfg.url, httpx_client_factory=httpx_client_factory) ) elif transport_type == "streamableHttp": - # Always provide an explicit httpx client so MCP HTTP transport does not - # inherit httpx's default 5s timeout and preempt the higher-level tool timeout. 
- http_client = await stack.enter_async_context( + http_client = await server_stack.enter_async_context( httpx.AsyncClient( headers=cfg.headers or None, follow_redirects=True, timeout=None, ) ) - read, write, _ = await stack.enter_async_context( + read, write, _ = await server_stack.enter_async_context( streamable_http_client(cfg.url, http_client=http_client) ) else: logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type) - continue + await server_stack.aclose() + return name, None - session = await stack.enter_async_context(ClientSession(read, write)) + session = await server_stack.enter_async_context(ClientSession(read, write)) await session.initialize() tools = await session.list_tools() @@ -418,7 +422,6 @@ async def connect_mcp_servers( ", ".join(available_wrapped_names) or "(none)", ) - # --- Register resources --- try: resources_result = await session.list_resources() for resource in resources_result.resources: @@ -433,7 +436,6 @@ async def connect_mcp_servers( except Exception as e: logger.debug("MCP server '{}': resources not supported or failed: {}", name, e) - # --- Register prompts --- try: prompts_result = await session.list_prompts() for prompt in prompts_result.prompts: @@ -442,14 +444,38 @@ async def connect_mcp_servers( ) registry.register(wrapper) registered_count += 1 - logger.debug( - "MCP: registered prompt '{}' from server '{}'", wrapper.name, name - ) + logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name) except Exception as e: logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e) logger.info( "MCP server '{}': connected, {} capabilities registered", name, registered_count ) + return name, server_stack + except Exception as e: logger.error("MCP server '{}': failed to connect: {}", name, e) + try: + await server_stack.aclose() + except Exception: + pass + return name, None + + server_stacks: dict[str, AsyncExitStack] = {} + + tasks: list[asyncio.Task] = [] + for 
name, cfg in mcp_servers.items(): + task = asyncio.create_task(connect_single_server(name, cfg)) + tasks.append(task) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + name = list(mcp_servers.keys())[i] + if isinstance(result, BaseException): + if not isinstance(result, asyncio.CancelledError): + logger.error("MCP server '{}' connection task failed: {}", name, result) + elif result is not None and result[1] is not None: + server_stacks[result[0]] = result[1] + + return server_stacks diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 67382d9ea..adeb78e75 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo"] @@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake")}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake")}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"] @@ -311,15 +303,11 @@ async def 
test_connect_mcp_servers_enabled_tools_supports_wrapped_names( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo"] @@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=[])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=[])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == [] @@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries( monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning) - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == [] @@ -389,9 +369,7 @@ def _make_resource_def( return SimpleNamespace(name=name, uri=uri, description=description) -def _make_resource_wrapper( - session: object, *, timeout: float = 0.1 -) -> MCPResourceWrapper: +def 
_make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper: return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout) @@ -434,9 +412,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None: await asyncio.sleep(1) return SimpleNamespace(contents=[]) - wrapper = _make_resource_wrapper( - SimpleNamespace(read_resource=read_resource), timeout=0.01 - ) + wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01) result = await wrapper.execute() assert result == "(MCP resource read timed out after 0.01s)" @@ -464,20 +440,14 @@ def _make_prompt_def( return SimpleNamespace(name=name, description=description, arguments=arguments) -def _make_prompt_wrapper( - session: object, *, timeout: float = 0.1 -) -> MCPPromptWrapper: - return MCPPromptWrapper( - session, "srv", _make_prompt_def(), prompt_timeout=timeout - ) +def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper: + return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout) def test_prompt_wrapper_properties() -> None: arg1 = SimpleNamespace(name="topic", required=True) arg2 = SimpleNamespace(name="style", required=False) - wrapper = MCPPromptWrapper( - None, "myserver", _make_prompt_def(arguments=[arg1, arg2]) - ) + wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2])) assert wrapper.name == "mcp_myserver_prompt_myprompt" assert "[MCP Prompt]" in wrapper.description assert "A test prompt" in wrapper.description @@ -528,9 +498,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None: await asyncio.sleep(1) return SimpleNamespace(messages=[]) - wrapper = _make_prompt_wrapper( - SimpleNamespace(get_prompt=get_prompt), timeout=0.01 - ) + wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01) result = await wrapper.execute() assert result == "(MCP prompt call timed out after 0.01s)" @@ 
-616,15 +584,11 @@ async def test_connect_registers_resources_and_prompts( prompt_names=["prompt_c"], ) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake")}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake")}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert "mcp_test_tool_a" in registry.tool_names From 696b64b5a6d0246c898ea3bcd50140cfb7f0b3a7 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 10 Apr 2026 16:02:00 +0000 Subject: [PATCH 057/115] fix(notebook): remove unused imports Clean up unused imports in notebook_edit so the Ruff F401 check passes cleanly. Made-with: Cursor --- nanobot/agent/tools/notebook.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nanobot/agent/tools/notebook.py b/nanobot/agent/tools/notebook.py index 8c4be110d..fa53809f1 100644 --- a/nanobot/agent/tools/notebook.py +++ b/nanobot/agent/tools/notebook.py @@ -4,10 +4,9 @@ from __future__ import annotations import json import uuid -from pathlib import Path from typing import Any -from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.base import tool_parameters from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema from nanobot.agent.tools.filesystem import _FsTool From e392c27f7e6981313198c078dd6b3ef01aeb4b1f Mon Sep 17 00:00:00 2001 From: 04cb <0x04cb@gmail.com> Date: Sat, 11 Apr 2026 00:47:23 +0800 Subject: [PATCH 058/115] fix(utils): anchor unclosed think-tag regex to string start (#3004) --- nanobot/utils/helpers.py | 4 ++-- tests/utils/test_strip_think.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 3b4f9f25a..1bfd9f18b 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ 
-17,10 +17,10 @@ from loguru import logger def strip_think(text: str) -> str: """Remove thinking blocks and any unclosed trailing tag.""" text = re.sub(r"[\s\S]*?", "", text) - text = re.sub(r"[\s\S]*$", "", text) + text = re.sub(r"^\s*[\s\S]*$", "", text) # Gemma 4 and similar models use thought-tag blocks text = re.sub(r"[\s\S]*?", "", text) - text = re.sub(r"[\s\S]*$", "", text) + text = re.sub(r"^\s*[\s\S]*$", "", text) return text.strip() diff --git a/tests/utils/test_strip_think.py b/tests/utils/test_strip_think.py index 6710dfc93..5828c6d1f 100644 --- a/tests/utils/test_strip_think.py +++ b/tests/utils/test_strip_think.py @@ -34,3 +34,32 @@ class TestStripThinkTag: def test_empty_string(self): assert strip_think("") == "" + + +class TestStripThinkFalsePositive: + """Ensure mid-content think / thought tags are NOT stripped (#3004).""" + + def test_backtick_think_tag_preserved(self): + text = "*Think Stripping:* A new utility to strip `` tags from output." + assert strip_think(text) == text + + def test_prose_think_tag_preserved(self): + text = "The model emits at the start of its response." + assert strip_think(text) == text + + def test_code_block_think_tag_preserved(self): + text = "Example:\n```\ntext = re.sub(r\"[\\s\\S]*\", \"\", text)\n```\nDone." + assert strip_think(text) == text + + def test_backtick_thought_tag_preserved(self): + text = "Gemma 4 uses `` blocks for reasoning." 
+ assert strip_think(text) == text + + def test_prefix_unclosed_think_still_stripped(self): + assert strip_think("reasoning without closing") == "" + + def test_prefix_unclosed_think_with_whitespace(self): + assert strip_think(" reasoning...") == "" + + def test_prefix_unclosed_thought_still_stripped(self): + assert strip_think("reasoning without closing") == "" From b52bfddf16636f85d5805fce393721d49eb5223c Mon Sep 17 00:00:00 2001 From: Daniel Phang Date: Sat, 11 Apr 2026 00:34:48 -0700 Subject: [PATCH 059/115] fix(cron): guard _load_store against reentrant reload during job execution When on_job callbacks call list_jobs() (which triggers _load_store), the in-memory state is reloaded from disk, discarding the next_run_at_ms updates that _on_timer is actively computing. This causes jobs to re-trigger indefinitely on the next tick. Add an _executing flag around the job execution loop. While set, _load_store returns the cached store instead of reloading from disk. Includes regression test. Co-Authored-By: Claude Opus 4.6 (1M context) --- nanobot/cron/service.py | 13 ++++++++-- tests/cron/test_cron_service.py | 46 +++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index 267613012..d3f74fbf8 100644 --- a/nanobot/cron/service.py +++ b/nanobot/cron/service.py @@ -80,6 +80,7 @@ class CronService: self._store: CronStore | None = None self._timer_task: asyncio.Task | None = None self._running = False + self._executing = False self.max_sleep_ms = max_sleep_ms def _load_jobs(self) -> tuple[list[CronJob], int]: @@ -171,7 +172,11 @@ class CronService: def _load_store(self) -> CronStore: """Load jobs from disk. Reloads automatically if file was modified externally. - Reload every time because it needs to merge operations on the jobs object from other instances. + - Skip reload when _executing to prevent on_job callbacks (e.g. 
list_jobs) + from replacing in-memory state that _on_timer is actively modifying. """ + if self._executing and self._store is not None: + return self._store jobs, version = self._load_jobs() self._store = CronStore(version=version, jobs=jobs) self._merge_action() @@ -298,8 +303,12 @@ class CronService: if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms ] - for job in due_jobs: - await self._execute_job(job) + self._executing = True + try: + for job in due_jobs: + await self._execute_job(job) + finally: + self._executing = False self._save_store() self._arm_timer() diff --git a/tests/cron/test_cron_service.py b/tests/cron/test_cron_service.py index f1956d8d2..4aa7fc06d 100644 --- a/tests/cron/test_cron_service.py +++ b/tests/cron/test_cron_service.py @@ -479,3 +479,49 @@ def test_update_job_sentinel_channel_and_to(tmp_path) -> None: assert isinstance(result, CronJob) assert result.payload.channel is None assert result.payload.to is None + + +@pytest.mark.asyncio +async def test_list_jobs_during_on_job_does_not_cause_stale_reload(tmp_path) -> None: + """Regression: if the bot calls list_jobs (which reloads from disk) during + on_job execution, the in-memory next_run_at_ms update must not be lost. 
+ Previously this caused an infinite re-trigger loop.""" + store_path = tmp_path / "cron" / "jobs.json" + execution_count = 0 + + async def on_job_that_lists(job): + nonlocal execution_count + execution_count += 1 + # Simulate the bot calling cron(action=list) mid-execution + service.list_jobs() + + service = CronService(store_path, on_job=on_job_that_lists, max_sleep_ms=100) + await service.start() + + # Add two jobs scheduled in the past so they're immediately due + now_ms = int(time.time() * 1000) + for name in ("job-a", "job-b"): + service.add_job( + name=name, + schedule=CronSchedule(kind="every", every_ms=3_600_000), + message="test", + ) + # Force next_run to the past so _on_timer picks them up + for job in service._store.jobs: + job.state.next_run_at_ms = now_ms - 1000 + service._save_store() + service._arm_timer() + + # Let the timer fire once + await asyncio.sleep(0.3) + service.stop() + + # Each job should have run exactly once, not looped + assert execution_count == 2 + + # Verify next_run_at_ms was persisted correctly (in the future) + raw = json.loads(store_path.read_text()) + for j in raw["jobs"]: + next_run = j["state"]["nextRunAtMs"] + assert next_run is not None + assert next_run > now_ms, f"Job '{j['name']}' next_run should be in the future" From fb6dd111e1c11afaaa2a7d7ea5ee4a65ce4787d1 Mon Sep 17 00:00:00 2001 From: chengyongru <61816729+chengyongru@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:43:42 +0800 Subject: [PATCH 060/115] =?UTF-8?q?feat(agent):=20auto=20compact=20?= =?UTF-8?q?=E2=80=94=20proactive=20session=20compression=20to=20reduce=20t?= =?UTF-8?q?oken=20cost=20and=20latency=20(#2982)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a user is idle for longer than a configured TTL, nanobot **proactively** compresses the session context into a summary. 
This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary and fresh input. --- README.md | 26 + nanobot/agent/auto_compact.py | 82 +++ nanobot/agent/context.py | 9 +- nanobot/agent/loop.py | 38 +- nanobot/agent/memory.py | 4 + nanobot/cli/commands.py | 3 + nanobot/config/schema.py | 1 + nanobot/nanobot.py | 1 + nanobot/session/manager.py | 3 + tests/agent/test_auto_compact.py | 931 +++++++++++++++++++++++++++++++ 10 files changed, 1091 insertions(+), 7 deletions(-) create mode 100644 nanobot/agent/auto_compact.py create mode 100644 tests/agent/test_auto_compact.py diff --git a/README.md b/README.md index 6098c55ca..a9bf4b5e3 100644 --- a/README.md +++ b/README.md @@ -1503,6 +1503,32 @@ MCP tools are automatically discovered and registered on startup. The LLM can us **Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation). +### Auto Compact + +When a user is idle for longer than a configured TTL, nanobot **proactively** compresses the session context into a summary. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary and fresh input. + +```json +{ + "agents": { + "defaults": { + "sessionTtlMinutes": 15 + } + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `agents.defaults.sessionTtlMinutes` | `0` (disabled) | Minutes of idle time before auto-compaction. Set to `0` to disable. Recommended: `15` — matches typical LLM KV cache expiration, so compacted sessions won't waste cache on cold entries. | + +How it works: +1. 
**Idle detection**: On each idle tick (~1 s), checks all sessions for expiration. +2. **Background compaction**: Expired sessions are summarized via LLM, then cleared. +3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted). + +> [!TIP] +> The summary survives bot restarts — it's stored in session metadata and recovered on the next message. + ### Timezone Time is context. Context should be precise. diff --git a/nanobot/agent/auto_compact.py b/nanobot/agent/auto_compact.py new file mode 100644 index 000000000..171f5f55a --- /dev/null +++ b/nanobot/agent/auto_compact.py @@ -0,0 +1,82 @@ +"""Auto compact: proactive compression of idle sessions to reduce token cost and latency.""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Callable, Coroutine + +from loguru import logger + +if TYPE_CHECKING: + from nanobot.agent.memory import Consolidator + from nanobot.session.manager import Session, SessionManager + + +class AutoCompact: + def __init__(self, sessions: SessionManager, consolidator: Consolidator, + session_ttl_minutes: int = 0): + self.sessions = sessions + self.consolidator = consolidator + self._ttl = session_ttl_minutes + self._archiving: set[str] = set() + self._summaries: dict[str, tuple[str, datetime]] = {} + + def _is_expired(self, ts: datetime | str | None) -> bool: + if self._ttl <= 0 or not ts: + return False + if isinstance(ts, str): + ts = datetime.fromisoformat(ts) + return (datetime.now() - ts).total_seconds() >= self._ttl * 60 + + @staticmethod + def _format_summary(text: str, last_active: datetime) -> str: + idle_min = int((datetime.now() - last_active).total_seconds() / 60) + return f"Inactive for {idle_min} minutes.\nPrevious conversation summary: {text}" + + def check_expired(self, schedule_background: Callable[[Coroutine], None]) -> None: + for info in self.sessions.list_sessions(): + key = info.get("key", "") + if key and 
key not in self._archiving and self._is_expired(info.get("updated_at")): + self._archiving.add(key) + logger.debug("Auto-compact: scheduling archival for {} (idle > {} min)", key, self._ttl) + schedule_background(self._archive(key)) + + async def _archive(self, key: str) -> None: + try: + self.sessions.invalidate(key) + session = self.sessions.get_or_create(key) + msgs = session.messages[session.last_consolidated:] + if not msgs: + logger.debug("Auto-compact: skipping {}, no un-consolidated messages", key) + session.updated_at = datetime.now() + self.sessions.save(session) + return + n = len(msgs) + last_active = session.updated_at + await self.consolidator.archive(msgs) + entry = self.consolidator.get_last_history_entry() + summary = (entry or {}).get("content", "") + if summary and summary != "(nothing)": + self._summaries[key] = (summary, last_active) + session.metadata["_last_summary"] = {"text": summary, "last_active": last_active.isoformat()} + session.clear() + self.sessions.save(session) + logger.info("Auto-compact: archived {} ({} messages, summary={})", key, n, bool(summary)) + except Exception: + logger.exception("Auto-compact: failed for {}", key) + finally: + self._archiving.discard(key) + + def prepare_session(self, session: Session, key: str) -> tuple[Session, str | None]: + if key in self._archiving or self._is_expired(session.updated_at): + logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving) + session = self.sessions.get_or_create(key) + entry = self._summaries.pop(key, None) + if entry: + session.metadata.pop("_last_summary", None) + return session, self._format_summary(entry[0], entry[1]) + if not session.messages and "_last_summary" in session.metadata: + meta = session.metadata.pop("_last_summary") + self.sessions.save(session) + return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"])) + return session, None diff --git a/nanobot/agent/context.py 
b/nanobot/agent/context.py index 3ac19e7f3..e3460ddfd 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -20,6 +20,7 @@ class ContextBuilder: BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"] _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" _MAX_RECENT_HISTORY = 50 + _RUNTIME_CONTEXT_END = "[/Runtime Context]" def __init__(self, workspace: Path, timezone: str | None = None): self.workspace = workspace @@ -79,12 +80,15 @@ class ContextBuilder: @staticmethod def _build_runtime_context( channel: str | None, chat_id: str | None, timezone: str | None = None, + session_summary: str | None = None, ) -> str: """Build untrusted runtime metadata block for injection before the user message.""" lines = [f"Current Time: {current_time_str(timezone)}"] if channel and chat_id: lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] - return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + if session_summary: + lines += ["", "[Resumed Session]", session_summary] + return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END @staticmethod def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: @@ -121,9 +125,10 @@ class ContextBuilder: channel: str | None = None, chat_id: str | None = None, current_role: str = "user", + session_summary: str | None = None, ) -> list[dict[str, Any]]: """Build the complete message list for an LLM call.""" - runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone) + runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary) user_content = self._build_user_content(current_message, media) # Merge runtime context and user content into a single user message diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index bc83cc77c..65a5a1abc 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -13,6 +13,7 @@ from typing 
import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger +from nanobot.agent.auto_compact import AutoCompact from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import Consolidator, Dream @@ -145,6 +146,7 @@ class AgentLoop: mcp_servers: dict | None = None, channels_config: ChannelsConfig | None = None, timezone: str | None = None, + session_ttl_minutes: int = 0, hooks: list[AgentHook] | None = None, unified_session: bool = False, ): @@ -217,6 +219,11 @@ class AgentLoop: get_tool_definitions=self.tools.get_definitions, max_completion_tokens=provider.generation.max_tokens, ) + self.auto_compact = AutoCompact( + sessions=self.sessions, + consolidator=self.consolidator, + session_ttl_minutes=session_ttl_minutes, + ) self.dream = Dream( store=self.context.memory, provider=provider, @@ -371,6 +378,7 @@ class AgentLoop: try: msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0) except asyncio.TimeoutError: + self.auto_compact.check_expired(self._schedule_background) continue except asyncio.CancelledError: # Preserve real task cancellation so shutdown can complete cleanly. 
@@ -497,13 +505,18 @@ class AgentLoop: session = self.sessions.get_or_create(key) if self._restore_runtime_checkpoint(session): self.sessions.save(session) + + session, pending = self.auto_compact.prepare_session(session, key) + await self.consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) history = session.get_history(max_messages=0) current_role = "assistant" if msg.sender_id == "subagent" else "user" + messages = self.context.build_messages( history=history, current_message=msg.content, channel=channel, chat_id=chat_id, + session_summary=pending, current_role=current_role, ) final_content, _, all_msgs, _ = await self._run_agent_loop( @@ -525,6 +538,8 @@ class AgentLoop: if self._restore_runtime_checkpoint(session): self.sessions.save(session) + session, pending = self.auto_compact.prepare_session(session, key) + # Slash commands raw = msg.content.strip() ctx = CommandContext(msg=msg, session=session, key=key, raw=raw, loop=self) @@ -539,9 +554,11 @@ class AgentLoop: message_tool.start_turn() history = session.get_history(max_messages=0) + initial_messages = self.context.build_messages( history=history, current_message=msg.content, + session_summary=pending, media=msg.media if msg.media else None, channel=msg.channel, chat_id=msg.chat_id, ) @@ -645,12 +662,23 @@ class AgentLoop: entry["content"] = filtered elif role == "user": if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): - # Strip the runtime-context prefix, keep only the user text. - parts = content.split("\n\n", 1) - if len(parts) > 1 and parts[1].strip(): - entry["content"] = parts[1] + # Strip the entire runtime-context block (including any session summary). + # The block is bounded by _RUNTIME_CONTEXT_TAG and _RUNTIME_CONTEXT_END. 
+ end_marker = ContextBuilder._RUNTIME_CONTEXT_END + end_pos = content.find(end_marker) + if end_pos >= 0: + after = content[end_pos + len(end_marker):].lstrip("\n") + if after: + entry["content"] = after + else: + continue else: - continue + # Fallback: no end marker found, strip the tag prefix + after_tag = content[len(ContextBuilder._RUNTIME_CONTEXT_TAG):].lstrip("\n") + if after_tag.strip(): + entry["content"] = after_tag + else: + continue if isinstance(content, list): filtered = self._sanitize_persisted_blocks(content, drop_runtime=True) if not filtered: diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 943d91855..26c5cd45f 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -374,6 +374,10 @@ class Consolidator: weakref.WeakValueDictionary() ) + def get_last_history_entry(self) -> dict[str, Any] | None: + """Return the most recent entry from history.jsonl.""" + return self.store._read_last_entry() + def get_lock(self, session_key: str) -> asyncio.Lock: """Return the shared consolidation lock for one session.""" return self._locks.setdefault(session_key, asyncio.Lock()) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 5ce8b7937..9d818a9db 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -591,6 +591,7 @@ def serve( channels_config=runtime_config.channels, timezone=runtime_config.agents.defaults.timezone, unified_session=runtime_config.agents.defaults.unified_session, + session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes, ) model_name = runtime_config.agents.defaults.model @@ -683,6 +684,7 @@ def gateway( channels_config=config.channels, timezone=config.agents.defaults.timezone, unified_session=config.agents.defaults.unified_session, + session_ttl_minutes=config.agents.defaults.session_ttl_minutes, ) # Set cron callback (needs agent) @@ -915,6 +917,7 @@ def agent( channels_config=config.channels, timezone=config.agents.defaults.timezone, 
unified_session=config.agents.defaults.unified_session, + session_ttl_minutes=config.agents.defaults.session_ttl_minutes, ) restart_notice = consume_restart_notice_from_env() if restart_notice and should_show_cli_restart_notice(restart_notice, session_id): diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 2d31c8bf9..8ab68d7b5 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -77,6 +77,7 @@ class AgentDefaults(Base): reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" unified_session: bool = False # Share one session across all channels (single-user multi-device) + session_ttl_minutes: int = Field(default=0, ge=0) # Auto /new after idle (0 = disabled) dream: DreamConfig = Field(default_factory=DreamConfig) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 9166acb27..df0e49842 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -82,6 +82,7 @@ class Nanobot: mcp_servers=config.tools.mcp_servers, timezone=defaults.timezone, unified_session=defaults.unified_session, + session_ttl_minutes=defaults.session_ttl_minutes, ) return cls(loop) diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index 27df31405..2ed0624a2 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -155,6 +155,7 @@ class SessionManager: messages = [] metadata = {} created_at = None + updated_at = None last_consolidated = 0 with open(path, encoding="utf-8") as f: @@ -168,6 +169,7 @@ class SessionManager: if data.get("_type") == "metadata": metadata = data.get("metadata", {}) created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None + updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None last_consolidated = data.get("last_consolidated", 0) else: messages.append(data) @@ -176,6 +178,7 @@ class SessionManager: 
key=key, messages=messages, created_at=created_at or datetime.now(), + updated_at=updated_at or datetime.now(), metadata=metadata, last_consolidated=last_consolidated ) diff --git a/tests/agent/test_auto_compact.py b/tests/agent/test_auto_compact.py new file mode 100644 index 000000000..8b26254e9 --- /dev/null +++ b/tests/agent/test_auto_compact.py @@ -0,0 +1,931 @@ +"""Tests for auto compact (idle TTL) feature.""" + +import asyncio +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock +from pathlib import Path + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.bus.events import InboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.config.schema import AgentDefaults +from nanobot.command import CommandContext +from nanobot.providers.base import LLMResponse + + +def _make_loop(tmp_path: Path, session_ttl_minutes: int = 15) -> AgentLoop: + """Create a minimal AgentLoop for testing.""" + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.estimate_prompt_tokens.return_value = (10_000, "test") + provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) + provider.generation.max_tokens = 4096 + loop = AgentLoop( + bus=bus, + provider=provider, + workspace=tmp_path, + model="test-model", + context_window_tokens=128_000, + session_ttl_minutes=session_ttl_minutes, + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + return loop + + +class TestSessionTTLConfig: + """Test session TTL configuration.""" + + def test_default_ttl_is_zero(self): + """Default TTL should be 0 (disabled).""" + defaults = AgentDefaults() + assert defaults.session_ttl_minutes == 0 + + def test_custom_ttl(self): + """Custom TTL should be stored correctly.""" + defaults = AgentDefaults(session_ttl_minutes=30) + assert defaults.session_ttl_minutes == 30 + + +class TestAgentLoopTTLParam: + """Test that AutoCompact receives and 
stores session_ttl_minutes.""" + + def test_loop_stores_ttl(self, tmp_path): + """AutoCompact should store the TTL value.""" + loop = _make_loop(tmp_path, session_ttl_minutes=25) + assert loop.auto_compact._ttl == 25 + + def test_loop_default_ttl_zero(self, tmp_path): + """AutoCompact default TTL should be 0 (disabled).""" + loop = _make_loop(tmp_path, session_ttl_minutes=0) + assert loop.auto_compact._ttl == 0 + + +class TestAutoCompact: + """Test the _archive method.""" + + @pytest.mark.asyncio + async def test_is_expired_boundary(self, tmp_path): + """Exactly at TTL boundary should be expired (>= not >).""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + ts = datetime.now() - timedelta(minutes=15) + assert loop.auto_compact._is_expired(ts) is True + ts2 = datetime.now() - timedelta(minutes=14, seconds=59) + assert loop.auto_compact._is_expired(ts2) is False + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_is_expired_string_timestamp(self, tmp_path): + """_is_expired should parse ISO string timestamps.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + ts = (datetime.now() - timedelta(minutes=20)).isoformat() + assert loop.auto_compact._is_expired(ts) is True + assert loop.auto_compact._is_expired(None) is False + assert loop.auto_compact._is_expired("") is False + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_check_expired_only_archives_expired_sessions(self, tmp_path): + """With multiple sessions, only the expired one should be archived.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + # Expired session + s1 = loop.sessions.get_or_create("cli:expired") + s1.add_message("user", "old") + s1.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(s1) + # Active session + s2 = loop.sessions.get_or_create("cli:active") + s2.add_message("user", "recent") + loop.sessions.save(s2) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + 
loop.auto_compact.check_expired(loop._schedule_background) + await asyncio.sleep(0.1) + + active_after = loop.sessions.get_or_create("cli:active") + assert len(active_after.messages) == 1 + assert active_after.messages[0]["content"] == "recent" + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_archives_and_clears(self, tmp_path): + """_archive should archive un-consolidated messages and clear session.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + for i in range(4): + session.add_message("user", f"msg{i}") + session.add_message("assistant", f"resp{i}") + loop.sessions.save(session) + + archived_messages = [] + + async def _fake_archive(messages): + archived_messages.extend(messages) + return True + + loop.consolidator.archive = _fake_archive + + await loop.auto_compact._archive("cli:test") + + assert len(archived_messages) == 8 + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 0 + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_stores_summary(self, tmp_path): + """_archive should store the summary in _summaries.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "hello") + session.add_message("assistant", "hi there") + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User said hello.", + } + + await loop.auto_compact._archive("cli:test") + + entry = loop.auto_compact._summaries.get("cli:test") + assert entry is not None + assert entry[0] == "User said hello." 
+ session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 0 + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_empty_session(self, tmp_path): + """_archive on empty session should not archive.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + + archive_called = False + + async def _fake_archive(messages): + nonlocal archive_called + archive_called = True + return True + + loop.consolidator.archive = _fake_archive + + await loop.auto_compact._archive("cli:test") + + assert not archive_called + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 0 + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_respects_last_consolidated(self, tmp_path): + """_archive should only archive un-consolidated messages.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + for i in range(10): + session.add_message("user", f"msg{i}") + session.add_message("assistant", f"resp{i}") + session.last_consolidated = 18 + loop.sessions.save(session) + + archived_count = 0 + + async def _fake_archive(messages): + nonlocal archived_count + archived_count = len(messages) + return True + + loop.consolidator.archive = _fake_archive + + await loop.auto_compact._archive("cli:test") + + assert archived_count == 2 + await loop.close_mcp() + + +class TestAutoCompactIdleDetection: + """Test idle detection triggers auto-new in _process_message.""" + + @pytest.mark.asyncio + async def test_no_auto_compact_when_ttl_disabled(self, tmp_path): + """No auto-new should happen when TTL is 0 (disabled).""" + loop = _make_loop(tmp_path, session_ttl_minutes=0) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.updated_at = datetime.now() - timedelta(minutes=30) + loop.sessions.save(session) + + msg = 
InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new msg") + await loop._process_message(msg) + + session_after = loop.sessions.get_or_create("cli:test") + assert any(m["content"] == "old message" for m in session_after.messages) + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_triggers_on_idle(self, tmp_path): + """Proactive auto-new archives expired session; _process_message reloads it.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archived_messages = [] + + async def _fake_archive(messages): + archived_messages.extend(messages) + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", + } + + # Simulate proactive archive completing before message arrives + await loop.auto_compact._archive("cli:test") + + msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new msg") + await loop._process_message(msg) + + session_after = loop.sessions.get_or_create("cli:test") + assert not any(m["content"] == "old message" for m in session_after.messages) + assert any(m["content"] == "new msg" for m in session_after.messages) + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_no_auto_compact_when_active(self, tmp_path): + """No auto-new should happen when session is recently active.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "recent message") + loop.sessions.save(session) + + msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new msg") + await loop._process_message(msg) + + session_after = loop.sessions.get_or_create("cli:test") + assert 
any(m["content"] == "recent message" for m in session_after.messages) + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_does_not_affect_priority_commands(self, tmp_path): + """Priority commands (/stop, /restart) bypass _process_message entirely via run().""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + # Priority commands are dispatched in run() before _process_message is called. + # Simulate that path directly via dispatch_priority. + raw = "/stop" + msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content=raw) + ctx = CommandContext(msg=msg, session=session, key="cli:test", raw=raw, loop=loop) + result = await loop.commands.dispatch_priority(ctx) + assert result is not None + assert "stopped" in result.content.lower() or "no active task" in result.content.lower() + + # Session should be untouched since priority commands skip _process_message + session_after = loop.sessions.get_or_create("cli:test") + assert any(m["content"] == "old message" for m in session_after.messages) + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_with_slash_new(self, tmp_path): + """Auto-new fires before /new dispatches; session is cleared twice but idempotent.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + for i in range(4): + session.add_message("user", f"msg{i}") + session.add_message("assistant", f"resp{i}") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + + msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") + response = await loop._process_message(msg) + + assert 
response is not None + assert "new session started" in response.content.lower() + + session_after = loop.sessions.get_or_create("cli:test") + # Session is empty (auto-new archived and cleared, /new cleared again) + assert len(session_after.messages) == 0 + await loop.close_mcp() + + +class TestAutoCompactSystemMessages: + """Test that auto-new also works for system messages.""" + + @pytest.mark.asyncio + async def test_auto_compact_triggers_for_system_messages(self, tmp_path): + """Proactive auto-new archives expired session; system messages reload it.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message from subagent context") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", + } + + # Simulate proactive archive completing before system message arrives + await loop.auto_compact._archive("cli:test") + + msg = InboundMessage( + channel="system", sender_id="subagent", chat_id="cli:test", + content="subagent result", + ) + await loop._process_message(msg) + + session_after = loop.sessions.get_or_create("cli:test") + assert not any( + m["content"] == "old message from subagent context" + for m in session_after.messages + ) + await loop.close_mcp() + + +class TestAutoCompactEdgeCases: + """Edge cases for auto session new.""" + + @pytest.mark.asyncio + async def test_auto_compact_with_nothing_summary(self, tmp_path): + """Auto-new should not inject when archive produces '(nothing)'.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "thanks") + session.add_message("assistant", "you're welcome") + session.updated_at = 
datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + loop.provider.chat_with_retry = AsyncMock( + return_value=LLMResponse(content="(nothing)", tool_calls=[]) + ) + + await loop.auto_compact._archive("cli:test") + + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 0 + # "(nothing)" summary should not be stored + assert "cli:test" not in loop.auto_compact._summaries + + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_archive_failure_still_clears(self, tmp_path): + """Auto-new should clear session even if LLM archive fails (raw_archive fallback).""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "important data") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + loop.provider.chat_with_retry = AsyncMock(side_effect=Exception("API down")) + + # Should not raise + await loop.auto_compact._archive("cli:test") + + session_after = loop.sessions.get_or_create("cli:test") + # Session should be cleared (archive falls back to raw dump) + assert len(session_after.messages) == 0 + + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_auto_compact_preserves_runtime_checkpoint_before_check(self, tmp_path): + """Runtime checkpoint is restored; proactive archive handles the expired session.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.metadata[AgentLoop._RUNTIME_CHECKPOINT_KEY] = { + "assistant_message": {"role": "assistant", "content": "interrupted response"}, + "completed_tool_results": [], + "pending_tool_calls": [], + } + session.add_message("user", "previous message") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archived_messages = [] + + async def _fake_archive(messages): + archived_messages.extend(messages) + 
return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", + } + + # Simulate proactive archive completing before message arrives + await loop.auto_compact._archive("cli:test") + + msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="continue") + await loop._process_message(msg) + + # The checkpoint-restored message should have been archived by proactive path + assert len(archived_messages) >= 1 + + await loop.close_mcp() + + +class TestAutoCompactIntegration: + """End-to-end test of auto session new feature.""" + + @pytest.mark.asyncio + async def test_full_lifecycle(self, tmp_path): + """ + Full lifecycle: messages -> idle -> auto-new -> archive -> clear -> summary injected as runtime context. + """ + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + + # Phase 1: User has a conversation + session.add_message("user", "I'm learning English, teach me past tense") + session.add_message("assistant", "Past tense is used for actions completed in the past...") + session.add_message("user", "Give me an example") + session.add_message("assistant", '"I walked to the store yesterday."') + loop.sessions.save(session) + + # Phase 2: Time passes (simulate idle) + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + # Phase 3: User returns with a new message + loop.provider.chat_with_retry = AsyncMock( + return_value=LLMResponse( + content="User is learning English past tense. 
Example: 'I walked to the store yesterday.'", + tool_calls=[], + ) + ) + + msg = InboundMessage( + channel="cli", sender_id="user", chat_id="test", + content="Let's continue, teach me present perfect", + ) + response = await loop._process_message(msg) + + # Phase 4: Verify + session_after = loop.sessions.get_or_create("cli:test") + + # Old messages should be gone + assert not any( + "past tense is used" in str(m.get("content", "")) for m in session_after.messages + ) + + # Summary should NOT be persisted in session (ephemeral, one-shot) + assert not any( + "[Resumed Session]" in str(m.get("content", "")) for m in session_after.messages + ) + # Runtime context end marker should NOT be persisted + assert not any( + "[/Runtime Context]" in str(m.get("content", "")) for m in session_after.messages + ) + + # Pending summary should be consumed (one-shot) + assert "cli:test" not in loop.auto_compact._summaries + + # The new message should be processed (response exists) + assert response is not None + + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_multi_paragraph_user_message_preserved(self, tmp_path): + """Multi-paragraph user messages must be fully preserved after auto-new.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", + } + + # Simulate proactive archive completing before message arrives + await loop.auto_compact._archive("cli:test") + + msg = InboundMessage( + channel="cli", sender_id="user", chat_id="test", + content="Paragraph one\n\nParagraph two\n\nParagraph three", + ) + await loop._process_message(msg) + + session_after = 
loop.sessions.get_or_create("cli:test") + user_msgs = [m for m in session_after.messages if m.get("role") == "user"] + assert len(user_msgs) >= 1 + # All three paragraphs must be preserved + persisted = user_msgs[-1]["content"] + assert "Paragraph one" in persisted + assert "Paragraph two" in persisted + assert "Paragraph three" in persisted + # No runtime context markers in persisted message + assert "[Runtime Context" not in persisted + assert "[/Runtime Context]" not in persisted + await loop.close_mcp() + + +class TestProactiveAutoCompact: + """Test proactive auto-new on idle ticks (TimeoutError path in run loop).""" + + @staticmethod + async def _run_check_expired(loop): + """Helper: run check_expired via callback and wait for background tasks.""" + loop.auto_compact.check_expired(loop._schedule_background) + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_no_check_when_ttl_disabled(self, tmp_path): + """check_expired should be a no-op when TTL is 0.""" + loop = _make_loop(tmp_path, session_ttl_minutes=0) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.updated_at = datetime.now() - timedelta(minutes=30) + loop.sessions.save(session) + + await self._run_check_expired(loop) + + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 1 + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_proactive_archive_on_idle_tick(self, tmp_path): + """Expired session should be archived during idle tick.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.add_message("assistant", "old response") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archived_messages = [] + + async def _fake_archive(messages): + archived_messages.extend(messages) + return True + + loop.consolidator.archive = 
_fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User chatted about old things.", + } + + await self._run_check_expired(loop) + + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 0 + assert len(archived_messages) == 2 + entry = loop.auto_compact._summaries.get("cli:test") + assert entry is not None + assert entry[0] == "User chatted about old things." + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_no_proactive_archive_when_active(self, tmp_path): + """Recently active session should NOT be archived on idle tick.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "recent message") + loop.sessions.save(session) + + await self._run_check_expired(loop) + + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 1 + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_no_duplicate_archive(self, tmp_path): + """Should not archive the same session twice if already in progress.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archive_count = 0 + started = asyncio.Event() + block_forever = asyncio.Event() + + async def _slow_archive(messages): + nonlocal archive_count + archive_count += 1 + started.set() + await block_forever.wait() + return True + + loop.consolidator.archive = _slow_archive + + # First call starts archiving via callback + loop.auto_compact.check_expired(loop._schedule_background) + await started.wait() + assert archive_count == 1 + + # Second call should skip (key is in _archiving) + loop.auto_compact.check_expired(loop._schedule_background) + await asyncio.sleep(0.05) + assert 
archive_count == 1 + + # Clean up + block_forever.set() + await asyncio.sleep(0.1) + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_proactive_archive_error_does_not_block(self, tmp_path): + """Proactive archive failure should be caught and not block future ticks.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + async def _failing_archive(messages): + raise RuntimeError("LLM down") + + loop.consolidator.archive = _failing_archive + + # Should not raise + await self._run_check_expired(loop) + + # Key should be removed from _archiving (finally block) + assert "cli:test" not in loop.auto_compact._archiving + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_proactive_archive_skips_empty_sessions(self, tmp_path): + """Proactive archive should not call LLM for sessions with no un-consolidated messages.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archive_called = False + + async def _fake_archive(messages): + nonlocal archive_called + archive_called = True + return True + + loop.consolidator.archive = _fake_archive + + await self._run_check_expired(loop) + + assert not archive_called + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_no_reschedule_after_successful_archive(self, tmp_path): + """Already-archived session should NOT be re-scheduled on subsequent ticks.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "old message") + session.add_message("assistant", "old response") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + 
archive_count = 0 + + async def _fake_archive(messages): + nonlocal archive_count + archive_count += 1 + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", + } + + # First tick: archives the session + await self._run_check_expired(loop) + assert archive_count == 1 + + # Second tick: should NOT re-schedule (updated_at is fresh after clear) + await self._run_check_expired(loop) + assert archive_count == 1 # Still 1, not re-scheduled + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_empty_skip_refreshes_updated_at_prevents_reschedule(self, tmp_path): + """Empty session skip refreshes updated_at, preventing immediate re-scheduling.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archive_count = 0 + + async def _fake_archive(messages): + nonlocal archive_count + archive_count += 1 + return True + + loop.consolidator.archive = _fake_archive + + # First tick: skips (no messages), refreshes updated_at + await self._run_check_expired(loop) + assert archive_count == 0 + + # Second tick: should NOT re-schedule because updated_at is fresh + await self._run_check_expired(loop) + assert archive_count == 0 + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_session_can_be_compacted_again_after_new_messages(self, tmp_path): + """After successful compact + user sends new messages + idle again, should compact again.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "first conversation") + session.add_message("assistant", "first response") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archive_count = 0 + + async def 
_fake_archive(messages): + nonlocal archive_count + archive_count += 1 + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", + } + + # First compact cycle + await loop.auto_compact._archive("cli:test") + assert archive_count == 1 + + # User returns, sends new messages + msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second topic") + await loop._process_message(msg) + + # Simulate idle again + loop.sessions.invalidate("cli:test") + session2 = loop.sessions.get_or_create("cli:test") + session2.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session2) + + # Second compact cycle should succeed + await loop.auto_compact._archive("cli:test") + assert archive_count == 2 + await loop.close_mcp() + + +class TestSummaryPersistence: + """Test that summary survives restart via session metadata.""" + + @pytest.mark.asyncio + async def test_summary_persisted_in_session_metadata(self, tmp_path): + """After archive, _last_summary should be in session metadata.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "hello") + session.add_message("assistant", "hi there") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User said hello.", + } + + await loop.auto_compact._archive("cli:test") + + # Summary should be persisted in session metadata + session_after = loop.sessions.get_or_create("cli:test") + meta = session_after.metadata.get("_last_summary") + assert meta is not None + assert meta["text"] == "User said hello." 
+ assert "last_active" in meta + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_summary_recovered_after_restart(self, tmp_path): + """Summary should be recovered from metadata when _summaries is empty (simulates restart).""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "hello") + session.add_message("assistant", "hi there") + last_active = datetime.now() - timedelta(minutes=20) + session.updated_at = last_active + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User said hello.", + } + + # Archive + await loop.auto_compact._archive("cli:test") + + # Simulate restart: clear in-memory state + loop.auto_compact._summaries.clear() + loop.sessions.invalidate("cli:test") + + # prepare_session should recover summary from metadata + reloaded = loop.sessions.get_or_create("cli:test") + _, summary = loop.auto_compact.prepare_session(reloaded, "cli:test") + + assert summary is not None + assert "User said hello." 
in summary + assert "Inactive for" in summary + # Metadata should be cleaned up after consumption + assert "_last_summary" not in reloaded.metadata + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_metadata_cleanup_no_leak(self, tmp_path): + """_last_summary should be removed from metadata after being consumed.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "hello") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", + } + + await loop.auto_compact._archive("cli:test") + + # Clear in-memory to force metadata path + loop.auto_compact._summaries.clear() + loop.sessions.invalidate("cli:test") + reloaded = loop.sessions.get_or_create("cli:test") + + # First call: consumes from metadata + _, summary = loop.auto_compact.prepare_session(reloaded, "cli:test") + assert summary is not None + + # Second call: no summary (already consumed) + _, summary2 = loop.auto_compact.prepare_session(reloaded, "cli:test") + assert summary2 is None + assert "_last_summary" not in reloaded.metadata + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_metadata_cleanup_on_inmemory_path(self, tmp_path): + """In-memory _summaries path should also clean up _last_summary from metadata.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + session.add_message("user", "hello") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + async def _fake_archive(messages): + return True + + loop.consolidator.archive = _fake_archive + loop.consolidator.get_last_history_entry = lambda: { + "cursor": 1, "timestamp": "2026-01-01 
00:00", "content": "Summary.", + } + + await loop.auto_compact._archive("cli:test") + + # Both _summaries and metadata have the summary + assert "cli:test" in loop.auto_compact._summaries + loop.sessions.invalidate("cli:test") + reloaded = loop.sessions.get_or_create("cli:test") + assert "_last_summary" in reloaded.metadata + + # In-memory path is taken (no restart) + _, summary = loop.auto_compact.prepare_session(reloaded, "cli:test") + assert summary is not None + # Metadata should also be cleaned up + assert "_last_summary" not in reloaded.metadata + await loop.close_mcp() From 69d60e2b063b1b5b649a97320b36d8b7612b7f8a Mon Sep 17 00:00:00 2001 From: chengyongru Date: Fri, 10 Apr 2026 18:03:36 +0800 Subject: [PATCH 061/115] fix(agent): handle UnicodeDecodeError in _read_last_entry history.jsonl may contain non-UTF-8 bytes (e.g. from email channel binary content), causing auto compact to fail when reading the last entry for summary generation. Catch UnicodeDecodeError alongside FileNotFoundError and JSONDecodeError. 
--- nanobot/agent/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 26c5cd45f..e9662ff2c 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -290,7 +290,7 @@ class MemoryStore: if not lines: return None return json.loads(lines[-1]) - except (FileNotFoundError, json.JSONDecodeError): + except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError): return None def _write_entries(self, entries: list[dict[str, Any]]) -> None: From d03458f0346ffbc59996e265dece8520e79e974b Mon Sep 17 00:00:00 2001 From: chengyongru Date: Fri, 10 Apr 2026 18:14:14 +0800 Subject: [PATCH 062/115] fix(agent): eliminate race condition in auto compact summary retrieval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make Consolidator.archive() return the summary string directly instead of writing to history.jsonl then reading back via get_last_history_entry(). This eliminates a race condition where concurrent _archive calls for different sessions could read each other's summaries from the shared history file (cross-user context leak in multi-user deployments). Also removes Consolidator.get_last_history_entry() — no longer needed. 
--- nanobot/agent/auto_compact.py | 6 +-- nanobot/agent/memory.py | 14 +++--- tests/agent/test_auto_compact.py | 76 +++++++++----------------------- tests/agent/test_consolidator.py | 6 +-- 4 files changed, 31 insertions(+), 71 deletions(-) diff --git a/nanobot/agent/auto_compact.py b/nanobot/agent/auto_compact.py index 171f5f55a..f30feac17 100644 --- a/nanobot/agent/auto_compact.py +++ b/nanobot/agent/auto_compact.py @@ -53,9 +53,7 @@ class AutoCompact: return n = len(msgs) last_active = session.updated_at - await self.consolidator.archive(msgs) - entry = self.consolidator.get_last_history_entry() - summary = (entry or {}).get("content", "") + summary = await self.consolidator.archive(msgs) or "" if summary and summary != "(nothing)": self._summaries[key] = (summary, last_active) session.metadata["_last_summary"] = {"text": summary, "last_active": last_active.isoformat()} @@ -71,6 +69,8 @@ class AutoCompact: if key in self._archiving or self._is_expired(session.updated_at): logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving) session = self.sessions.get_or_create(key) + # Hot path: summary from in-memory dict (process hasn't restarted). + # Also clean metadata copy so stale _last_summary never leaks to disk. 
entry = self._summaries.pop(key, None) if entry: session.metadata.pop("_last_summary", None) diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index e9662ff2c..04d988ee5 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -374,10 +374,6 @@ class Consolidator: weakref.WeakValueDictionary() ) - def get_last_history_entry(self) -> dict[str, Any] | None: - """Return the most recent entry from history.jsonl.""" - return self.store._read_last_entry() - def get_lock(self, session_key: str) -> asyncio.Lock: """Return the shared consolidation lock for one session.""" return self._locks.setdefault(session_key, asyncio.Lock()) @@ -437,13 +433,13 @@ class Consolidator: self._get_tool_definitions(), ) - async def archive(self, messages: list[dict]) -> bool: + async def archive(self, messages: list[dict]) -> str | None: """Summarize messages via LLM and append to history.jsonl. - Returns True on success (or degraded success), False if nothing to do. + Returns the summary text on success, None if nothing to archive. """ if not messages: - return False + return None try: formatted = MemoryStore._format_messages(messages) response = await self.provider.chat_with_retry( @@ -463,11 +459,11 @@ class Consolidator: ) summary = response.content or "[no summary]" self.store.append_history(summary) - return True + return summary except Exception: logger.warning("Consolidation LLM call failed, raw-dumping to history") self.store.raw_archive(messages) - return True + return None async def maybe_consolidate_by_tokens(self, session: Session) -> None: """Loop: archive old messages until prompt fits within safe budget. diff --git a/tests/agent/test_auto_compact.py b/tests/agent/test_auto_compact.py index 8b26254e9..39792e290 100644 --- a/tests/agent/test_auto_compact.py +++ b/tests/agent/test_auto_compact.py @@ -101,7 +101,7 @@ class TestAutoCompact: loop.sessions.save(s2) async def _fake_archive(messages): - return True + return "Summary." 
loop.consolidator.archive = _fake_archive loop.auto_compact.check_expired(loop._schedule_background) @@ -126,7 +126,7 @@ class TestAutoCompact: async def _fake_archive(messages): archived_messages.extend(messages) - return True + return "Summary." loop.consolidator.archive = _fake_archive @@ -147,12 +147,9 @@ class TestAutoCompact: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "User said hello." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User said hello.", - } await loop.auto_compact._archive("cli:test") @@ -174,7 +171,7 @@ class TestAutoCompact: async def _fake_archive(messages): nonlocal archive_called archive_called = True - return True + return "Summary." loop.consolidator.archive = _fake_archive @@ -201,7 +198,7 @@ class TestAutoCompact: async def _fake_archive(messages): nonlocal archived_count archived_count = len(messages) - return True + return "Summary." loop.consolidator.archive = _fake_archive @@ -243,12 +240,9 @@ class TestAutoCompactIdleDetection: async def _fake_archive(messages): archived_messages.extend(messages) - return True + return "Summary." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } # Simulate proactive archive completing before message arrives await loop.auto_compact._archive("cli:test") @@ -311,7 +305,7 @@ class TestAutoCompactIdleDetection: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "Summary." loop.consolidator.archive = _fake_archive @@ -340,12 +334,9 @@ class TestAutoCompactSystemMessages: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "Summary." 
loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } # Simulate proactive archive completing before system message arrives await loop.auto_compact._archive("cli:test") @@ -428,12 +419,9 @@ class TestAutoCompactEdgeCases: async def _fake_archive(messages): archived_messages.extend(messages) - return True + return "Summary." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } # Simulate proactive archive completing before message arrives await loop.auto_compact._archive("cli:test") @@ -518,12 +506,9 @@ class TestAutoCompactIntegration: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "Summary." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } # Simulate proactive archive completing before message arrives await loop.auto_compact._archive("cli:test") @@ -586,12 +571,9 @@ class TestProactiveAutoCompact: async def _fake_archive(messages): archived_messages.extend(messages) - return True + return "User chatted about old things." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User chatted about old things.", - } await self._run_check_expired(loop) @@ -635,7 +617,7 @@ class TestProactiveAutoCompact: archive_count += 1 started.set() await block_forever.wait() - return True + return "Summary." loop.consolidator.archive = _slow_archive @@ -688,7 +670,7 @@ class TestProactiveAutoCompact: async def _fake_archive(messages): nonlocal archive_called archive_called = True - return True + return "Summary." 
loop.consolidator.archive = _fake_archive @@ -712,12 +694,9 @@ class TestProactiveAutoCompact: async def _fake_archive(messages): nonlocal archive_count archive_count += 1 - return True + return "Summary." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } # First tick: archives the session await self._run_check_expired(loop) @@ -741,7 +720,7 @@ class TestProactiveAutoCompact: async def _fake_archive(messages): nonlocal archive_count archive_count += 1 - return True + return "Summary." loop.consolidator.archive = _fake_archive @@ -769,12 +748,9 @@ class TestProactiveAutoCompact: async def _fake_archive(messages): nonlocal archive_count archive_count += 1 - return True + return "Summary." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } # First compact cycle await loop.auto_compact._archive("cli:test") @@ -810,12 +786,9 @@ class TestSummaryPersistence: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "User said hello." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User said hello.", - } await loop.auto_compact._archive("cli:test") @@ -839,12 +812,9 @@ class TestSummaryPersistence: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "User said hello." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "User said hello.", - } # Archive await loop.auto_compact._archive("cli:test") @@ -874,12 +844,9 @@ class TestSummaryPersistence: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "Summary." 
loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } await loop.auto_compact._archive("cli:test") @@ -908,12 +875,9 @@ class TestSummaryPersistence: loop.sessions.save(session) async def _fake_archive(messages): - return True + return "Summary." loop.consolidator.archive = _fake_archive - loop.consolidator.get_last_history_entry = lambda: { - "cursor": 1, "timestamp": "2026-01-01 00:00", "content": "Summary.", - } await loop.auto_compact._archive("cli:test") diff --git a/tests/agent/test_consolidator.py b/tests/agent/test_consolidator.py index b7989d9dd..28587e1b4 100644 --- a/tests/agent/test_consolidator.py +++ b/tests/agent/test_consolidator.py @@ -46,7 +46,7 @@ class TestConsolidatorSummarize: {"role": "assistant", "content": "Done, fixed the race condition."}, ] result = await consolidator.archive(messages) - assert result is True + assert result == "User fixed a bug in the auth module." 
entries = store.read_unprocessed_history(since_cursor=0) assert len(entries) == 1 @@ -55,14 +55,14 @@ class TestConsolidatorSummarize: mock_provider.chat_with_retry.side_effect = Exception("API error") messages = [{"role": "user", "content": "hello"}] result = await consolidator.archive(messages) - assert result is True # always succeeds + assert result is None # no summary on raw dump fallback entries = store.read_unprocessed_history(since_cursor=0) assert len(entries) == 1 assert "[RAW]" in entries[0]["content"] async def test_summarize_skips_empty_messages(self, consolidator): result = await consolidator.archive([]) - assert result is False + assert result is None class TestConsolidatorTokenBudget: From 1cb28b39a30cea4be4ca57e7a1997bb06e9f71b3 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 07:25:50 +0000 Subject: [PATCH 063/115] feat(agent): retain recent context during auto compact Keep a legal recent suffix in idle auto-compacted sessions so resumed chats preserve their freshest live context while older messages are summarized. Recover persisted summaries even when retained messages remain, and document the new behavior. --- README.md | 6 +- nanobot/agent/auto_compact.py | 51 +++++++++++--- tests/agent/test_auto_compact.py | 116 ++++++++++++++++--------------- 3 files changed, 104 insertions(+), 69 deletions(-) diff --git a/README.md b/README.md index a9bf4b5e3..88ff35f29 100644 --- a/README.md +++ b/README.md @@ -1505,7 +1505,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us ### Auto Compact -When a user is idle for longer than a configured TTL, nanobot **proactively** compresses the session context into a summary. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary and fresh input. 
+When a user is idle for longer than a configured TTL, nanobot **proactively** compresses the older part of the session context into a summary while keeping a recent legal suffix of live messages. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary, the most recent live context, and fresh input. ```json { @@ -1523,8 +1523,8 @@ When a user is idle for longer than a configured TTL, nanobot **proactively** co How it works: 1. **Idle detection**: On each idle tick (~1 s), checks all sessions for expiration. -2. **Background compaction**: Expired sessions are summarized via LLM, then cleared. -3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted). +2. **Background compaction**: Expired sessions summarize the older live prefix via LLM and keep the most recent legal suffix (currently 8 messages). +3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted) alongside the retained recent suffix. > [!TIP] > The summary survives bot restarts — it's stored in session metadata and recovered on the next message. 
diff --git a/nanobot/agent/auto_compact.py b/nanobot/agent/auto_compact.py index f30feac17..47c7b5a36 100644 --- a/nanobot/agent/auto_compact.py +++ b/nanobot/agent/auto_compact.py @@ -3,16 +3,18 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Callable, Coroutine +from typing import TYPE_CHECKING, Any, Callable, Coroutine from loguru import logger +from nanobot.session.manager import Session, SessionManager if TYPE_CHECKING: from nanobot.agent.memory import Consolidator - from nanobot.session.manager import Session, SessionManager class AutoCompact: + _RECENT_SUFFIX_MESSAGES = 8 + def __init__(self, sessions: SessionManager, consolidator: Consolidator, session_ttl_minutes: int = 0): self.sessions = sessions @@ -33,6 +35,27 @@ class AutoCompact: idle_min = int((datetime.now() - last_active).total_seconds() / 60) return f"Inactive for {idle_min} minutes.\nPrevious conversation summary: {text}" + def _split_unconsolidated( + self, session: Session, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Split live session tail into archiveable prefix and retained recent suffix.""" + tail = list(session.messages[session.last_consolidated:]) + if not tail: + return [], [] + + probe = Session( + key=session.key, + messages=tail.copy(), + created_at=session.created_at, + updated_at=session.updated_at, + metadata={}, + last_consolidated=0, + ) + probe.retain_recent_legal_suffix(self._RECENT_SUFFIX_MESSAGES) + kept = probe.messages + cut = len(tail) - len(kept) + return tail[:cut], kept + def check_expired(self, schedule_background: Callable[[Coroutine], None]) -> None: for info in self.sessions.list_sessions(): key = info.get("key", "") @@ -45,21 +68,31 @@ class AutoCompact: try: self.sessions.invalidate(key) session = self.sessions.get_or_create(key) - msgs = session.messages[session.last_consolidated:] - if not msgs: + archive_msgs, kept_msgs = self._split_unconsolidated(session) + if not archive_msgs and not 
kept_msgs: logger.debug("Auto-compact: skipping {}, no un-consolidated messages", key) session.updated_at = datetime.now() self.sessions.save(session) return - n = len(msgs) + last_active = session.updated_at - summary = await self.consolidator.archive(msgs) or "" + summary = "" + if archive_msgs: + summary = await self.consolidator.archive(archive_msgs) or "" if summary and summary != "(nothing)": self._summaries[key] = (summary, last_active) session.metadata["_last_summary"] = {"text": summary, "last_active": last_active.isoformat()} - session.clear() + session.messages = kept_msgs + session.last_consolidated = 0 + session.updated_at = datetime.now() self.sessions.save(session) - logger.info("Auto-compact: archived {} ({} messages, summary={})", key, n, bool(summary)) + logger.info( + "Auto-compact: archived {} (archived={}, kept={}, summary={})", + key, + len(archive_msgs), + len(kept_msgs), + bool(summary), + ) except Exception: logger.exception("Auto-compact: failed for {}", key) finally: @@ -75,7 +108,7 @@ class AutoCompact: if entry: session.metadata.pop("_last_summary", None) return session, self._format_summary(entry[0], entry[1]) - if not session.messages and "_last_summary" in session.metadata: + if "_last_summary" in session.metadata: meta = session.metadata.pop("_last_summary") self.sessions.save(session) return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"])) diff --git a/tests/agent/test_auto_compact.py b/tests/agent/test_auto_compact.py index 39792e290..8f1be03a2 100644 --- a/tests/agent/test_auto_compact.py +++ b/tests/agent/test_auto_compact.py @@ -35,6 +35,13 @@ def _make_loop(tmp_path: Path, session_ttl_minutes: int = 15) -> AgentLoop: return loop +def _add_turns(session, turns: int, *, prefix: str = "msg") -> None: + """Append simple user/assistant turns to a session.""" + for i in range(turns): + session.add_message("user", f"{prefix} user {i}") + session.add_message("assistant", f"{prefix} assistant 
{i}") + + class TestSessionTTLConfig: """Test session TTL configuration.""" @@ -113,13 +120,11 @@ class TestAutoCompact: await loop.close_mcp() @pytest.mark.asyncio - async def test_auto_compact_archives_and_clears(self, tmp_path): - """_archive should archive un-consolidated messages and clear session.""" + async def test_auto_compact_archives_prefix_and_keeps_recent_suffix(self, tmp_path): + """_archive should summarize the old prefix and keep a recent legal suffix.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - for i in range(4): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") + _add_turns(session, 6) loop.sessions.save(session) archived_messages = [] @@ -132,9 +137,11 @@ class TestAutoCompact: await loop.auto_compact._archive("cli:test") - assert len(archived_messages) == 8 + assert len(archived_messages) == 4 session_after = loop.sessions.get_or_create("cli:test") - assert len(session_after.messages) == 0 + assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES + assert session_after.messages[0]["content"] == "msg user 2" + assert session_after.messages[-1]["content"] == "msg assistant 5" await loop.close_mcp() @pytest.mark.asyncio @@ -142,8 +149,7 @@ class TestAutoCompact: """_archive should store the summary in _summaries.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "hello") - session.add_message("assistant", "hi there") + _add_turns(session, 6, prefix="hello") loop.sessions.save(session) async def _fake_archive(messages): @@ -157,7 +163,7 @@ class TestAutoCompact: assert entry is not None assert entry[0] == "User said hello." 
session_after = loop.sessions.get_or_create("cli:test") - assert len(session_after.messages) == 0 + assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES await loop.close_mcp() @pytest.mark.asyncio @@ -187,9 +193,7 @@ class TestAutoCompact: """_archive should only archive un-consolidated messages.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - for i in range(10): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") + _add_turns(session, 14) session.last_consolidated = 18 loop.sessions.save(session) @@ -232,7 +236,7 @@ class TestAutoCompactIdleDetection: """Proactive auto-new archives expired session; _process_message reloads it.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "old message") + _add_turns(session, 6, prefix="old") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -251,7 +255,8 @@ class TestAutoCompactIdleDetection: await loop._process_message(msg) session_after = loop.sessions.get_or_create("cli:test") - assert not any(m["content"] == "old message" for m in session_after.messages) + assert len(archived_messages) == 4 + assert not any(m["content"] == "old user 0" for m in session_after.messages) assert any(m["content"] == "new msg" for m in session_after.messages) await loop.close_mcp() @@ -329,7 +334,7 @@ class TestAutoCompactSystemMessages: """Proactive auto-new archives expired session; system messages reload it.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "old message from subagent context") + _add_turns(session, 6, prefix="old") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -349,7 +354,7 @@ class TestAutoCompactSystemMessages: session_after = 
loop.sessions.get_or_create("cli:test") assert not any( - m["content"] == "old message from subagent context" + m["content"] == "old user 0" for m in session_after.messages ) await loop.close_mcp() @@ -363,8 +368,7 @@ class TestAutoCompactEdgeCases: """Auto-new should not inject when archive produces '(nothing)'.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "thanks") - session.add_message("assistant", "you're welcome") + _add_turns(session, 6, prefix="thanks") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -375,18 +379,18 @@ class TestAutoCompactEdgeCases: await loop.auto_compact._archive("cli:test") session_after = loop.sessions.get_or_create("cli:test") - assert len(session_after.messages) == 0 + assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES # "(nothing)" summary should not be stored assert "cli:test" not in loop.auto_compact._summaries await loop.close_mcp() @pytest.mark.asyncio - async def test_auto_compact_archive_failure_still_clears(self, tmp_path): - """Auto-new should clear session even if LLM archive fails (raw_archive fallback).""" + async def test_auto_compact_archive_failure_still_keeps_recent_suffix(self, tmp_path): + """Auto-new should keep the recent suffix even if LLM archive falls back to raw dump.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "important data") + _add_turns(session, 6, prefix="important") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -396,14 +400,13 @@ class TestAutoCompactEdgeCases: await loop.auto_compact._archive("cli:test") session_after = loop.sessions.get_or_create("cli:test") - # Session should be cleared (archive falls back to raw dump) - assert len(session_after.messages) == 0 + assert len(session_after.messages) == 
loop.auto_compact._RECENT_SUFFIX_MESSAGES await loop.close_mcp() @pytest.mark.asyncio async def test_auto_compact_preserves_runtime_checkpoint_before_check(self, tmp_path): - """Runtime checkpoint is restored; proactive archive handles the expired session.""" + """Short expired sessions keep recent messages; checkpoint restore still works on resume.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") session.metadata[AgentLoop._RUNTIME_CHECKPOINT_KEY] = { @@ -429,8 +432,10 @@ class TestAutoCompactEdgeCases: msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="continue") await loop._process_message(msg) - # The checkpoint-restored message should have been archived by proactive path - assert len(archived_messages) >= 1 + session_after = loop.sessions.get_or_create("cli:test") + assert archived_messages == [] + assert any(m["content"] == "previous message" for m in session_after.messages) + assert any(m["content"] == "interrupted response" for m in session_after.messages) await loop.close_mcp() @@ -446,11 +451,17 @@ class TestAutoCompactIntegration: loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - # Phase 1: User has a conversation + # Phase 1: User has a conversation longer than the retained recent suffix session.add_message("user", "I'm learning English, teach me past tense") session.add_message("assistant", "Past tense is used for actions completed in the past...") session.add_message("user", "Give me an example") session.add_message("assistant", '"I walked to the store yesterday."') + session.add_message("user", "Give me another example") + session.add_message("assistant", '"She visited Paris last year."') + session.add_message("user", "Quiz me") + session.add_message("assistant", "What is the past tense of go?") + session.add_message("user", "I think it is went") + session.add_message("assistant", "Correct.") 
loop.sessions.save(session) # Phase 2: Time passes (simulate idle) @@ -474,7 +485,7 @@ class TestAutoCompactIntegration: # Phase 4: Verify session_after = loop.sessions.get_or_create("cli:test") - # Old messages should be gone + # The oldest messages should be trimmed from live session history assert not any( "past tense is used" in str(m.get("content", "")) for m in session_after.messages ) @@ -497,8 +508,8 @@ class TestAutoCompactIntegration: await loop.close_mcp() @pytest.mark.asyncio - async def test_multi_paragraph_user_message_preserved(self, tmp_path): - """Multi-paragraph user messages must be fully preserved after auto-new.""" + async def test_runtime_context_markers_not_persisted_for_multi_paragraph_turn(self, tmp_path): + """Auto-compact resume context must not leak runtime markers into persisted session history.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") session.add_message("user", "old message") @@ -520,16 +531,11 @@ class TestAutoCompactIntegration: await loop._process_message(msg) session_after = loop.sessions.get_or_create("cli:test") - user_msgs = [m for m in session_after.messages if m.get("role") == "user"] - assert len(user_msgs) >= 1 - # All three paragraphs must be preserved - persisted = user_msgs[-1]["content"] - assert "Paragraph one" in persisted - assert "Paragraph two" in persisted - assert "Paragraph three" in persisted - # No runtime context markers in persisted message - assert "[Runtime Context" not in persisted - assert "[/Runtime Context]" not in persisted + assert any(m.get("content") == "old message" for m in session_after.messages) + for persisted in session_after.messages: + content = str(persisted.get("content", "")) + assert "[Runtime Context" not in content + assert "[/Runtime Context]" not in content await loop.close_mcp() @@ -562,8 +568,7 @@ class TestProactiveAutoCompact: """Expired session should be archived during idle tick.""" loop = _make_loop(tmp_path, 
session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "old message") - session.add_message("assistant", "old response") + _add_turns(session, 5, prefix="old") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -578,7 +583,7 @@ class TestProactiveAutoCompact: await self._run_check_expired(loop) session_after = loop.sessions.get_or_create("cli:test") - assert len(session_after.messages) == 0 + assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES assert len(archived_messages) == 2 entry = loop.auto_compact._summaries.get("cli:test") assert entry is not None @@ -604,7 +609,7 @@ class TestProactiveAutoCompact: """Should not archive the same session twice if already in progress.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "old message") + _add_turns(session, 6, prefix="old") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -641,7 +646,7 @@ class TestProactiveAutoCompact: """Proactive archive failure should be caught and not block future ticks.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "old message") + _add_turns(session, 6, prefix="old") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -684,8 +689,7 @@ class TestProactiveAutoCompact: """Already-archived session should NOT be re-scheduled on subsequent ticks.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "old message") - session.add_message("assistant", "old response") + _add_turns(session, 5, prefix="old") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -738,8 +742,7 @@ class TestProactiveAutoCompact: """After 
successful compact + user sends new messages + idle again, should compact again.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "first conversation") - session.add_message("assistant", "first response") + _add_turns(session, 5, prefix="first") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -780,8 +783,7 @@ class TestSummaryPersistence: """After archive, _last_summary should be in session metadata.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "hello") - session.add_message("assistant", "hi there") + _add_turns(session, 6, prefix="hello") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -805,8 +807,7 @@ class TestSummaryPersistence: """Summary should be recovered from metadata when _summaries is empty (simulates restart).""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "hello") - session.add_message("assistant", "hi there") + _add_turns(session, 6, prefix="hello") last_active = datetime.now() - timedelta(minutes=20) session.updated_at = last_active loop.sessions.save(session) @@ -825,6 +826,7 @@ class TestSummaryPersistence: # prepare_session should recover summary from metadata reloaded = loop.sessions.get_or_create("cli:test") + assert len(reloaded.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES _, summary = loop.auto_compact.prepare_session(reloaded, "cli:test") assert summary is not None @@ -839,7 +841,7 @@ class TestSummaryPersistence: """_last_summary should be removed from metadata after being consumed.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "hello") + _add_turns(session, 6, prefix="hello") session.updated_at = 
datetime.now() - timedelta(minutes=20) loop.sessions.save(session) @@ -870,7 +872,7 @@ class TestSummaryPersistence: """In-memory _summaries path should also clean up _last_summary from metadata.""" loop = _make_loop(tmp_path, session_ttl_minutes=15) session = loop.sessions.get_or_create("cli:test") - session.add_message("user", "hello") + _add_turns(session, 6, prefix="hello") session.updated_at = datetime.now() - timedelta(minutes=20) loop.sessions.save(session) From 84e840659aabc5682d3699c9466a050d82f644df Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 07:32:56 +0000 Subject: [PATCH 064/115] refactor(config): rename auto compact config key Prefer the more user-friendly idleCompactAfterMinutes name for auto compact while keeping sessionTtlMinutes as a backward-compatible alias. Update tests and README to document the retained recent-context behavior and the new preferred key. --- README.md | 13 ++++++++----- nanobot/config/schema.py | 7 ++++++- tests/agent/test_auto_compact.py | 17 +++++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 88ff35f29..856986754 100644 --- a/README.md +++ b/README.md @@ -1505,13 +1505,13 @@ MCP tools are automatically discovered and registered on startup. The LLM can us ### Auto Compact -When a user is idle for longer than a configured TTL, nanobot **proactively** compresses the older part of the session context into a summary while keeping a recent legal suffix of live messages. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary, the most recent live context, and fresh input. +When a user is idle for longer than a configured threshold, nanobot **proactively** compresses the older part of the session context into a summary while keeping a recent legal suffix of live messages. 
This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary, the most recent live context, and fresh input. ```json { "agents": { "defaults": { - "sessionTtlMinutes": 15 + "idleCompactAfterMinutes": 15 } } } @@ -1519,15 +1519,18 @@ When a user is idle for longer than a configured TTL, nanobot **proactively** co | Option | Default | Description | |--------|---------|-------------| -| `agents.defaults.sessionTtlMinutes` | `0` (disabled) | Minutes of idle time before auto-compaction. Set to `0` to disable. Recommended: `15` — matches typical LLM KV cache expiration, so compacted sessions won't waste cache on cold entries. | +| `agents.defaults.idleCompactAfterMinutes` | `0` (disabled) | Minutes of idle time before auto-compaction starts. Set to `0` to disable. Recommended: `15` — close to a typical LLM KV cache expiry window, so stale sessions get compacted before the user returns. | + +`sessionTtlMinutes` remains accepted as a legacy alias for backward compatibility, but `idleCompactAfterMinutes` is the preferred config key going forward. How it works: 1. **Idle detection**: On each idle tick (~1 s), checks all sessions for expiration. -2. **Background compaction**: Expired sessions summarize the older live prefix via LLM and keep the most recent legal suffix (currently 8 messages). +2. **Background compaction**: Idle sessions summarize the older live prefix via LLM and keep the most recent legal suffix (currently 8 messages). 3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted) alongside the retained recent suffix. +4. **Restart-safe resume**: The summary is also mirrored into session metadata so it can still be recovered after a process restart. > [!TIP] -> The summary survives bot restarts — it's stored in session metadata and recovered on the next message. 
+> Think of auto compact as "summarize older context, keep the freshest live turns." It is not a hard session reset. ### Timezone diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 8ab68d7b5..67cce4470 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -77,7 +77,12 @@ class AgentDefaults(Base): reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" unified_session: bool = False # Share one session across all channels (single-user multi-device) - session_ttl_minutes: int = Field(default=0, ge=0) # Auto /new after idle (0 = disabled) + session_ttl_minutes: int = Field( + default=0, + ge=0, + validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"), + serialization_alias="idleCompactAfterMinutes", + ) # Auto-compact idle threshold in minutes (0 = disabled) dream: DreamConfig = Field(default_factory=DreamConfig) diff --git a/tests/agent/test_auto_compact.py b/tests/agent/test_auto_compact.py index 8f1be03a2..b3462820b 100644 --- a/tests/agent/test_auto_compact.py +++ b/tests/agent/test_auto_compact.py @@ -55,6 +55,23 @@ class TestSessionTTLConfig: defaults = AgentDefaults(session_ttl_minutes=30) assert defaults.session_ttl_minutes == 30 + def test_user_friendly_alias_is_supported(self): + """Config should accept idleCompactAfterMinutes as the preferred JSON key.""" + defaults = AgentDefaults.model_validate({"idleCompactAfterMinutes": 30}) + assert defaults.session_ttl_minutes == 30 + + def test_legacy_alias_is_still_supported(self): + """Config should still accept the old sessionTtlMinutes key for compatibility.""" + defaults = AgentDefaults.model_validate({"sessionTtlMinutes": 30}) + assert defaults.session_ttl_minutes == 30 + + def test_serializes_with_user_friendly_alias(self): + """Config dumps should use idleCompactAfterMinutes for JSON output.""" + defaults = 
AgentDefaults(session_ttl_minutes=30) + data = defaults.model_dump(mode="json", by_alias=True) + assert data["idleCompactAfterMinutes"] == 30 + assert "sessionTtlMinutes" not in data + class TestAgentLoopTTLParam: """Test that AutoCompact receives and stores session_ttl_minutes.""" From 5932482d01bb442e99143cabea2c0a0c272c5a1b Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 07:49:31 +0000 Subject: [PATCH 065/115] refactor(agent): rename auto compact module Rename the auto compact module to autocompact.py for a cleaner path while keeping the AutoCompact type and behavior unchanged. Update the agent loop import to match. --- nanobot/agent/{auto_compact.py => autocompact.py} | 0 nanobot/agent/loop.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename nanobot/agent/{auto_compact.py => autocompact.py} (100%) diff --git a/nanobot/agent/auto_compact.py b/nanobot/agent/autocompact.py similarity index 100% rename from nanobot/agent/auto_compact.py rename to nanobot/agent/autocompact.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 65a5a1abc..05a27349f 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger -from nanobot.agent.auto_compact import AutoCompact +from nanobot.agent.autocompact import AutoCompact from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import Consolidator, Dream From e0ba56808967f6bac652fa8e5ed13ac32874f2e4 Mon Sep 17 00:00:00 2001 From: weitongtong Date: Sat, 11 Apr 2026 14:34:45 +0800 Subject: [PATCH 066/115] =?UTF-8?q?fix(cron):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=9B=BA=E5=AE=9A=E9=97=B4=E9=9A=94=E4=BB=BB=E5=8A=A1=E5=9B=A0?= =?UTF-8?q?=20store=20=E5=B9=B6=E5=8F=91=E6=9B=BF=E6=8D=A2=E5=AF=BC?= =?UTF-8?q?=E8=87=B4=E7=9A=84=E9=87=8D=E5=A4=8D=E6=89=A7=E8=A1=8C?= MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _on_timer 中 await _execute_job 让出控制权期间,前端轮询触发的 list_jobs 调用 _load_store 从磁盘重新加载覆盖 self._store, 已执行任务的状态被旧值回退,导致再次触发。 引入 _timer_active 标志位,在任务执行期间阻止并发 _load_store 替换 store。同时修复 store 为空时未重新 arm timer 的问题。 Made-with: Cursor --- nanobot/cron/service.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index 267613012..165ce54d7 100644 --- a/nanobot/cron/service.py +++ b/nanobot/cron/service.py @@ -80,6 +80,7 @@ class CronService: self._store: CronStore | None = None self._timer_task: asyncio.Task | None = None self._running = False + self._timer_active = False self.max_sleep_ms = max_sleep_ms def _load_jobs(self) -> tuple[list[CronJob], int]: @@ -171,7 +172,11 @@ class CronService: def _load_store(self) -> CronStore: """Load jobs from disk. Reloads automatically if file was modified externally. - Reload every time because it needs to merge operations on the jobs object from other instances. + - During _on_timer execution, return the existing store to prevent concurrent + _load_store calls (e.g. from list_jobs polling) from replacing it mid-execution. 
""" + if self._timer_active and self._store: + return self._store jobs, version = self._load_jobs() self._store = CronStore(version=version, jobs=jobs) self._merge_action() @@ -290,18 +295,23 @@ class CronService: """Handle timer tick - run due jobs.""" self._load_store() if not self._store: + self._arm_timer() return - now = _now_ms() - due_jobs = [ - j for j in self._store.jobs - if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms - ] + self._timer_active = True + try: + now = _now_ms() + due_jobs = [ + j for j in self._store.jobs + if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms + ] - for job in due_jobs: - await self._execute_job(job) + for job in due_jobs: + await self._execute_job(job) - self._save_store() + self._save_store() + finally: + self._timer_active = False self._arm_timer() async def _execute_job(self, job: CronJob) -> None: From 5bb7f77b80934d49432c6acc96644f33c1c7b956 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 08:43:25 +0000 Subject: [PATCH 067/115] feat(tests): add regression test for timer execution to prevent store rollback during job execution --- tests/cron/test_cron_service.py | 39 +++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/cron/test_cron_service.py b/tests/cron/test_cron_service.py index 4aa7fc06d..747f8ec81 100644 --- a/tests/cron/test_cron_service.py +++ b/tests/cron/test_cron_service.py @@ -329,6 +329,45 @@ async def test_external_update_preserves_run_history_records(tmp_path): fresh._save_store() +# ── timer race regression tests ── + + +@pytest.mark.asyncio +async def test_timer_execution_is_not_rolled_back_by_list_jobs_reload(tmp_path): + """list_jobs() during _on_timer should not replace the active store and re-run the same due job.""" + store_path = tmp_path / "cron" / "jobs.json" + calls: list[str] = [] + + async def on_job(job): + calls.append(job.id) + # Simulate frontend polling list_jobs while the timer callback is 
mid-execution. + service.list_jobs(include_disabled=True) + await asyncio.sleep(0) + + service = CronService(store_path, on_job=on_job) + service._running = True + service._load_store() + service._arm_timer = lambda: None + + job = service.add_job( + name="race", + schedule=CronSchedule(kind="every", every_ms=60_000), + message="hello", + ) + job.state.next_run_at_ms = max(1, int(time.time() * 1000) - 1_000) + service._save_store() + + await service._on_timer() + await service._on_timer() + + assert calls == [job.id] + loaded = service.get_job(job.id) + assert loaded is not None + assert loaded.state.last_run_at_ms is not None + assert loaded.state.next_run_at_ms is not None + assert loaded.state.next_run_at_ms > loaded.state.last_run_at_ms + + # ── update_job tests ── From d3aa209cf6967563e023701acb2d6f4b867762d9 Mon Sep 17 00:00:00 2001 From: Mike Terhar Date: Wed, 8 Apr 2026 09:24:32 -0400 Subject: [PATCH 068/115] add kagi web search tool --- nanobot/agent/tools/web.py | 25 +++++++++++++++++++++++++ nanobot/config/schema.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 275fcf88c..38fc33d74 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -114,6 +114,8 @@ class WebSearchTool(Tool): return await self._search_jina(query, n) elif provider == "brave": return await self._search_brave(query, n) + elif provider == "kagi": + return await self._search_kagi(query, n) else: return f"Error: unknown search provider '{provider}'" @@ -204,6 +206,29 @@ class WebSearchTool(Tool): logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e) return await self._search_duckduckgo(query, n) + async def _search_kagi(self, query: str, n: int) -> str: + api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "") + if not api_key: + logger.warning("KAGI_API_KEY not set, falling back to DuckDuckGo") + return await self._search_duckduckgo(query, n) + try: + 
async with httpx.AsyncClient(proxy=self.proxy) as client: + r = await client.get( + "https://kagi.com/api/v0/search", + params={"q": query, "limit": n}, + headers={"Authorization": f"Bot {api_key}"}, + timeout=10.0, + ) + r.raise_for_status() + # t=0 items are search results; other values are related searches, etc. + items = [ + {"title": d.get("title", ""), "url": d.get("url", ""), "content": d.get("snippet", "")} + for d in r.json().get("data", []) if d.get("t") == 0 + ] + return _format_results(query, items, n) + except Exception as e: + return f"Error: {e}" + async def _search_duckduckgo(self, query: str, n: int) -> str: try: # Note: duckduckgo_search is synchronous and does its own requests diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 67cce4470..a841fe159 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -159,7 +159,7 @@ class GatewayConfig(Base): class WebSearchConfig(Base): """Web search tool configuration.""" - provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina + provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina, kagi api_key: str = "" base_url: str = "" # SearXNG base URL max_results: int = 5 From 74dbce3770ce180c553ae5fbe7a17a8512c17ef7 Mon Sep 17 00:00:00 2001 From: Mike Terhar Date: Wed, 8 Apr 2026 09:25:21 -0400 Subject: [PATCH 069/115] add kagi info to README --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 856986754..72fd62aa0 100644 --- a/README.md +++ b/README.md @@ -1312,6 +1312,7 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses, | `brave` | `apiKey` | `BRAVE_API_KEY` | No | | `tavily` | `apiKey` | `TAVILY_API_KEY` | No | | `jina` | `apiKey` | `JINA_API_KEY` | Free tier (10M tokens) | +| `kagi` | `apiKey` | `KAGI_API_KEY` | No | | `searxng` | `baseUrl` | `SEARXNG_BASE_URL` | Yes (self-hosted) | | `duckduckgo` (default) | — | — | Yes | @@ -1368,6 +1369,20 @@ If you 
need to allow trusted private ranges such as Tailscale / CGNAT addresses, } ``` +**Kagi:** +```json +{ + "tools": { + "web": { + "search": { + "provider": "kagi", + "apiKey": "your-kagi-api-key" + } + } + } +} +``` + **SearXNG** (self-hosted, no API key needed): ```json { From b959ae6d8965ca58b8dde49247137403e09ecea1 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 08:48:42 +0000 Subject: [PATCH 070/115] test(web): cover Kagi search provider Add focused coverage for the Kagi web search provider, including the request format and the DuckDuckGo fallback when no API key is configured. --- tests/tools/test_web_search_tool.py | 38 +++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/tools/test_web_search_tool.py b/tests/tools/test_web_search_tool.py index e33dd7e6c..790d8adcd 100644 --- a/tests/tools/test_web_search_tool.py +++ b/tests/tools/test_web_search_tool.py @@ -120,6 +120,27 @@ async def test_jina_search(monkeypatch): assert "https://jina.ai" in result +@pytest.mark.asyncio +async def test_kagi_search(monkeypatch): + async def mock_get(self, url, **kw): + assert "kagi.com/api/v0/search" in url + assert kw["headers"]["Authorization"] == "Bot kagi-key" + assert kw["params"] == {"q": "test", "limit": 2} + return _response(json={ + "data": [ + {"t": 0, "title": "Kagi Result", "url": "https://kagi.com", "snippet": "Premium search"}, + {"t": 1, "list": ["ignored related search"]}, + ] + }) + + monkeypatch.setattr(httpx.AsyncClient, "get", mock_get) + tool = _tool(provider="kagi", api_key="kagi-key") + result = await tool.execute(query="test", count=2) + assert "Kagi Result" in result + assert "https://kagi.com" in result + assert "ignored related search" not in result + + @pytest.mark.asyncio async def test_unknown_provider(): tool = _tool(provider="unknown") @@ -189,6 +210,23 @@ async def test_jina_422_falls_back_to_duckduckgo(monkeypatch): assert "DuckDuckGo fallback" in result +@pytest.mark.asyncio +async def 
test_kagi_fallback_to_duckduckgo_when_no_key(monkeypatch): + class MockDDGS: + def __init__(self, **kw): + pass + + def text(self, query, max_results=5): + return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}] + + monkeypatch.setattr("ddgs.DDGS", MockDDGS) + monkeypatch.delenv("KAGI_API_KEY", raising=False) + + tool = _tool(provider="kagi", api_key="") + result = await tool.execute(query="test") + assert "Fallback" in result + + @pytest.mark.asyncio async def test_jina_search_uses_path_encoded_query(monkeypatch): calls = {} From f5640d69fe39ae5dee993f64567d953fd7fc7b71 Mon Sep 17 00:00:00 2001 From: Jiajun Xie Date: Thu, 9 Apr 2026 09:18:33 +0800 Subject: [PATCH 071/115] fix(feishu): improve voice message download with detailed logging - Add explicit error logging for missing file_key and message_id - Add logging for download failures - Change audio extension from .opus to .ogg for better Whisper compatibility - Feishu voice messages are opus in OGG container; .ogg is more widely recognized --- nanobot/channels/feishu.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index e57fcef85..7d9c2772b 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1014,14 +1014,29 @@ class FeishuChannel(BaseChannel): elif msg_type in ("audio", "file", "media"): file_key = content_json.get("file_key") - if file_key and message_id: - data, filename = await loop.run_in_executor( - None, self._download_file_sync, message_id, file_key, msg_type - ) - if not filename: - filename = file_key[:16] - if msg_type == "audio" and not filename.endswith(".opus"): - filename = f"{filename}.opus" + if not file_key: + logger.warning("Feishu {} message missing file_key: {}", msg_type, content_json) + return None, f"[{msg_type}: missing file_key]" + if not message_id: + logger.warning("Feishu {} message missing message_id", msg_type) + 
return None, f"[{msg_type}: missing message_id]" + + data, filename = await loop.run_in_executor( + None, self._download_file_sync, message_id, file_key, msg_type + ) + + if not data: + logger.warning("Feishu {} download failed: file_key={}", msg_type, file_key) + return None, f"[{msg_type}: download failed]" + + if not filename: + filename = file_key[:16] + + # Feishu voice messages are opus in OGG container. + # Use .ogg extension for better Whisper compatibility. + if msg_type == "audio": + if not any(filename.endswith(ext) for ext in (".opus", ".ogg", ".oga")): + filename = f"{filename}.ogg" if data and filename: file_path = media_dir / filename From 36d2a11e73fcf095edf07c591cd6a5e06bba7e54 Mon Sep 17 00:00:00 2001 From: chengyongru <61816729+chengyongru@users.noreply.github.com> Date: Sat, 11 Apr 2026 02:11:02 +0800 Subject: [PATCH 072/115] feat(agent): mid-turn message injection for responsive follow-ups (#2985) * feat(agent): add mid-turn message injection for responsive follow-ups Allow user messages sent during an active agent turn to be injected into the running LLM context instead of being queued behind a per-session lock. Inspired by Claude Code's mid-turn queue drain mechanism (query.ts:1547-1643). 
Key design decisions: - Messages are injected as natural user messages between iterations, no tool cancellation or special system prompt needed - Two drain checkpoints: after tool execution and after final LLM response ("last-mile" to prevent dropping late arrivals) - Bounded by MAX_INJECTION_CYCLES (5) to prevent consuming the iteration budget on rapid follow-ups - had_injections flag bypasses _sent_in_turn suppression so follow-up responses are always delivered Closes #1609 * fix(agent): harden mid-turn injection with streaming fix, bounded queue, and message safety - Fix streaming protocol violation: Checkpoint 2 now checks for injections BEFORE calling on_stream_end, passing resuming=True when injections found so streaming channels (Feishu) don't prematurely finalize the card - Bound pending queue to maxsize=20 with QueueFull handling - Add warning log when injection batch exceeds _MAX_INJECTIONS_PER_TURN - Re-publish leftover queue messages to bus in _dispatch finally block to prevent silent message loss on early exit (max_iterations, tool_error, cancel) - Fix PEP 8 blank line before dataclass and logger.info indentation - Add 12 new tests covering drain, checkpoints, cycle cap, queue routing, cleanup, and leftover re-publish --- nanobot/agent/loop.py | 226 +++++++----- nanobot/agent/runner.py | 77 +++- tests/agent/test_hook_composite.py | 6 +- tests/agent/test_runner.py | 415 +++++++++++++++++++++- tests/tools/test_message_tool_suppress.py | 2 +- 5 files changed, 631 insertions(+), 95 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0c33bc5c8..b51444650 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -207,6 +207,10 @@ class AgentLoop: self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._background_tasks: list[asyncio.Task] = [] self._session_locks: dict[str, asyncio.Lock] = {} + # Per-session pending queues for mid-turn message injection. 
+ # When a session has an active task, new messages for that session + # are routed here instead of creating a new task. + self._pending_queues: dict[str, asyncio.Queue] = {} # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3. _max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3")) self._concurrency_gate: asyncio.Semaphore | None = ( @@ -331,13 +335,16 @@ class AgentLoop: channel: str = "cli", chat_id: str = "direct", message_id: str | None = None, - ) -> tuple[str | None, list[str], list[dict], str]: + pending_queue: asyncio.Queue | None = None, + ) -> tuple[str | None, list[str], list[dict], str, bool]: """Run the agent iteration loop. *on_stream*: called with each content delta during streaming. *on_stream_end(resuming)*: called when a streaming session finishes. ``resuming=True`` means tool calls follow (spinner should restart); ``resuming=False`` means this is the final response. + + Returns (final_content, tools_used, messages, stop_reason, had_injections). 
""" loop_hook = _LoopHook( self, @@ -357,31 +364,42 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) - result = await self.runner.run( - AgentRunSpec( - initial_messages=initial_messages, - tools=self.tools, - model=self.model, - max_iterations=self.max_iterations, - max_tool_result_chars=self.max_tool_result_chars, - hook=hook, - error_message="Sorry, I encountered an error calling the AI model.", - concurrent_tools=True, - workspace=self.workspace, - session_key=session.key if session else None, - context_window_tokens=self.context_window_tokens, - context_block_limit=self.context_block_limit, - provider_retry_mode=self.provider_retry_mode, - progress_callback=on_progress, - checkpoint_callback=_checkpoint, - ) - ) + async def _drain_pending() -> list[InboundMessage]: + """Non-blocking drain of follow-up messages from the pending queue.""" + if pending_queue is None: + return [] + items: list[InboundMessage] = [] + while True: + try: + items.append(pending_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return items + + result = await self.runner.run(AgentRunSpec( + initial_messages=initial_messages, + tools=self.tools, + model=self.model, + max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, + hook=hook, + error_message="Sorry, I encountered an error calling the AI model.", + concurrent_tools=True, + workspace=self.workspace, + session_key=session.key if session else None, + context_window_tokens=self.context_window_tokens, + context_block_limit=self.context_block_limit, + provider_retry_mode=self.provider_retry_mode, + progress_callback=on_progress, + checkpoint_callback=_checkpoint, + injection_callback=_drain_pending, + )) self._last_usage = result.usage if result.stop_reason == "max_iterations": logger.warning("Max iterations ({}) reached", self.max_iterations) elif result.stop_reason == "error": logger.error("LLM returned error: {}", (result.final_content or "")[:200]) - return 
result.final_content, result.tools_used, result.messages, result.stop_reason + return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections async def run(self) -> None: """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" @@ -412,6 +430,23 @@ class AgentLoop: if result: await self.bus.publish_outbound(result) continue + # If this session already has an active pending queue (i.e. a task + # is processing this session), route the message there for mid-turn + # injection instead of creating a competing task. + if msg.session_key in self._pending_queues: + try: + self._pending_queues[msg.session_key].put_nowait(msg) + except asyncio.QueueFull: + logger.warning( + "Pending queue full for session {}, dropping follow-up", + msg.session_key, + ) + else: + logger.info( + "Routed follow-up message to pending queue for session {}", + msg.session_key, + ) + continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled effective_key = ( @@ -432,76 +467,89 @@ class AgentLoop: """Process a message: per-session serial, cross-session concurrent.""" if self._unified_session and not msg.session_key_override: msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY) - lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) + session_key = msg.session_key + lock = self._session_locks.setdefault(session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() - async with lock, gate: - try: - on_stream = on_stream_end = None - if msg.metadata.get("_wants_stream"): - # Split one answer into distinct stream segments. 
- stream_base_id = f"{msg.session_key}:{time.time_ns()}" - stream_segment = 0 - def _current_stream_id() -> str: - return f"{stream_base_id}:{stream_segment}" + # Register a pending queue so follow-up messages for this session are + # routed here (mid-turn injection) instead of spawning a new task. + pending = asyncio.Queue(maxsize=20) + self._pending_queues[session_key] = pending - async def on_stream(delta: str) -> None: - meta = dict(msg.metadata or {}) - meta["_stream_delta"] = True - meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, + try: + async with lock, gate: + try: + on_stream = on_stream_end = None + if msg.metadata.get("_wants_stream"): + # Split one answer into distinct stream segments. + stream_base_id = f"{msg.session_key}:{time.time_ns()}" + stream_segment = 0 + + def _current_stream_id() -> str: + return f"{stream_base_id}:{stream_segment}" + + async def on_stream(delta: str) -> None: + meta = dict(msg.metadata or {}) + meta["_stream_delta"] = True + meta["_stream_id"] = _current_stream_id() + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content=delta, metadata=meta, - ) - ) + )) - async def on_stream_end(*, resuming: bool = False) -> None: - nonlocal stream_segment - meta = dict(msg.metadata or {}) - meta["_stream_end"] = True - meta["_resuming"] = resuming - meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, + async def on_stream_end(*, resuming: bool = False) -> None: + nonlocal stream_segment + meta = dict(msg.metadata or {}) + meta["_stream_end"] = True + meta["_resuming"] = resuming + meta["_stream_id"] = _current_stream_id() + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content="", metadata=meta, - ) - ) - stream_segment += 1 + )) + stream_segment += 1 - response = 
await self._process_message( - msg, - on_stream=on_stream, - on_stream_end=on_stream_end, - ) - if response is not None: - await self.bus.publish_outbound(response) - elif msg.channel == "cli": - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content="", - metadata=msg.metadata or {}, - ) + response = await self._process_message( + msg, on_stream=on_stream, on_stream_end=on_stream_end, + pending_queue=pending, ) - except asyncio.CancelledError: - logger.info("Task cancelled for session {}", msg.session_key) - raise - except Exception: - logger.exception("Error processing message for session {}", msg.session_key) - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, + if response is not None: + await self.bus.publish_outbound(response) + elif msg.channel == "cli": + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="", metadata=msg.metadata or {}, + )) + except asyncio.CancelledError: + logger.info("Task cancelled for session {}", session_key) + raise + except Exception: + logger.exception("Error processing message for session {}", session_key) + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content="Sorry, I encountered an error.", + )) + finally: + # Drain any messages still in the pending queue and re-publish + # them to the bus so they are processed as fresh inbound messages + # rather than silently lost. 
+ queue = self._pending_queues.pop(session_key, None) + if queue is not None: + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await self.bus.publish_inbound(item) + leftover += 1 + if leftover: + logger.info( + "Re-published {} leftover message(s) to bus for session {}", + leftover, session_key, ) - ) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" @@ -533,6 +581,7 @@ class AgentLoop: on_progress: Callable[[str], Awaitable[None]] | None = None, on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None, + pending_queue: asyncio.Queue | None = None, ) -> OutboundMessage | None: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") @@ -559,11 +608,8 @@ class AgentLoop: session_summary=pending, current_role=current_role, ) - final_content, _, all_msgs, _ = await self._run_agent_loop( - messages, - session=session, - channel=channel, - chat_id=chat_id, + final_content, _, all_msgs, _, _ = await self._run_agent_loop( + messages, session=session, channel=channel, chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) self._save_turn(session, all_msgs, 1 + len(history)) @@ -623,7 +669,7 @@ class AgentLoop: ) ) - final_content, _, all_msgs, stop_reason = await self._run_agent_loop( + final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, on_stream=on_stream, @@ -632,6 +678,7 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), + pending_queue=pending_queue, ) if final_content is None or not final_content.strip(): @@ -642,8 +689,13 @@ class AgentLoop: self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - if (mt := 
self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: - return None + # When follow-up messages were injected mid-turn, the LLM's final + # response addresses those follow-ups. Always send the response in + # this case, even if MessageTool was used earlier in the turn — the + # follow-up response is new content the user hasn't seen. + if not had_injections: + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + return None preview = final_content[:120] + "..." if len(final_content) > 120 else final_content logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index cfebe098f..f7187191e 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -34,6 +34,8 @@ _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." _PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]" _MAX_EMPTY_RETRIES = 2 _MAX_LENGTH_RECOVERIES = 3 +_MAX_INJECTIONS_PER_TURN = 3 +_MAX_INJECTION_CYCLES = 5 _SNIP_SAFETY_BUFFER = 1024 _MICROCOMPACT_KEEP_RECENT = 10 _MICROCOMPACT_MIN_CHARS = 500 @@ -42,6 +44,9 @@ _COMPACTABLE_TOOLS = frozenset({ "web_search", "web_fetch", "list_dir", }) _BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]" + + + @dataclass(slots=True) class AgentRunSpec: """Configuration for a single agent execution.""" @@ -66,6 +71,7 @@ class AgentRunSpec: provider_retry_mode: str = "standard" progress_callback: Any | None = None checkpoint_callback: Any | None = None + injection_callback: Any | None = None @dataclass(slots=True) @@ -79,6 +85,7 @@ class AgentRunResult: stop_reason: str = "completed" error: str | None = None tool_events: list[dict[str, str]] = field(default_factory=list) + had_injections: bool = False class AgentRunner: @@ -87,6 +94,38 @@ class AgentRunner: def __init__(self, provider: LLMProvider): self.provider = provider + 
async def _drain_injections(self, spec: AgentRunSpec) -> list[str]: + """Drain pending user messages via the injection callback. + + Returns all drained message contents (capped by + ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is + nothing to inject. Messages beyond the cap are logged so they + are not silently lost. + """ + if spec.injection_callback is None: + return [] + try: + items = await spec.injection_callback() + except Exception: + logger.exception("injection_callback failed") + return [] + if not items: + return [] + # items are InboundMessage objects from _drain_pending + texts: list[str] = [] + for item in items: + text = getattr(item, "content", str(item)) + if text.strip(): + texts.append(text) + if len(texts) > _MAX_INJECTIONS_PER_TURN: + dropped = len(texts) - _MAX_INJECTIONS_PER_TURN + logger.warning( + "Injection batch has {} messages, capping to {} ({} dropped)", + len(texts), _MAX_INJECTIONS_PER_TURN, dropped, + ) + texts = texts[-_MAX_INJECTIONS_PER_TURN:] + return texts + async def run(self, spec: AgentRunSpec) -> AgentRunResult: hook = spec.hook or AgentHook() messages = list(spec.initial_messages) @@ -99,6 +138,8 @@ class AgentRunner: external_lookup_counts: dict[str, int] = {} empty_content_retries = 0 length_recovery_count = 0 + had_injections = False + injection_cycles = 0 for iteration in range(spec.max_iterations): try: @@ -200,6 +241,18 @@ class AgentRunner: ) empty_content_retries = 0 length_recovery_count = 0 + # Checkpoint 1: drain injections after tools, before next LLM call + if injection_cycles < _MAX_INJECTION_CYCLES: + injections = await self._drain_injections(spec) + if injections: + had_injections = True + injection_cycles += 1 + for text in injections: + messages.append({"role": "user", "content": text}) + logger.info( + "Injected {} follow-up message(s) after tool execution ({}/{})", + len(injections), injection_cycles, _MAX_INJECTION_CYCLES, + ) await hook.after_iteration(context) continue @@ -256,8 +309,29 
@@ class AgentRunner: await hook.after_iteration(context) continue + # Check for mid-turn injections BEFORE signaling stream end. + # If injections are found we keep the stream alive (resuming=True) + # so streaming channels don't prematurely finalize the card. + _injected_after_final = False + if injection_cycles < _MAX_INJECTION_CYCLES: + injections = await self._drain_injections(spec) + if injections: + had_injections = True + injection_cycles += 1 + _injected_after_final = True + for text in injections: + messages.append({"role": "user", "content": text}) + logger.info( + "Injected {} follow-up message(s) after final response ({}/{})", + len(injections), injection_cycles, _MAX_INJECTION_CYCLES, + ) + if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=False) + await hook.on_stream_end(context, resuming=_injected_after_final) + + if _injected_after_final: + await hook.after_iteration(context) + continue if response.finish_reason == "error": final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE @@ -323,6 +397,7 @@ class AgentRunner: stop_reason=stop_reason, error=error, tool_events=tool_events, + had_injections=had_injections, ) def _build_request_kwargs( diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 672f38ed2..8971d48ec 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -307,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, tools_used, messages, _ = await loop._run_agent_loop( + content, tools_used, messages, _, _ = await loop._run_agent_loop( [{"role": "user", "content": "hi"}] ) @@ -331,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, _, _, _ = await loop._run_agent_loop( + content, _, _, _, _ = await loop._run_agent_loop( [{"role": "user", "content": 
"hi"}] ) @@ -373,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - content, tools_used, _, _ = await loop._run_agent_loop([]) + content, tools_used, _, _, _ = await loop._run_agent_loop([]) assert content == ( "I reached the maximum number of tool call iterations (2) " "without completing the task. You can try breaking the task into smaller steps." diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index ef4206573..ba503e988 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -798,7 +798,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - final_content, _, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _, _ = await loop._run_agent_loop([]) assert final_content == ( "I reached the maximum number of tool call iterations (2) " @@ -825,7 +825,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp async def on_stream_end(*, resuming: bool = False) -> None: endings.append(resuming) - final_content, _, _, _ = await loop._run_agent_loop( + final_content, _, _, _, _ = await loop._run_agent_loop( [], on_stream=on_stream, on_stream_end=on_stream_end, @@ -849,7 +849,7 @@ async def test_loop_retries_think_only_final_response(tmp_path): loop.provider.chat_with_retry = chat_with_retry - final_content, _, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _, _ = await loop._run_agent_loop([]) assert final_content == "Recovered answer" assert call_count["n"] == 2 @@ -1607,3 +1607,412 @@ async def test_microcompact_skips_non_compactable_tools(): result = AgentRunner._microcompact(messages) assert result is messages # no compactable tools found + + +# ── Mid-turn injection tests ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def 
test_drain_injections_returns_empty_when_no_callback(): + """No injection_callback → empty list.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=None, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_drain_injections_extracts_content_from_inbound_messages(): + """Should extract .content from InboundMessage objects.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == ["hello", "world"] + + +@pytest.mark.asyncio +async def test_drain_injections_caps_at_max_and_logs_warning(): + """When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") + for i in range(_MAX_INJECTIONS_PER_TURN + 3) + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( 
+ initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert len(result) == _MAX_INJECTIONS_PER_TURN + # Should keep the LAST _MAX_INJECTIONS_PER_TURN items + assert result[0] == "msg3" + assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}" + + +@pytest.mark.asyncio +async def test_drain_injections_skips_empty_content(): + """Messages with blank content should be filtered out.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == ["valid"] + + +@pytest.mark.asyncio +async def test_drain_injections_handles_callback_exception(): + """If the callback raises, return empty list (error is logged).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def cb(): + raise RuntimeError("boom") + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_checkpoint1_injects_after_tool_execution(): + """Follow-up messages are injected after tool execution, 
before next LLM call.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse( + content="using tool", + tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})], + usage={}, + ) + return LLMResponse(content="final answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + # Put a follow-up message in the queue before the run starts + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "final answer" + # The second call should have the injected user message + assert call_count["n"] == 2 + last_messages = captured_messages[-1] + injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): + """After final response, if injections exist, stream_end should get resuming=True.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + 
from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + stream_end_calls = [] + + class TrackingHook(AgentHook): + def wants_streaming(self) -> bool: + return True + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + stream_end_calls.append(resuming) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return content + + async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + # Inject a follow-up that arrives during the first response + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=TrackingHook(), + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "second answer" + assert call_count["n"] == 2 + # First stream_end should have resuming=True (because injections found) + assert stream_end_calls[0] is True + # Second (final) stream_end should have resuming=False + assert stream_end_calls[-1] is False + + +@pytest.mark.asyncio +async def test_injection_cycles_capped_at_max(): + """Injection cycles should be 
capped at _MAX_INJECTION_CYCLES.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + drain_count = {"n": 0} + + async def inject_cb(): + drain_count["n"] += 1 + # Only inject for the first _MAX_INJECTION_CYCLES drains + if drain_count["n"] <= _MAX_INJECTION_CYCLES: + return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "start"}], + tools=tools, + model="test-model", + max_iterations=20, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + # Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round + assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 + + +@pytest.mark.asyncio +async def test_no_injections_flag_is_false_by_default(): + """had_injections should be False when no injection callback or no messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.had_injections is False + + 
+@pytest.mark.asyncio +async def test_pending_queue_cleanup_on_dispatch(tmp_path): + """_pending_queues should be cleaned up after _dispatch completes.""" + loop = _make_loop(tmp_path) + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + loop.provider.chat_with_retry = chat_with_retry + + from nanobot.bus.events import InboundMessage + + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello") + # The queue should not exist before dispatch + assert msg.session_key not in loop._pending_queues + + await loop._dispatch(msg) + + # The queue should be cleaned up after dispatch + assert msg.session_key not in loop._pending_queues + + +@pytest.mark.asyncio +async def test_followup_routed_to_pending_queue(tmp_path): + """When a session has an active dispatch, follow-up messages go to pending queue.""" + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + + # Simulate an active dispatch by manually adding a pending queue + pending = asyncio.Queue(maxsize=20) + loop._pending_queues["cli:c"] = pending + + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + + # Directly test the routing logic from run() — if session_key is in + # _pending_queues, the message should be put into the queue. + assert msg.session_key in loop._pending_queues + loop._pending_queues[msg.session_key].put_nowait(msg) + + assert not pending.empty() + queued_msg = pending.get_nowait() + assert queued_msg.content == "follow-up" + + +@pytest.mark.asyncio +async def test_dispatch_republishes_leftover_queue_messages(tmp_path): + """Messages left in the pending queue after _dispatch are re-published to the bus. + + This tests the finally-block cleanup that prevents message loss when + the runner exits early (e.g., max_iterations, tool_error) with messages + still in the queue. 
+ """ + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + bus = loop.bus + + # Simulate a completed dispatch by manually registering a queue + # with leftover messages, then running the cleanup logic directly. + pending = asyncio.Queue(maxsize=20) + session_key = "cli:c" + loop._pending_queues[session_key] = pending + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-1")) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-2")) + + # Execute the cleanup logic from the finally block + queue = loop._pending_queues.pop(session_key, None) + assert queue is not None + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await bus.publish_inbound(item) + leftover += 1 + + assert leftover == 2 + + # Verify the messages are now on the bus + msgs = [] + while not bus.inbound.empty(): + msgs.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5)) + contents = [m.content for m in msgs] + assert "leftover-1" in contents + assert "leftover-2" in contents diff --git a/tests/tools/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py index 3f06b4a70..a922e95ed 100644 --- a/tests/tools/test_message_tool_suppress.py +++ b/tests/tools/test_message_tool_suppress.py @@ -107,7 +107,7 @@ class TestMessageToolSuppressLogic: async def on_progress(content: str, *, tool_hint: bool = False) -> None: progress.append((content, tool_hint)) - final_content, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) assert final_content == "Done" assert progress == [ From f6c39ec946d01e358e734d7dd99655e66323f30c Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 13:17:28 +0000 Subject: [PATCH 073/115] feat(agent): enhance session key handling for follow-up messages --- nanobot/agent/loop.py | 32 
+++++++++------ nanobot/agent/runner.py | 23 ++++++++++- tests/agent/test_runner.py | 84 ++++++++++++++++++++++++++++++++++---- 3 files changed, 118 insertions(+), 21 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b51444650..0f72e39f6 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -324,6 +324,12 @@ class AgentLoop: return format_tool_hints(tool_calls) + def _effective_session_key(self, msg: InboundMessage) -> str: + """Return the session key used for task routing and mid-turn injections.""" + if self._unified_session and not msg.session_key_override: + return UNIFIED_SESSION_KEY + return msg.session_key + async def _run_agent_loop( self, initial_messages: list[dict], @@ -430,30 +436,32 @@ class AgentLoop: if result: await self.bus.publish_outbound(result) continue + effective_key = self._effective_session_key(msg) # If this session already has an active pending queue (i.e. a task # is processing this session), route the message there for mid-turn # injection instead of creating a competing task. 
- if msg.session_key in self._pending_queues: + if effective_key in self._pending_queues: + pending_msg = msg + if effective_key != msg.session_key: + pending_msg = dataclasses.replace( + msg, + session_key_override=effective_key, + ) try: - self._pending_queues[msg.session_key].put_nowait(msg) + self._pending_queues[effective_key].put_nowait(pending_msg) except asyncio.QueueFull: logger.warning( "Pending queue full for session {}, dropping follow-up", - msg.session_key, + effective_key, ) else: logger.info( "Routed follow-up message to pending queue for session {}", - msg.session_key, + effective_key, ) continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled - effective_key = ( - UNIFIED_SESSION_KEY - if self._unified_session and not msg.session_key_override - else msg.session_key - ) task = asyncio.create_task(self._dispatch(msg)) self._active_tasks.setdefault(effective_key, []).append(task) task.add_done_callback( @@ -465,9 +473,9 @@ class AgentLoop: async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" - if self._unified_session and not msg.session_key_override: - msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY) - session_key = msg.session_key + session_key = self._effective_session_key(msg) + if session_key != msg.session_key: + msg = dataclasses.replace(msg, session_key_override=session_key) lock = self._session_locks.setdefault(session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index f7187191e..0ba0e6bc6 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -309,6 +309,14 @@ class AgentRunner: await hook.after_iteration(context) continue + assistant_message: dict[str, Any] | None = None + if response.finish_reason != "error" and not is_blank_text(clean): + 
assistant_message = build_assistant_message( + clean, + reasoning_content=response.reasoning_content, + thinking_blocks=response.thinking_blocks, + ) + # Check for mid-turn injections BEFORE signaling stream end. # If injections are found we keep the stream alive (resuming=True) # so streaming channels don't prematurely finalize the card. @@ -319,6 +327,19 @@ class AgentRunner: had_injections = True injection_cycles += 1 _injected_after_final = True + if assistant_message is not None: + messages.append(assistant_message) + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [], + }, + ) for text in injections: messages.append({"role": "user", "content": text}) logger.info( @@ -354,7 +375,7 @@ class AgentRunner: await hook.after_iteration(context) break - messages.append(build_assistant_message( + messages.append(assistant_message or build_assistant_message( clean, reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index ba503e988..a9e32e0f8 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -1862,6 +1862,66 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): assert stream_end_calls[-1] is False +@pytest.mark.asyncio +async def test_checkpoint2_preserves_final_response_in_history_before_followup(): + """A follow-up injected after a final answer must still see that answer in history.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + if call_count["n"] == 1: + return 
LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + assert captured_messages[-1] == [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "first answer"}, + {"role": "user", "content": "follow-up question"}, + ] + assert [ + {"role": message["role"], "content": message["content"]} + for message in result.messages + if message.get("role") == "assistant" + ] == [ + {"role": "assistant", "content": "first answer"}, + {"role": "assistant", "content": "second answer"}, + ] + + @pytest.mark.asyncio async def test_injection_cycles_capped_at_max(): """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" @@ -1953,25 +2013,33 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path): @pytest.mark.asyncio async def test_followup_routed_to_pending_queue(tmp_path): - """When a session has an active dispatch, follow-up messages go to pending queue.""" + """Unified-session follow-ups should route into the active pending queue.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY from nanobot.bus.events import InboundMessage loop = _make_loop(tmp_path) + loop._unified_session = True + loop._dispatch = 
AsyncMock() # type: ignore[method-assign] - # Simulate an active dispatch by manually adding a pending queue pending = asyncio.Queue(maxsize=20) - loop._pending_queues["cli:c"] = pending + loop._pending_queues[UNIFIED_SESSION_KEY] = pending - msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) - # Directly test the routing logic from run() — if session_key is in - # _pending_queues, the message should be put into the queue. - assert msg.session_key in loop._pending_queues - loop._pending_queues[msg.session_key].put_nowait(msg) + deadline = time.time() + 2 + while pending.empty() and time.time() < deadline: + await asyncio.sleep(0.01) + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 0 assert not pending.empty() queued_msg = pending.get_nowait() assert queued_msg.content == "follow-up" + assert queued_msg.session_key == UNIFIED_SESSION_KEY @pytest.mark.asyncio From cf8381f517084b5e9ec8e14539dcdc5b0eab2baa Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 13:37:47 +0000 Subject: [PATCH 074/115] feat(agent): enhance message injection handling and content merging --- nanobot/agent/loop.py | 42 ++-- nanobot/agent/runner.py | 85 ++++++-- tests/agent/test_runner.py | 235 +++++++++++++++++++++- tests/tools/test_message_tool_suppress.py | 37 ++++ 4 files changed, 358 insertions(+), 41 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0f72e39f6..675865350 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import Consolidator, Dream -from 
nanobot.agent.runner import AgentRunSpec, AgentRunner +from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool from nanobot.agent.skills import BUILTIN_SKILLS_DIR @@ -370,16 +370,30 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) - async def _drain_pending() -> list[InboundMessage]: + async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]: """Non-blocking drain of follow-up messages from the pending queue.""" if pending_queue is None: return [] - items: list[InboundMessage] = [] - while True: + items: list[dict[str, Any]] = [] + while len(items) < limit: try: - items.append(pending_queue.get_nowait()) + pending_msg = pending_queue.get_nowait() except asyncio.QueueEmpty: break + user_content = self.context._build_user_content( + pending_msg.content, + pending_msg.media if pending_msg.media else None, + ) + runtime_ctx = self.context._build_runtime_context( + pending_msg.channel, + pending_msg.chat_id, + self.context.timezone, + ) + if isinstance(user_content, str): + merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}" + else: + merged = [{"type": "text", "text": runtime_ctx}] + user_content + items.append({"role": "user", "content": merged}) return items result = await self.runner.run(AgentRunSpec( @@ -451,7 +465,7 @@ class AgentLoop: self._pending_queues[effective_key].put_nowait(pending_msg) except asyncio.QueueFull: logger.warning( - "Pending queue full for session {}, dropping follow-up", + "Pending queue full for session {}, falling back to queued task", effective_key, ) else: @@ -459,7 +473,7 @@ class AgentLoop: "Routed follow-up message to pending queue for session {}", effective_key, ) - continue + continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled task = 
asyncio.create_task(self._dispatch(msg)) @@ -697,12 +711,14 @@ class AgentLoop: self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - # When follow-up messages were injected mid-turn, the LLM's final - # response addresses those follow-ups. Always send the response in - # this case, even if MessageTool was used earlier in the turn — the - # follow-up response is new content the user hasn't seen. - if not had_injections: - if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + # When follow-up messages were injected mid-turn, a later natural + # language reply may address those follow-ups and should not be + # suppressed just because MessageTool was used earlier in the turn. + # However, if the turn falls back to the empty-final-response + # placeholder, suppress it when the real user-visible output already + # came from MessageTool. + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + if not had_injections or stop_reason == "empty_final_response": return None preview = final_content[:120] + "..." 
if len(final_content) > 120 else final_content diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 0ba0e6bc6..164921bb4 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from dataclasses import dataclass, field +import inspect from pathlib import Path from typing import Any @@ -94,37 +95,89 @@ class AgentRunner: def __init__(self, provider: LLMProvider): self.provider = provider - async def _drain_injections(self, spec: AgentRunSpec) -> list[str]: + @staticmethod + def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: + if isinstance(left, str) and isinstance(right, str): + return f"{left}\n\n{right}" if left else right + + def _to_blocks(value: Any) -> list[dict[str, Any]]: + if isinstance(value, list): + return [ + item if isinstance(item, dict) else {"type": "text", "text": str(item)} + for item in value + ] + if value is None: + return [] + return [{"type": "text", "text": str(value)}] + + return _to_blocks(left) + _to_blocks(right) + + @classmethod + def _append_injected_messages( + cls, + messages: list[dict[str, Any]], + injections: list[dict[str, Any]], + ) -> None: + """Append injected user messages while preserving role alternation.""" + for injection in injections: + if ( + messages + and injection.get("role") == "user" + and messages[-1].get("role") == "user" + ): + merged = dict(messages[-1]) + merged["content"] = cls._merge_message_content( + merged.get("content"), + injection.get("content"), + ) + messages[-1] = merged + continue + messages.append(injection) + + async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]: """Drain pending user messages via the injection callback. - Returns all drained message contents (capped by + Returns normalized user messages (capped by ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is - nothing to inject. 
Messages beyond the cap are logged so they + nothing to inject. Messages beyond the cap are logged so they are not silently lost. """ if spec.injection_callback is None: return [] try: - items = await spec.injection_callback() + signature = inspect.signature(spec.injection_callback) + accepts_limit = ( + "limit" in signature.parameters + or any( + parameter.kind is inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ) + ) + if accepts_limit: + items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN) + else: + items = await spec.injection_callback() except Exception: logger.exception("injection_callback failed") return [] if not items: return [] - # items are InboundMessage objects from _drain_pending - texts: list[str] = [] + injected_messages: list[dict[str, Any]] = [] for item in items: + if isinstance(item, dict) and item.get("role") == "user" and "content" in item: + injected_messages.append(item) + continue text = getattr(item, "content", str(item)) if text.strip(): - texts.append(text) - if len(texts) > _MAX_INJECTIONS_PER_TURN: - dropped = len(texts) - _MAX_INJECTIONS_PER_TURN + injected_messages.append({"role": "user", "content": text}) + if len(injected_messages) > _MAX_INJECTIONS_PER_TURN: + dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN logger.warning( - "Injection batch has {} messages, capping to {} ({} dropped)", - len(texts), _MAX_INJECTIONS_PER_TURN, dropped, + "Injection callback returned {} messages, capping to {} ({} dropped)", + len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped, ) - texts = texts[-_MAX_INJECTIONS_PER_TURN:] - return texts + injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN] + return injected_messages async def run(self, spec: AgentRunSpec) -> AgentRunResult: hook = spec.hook or AgentHook() @@ -247,8 +300,7 @@ class AgentRunner: if injections: had_injections = True injection_cycles += 1 - for text in injections: - messages.append({"role": "user", 
"content": text}) + self._append_injected_messages(messages, injections) logger.info( "Injected {} follow-up message(s) after tool execution ({}/{})", len(injections), injection_cycles, _MAX_INJECTION_CYCLES, @@ -340,8 +392,7 @@ class AgentRunner: "pending_tool_calls": [], }, ) - for text in injections: - messages.append({"role": "user", "content": text}) + self._append_injected_messages(messages, injections) logger.info( "Injected {} follow-up message(s) after final response ({}/{})", len(injections), injection_cycles, _MAX_INJECTION_CYCLES, diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index a9e32e0f8..b9047b674 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import base64 import os import time from unittest.mock import AsyncMock, MagicMock, patch @@ -1633,7 +1634,7 @@ async def test_drain_injections_returns_empty_when_no_callback(): @pytest.mark.asyncio async def test_drain_injections_extracts_content_from_inbound_messages(): """Should extract .content from InboundMessage objects.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -1655,12 +1656,15 @@ async def test_drain_injections_extracts_content_from_inbound_messages(): injection_callback=cb, ) result = await runner._drain_injections(spec) - assert result == ["hello", "world"] + assert result == [ + {"role": "user", "content": "hello"}, + {"role": "user", "content": "world"}, + ] @pytest.mark.asyncio -async def test_drain_injections_caps_at_max_and_logs_warning(): - """When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept.""" +async def test_drain_injections_passes_limit_to_callback_when_supported(): + """Limit-aware callbacks can preserve overflow in their own queue.""" from nanobot.agent.runner import 
AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN from nanobot.bus.events import InboundMessage @@ -1668,14 +1672,16 @@ async def test_drain_injections_caps_at_max_and_logs_warning(): runner = AgentRunner(provider) tools = MagicMock() tools.get_definitions.return_value = [] + seen_limits: list[int] = [] msgs = [ InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") for i in range(_MAX_INJECTIONS_PER_TURN + 3) ] - async def cb(): - return msgs + async def cb(*, limit: int): + seen_limits.append(limit) + return msgs[:limit] spec = AgentRunSpec( initial_messages=[], tools=tools, model="m", @@ -1683,10 +1689,12 @@ async def test_drain_injections_caps_at_max_and_logs_warning(): injection_callback=cb, ) result = await runner._drain_injections(spec) - assert len(result) == _MAX_INJECTIONS_PER_TURN - # Should keep the LAST _MAX_INJECTIONS_PER_TURN items - assert result[0] == "msg3" - assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}" + assert seen_limits == [_MAX_INJECTIONS_PER_TURN] + assert result == [ + {"role": "user", "content": "msg0"}, + {"role": "user", "content": "msg1"}, + {"role": "user", "content": "msg2"}, + ] @pytest.mark.asyncio @@ -1715,7 +1723,7 @@ async def test_drain_injections_skips_empty_content(): injection_callback=cb, ) result = await runner._drain_injections(spec) - assert result == ["valid"] + assert result == [{"role": "user", "content": "valid"}] @pytest.mark.asyncio @@ -1922,6 +1930,129 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup() ] +@pytest.mark.asyncio +async def test_loop_injected_followup_preserves_image_media(tmp_path): + """Mid-turn follow-ups with images should keep multimodal content.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + image_path = tmp_path / "followup.png" + image_path.write_bytes(base64.b64decode( + 
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII=" + )) + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content="", + media=[str(image_path)], + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "second answer" + assert had_injections is True + assert call_count["n"] == 2 + injected_user_messages = [ + message for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), list) + ] + assert injected_user_messages + assert any( + block.get("type") == "image_url" + for block in injected_user_messages[-1]["content"] + if isinstance(block, dict) + ) + + +@pytest.mark.asyncio +async def test_runner_merges_multiple_injected_user_messages_without_losing_media(): + """Multiple injected follow-ups should not create lossy consecutive user messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for 
message in messages]) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def inject_cb(): + if call_count["n"] == 1: + return [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + {"type": "text", "text": "look at this"}, + ], + }, + {"role": "user", "content": "and answer briefly"}, + ] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + second_call = captured_messages[-1] + user_messages = [message for message in second_call if message.get("role") == "user"] + assert len(user_messages) == 2 + injected = user_messages[-1] + assert isinstance(injected["content"], list) + assert any( + block.get("type") == "image_url" + for block in injected["content"] + if isinstance(block, dict) + ) + assert any( + block.get("type") == "text" and block.get("text") == "and answer briefly" + for block in injected["content"] + if isinstance(block, dict) + ) + + @pytest.mark.asyncio async def test_injection_cycles_capped_at_max(): """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" @@ -2042,6 +2173,88 @@ async def test_followup_routed_to_pending_queue(tmp_path): assert queued_msg.session_key == UNIFIED_SESSION_KEY +@pytest.mark.asyncio +async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path): + """Pending queue should leave overflow messages queued for later drains.""" + from nanobot.agent.loop import AgentLoop + from 
nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + total_followups = _MAX_INJECTIONS_PER_TURN + 2 + for idx in range(total_followups): + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content=f"follow-up-{idx}", + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "answer-3" + assert had_injections is True + assert call_count["n"] == 3 + flattened_user_content = "\n".join( + message["content"] + for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), str) + ) + for idx in range(total_followups): + assert f"follow-up-{idx}" in flattened_user_content + assert pending_queue.empty() + + +@pytest.mark.asyncio +async def test_pending_queue_full_falls_back_to_queued_task(tmp_path): + """QueueFull should preserve the message by dispatching a queued task.""" + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + loop._dispatch = AsyncMock() # type: ignore[method-assign] + + pending = asyncio.Queue(maxsize=1) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", 
chat_id="c", content="already queued")) + loop._pending_queues["cli:c"] = pending + + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) + + deadline = time.time() + 2 + while loop._dispatch.await_count == 0 and time.time() < deadline: + await asyncio.sleep(0.01) + + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 1 + dispatched_msg = loop._dispatch.await_args.args[0] + assert dispatched_msg.content == "follow-up" + assert pending.qsize() == 1 + + @pytest.mark.asyncio async def test_dispatch_republishes_leftover_queue_messages(tmp_path): """Messages left in the pending queue after _dispatch are re-published to the bus. diff --git a/tests/tools/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py index a922e95ed..434b2ca71 100644 --- a/tests/tools/test_message_tool_suppress.py +++ b/tests/tools/test_message_tool_suppress.py @@ -1,5 +1,6 @@ """Test message tool suppress logic for final replies.""" +import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic: assert result is not None assert "Hello" in result.content + @pytest.mark.asyncio + async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback( + self, tmp_path: Path + ) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call1", name="message", + arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"}, + ) + calls = iter([ + LLMResponse(content="First answer", tool_calls=[]), + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="", tool_calls=[]), + LLMResponse(content="", tool_calls=[]), + LLMResponse(content="", tool_calls=[]), + ]) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = 
MagicMock(return_value=[]) + + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) + + pending_queue = asyncio.Queue() + await pending_queue.put( + InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up") + ) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start") + result = await loop._process_message(msg, pending_queue=pending_queue) + + assert len(sent) == 1 + assert sent[0].content == "Tool reply" + assert result is None + async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: loop = _make_loop(tmp_path) tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"}) From 48f6bbd256bde3d43bfa04166c9ae497aa4367a8 Mon Sep 17 00:00:00 2001 From: gem12 Date: Sat, 21 Mar 2026 21:28:41 +0800 Subject: [PATCH 075/115] feat(channels): Add full media support for QQ and WeCom channels QQ channel improvements (on top of nightly): - Add top-level try/except in _on_message and send() for resilience - Use defensive getattr() for attachment attributes (botpy version compat) - Skip file_name for image uploads to avoid QQ rendering as file attachment - Extract only file_info from upload response to avoid extra fields - Handle protocol-relative URLs (//...) 
in attachment downloads WeCom channel improvements: - Add _upload_media_ws() for WebSocket 3-step media upload protocol - Send media files (image/video/voice/file) via WeCom rich media API - Support progress messages (plain reply) vs final response (streaming) - Support proactive send when no frame available (cron push) - Pass media_paths to message bus for downstream processing --- nanobot/channels/qq.py | 194 ++++++++++++++++++++++---------------- nanobot/channels/wecom.py | 163 ++++++++++++++++++++++++++++---- 2 files changed, 255 insertions(+), 102 deletions(-) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index bef2cf27a..484eed6e2 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -242,43 +242,46 @@ class QQChannel(BaseChannel): async def send(self, msg: OutboundMessage) -> None: """Send attachments first, then text.""" - if not self._client: - logger.warning("QQ client not initialized") - return + try: + if not self._client: + logger.warning("QQ client not initialized") + return - msg_id = msg.metadata.get("message_id") - chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") - is_group = chat_type == "group" + msg_id = msg.metadata.get("message_id") + chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") + is_group = chat_type == "group" - # 1) Send media - for media_ref in msg.media or []: - ok = await self._send_media( - chat_id=msg.chat_id, - media_ref=media_ref, - msg_id=msg_id, - is_group=is_group, - ) - if not ok: - filename = ( - os.path.basename(urlparse(media_ref).path) - or os.path.basename(media_ref) - or "file" + # 1) Send media + for media_ref in msg.media or []: + ok = await self._send_media( + chat_id=msg.chat_id, + media_ref=media_ref, + msg_id=msg_id, + is_group=is_group, ) + if not ok: + filename = ( + os.path.basename(urlparse(media_ref).path) + or os.path.basename(media_ref) + or "file" + ) + await self._send_text_only( + chat_id=msg.chat_id, + is_group=is_group, + msg_id=msg_id, + 
content=f"[Attachment send failed: {filename}]", + ) + + # 2) Send text + if msg.content and msg.content.strip(): await self._send_text_only( chat_id=msg.chat_id, is_group=is_group, msg_id=msg_id, - content=f"[Attachment send failed: {filename}]", + content=msg.content.strip(), ) - - # 2) Send text - if msg.content and msg.content.strip(): - await self._send_text_only( - chat_id=msg.chat_id, - is_group=is_group, - msg_id=msg_id, - content=msg.content.strip(), - ) + except Exception: + logger.exception("Error sending QQ message to chat_id={}", msg.chat_id) async def _send_text_only( self, @@ -438,15 +441,26 @@ class QQChannel(BaseChannel): endpoint = "/v2/users/{openid}/files" id_key = "openid" - payload = { + payload: dict[str, Any] = { id_key: chat_id, "file_type": file_type, "file_data": file_data, - "file_name": file_name, "srv_send_msg": srv_send_msg, } + # Only pass file_name for non-image types (file_type=4). + # Passing file_name for images causes QQ client to render them as + # file attachments instead of inline images. + if file_type != QQ_FILE_TYPE_IMAGE and file_name: + payload["file_name"] = file_name + route = Route("POST", endpoint, **{id_key: chat_id}) - return await self._client.api._http.request(route, json=payload) + result = await self._client.api._http.request(route, json=payload) + + # Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.) + # that may confuse QQ client when sending the media object. 
+ if isinstance(result, dict) and "file_info" in result: + return {"file_info": result["file_info"]} + return result # --------------------------- # Inbound (receive) @@ -454,58 +468,68 @@ class QQChannel(BaseChannel): async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None: """Parse inbound message, download attachments, and publish to the bus.""" - if data.id in self._processed_ids: - return - self._processed_ids.append(data.id) + try: + if data.id in self._processed_ids: + return + self._processed_ids.append(data.id) - if is_group: - chat_id = data.group_openid - user_id = data.author.member_openid - self._chat_type_cache[chat_id] = "group" - else: - chat_id = str( - getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown") - ) - user_id = chat_id - self._chat_type_cache[chat_id] = "c2c" - - content = (data.content or "").strip() - - # the data used by tests don't contain attachments property - # so we use getattr with a default of [] to avoid AttributeError in tests - attachments = getattr(data, "attachments", None) or [] - media_paths, recv_lines, att_meta = await self._handle_attachments(attachments) - - # Compose content that always contains actionable saved paths - if recv_lines: - tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]" - file_block = "Received files:\n" + "\n".join(recv_lines) - content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" - - if not content and not media_paths: - return - - if self.config.ack_message: - try: - await self._send_text_only( - chat_id=chat_id, - is_group=is_group, - msg_id=data.id, - content=self.config.ack_message, + if is_group: + chat_id = data.group_openid + user_id = data.author.member_openid + self._chat_type_cache[chat_id] = "group" + else: + chat_id = str( + getattr(data.author, "id", None) + or getattr(data.author, "user_openid", "unknown") ) - except Exception: - logger.debug("QQ ack 
message failed for chat_id={}", chat_id) + user_id = chat_id + self._chat_type_cache[chat_id] = "c2c" - await self._handle_message( - sender_id=user_id, - chat_id=chat_id, - content=content, - media=media_paths if media_paths else None, - metadata={ - "message_id": data.id, - "attachments": att_meta, - }, - ) + content = (data.content or "").strip() + + # the data used by tests don't contain attachments property + # so we use getattr with a default of [] to avoid AttributeError in tests + attachments = getattr(data, "attachments", None) or [] + media_paths, recv_lines, att_meta = await self._handle_attachments(attachments) + + # Compose content that always contains actionable saved paths + if recv_lines: + tag = ( + "[Image]" + if any(_is_image_name(Path(p).name) for p in media_paths) + else "[File]" + ) + file_block = "Received files:\n" + "\n".join(recv_lines) + content = ( + f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" + ) + + if not content and not media_paths: + return + + if self.config.ack_message: + try: + await self._send_text_only( + chat_id=chat_id, + is_group=is_group, + msg_id=data.id, + content=self.config.ack_message, + ) + except Exception: + logger.debug("QQ ack message failed for chat_id={}", chat_id) + + await self._handle_message( + sender_id=user_id, + chat_id=chat_id, + content=content, + media=media_paths if media_paths else None, + metadata={ + "message_id": data.id, + "attachments": att_meta, + }, + ) + except Exception: + logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?")) async def _handle_attachments( self, @@ -520,7 +544,9 @@ class QQChannel(BaseChannel): return media_paths, recv_lines, att_meta for att in attachments: - url, filename, ctype = att.url, att.filename, att.content_type + url = getattr(att, "url", None) or "" + filename = getattr(att, "filename", None) or "" + ctype = getattr(att, "content_type", None) or "" logger.info("Downloading file from QQ: {}", filename 
or url) local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename) @@ -555,6 +581,10 @@ class QQChannel(BaseChannel): Enforces a max download size and writes to a .part temp file that is atomically renamed on success. """ + # Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...") + if url.startswith("//"): + url = f"https:{url}" + if not self._http: self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index 05ad14825..ec9e782be 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -1,6 +1,8 @@ """WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk.""" import asyncio +import base64 +import hashlib import importlib.util import os from collections import OrderedDict @@ -217,6 +219,7 @@ class WecomChannel(BaseChannel): chat_id = body.get("chatid", sender_id) content_parts = [] + media_paths: list[str] = [] if msg_type == "text": text = body.get("text", {}).get("content", "") @@ -232,7 +235,8 @@ class WecomChannel(BaseChannel): file_path = await self._download_and_save_media(file_url, aes_key, "image") if file_path: filename = os.path.basename(file_path) - content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]") + content_parts.append(f"[image: {filename}]") + media_paths.append(file_path) else: content_parts.append("[image: download failed]") else: @@ -286,12 +290,11 @@ class WecomChannel(BaseChannel): self._chat_frames[chat_id] = frame # Forward to message bus - # Note: media paths are included in content for broader model compatibility await self._handle_message( sender_id=sender_id, chat_id=chat_id, content=content, - media=None, + media=media_paths or None, metadata={ "message_id": msg_id, "msg_type": msg_type, @@ -336,6 +339,93 @@ class WecomChannel(BaseChannel): logger.error("Error downloading media: {}", e) return None + async def _upload_media_ws( + self, client: Any, file_path: 
str, + ) -> "tuple[str, str] | tuple[None, None]": + """Upload a local file to WeCom via WebSocket 3-step protocol (base64). + + Uses the WeCom WebSocket upload commands directly via + ``client._ws_manager.send_reply()``: + + ``aibot_upload_media_init`` → upload_id + ``aibot_upload_media_chunk`` × N (≤512 KB raw per chunk, base64) + ``aibot_upload_media_finish`` → media_id + + Returns (media_id, media_type) on success, (None, None) on failure. + """ + from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id + + try: + fname = os.path.basename(file_path) + ext = os.path.splitext(fname)[1].lower() + + if ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"): + media_type = "image" + elif ext in (".mp4", ".avi", ".mov"): + media_type = "video" + elif ext in (".amr", ".mp3", ".wav", ".ogg"): + media_type = "voice" + else: + media_type = "file" + + data = open(file_path, "rb").read() # noqa: SIM115 + file_size = len(data) + md5_hash = hashlib.md5(data).hexdigest() # noqa: S324 + + CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64) + chunk_list = [data[i : i + CHUNK_SIZE] for i in range(0, file_size, CHUNK_SIZE)] + n_chunks = len(chunk_list) + + # Step 1: init + req_id = _gen_req_id("upload_init") + resp = await client._ws_manager.send_reply(req_id, { + "type": media_type, + "filename": fname, + "total_size": file_size, + "total_chunks": n_chunks, + "md5": md5_hash, + }, "aibot_upload_media_init") + if resp.errcode != 0: + logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg) + return None, None + upload_id = resp.body.get("upload_id") if resp.body else None + if not upload_id: + logger.warning("WeCom upload init: no upload_id in response") + return None, None + + # Step 2: send chunks + for i, chunk in enumerate(chunk_list): + req_id = _gen_req_id("upload_chunk") + resp = await client._ws_manager.send_reply(req_id, { + "upload_id": upload_id, + "chunk_index": i, + "base64_data": base64.b64encode(chunk).decode(), + }, 
"aibot_upload_media_chunk") + if resp.errcode != 0: + logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg) + return None, None + + # Step 3: finish + req_id = _gen_req_id("upload_finish") + resp = await client._ws_manager.send_reply(req_id, { + "upload_id": upload_id, + }, "aibot_upload_media_finish") + if resp.errcode != 0: + logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg) + return None, None + + media_id = resp.body.get("media_id") if resp.body else None + if not media_id: + logger.warning("WeCom upload finish: no media_id in response body={}", resp.body) + return None, None + + logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + "...") + return media_id, media_type + + except Exception as e: + logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e) + return None, None + async def send(self, msg: OutboundMessage) -> None: """Send a message through WeCom.""" if not self._client: @@ -343,28 +433,61 @@ class WecomChannel(BaseChannel): return try: - content = msg.content.strip() - if not content: - return + content = (msg.content or "").strip() + is_progress = bool(msg.metadata.get("_progress")) # Get the stored frame for this chat frame = self._chat_frames.get(msg.chat_id) - if not frame: - logger.warning("No frame found for chat {}, cannot reply", msg.chat_id) + + # Send media files via WebSocket upload + for file_path in msg.media or []: + if not os.path.isfile(file_path): + logger.warning("WeCom media file not found: {}", file_path) + continue + media_id, media_type = await self._upload_media_ws(self._client, file_path) + if media_id: + if frame: + await self._client.reply(frame, { + "msgtype": media_type, + media_type: {"media_id": media_id}, + }) + else: + await self._client.send_message(msg.chat_id, { + "msgtype": media_type, + media_type: {"media_id": media_id}, + }) + logger.debug("WeCom sent {} → {}", media_type, msg.chat_id) + else: + content 
+= f"\n[file upload failed: {os.path.basename(file_path)}]" + + if not content: return - # Use streaming reply for better UX - stream_id = self._generate_req_id("stream") - - # Send as streaming message with finish=True - await self._client.reply_stream( - frame, - stream_id, - content, - finish=True, - ) - - logger.debug("WeCom message sent to {}", msg.chat_id) + if frame: + if is_progress: + # Progress messages (thinking text): send as plain reply, no streaming + await self._client.reply(frame, { + "msgtype": "text", + "text": {"content": content}, + }) + logger.debug("WeCom progress sent to {}", msg.chat_id) + else: + # Final response: use streaming reply for better UX + stream_id = self._generate_req_id("stream") + await self._client.reply_stream( + frame, + stream_id, + content, + finish=True, + ) + logger.debug("WeCom message sent to {}", msg.chat_id) + else: + # No frame (e.g. cron push): proactive send only supports markdown + await self._client.send_message(msg.chat_id, { + "msgtype": "markdown", + "markdown": {"content": content}, + }) + logger.info("WeCom proactive send to {}", msg.chat_id) except Exception as e: logger.error("Error sending WeCom message: {}", e) From f900e4f259139dc368c4f1adc65898bcb24641a2 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Fri, 10 Apr 2026 16:27:23 +0800 Subject: [PATCH 076/115] fix(wecom): harden upload and inbound media handling - Use asyncio.to_thread for file I/O to avoid blocking event loop - Add 200MB upload size limit with early rejection - Fix file handle leak by using context manager - Free raw bytes early after chunking to reduce memory pressure - Add file attachments to media_paths (was text-only, inconsistent with image) - Use robust _sanitize_filename() instead of os.path.basename() for path safety - Remove re-raise in send() for consistency with QQ channel - Fix truncated media_id logging for short IDs --- nanobot/channels/wecom.py | 49 ++++++++++++++++++++++++++++++++------- 1 file changed, 40 
insertions(+), 9 deletions(-) diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index ec9e782be..d285d4bcd 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -5,7 +5,9 @@ import base64 import hashlib import importlib.util import os +import re from collections import OrderedDict +from pathlib import Path from typing import Any from loguru import logger @@ -19,6 +21,20 @@ from pydantic import Field WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None +# Upload safety limits (matching QQ channel defaults) +WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB + +# Replace unsafe characters with "_", keep Chinese and common safe punctuation. +_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE) + + +def _sanitize_filename(name: str) -> str: + """Sanitize filename to avoid traversal and problematic chars.""" + name = (name or "").strip() + name = Path(name).name + name = _SAFE_NAME_RE.sub("_", name).strip("._ ") + return name + class WecomConfig(Base): """WeCom (Enterprise WeChat) AI Bot channel configuration.""" @@ -260,7 +276,8 @@ class WecomChannel(BaseChannel): if file_url and aes_key: file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name) if file_path: - content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]") + content_parts.append(f"[file: {file_name}]") + media_paths.append(file_path) else: content_parts.append(f"[file: {file_name}: download failed]") else: @@ -328,7 +345,7 @@ class WecomChannel(BaseChannel): media_dir = get_media_dir("wecom") if not filename: filename = fname or f"{media_type}_{hash(file_url) % 100000}" - filename = os.path.basename(filename) + filename = _sanitize_filename(filename) file_path = media_dir / filename file_path.write_bytes(data) @@ -368,13 +385,24 @@ class WecomChannel(BaseChannel): else: media_type = "file" - data = open(file_path, "rb").read() # noqa: SIM115 - file_size = len(data) - md5_hash = 
hashlib.md5(data).hexdigest() # noqa: S324 + # Read file size and data in a thread to avoid blocking the event loop + def _read_file(): + file_size = os.path.getsize(file_path) + if file_size > WECOM_UPLOAD_MAX_BYTES: + raise ValueError( + f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})" + ) + with open(file_path, "rb") as f: + return file_size, f.read() + + file_size, data = await asyncio.to_thread(_read_file) + # MD5 is used for file integrity only, not cryptographic security + md5_hash = hashlib.md5(data).hexdigest() CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64) chunk_list = [data[i : i + CHUNK_SIZE] for i in range(0, file_size, CHUNK_SIZE)] n_chunks = len(chunk_list) + del data # free raw bytes early # Step 1: init req_id = _gen_req_id("upload_init") @@ -419,9 +447,13 @@ class WecomChannel(BaseChannel): logger.warning("WeCom upload finish: no media_id in response body={}", resp.body) return None, None - logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + "...") + suffix = "..." 
if len(media_id) > 16 else "" + logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix) return media_id, media_type + except ValueError as e: + logger.warning("WeCom upload skipped for {}: {}", file_path, e) + return None, None except Exception as e: logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e) return None, None @@ -489,6 +521,5 @@ class WecomChannel(BaseChannel): }) logger.info("WeCom proactive send to {}", msg.chat_id) - except Exception as e: - logger.error("Error sending WeCom message: {}", e) - raise + except Exception: + logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id) From f6f712a2ae74473283c6f0e88bf93eeae4f54c9a Mon Sep 17 00:00:00 2001 From: chengyongru Date: Fri, 10 Apr 2026 16:33:57 +0800 Subject: [PATCH 077/115] fix(wecom): harden upload/download, extract media type helper - Use asyncio.to_thread for file I/O to avoid blocking event loop - Add 200MB upload size limit with early rejection - Fix file handle leak by using context manager - Use memoryview for upload chunking to reduce peak memory - Add inbound download size check to prevent OOM - Use asyncio.to_thread for write_bytes in download path - Extract inline media_type detection to _guess_wecom_media_type() --- nanobot/channels/wecom.py | 43 +++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index d285d4bcd..910f02489 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -35,6 +35,23 @@ def _sanitize_filename(name: str) -> str: name = _SAFE_NAME_RE.sub("_", name).strip("._ ") return name + +_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} +_VIDEO_EXTS = {".mp4", ".avi", ".mov"} +_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"} + + +def _guess_wecom_media_type(filename: str) -> str: + """Classify file extension as WeCom media_type string.""" + ext = 
Path(filename).suffix.lower() + if ext in _IMAGE_EXTS: + return "image" + if ext in _VIDEO_EXTS: + return "video" + if ext in _AUDIO_EXTS: + return "voice" + return "file" + class WecomConfig(Base): """WeCom (Enterprise WeChat) AI Bot channel configuration.""" @@ -342,13 +359,21 @@ class WecomChannel(BaseChannel): logger.warning("Failed to download media from WeCom") return None + if len(data) > WECOM_UPLOAD_MAX_BYTES: + logger.warning( + "WeCom inbound media too large: {} bytes (max {})", + len(data), + WECOM_UPLOAD_MAX_BYTES, + ) + return None + media_dir = get_media_dir("wecom") if not filename: filename = fname or f"{media_type}_{hash(file_url) % 100000}" filename = _sanitize_filename(filename) file_path = media_dir / filename - file_path.write_bytes(data) + await asyncio.to_thread(file_path.write_bytes, data) logger.debug("Downloaded {} to {}", media_type, file_path) return str(file_path) @@ -374,16 +399,7 @@ class WecomChannel(BaseChannel): try: fname = os.path.basename(file_path) - ext = os.path.splitext(fname)[1].lower() - - if ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"): - media_type = "image" - elif ext in (".mp4", ".avi", ".mov"): - media_type = "video" - elif ext in (".amr", ".mp3", ".wav", ".ogg"): - media_type = "voice" - else: - media_type = "file" + media_type = _guess_wecom_media_type(fname) # Read file size and data in a thread to avoid blocking the event loop def _read_file(): @@ -400,9 +416,10 @@ class WecomChannel(BaseChannel): md5_hash = hashlib.md5(data).hexdigest() CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64) - chunk_list = [data[i : i + CHUNK_SIZE] for i in range(0, file_size, CHUNK_SIZE)] + mv = memoryview(data) + chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)] n_chunks = len(chunk_list) - del data # free raw bytes early + del mv, data # Step 1: init req_id = _gen_req_id("upload_init") From 0d03f10fa02eddb2fcb5efe70244bfc313c34fa0 Mon Sep 17 00:00:00 2001 From: chengyongru Date: 
Fri, 10 Apr 2026 17:02:08 +0800 Subject: [PATCH 078/115] test(channels): add media support tests for QQ and WeCom channels Cover helpers (sanitize_filename, guess media type), outbound send (exception handling, media-then-text order, fallback), inbound message processing (attachments, dedup, empty content), _post_base64file payload filtering, and WeCom upload/download flows. --- tests/channels/test_qq_media.py | 304 ++++++++++++++ tests/channels/test_wecom_channel.py | 583 +++++++++++++++++++++++++++ 2 files changed, 887 insertions(+) create mode 100644 tests/channels/test_qq_media.py create mode 100644 tests/channels/test_wecom_channel.py diff --git a/tests/channels/test_qq_media.py b/tests/channels/test_qq_media.py new file mode 100644 index 000000000..80a5ad20e --- /dev/null +++ b/tests/channels/test_qq_media.py @@ -0,0 +1,304 @@ +"""Tests for QQ channel media support: helpers, send, inbound, and upload.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +try: + from nanobot.channels import qq + + QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False) +except ImportError: + QQ_AVAILABLE = False + +if not QQ_AVAILABLE: + pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True) + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.qq import ( + QQ_FILE_TYPE_FILE, + QQ_FILE_TYPE_IMAGE, + QQChannel, + QQConfig, + _guess_send_file_type, + _is_image_name, + _sanitize_filename, +) + + +class _FakeApi: + def __init__(self) -> None: + self.c2c_calls: list[dict] = [] + self.group_calls: list[dict] = [] + + async def post_c2c_message(self, **kwargs) -> None: + self.c2c_calls.append(kwargs) + + async def post_group_message(self, **kwargs) -> None: + self.group_calls.append(kwargs) + + +class _FakeHttp: + """Fake _http for _post_base64file tests.""" + + def __init__(self, return_value: dict | None = None) -> None: + self.return_value 
= return_value or {} + self.calls: list[tuple] = [] + + async def request(self, route, **kwargs): + self.calls.append((route, kwargs)) + return self.return_value + + +class _FakeClient: + def __init__(self, http_return: dict | None = None) -> None: + self.api = _FakeApi() + self.api._http = _FakeHttp(http_return) + + +# ── Helper function tests (pure, no async) ────────────────────────── + + +def test_sanitize_filename_strips_path_traversal() -> None: + assert _sanitize_filename("../../etc/passwd") == "passwd" + + +def test_sanitize_filename_keeps_chinese_chars() -> None: + assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg" + + +def test_sanitize_filename_strips_unsafe_chars() -> None: + result = _sanitize_filename('file<>:"|?*.txt') + # All unsafe chars replaced with "_", but * is replaced too + assert result.startswith("file") + assert result.endswith(".txt") + assert "<" not in result + assert ">" not in result + assert '"' not in result + assert "|" not in result + assert "?" not in result + + +def test_sanitize_filename_empty_input() -> None: + assert _sanitize_filename("") == "" + + +def test_is_image_name_with_known_extensions() -> None: + for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"): + assert _is_image_name(f"photo{ext}") is True + + +def test_is_image_name_with_unknown_extension() -> None: + for ext in (".pdf", ".txt", ".mp3", ".mp4"): + assert _is_image_name(f"doc{ext}") is False + + +def test_guess_send_file_type_image() -> None: + assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE + assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE + + +def test_guess_send_file_type_file() -> None: + assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE + + +def test_guess_send_file_type_by_mime() -> None: + # A filename with no known extension but whose mime type is image/* + assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE + + +# ── send() exception handling 
─────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_send_exception_caught_not_raised() -> None: + """Exceptions inside send() must not propagate.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")): + await channel.send( + OutboundMessage(channel="qq", chat_id="user1", content="hello") + ) + # No exception raised — test passes if we get here. + + +@pytest.mark.asyncio +async def test_send_media_then_text() -> None: + """Media is sent before text when both are present.""" + import tempfile + + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload: + await channel.send( + OutboundMessage( + channel="qq", + chat_id="user1", + content="text after image", + media=[tmp], + metadata={"message_id": "m1"}, + ) + ) + assert mock_upload.called + + # Text should have been sent via c2c (default chat type) + text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0] + assert len(text_calls) >= 1 + assert text_calls[-1]["content"] == "text after image" + finally: + import os + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_send_media_failure_falls_back_to_text() -> None: + """When _send_media returns False, a failure notice is appended.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False): + await channel.send( + OutboundMessage( + channel="qq", + 
chat_id="user1", + content="hello", + media=["https://example.com/bad.png"], + metadata={"message_id": "m1"}, + ) + ) + + # Should have the failure text among the c2c calls + failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")] + assert len(failure_calls) == 1 + assert "bad.png" in failure_calls[0]["content"] + + +# ── _on_message() exception handling ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_on_message_exception_caught_not_raised() -> None: + """Missing required attributes should not crash _on_message.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + # Construct a message-like object that lacks 'author' — triggers AttributeError + bad_data = SimpleNamespace(id="x1", content="hi") + # Should not raise + await channel._on_message(bad_data, is_group=False) + + +@pytest.mark.asyncio +async def test_on_message_with_attachments() -> None: + """Messages with attachments produce media_paths and formatted content.""" + import tempfile + + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + saved_path = f.name + + att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png") + + # Patch _download_to_media_dir_chunked to return the temp file path + async def fake_download(url, filename_hint=""): + return saved_path + + try: + with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download): + data = SimpleNamespace( + id="att1", + content="look at this", + author=SimpleNamespace(user_openid="u1"), + attachments=[att], + ) + await channel._on_message(data, is_group=False) + + msg = await channel.bus.consume_inbound() + assert "look at this" in msg.content + assert "screenshot.png" in 
msg.content + assert "Received files:" in msg.content + assert len(msg.media) == 1 + assert msg.media[0] == saved_path + finally: + import os + os.unlink(saved_path) + + +# ── _post_base64file() ───────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_post_base64file_omits_file_name_for_images() -> None: + """file_type=1 (image) → payload must not contain file_name.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + channel._client = _FakeClient(http_return={"file_info": "img_abc"}) + + await channel._post_base64file( + chat_id="user1", + is_group=False, + file_type=QQ_FILE_TYPE_IMAGE, + file_data="ZmFrZQ==", + file_name="photo.png", + ) + + http = channel._client.api._http + assert len(http.calls) == 1 + payload = http.calls[0][1]["json"] + assert "file_name" not in payload + assert payload["file_type"] == QQ_FILE_TYPE_IMAGE + + +@pytest.mark.asyncio +async def test_post_base64file_includes_file_name_for_files() -> None: + """file_type=4 (file) → payload must contain file_name.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + channel._client = _FakeClient(http_return={"file_info": "file_abc"}) + + await channel._post_base64file( + chat_id="user1", + is_group=False, + file_type=QQ_FILE_TYPE_FILE, + file_data="ZmFrZQ==", + file_name="report.pdf", + ) + + http = channel._client.api._http + assert len(http.calls) == 1 + payload = http.calls[0][1]["json"] + assert payload["file_name"] == "report.pdf" + assert payload["file_type"] == QQ_FILE_TYPE_FILE + + +@pytest.mark.asyncio +async def test_post_base64file_filters_response_to_file_info() -> None: + """Response with file_info + extra fields must be filtered to only file_info.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + channel._client = _FakeClient(http_return={ + "file_info": "fi_123", + "file_uuid": "uuid_xxx", + "ttl": 3600, + }) + + result = await channel._post_base64file( + 
chat_id="user1", + is_group=False, + file_type=QQ_FILE_TYPE_FILE, + file_data="ZmFrZQ==", + file_name="doc.pdf", + ) + + assert result == {"file_info": "fi_123"} + assert "file_uuid" not in result + assert "ttl" not in result diff --git a/tests/channels/test_wecom_channel.py b/tests/channels/test_wecom_channel.py new file mode 100644 index 000000000..164c01ea2 --- /dev/null +++ b/tests/channels/test_wecom_channel.py @@ -0,0 +1,583 @@ +"""Tests for WeCom channel: helpers, download, upload, send, and message processing.""" + +import os +import tempfile +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +try: + import importlib.util + + WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None +except ImportError: + WECOM_AVAILABLE = False + +if not WECOM_AVAILABLE: + pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True) + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.wecom import ( + WecomChannel, + WecomConfig, + _guess_wecom_media_type, + _sanitize_filename, +) + +# Try to import the real response class; fall back to a stub if unavailable. 
+try: + from wecom_aibot_sdk.utils import WsResponse + + _RealWsResponse = WsResponse +except ImportError: + _RealWsResponse = None + + +class _FakeResponse: + """Minimal stand-in for wecom_aibot_sdk WsResponse.""" + + def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"): + self.errcode = errcode + self.errmsg = errmsg + self.body = body or {} + + +class _FakeWsManager: + """Tracks send_reply calls and returns configurable responses.""" + + def __init__(self, responses: list[_FakeResponse] | None = None): + self.responses = responses or [] + self.calls: list[tuple[str, dict, str]] = [] + self._idx = 0 + + async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse: + self.calls.append((req_id, data, cmd)) + if self._idx < len(self.responses): + resp = self.responses[self._idx] + self._idx += 1 + return resp + return _FakeResponse() + + +class _FakeFrame: + """Minimal frame object with a body dict.""" + + def __init__(self, body: dict | None = None): + self.body = body or {} + + +class _FakeWeComClient: + """Fake WeCom client with mock methods.""" + + def __init__(self, ws_responses: list[_FakeResponse] | None = None): + self._ws_manager = _FakeWsManager(ws_responses) + self.download_file = AsyncMock(return_value=(None, None)) + self.reply = AsyncMock() + self.reply_stream = AsyncMock() + self.send_message = AsyncMock() + self.reply_welcome = AsyncMock() + + +# ── Helper function tests (pure, no async) ────────────────────────── + + +def test_sanitize_filename_strips_path_traversal() -> None: + assert _sanitize_filename("../../etc/passwd") == "passwd" + + +def test_sanitize_filename_keeps_chinese_chars() -> None: + assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg" + + +def test_sanitize_filename_empty_input() -> None: + assert _sanitize_filename("") == "" + + +def test_guess_wecom_media_type_image() -> None: + for ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"): + assert 
_guess_wecom_media_type(f"photo{ext}") == "image" + + +def test_guess_wecom_media_type_video() -> None: + for ext in (".mp4", ".avi", ".mov"): + assert _guess_wecom_media_type(f"video{ext}") == "video" + + +def test_guess_wecom_media_type_voice() -> None: + for ext in (".amr", ".mp3", ".wav", ".ogg"): + assert _guess_wecom_media_type(f"audio{ext}") == "voice" + + +def test_guess_wecom_media_type_file_fallback() -> None: + for ext in (".pdf", ".doc", ".xlsx", ".zip"): + assert _guess_wecom_media_type(f"doc{ext}") == "file" + + +def test_guess_wecom_media_type_case_insensitive() -> None: + assert _guess_wecom_media_type("photo.PNG") == "image" + assert _guess_wecom_media_type("photo.Jpg") == "image" + + +# ── _download_and_save_media() ────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_download_and_save_success() -> None: + """Successful download writes file and returns sanitized path.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + fake_data = b"\x89PNG\r\nfake image" + client.download_file.return_value = (fake_data, "raw_photo.png") + + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())): + path = await channel._download_and_save_media("https://example.com/img.png", "aes_key", "image", "photo.png") + + assert path is not None + assert os.path.isfile(path) + assert os.path.basename(path) == "photo.png" + # Cleanup + os.unlink(path) + + +@pytest.mark.asyncio +async def test_download_and_save_oversized_rejected() -> None: + """Data exceeding 200MB is rejected → returns None.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + big_data = b"\x00" * (200 * 1024 * 1024 + 1) # 200MB + 1 byte + client.download_file.return_value = (big_data, "big.bin") + + with 
patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())): + result = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin") + + assert result is None + + +@pytest.mark.asyncio +async def test_download_and_save_failure() -> None: + """SDK returns None data → returns None.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + client.download_file.return_value = (None, None) + + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())): + result = await channel._download_and_save_media("https://example.com/fail.png", "key", "image") + + assert result is None + + +# ── _upload_media_ws() ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_upload_media_ws_success() -> None: + """Happy path: init → chunk → finish → returns (media_id, media_type).""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=0, body={}), + _FakeResponse(errcode=0, body={"media_id": "media_abc"}), + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + media_id, media_type = await channel._upload_media_ws(client, tmp) + + assert media_id == "media_abc" + assert media_type == "image" + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_upload_media_ws_oversized_file() -> None: + """File >200MB triggers ValueError → returns (None, None).""" + # Instead of creating a real 200MB+ file, mock os.path.getsize and open + with patch("os.path.getsize", return_value=200 * 1024 * 1024 + 
1), \ + patch("builtins.open", MagicMock()): + client = _FakeWeComClient() + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + result = await channel._upload_media_ws(client, "/fake/large.bin") + assert result == (None, None) + + +@pytest.mark.asyncio +async def test_upload_media_ws_init_failure() -> None: + """Init step returns errcode != 0 → returns (None, None).""" + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: + f.write(b"hello") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=50001, errmsg="invalid"), + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + result = await channel._upload_media_ws(client, tmp) + + assert result == (None, None) + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_upload_media_ws_chunk_failure() -> None: + """Chunk step returns errcode != 0 → returns (None, None).""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=50002, errmsg="chunk fail"), + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + result = await channel._upload_media_ws(client, tmp) + + assert result == (None, None) + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_upload_media_ws_finish_no_media_id() -> None: + """Finish step returns empty media_id → returns (None, None).""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + 
f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=0, body={}), + _FakeResponse(errcode=0, body={}), # no media_id + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + result = await channel._upload_media_ws(client, tmp) + + assert result == (None, None) + finally: + os.unlink(tmp) + + +# ── send() ────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_send_text_with_frame() -> None: + """When frame is stored, send uses reply_stream for final text.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="hello") + ) + + client.reply_stream.assert_called_once() + call_args = client.reply_stream.call_args + assert call_args[0][2] == "hello" # content arg + + +@pytest.mark.asyncio +async def test_send_progress_with_frame() -> None: + """When metadata has _progress, send uses reply (not reply_stream).""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True}) + ) + + client.reply.assert_called_once() + client.reply_stream.assert_not_called() + call_args = client.reply.call_args + assert call_args[0][1]["text"]["content"] == "thinking..." 
+ + +@pytest.mark.asyncio +async def test_send_proactive_without_frame() -> None: + """Without stored frame, send uses send_message with markdown.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg") + ) + + client.send_message.assert_called_once() + call_args = client.send_message.call_args + assert call_args[0][0] == "chat1" + assert call_args[0][1]["msgtype"] == "markdown" + + +@pytest.mark.asyncio +async def test_send_media_then_text() -> None: + """Media files are uploaded and sent before text content.""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=0, body={}), + _FakeResponse(errcode=0, body={"media_id": "media_123"}), + ] + + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient(responses) + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="see image", media=[tmp]) + ) + + # Media should have been sent via reply + media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") == "image"] + assert len(media_calls) == 1 + assert media_calls[0][0][1]["image"]["media_id"] == "media_123" + + # Text should have been sent via reply_stream + client.reply_stream.assert_called_once() + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_send_media_file_not_found() -> None: + """Non-existent media path is skipped with a warning.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = 
_FakeWeComClient() + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"]) + ) + + # reply_stream should still be called for the text part + client.reply_stream.assert_called_once() + # No media reply should happen + media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") in ("image", "file", "video")] + assert len(media_calls) == 0 + + +@pytest.mark.asyncio +async def test_send_exception_caught_not_raised() -> None: + """Exceptions inside send() must not propagate.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + # Make reply_stream raise + client.reply_stream.side_effect = RuntimeError("boom") + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="fail test") + ) + # No exception — test passes if we reach here. 
+ + +# ── _process_message() ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_process_text_message() -> None: + """Text message is routed to bus with correct fields.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_text_1", + "chatid": "chat1", + "chattype": "single", + "from": {"userid": "user1"}, + "text": {"content": "hello wecom"}, + }) + + await channel._process_message(frame, "text") + + msg = await channel.bus.consume_inbound() + assert msg.sender_id == "user1" + assert msg.chat_id == "chat1" + assert msg.content == "hello wecom" + assert msg.metadata["msg_type"] == "text" + + +@pytest.mark.asyncio +async def test_process_image_message() -> None: + """Image message: download success → media_paths non-empty.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + saved = f.name + + client.download_file.return_value = (b"\x89PNG\r\n", "photo.png") + channel._client = client + + try: + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))): + frame = _FakeFrame(body={ + "msgid": "msg_img_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "image": {"url": "https://example.com/img.png", "aeskey": "key123"}, + }) + await channel._process_message(frame, "image") + + msg = await channel.bus.consume_inbound() + assert len(msg.media) == 1 + assert msg.media[0].endswith("photo.png") + assert "[image:" in msg.content + finally: + if os.path.exists(saved): + pass # may have been overwritten; clean up if exists + # Clean up any photo.png in tempdir + p = os.path.join(os.path.dirname(saved), "photo.png") + if os.path.exists(p): + os.unlink(p) + + 
+@pytest.mark.asyncio +async def test_process_file_message() -> None: + """File message: download success → media_paths non-empty (critical fix verification).""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: + f.write(b"%PDF-1.4 fake") + saved = f.name + + client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf") + channel._client = client + + try: + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))): + frame = _FakeFrame(body={ + "msgid": "msg_file_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"}, + }) + await channel._process_message(frame, "file") + + msg = await channel.bus.consume_inbound() + assert len(msg.media) == 1 + assert msg.media[0].endswith("report.pdf") + assert "[file: report.pdf]" in msg.content + finally: + p = os.path.join(os.path.dirname(saved), "report.pdf") + if os.path.exists(p): + os.unlink(p) + + +@pytest.mark.asyncio +async def test_process_voice_message() -> None: + """Voice message: transcribed text is included in content.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_voice_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "voice": {"content": "transcribed text here"}, + }) + + await channel._process_message(frame, "voice") + + msg = await channel.bus.consume_inbound() + assert "transcribed text here" in msg.content + assert "[voice]" in msg.content + + +@pytest.mark.asyncio +async def test_process_message_deduplication() -> None: + """Same msg_id is not processed twice.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + 
client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_dup_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "text": {"content": "once"}, + }) + + await channel._process_message(frame, "text") + await channel._process_message(frame, "text") + + msg = await channel.bus.consume_inbound() + assert msg.content == "once" + + # Second message should not appear on the bus + assert channel.bus.inbound.empty() + + +@pytest.mark.asyncio +async def test_process_message_empty_content_skipped() -> None: + """Message with empty content produces no bus message.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_empty_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "text": {"content": ""}, + }) + + await channel._process_message(frame, "text") + + assert channel.bus.inbound.empty() From 9f433cab014701d69cfd1b93040a8844867733b1 Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Fri, 10 Apr 2026 22:20:28 +0800 Subject: [PATCH 079/115] fix(wecom): use reply_stream for progress messages to avoid errcode=40008 The plain reply() uses cmd="reply" which does not support "text" msgtype and causes WeCom API to return errcode=40008 (invalid message type). Unify both progress and final text messages to use reply_stream() (cmd="aibot_respond_msg"), differentiating via finish flag. 
Fixes #2999 --- nanobot/channels/wecom.py | 32 +++++++++++++--------------- tests/channels/test_wecom_channel.py | 11 +++++----- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index 910f02489..a7d7f1fe2 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -513,23 +513,21 @@ class WecomChannel(BaseChannel): return if frame: - if is_progress: - # Progress messages (thinking text): send as plain reply, no streaming - await self._client.reply(frame, { - "msgtype": "text", - "text": {"content": content}, - }) - logger.debug("WeCom progress sent to {}", msg.chat_id) - else: - # Final response: use streaming reply for better UX - stream_id = self._generate_req_id("stream") - await self._client.reply_stream( - frame, - stream_id, - content, - finish=True, - ) - logger.debug("WeCom message sent to {}", msg.chat_id) + # Both progress and final messages must use reply_stream (cmd="aibot_respond_msg"). + # The plain reply() uses cmd="reply" which does not support "text" msgtype + # and causes errcode=40008 from WeCom API. + stream_id = self._generate_req_id("stream") + await self._client.reply_stream( + frame, + stream_id, + content, + finish=not is_progress, + ) + logger.debug( + "WeCom {} sent to {}", + "progress" if is_progress else "message", + msg.chat_id, + ) else: # No frame (e.g. 
cron push): proactive send only supports markdown await self._client.send_message(msg.chat_id, { diff --git a/tests/channels/test_wecom_channel.py b/tests/channels/test_wecom_channel.py index 164c01ea2..b79c023ba 100644 --- a/tests/channels/test_wecom_channel.py +++ b/tests/channels/test_wecom_channel.py @@ -317,20 +317,21 @@ async def test_send_text_with_frame() -> None: @pytest.mark.asyncio async def test_send_progress_with_frame() -> None: - """When metadata has _progress, send uses reply (not reply_stream).""" + """When metadata has _progress, send uses reply_stream with finish=False.""" channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) client = _FakeWeComClient() channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" channel._chat_frames["chat1"] = _FakeFrame() await channel.send( OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True}) ) - client.reply.assert_called_once() - client.reply_stream.assert_not_called() - call_args = client.reply.call_args - assert call_args[0][1]["text"]["content"] == "thinking..." + client.reply_stream.assert_called_once() + call_args = client.reply_stream.call_args + assert call_args[0][2] == "thinking..." 
# content arg + assert call_args[1]["finish"] is False @pytest.mark.asyncio From 4cd4ed8adae54af57d7743cc16c332ec0a004c27 Mon Sep 17 00:00:00 2001 From: 04cb <0x04cb@gmail.com> Date: Sat, 11 Apr 2026 21:48:31 +0800 Subject: [PATCH 080/115] fix(agent): preserve tool results on fatal error to prevent orphan tool_calls (#2943) --- nanobot/agent/runner.py | 31 ++++++---- tests/agent/test_runner.py | 115 +++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 12 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index cfebe098f..0d8062842 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -111,14 +111,21 @@ class AgentRunner: messages_for_model = self._microcompact(messages_for_model) messages_for_model = self._apply_tool_result_budget(spec, messages_for_model) messages_for_model = self._snip_history(spec, messages_for_model) + # Snipping may have created new orphans; clean them up. + messages_for_model = self._drop_orphan_tool_results(messages_for_model) + messages_for_model = self._backfill_missing_tool_results(messages_for_model) except Exception as exc: logger.warning( - "Context governance failed on turn {} for {}: {}; using raw messages", + "Context governance failed on turn {} for {}: {}; applying minimal repair", iteration, spec.session_key or "default", exc, ) - messages_for_model = messages + try: + messages_for_model = self._drop_orphan_tool_results(messages) + messages_for_model = self._backfill_missing_tool_results(messages_for_model) + except Exception: + messages_for_model = messages context = AgentHookContext(iteration=iteration, messages=messages) await hook.before_iteration(context) response = await self._request_model(spec, messages_for_model, hook, context) @@ -162,16 +169,6 @@ class AgentRunner: tool_events.extend(new_events) context.tool_results = list(results) context.tool_events = list(new_events) - if fatal_error is not None: - error = f"Error: {type(fatal_error).__name__}: 
{fatal_error}" - final_content = error - stop_reason = "tool_error" - self._append_final_message(messages, final_content) - context.final_content = final_content - context.error = error - context.stop_reason = stop_reason - await hook.after_iteration(context) - break completed_tool_results: list[dict[str, Any]] = [] for tool_call, result in zip(response.tool_calls, results): tool_message = { @@ -187,6 +184,16 @@ class AgentRunner: } messages.append(tool_message) completed_tool_results.append(tool_message) + if fatal_error is not None: + error = f"Error: {type(fatal_error).__name__}: {fatal_error}" + final_content = error + stop_reason = "tool_error" + self._append_final_message(messages, final_content) + context.final_content = final_content + context.error = error + context.stop_reason = stop_reason + await hook.after_iteration(context) + break await self._emit_checkpoint( spec, { diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index ef4206573..45da0896c 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -1607,3 +1607,118 @@ async def test_microcompact_skips_non_compactable_tools(): result = AgentRunner._microcompact(messages) assert result is messages # no compactable tools found + + +@pytest.mark.asyncio +async def test_runner_tool_error_preserves_tool_results_in_messages(): + """When a tool raises a fatal error, its results must still be appended + to messages so the session never contains orphan tool_calls (#2943).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(*, messages, **kwargs): + return LLMResponse( + content=None, + tool_calls=[ + ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a"}), + ToolCallRequest(id="tc2", name="exec", arguments={"cmd": "bad"}), + ], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + provider.chat_stream_with_retry = chat_with_retry + + call_idx = 0 + + async def fake_execute(name, 
args, **kw): + nonlocal call_idx + call_idx += 1 + if call_idx == 2: + raise RuntimeError("boom") + return "file content" + + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(side_effect=fake_execute) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do stuff"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + fail_on_tool_error=True, + )) + + assert result.stop_reason == "tool_error" + # Both tool results must be in messages even though tc2 had a fatal error. + tool_msgs = [m for m in result.messages if m.get("role") == "tool"] + assert len(tool_msgs) == 2 + assert tool_msgs[0]["tool_call_id"] == "tc1" + assert tool_msgs[1]["tool_call_id"] == "tc2" + # The assistant message with tool_calls must precede the tool results. + asst_tc_idx = next( + i for i, m in enumerate(result.messages) + if m.get("role") == "assistant" and m.get("tool_calls") + ) + tool_indices = [ + i for i, m in enumerate(result.messages) if m.get("role") == "tool" + ] + assert all(ti > asst_tc_idx for ti in tool_indices) + + +def test_governance_repairs_orphans_after_snip(): + """After _snip_history clips an assistant+tool_calls, the second + _drop_orphan_tool_results pass must clean up the resulting orphans.""" + from nanobot.agent.runner import AgentRunner + + messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old msg"}, + {"role": "assistant", "content": None, + "tool_calls": [{"id": "tc_old", "type": "function", + "function": {"name": "search", "arguments": "{}"}}]}, + {"role": "tool", "tool_call_id": "tc_old", "name": "search", + "content": "old result"}, + {"role": "assistant", "content": "old answer"}, + {"role": "user", "content": "new msg"}, + ] + + # Simulate snipping that keeps only the tail: drop the assistant with + # tool_calls but keep its tool result (orphan). 
+ snipped = [ + {"role": "system", "content": "system"}, + {"role": "tool", "tool_call_id": "tc_old", "name": "search", + "content": "old result"}, + {"role": "assistant", "content": "old answer"}, + {"role": "user", "content": "new msg"}, + ] + + cleaned = AgentRunner._drop_orphan_tool_results(snipped) + # The orphan tool result should be removed. + assert not any( + m.get("role") == "tool" and m.get("tool_call_id") == "tc_old" + for m in cleaned + ) + + +def test_governance_fallback_still_repairs_orphans(): + """When full governance fails, the fallback must still run + _drop_orphan_tool_results and _backfill_missing_tool_results.""" + from nanobot.agent.runner import AgentRunner + + # Messages with an orphan tool result (no matching assistant tool_call). + messages = [ + {"role": "user", "content": "hello"}, + {"role": "tool", "tool_call_id": "orphan_tc", "name": "read", + "content": "stale"}, + {"role": "assistant", "content": "hi"}, + ] + + repaired = AgentRunner._drop_orphan_tool_results(messages) + repaired = AgentRunner._backfill_missing_tool_results(repaired) + # Orphan tool result should be gone. + assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired) From ee946d96ca6ae65e04478e32bdcd7daac04ea022 Mon Sep 17 00:00:00 2001 From: Dianqi Ji Date: Mon, 6 Apr 2026 12:03:56 -0700 Subject: [PATCH 081/115] feat(channels/feishu): add domain config for Lark global support Add 'domain' field to FeishuConfig (Literal['feishu', 'lark'], default 'feishu'). Pass domain to lark.Client.builder() and lark.ws.Client to support Lark global (open.larksuite.com) in addition to Feishu China (open.feishu.cn). Existing configs default to 'feishu' for backward compatibility. Also add documentation for domain field in README.md and add tests for domain config. 
--- README.md | 4 ++- nanobot/channels/feishu.py | 6 ++++ tests/channels/test_feishu_domain.py | 48 ++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 tests/channels/test_feishu_domain.py diff --git a/README.md b/README.md index 72fd62aa0..c044073f0 100644 --- a/README.md +++ b/README.md @@ -563,7 +563,8 @@ Uses **WebSocket** long connection — no public IP required. "reactEmoji": "OnIt", "doneEmoji": "DONE", "toolHintPrefix": "🔧", - "streaming": true + "streaming": true, + "domain": "feishu" } } } @@ -576,6 +577,7 @@ Uses **WebSocket** long connection — no public IP required. > `reactEmoji`: Emoji for "processing" status (default: `OnIt`). See [available emojis](https://open.larkoffice.com/document/server-docs/im-v1/message-reaction/emojis-introduce). > `doneEmoji`: Optional emoji for "completed" status (e.g., `DONE`, `OK`, `HEART`). When set, bot adds this reaction after removing `reactEmoji`. > `toolHintPrefix`: Prefix for inline tool hints in streaming cards (default: `🔧`). +> `domain`: `"feishu"` (default) for China (open.feishu.cn), `"lark"` for international Lark (open.larksuite.com). **3. 
Run** diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 7d9c2772b..5afeca35f 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -22,6 +22,8 @@ from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base +from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN + FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None # Message type display mapping @@ -255,6 +257,7 @@ class FeishuConfig(Base): group_policy: Literal["open", "mention"] = "mention" reply_to_message: bool = False # If True, bot replies quote the user's original message streaming: bool = True + domain: Literal["feishu", "lark"] = "feishu" # Set to "lark" for international Lark _STREAM_ELEMENT_ID = "streaming_md" @@ -328,10 +331,12 @@ class FeishuChannel(BaseChannel): self._loop = asyncio.get_running_loop() # Create Lark client for sending messages + domain = LARK_DOMAIN if self.config.domain == "lark" else FEISHU_DOMAIN self._client = ( lark.Client.builder() .app_id(self.config.app_id) .app_secret(self.config.app_secret) + .domain(domain) .log_level(lark.LogLevel.INFO) .build() ) @@ -359,6 +364,7 @@ class FeishuChannel(BaseChannel): self._ws_client = lark.ws.Client( self.config.app_id, self.config.app_secret, + domain=domain, event_handler=event_handler, log_level=lark.LogLevel.INFO, ) diff --git a/tests/channels/test_feishu_domain.py b/tests/channels/test_feishu_domain.py new file mode 100644 index 000000000..caa1c4145 --- /dev/null +++ b/tests/channels/test_feishu_domain.py @@ -0,0 +1,48 @@ +"""Tests for Feishu/Lark domain configuration.""" +from unittest.mock import MagicMock + +import pytest + +from nanobot.bus.queue import MessageBus +from nanobot.channels.feishu import FeishuChannel, FeishuConfig + + +def _make_channel(domain: str = "feishu") -> FeishuChannel: + config = FeishuConfig( + enabled=True, + app_id="cli_test", + app_secret="secret", + allow_from=["*"], 
+ domain=domain, + ) + ch = FeishuChannel(config, MessageBus()) + ch._client = MagicMock() + ch._loop = None + return ch + + +class TestFeishuConfigDomain: + def test_domain_default_is_feishu(self): + config = FeishuConfig() + assert config.domain == "feishu" + + def test_domain_accepts_lark(self): + config = FeishuConfig(domain="lark") + assert config.domain == "lark" + + def test_domain_accepts_feishu(self): + config = FeishuConfig(domain="feishu") + assert config.domain == "feishu" + + def test_default_config_includes_domain(self): + default_cfg = FeishuChannel.default_config() + assert "domain" in default_cfg + assert default_cfg["domain"] == "feishu" + + def test_channel_persists_domain_from_config(self): + ch = _make_channel(domain="lark") + assert ch.config.domain == "lark" + + def test_channel_persists_feishu_domain_from_config(self): + ch = _make_channel(domain="feishu") + assert ch.config.domain == "feishu" From e229c2ebc0bdef2fe5a2b7000a28fc1cf2308bf3 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 12 Apr 2026 02:21:46 +0000 Subject: [PATCH 082/115] fix(pr): remove internal .docs file from PR Keep the local review note out of the GitHub diff while preserving the actual code and test changes for this PR. Made-with: Cursor --- .docs/NEW_THINGS_SHOULD_ADD_TO_WEB.md | 189 -------------------------- 1 file changed, 189 deletions(-) delete mode 100644 .docs/NEW_THINGS_SHOULD_ADD_TO_WEB.md diff --git a/.docs/NEW_THINGS_SHOULD_ADD_TO_WEB.md b/.docs/NEW_THINGS_SHOULD_ADD_TO_WEB.md deleted file mode 100644 index 75fe4a541..000000000 --- a/.docs/NEW_THINGS_SHOULD_ADD_TO_WEB.md +++ /dev/null @@ -1,189 +0,0 @@ -# Pending Web Documentation Updates - -Items that need to be synced to `.web/nanobot-web/nanobot-web-page/docs/` when ready. - -## Merged (ready to update) - -### 1. Anthropic adaptive thinking mode (PR #2882) -- **What changed:** `reasoning_effort` now supports `"adaptive"` in addition to `"low"` / `"medium"` / `"high"`. 
- When set to `"adaptive"`, the model decides when and how much to think (supported on claude-sonnet-4-6, claude-opus-4-6). -- **Where to update:** - - `content.js` → `agents.defaults` reference section → `reasoningEffort` field description: - current text lists `"low"`, `"medium"`, `"high"`, or `null` — add `"adaptive"`. - - All 6 locale files (`zh-CN.js`, `zh-TW.js`, `ja.js`, `ko.js`, `es.js`, `fr.js`) → same field. -- **Source:** `nanobot/config/schema.py` line comment, `nanobot/providers/anthropic_provider.py` `_build_kwargs`. - -## Not yet merged (update after merge) - -### 2. Windows shell / cross-platform exec tool (PR #2926 + PR #2941) -- **What changed:** `exec` tool now works on Windows via `cmd.exe /c`. Environment isolation is - platform-aware: Unix passes `HOME`/`LANG`/`TERM` (bash -l handles PATH); Windows passes a curated - set of 15 system variables (`PATH`, `SYSTEMROOT`, `COMSPEC`, `USERPROFILE`, `HOMEDRIVE`, - `HOMEPATH`, `TEMP`, `TMP`, `PATHEXT`, `APPDATA`, `LOCALAPPDATA`, `ProgramData`, `ProgramFiles`, - `ProgramFiles(x86)`, `ProgramW6432`) while still excluding secrets. `bwrap` sandbox is gracefully - skipped on Windows with a warning. -- **Where to update:** - - `content.js` → Security section → exec tool environment description: - current text says "only HOME, LANG, TERM" — needs platform-specific note listing the 15 Windows variables. - - All 6 locale files → same section. -- **Source:** `nanobot/agent/tools/shell.py` `_build_env`, `_spawn`. - -### 3. Channel Plugin Guide — Pydantic config requirement (PR #2850) -- **Status:** ✅ Already updated in this batch (v=20260407d). -- Code examples updated to use `WebhookConfig(Base)` Pydantic model. -- Warning note added in all 7 languages explaining `is_allowed()` silent failure with plain dict. - -### 4. Telegram location sharing support (PR #2910) -- **What changed:** Telegram channel now handles location messages. 
When a user shares a - location pin, coordinates are forwarded to the agent as `[location: lat, lon]` — consistent - with the existing `[image: ...]` / `[transcription: ...]` conventions. This enables MCP tools - that accept geo coordinates (maps, weather, nearby search) to be triggered from a Telegram - location share. -- **Where to update:** - - `content.js` → Telegram channel section → supported message types: - current text lists text, images, voice, audio, documents — add location pins. - - All 6 locale files → same section. -- **Source:** `nanobot/channels/telegram.py` — `filters.LOCATION` in handler, `message.location` extraction in `_on_message`. - -### 5. Tool hint formatting for exec paths and dedup (PR #2926) -- **What changed:** Tool hints now fold file paths embedded in `exec` commands instead of blindly - truncating them mid-path. This includes quoted paths with spaces on Unix and Windows. Consecutive - hints are also deduplicated by the final formatted hint string, so different arguments are shown - separately while truly identical calls still fold as `× N`. -- **Where to update:** - - `content.js` → Agent loop / tool hint display section: - explain that exec command previews abbreviate embedded paths for readability and that folding - happens only for repeated identical rendered hints. - - All 6 locale files → same section. -- **Source:** `nanobot/utils/tool_hints.py`, `tests/agent/test_tool_hint.py`. - -### 6. Discord streaming replies enabled by default (PR #2939) -- **What changed:** Discord now supports the streaming reply path used by Telegram, and Discord - config gains a `streaming` flag that defaults to `true`. This avoids the previous non-streaming - fallback path that could end in an empty final response with some OpenAI-compatible gateways. 
-- **Where to update:** - - `content.js` → Discord channel section → config reference: - add the `streaming` field, note that it defaults to `true`, and explain it can be disabled to - force non-streaming replies. - - All 6 locale files → same section. -- **Source:** `nanobot/channels/discord.py`, `tests/channels/test_discord_channel.py`, `README.md`. - -### 7. WebSocket server channel (PR #2964) -- **What changed:** New `websocket` channel that runs a WebSocket server, allowing external clients - (web apps, CLIs, Chrome extensions, scripts) to interact with the agent in real time via persistent - connections. Supports streaming (`delta` + `stream_end` events), token-based authentication - (static tokens and short-lived issued tokens via HTTP endpoint), per-connection sessions, - TLS/SSL (WSS), and client allow-list. -- **Where to update:** - - `content.js` → Channels section: - add a new WebSocket channel subsection covering configuration (`channels.websocket`), wire - protocol (`ready`, `message`, `delta`, `stream_end` events), authentication modes (static token, - issued tokens via `tokenIssuePath`), and common deployment patterns. - - All 6 locale files → same section. - - README → supported channels list: add WebSocket. -- **Source:** `nanobot/channels/websocket.py`, `docs/WEBSOCKET.md` (comprehensive standalone doc). - -### 8. Exec tool `allowed_env_keys` config (PR #2962) -- **What changed:** New `allowed_env_keys` field in `tools.exec` config. Users can list host - environment variable names (e.g. `["GOPATH", "JAVA_HOME"]`) to selectively forward into the - sandboxed subprocess. Default is an empty list — no behavior change for existing users. Works - on both Unix and Windows. -- **Where to update:** - - `content.js` → Security section → exec tool environment description: - current text describes the default allow-list (HOME/LANG/TERM on Unix, 15 vars on Windows). - Add a note about `allowed_env_keys` for passing additional env vars. 
- - All 6 locale files → same section. -- **Source:** `nanobot/config/schema.py` (`ExecToolConfig.allowed_env_keys`), `nanobot/agent/tools/shell.py` (`_build_env`). - -### 9. Discord proxy support (PR #2960) -- **What changed:** Discord channel config gains `proxy`, `proxy_username`, and `proxy_password` - fields. When set, the Discord bot connection is routed through the specified HTTP proxy, - optionally with BasicAuth. Partial credentials (only username or only password) are logged - as a warning and ignored. -- **Where to update:** - - `content.js` → Discord channel section → config reference: - add the three proxy fields, note that `proxy_username`/`proxy_password` are both required - for auth, and that partial credentials are ignored with a warning. - - All 6 locale files → same section. -- **Source:** `nanobot/channels/discord.py` (`DiscordConfig`, `DiscordChannel.start`). - -### 10. Feishu streaming enhancements: resuming, inline tool hints, done emoji (PR #2993) -- **What changed:** Three Feishu channel improvements: - 1. `doneEmoji` config field — optional completion emoji (e.g. `"DONE"`) added after `reactEmoji` is removed when the bot finishes processing. - 2. `toolHintPrefix` config field — configurable prefix for inline tool hints (default: `🔧`). - 3. Streaming resuming — mid-turn tool calls flush text to the streaming card without closing it, so the next text segment continues on the same card. Tool hints are inlined into active streaming cards instead of sent as separate messages. -- **Where to update:** - - `content.js` → Feishu channel section → config reference: - add `doneEmoji` (optional string, emoji name for completion reaction) and `toolHintPrefix` (string, default `🔧`). - Note streaming resuming behavior for mid-turn tool calls. - - All 6 locale files → same section. - - README → already updated in this PR with config example. 
-- **Source:** `nanobot/channels/feishu.py` (`FeishuConfig.done_emoji`, `FeishuConfig.tool_hint_prefix`, `send_delta` resuming logic, `send` tool hint inline logic). - -### 11. Unified session across channels (PR #2900) -- **What changed:** New `unifiedSession` toggle in `config.json` (`agents.defaults`). When set to - `true`, all incoming messages — regardless of which channel they arrive on — share a single - session key (`unified:default`). Switching from Telegram to Discord continues the same - conversation. Defaults to `false` — zero behavior change for existing users. Existing - `session_key_override` (e.g. Telegram thread) is respected and not overwritten. -- **Where to update:** - - `content.js` → `agents.defaults` reference section: - add `unifiedSession` field, type `boolean`, default `false`, explain single-user multi-device - use case and that it merges all channel sessions into one. - - All 6 locale files → same section. - - README → config example or feature list, mention cross-channel unified session. -- **Source:** `nanobot/config/schema.py` (`unified_session`), `nanobot/agent/loop.py` (`UNIFIED_SESSION_KEY`, `_dispatch`). - -### 12. Auto compact config rename + recent live suffix retention (PR #3007) -- **What changed:** Auto compact now preserves a recent legal suffix of live session messages while - summarizing the older unconsolidated prefix, instead of clearing the entire live session. The - preferred config key is now `idleCompactAfterMinutes`; legacy `sessionTtlMinutes` remains accepted - as a backward-compatible alias. -- **Where to update:** - - `content.js` → `agents.defaults` reference section: - rename the field to `idleCompactAfterMinutes`, note that `sessionTtlMinutes` is a legacy alias, - and explain that auto compact keeps recent live context instead of replacing the whole session - with only a summary. - - All 6 locale files → same section. 
- - Any auto-compact behavior notes: - update wording from "session cleared" to "older context summarized, recent live suffix retained". -- **Source:** `nanobot/config/schema.py` (`AgentDefaults.session_ttl_minutes` aliases), - `nanobot/agent/auto_compact.py` (`_split_unconsolidated`, `_archive`), `README.md` Auto Compact section. - -### 13. Kagi web search provider (PR #2945) -- **What changed:** `tools.web.search.provider` now accepts `kagi`, using `apiKey` / `KAGI_API_KEY` - to call Kagi's Search API through the built-in `web_search` tool. -- **Where to update:** - - `content.js` → web tools / search provider section: - add `kagi` to the provider list, note that it uses the standard `apiKey` field or `KAGI_API_KEY`. - - All 6 locale files → same section. - - Any provider comparison tables: - add Kagi alongside Brave, Tavily, Jina, SearXNG, and DuckDuckGo. -- **Source:** `nanobot/agent/tools/web.py` (`_search_kagi`), - `nanobot/config/schema.py` (`WebSearchConfig.provider` comment), `README.md` web tools section. - -### 14. Mid-turn follow-up injection for active agent runs (PR #3042) -- **What changed:** If a user sends another message while the agent is still working on the same - session, the follow-up can now be injected into the current agent turn instead of waiting behind - the per-session lock as a separate later turn. Streaming channels keep the active reply open when - the turn resumes, so the follow-up answer can continue in the same live response flow. -- **Where to update:** - - `content.js` → agent loop / streaming behavior section: - explain that same-session follow-ups during an active turn may be folded into the in-flight - response instead of always starting a brand-new queued turn. - - All 6 locale files → same section. -- **Source:** `nanobot/agent/loop.py` (`_pending_queues`, unified-session routing, leftover re-publish), - `nanobot/agent/runner.py` (injection checkpoints, resumed stream end handling). - -### 15. 
Disable built-in/workspace skills via config (PR #2959) -- **What changed:** New `disabledSkills` field under `agents.defaults`. Users can provide a list of - skill directory names to exclude from loading, so selected built-in or workspace skills no longer - appear in the main agent or subagent skill summaries and are not auto-injected as always-on skills. -- **Where to update:** - - `content.js` -> `agents.defaults` reference section: - add `disabledSkills` as an array of skill names, explain that names match skill directory names, - and note that disabled skills are hidden from both the main agent and subagents. - - All 6 locale files -> same section. -- **Source:** `nanobot/config/schema.py` (`AgentDefaults.disabled_skills`), - `nanobot/agent/context.py` (`ContextBuilder`), `nanobot/agent/subagent.py` (`SubagentManager._build_subagent_prompt`), - `nanobot/agent/skills.py` (`SkillsLoader` filtering). From a142788da9141b665437f475ab198914edf14adb Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 12 Apr 2026 02:42:52 +0000 Subject: [PATCH 083/115] docs(readme): document disabledSkills config Explain the new agents.defaults.disabledSkills option so users can discover and configure skill exclusion from the main agent and subagents. Made-with: Cursor --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index c044073f0..6d7763c42 100644 --- a/README.md +++ b/README.md @@ -1597,6 +1597,26 @@ When enabled, all incoming messages — regardless of which channel they arrive > This is designed for single-user, multi-device setups. It is **off by default** — existing users see zero behavior change. +### Disabled Skills + +nanobot ships with built-in skills, and your workspace can also define custom skills under `skills/`. 
If you want to hide specific skills from the agent, set `agents.defaults.disabledSkills` to a list of skill directory names: + +```json +{ + "agents": { + "defaults": { + "disabledSkills": ["github", "weather"] + } + } +} +``` + +Disabled skills are excluded from the main agent's skill summary, from always-on skill injection, and from subagent skill summaries. This is useful when some bundled skills are unnecessary for your deployment or should not be exposed to end users. + +| Option | Default | Description | +|--------|---------|-------------| +| `agents.defaults.disabledSkills` | `[]` | List of skill directory names to exclude from loading. Applies to both built-in skills and workspace skills. | + ## 🧩 Multiple Instances Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance. From 00fb491bc9cee04088c4f7c46198e5e8f7f85030 Mon Sep 17 00:00:00 2001 From: 04cb <0x04cb@gmail.com> Date: Sun, 12 Apr 2026 15:09:22 +0800 Subject: [PATCH 084/115] fix(shell): block exec writes to history.jsonl and cursor files (#2989) --- nanobot/agent/tools/shell.py | 8 ++++++ tests/tools/test_exec_security.py | 46 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index d80b69fbe..6af9629aa 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -61,6 +61,14 @@ class ExecTool(Tool): r">\s*/dev/sd", # write to disk r"\b(shutdown|reboot|poweroff)\b", # system power r":\(\)\s*\{.*\};\s*:", # fork bomb + # Block writes to nanobot internal state files (#2989). + # history.jsonl / .dream_cursor are managed by append_history(); + # direct writes corrupt the cursor format and crash /dream. 
+ r">>?\s*\S*(?:history\.jsonl|\.dream_cursor)", # > / >> redirect + r"\btee\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # tee / tee -a + r"\b(?:cp|mv)\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # cp/mv target + r"\bdd\b[^|;&<>]*\bof=\S*(?:history\.jsonl|\.dream_cursor)", # dd of= + r"\bsed\s+-i[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # sed -i ] self.allow_patterns = allow_patterns or [] self.restrict_to_workspace = restrict_to_workspace diff --git a/tests/tools/test_exec_security.py b/tests/tools/test_exec_security.py index e65d57565..bb8fc21ec 100644 --- a/tests/tools/test_exec_security.py +++ b/tests/tools/test_exec_security.py @@ -67,3 +67,49 @@ async def test_exec_blocks_chained_internal_url(): command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done" ) assert "Error" in result + + +# --- #2989: block writes to nanobot internal state files ----------------- + + +@pytest.mark.parametrize( + "command", + [ + "cat foo >> history.jsonl", + "echo '{}' > history.jsonl", + "echo '{}' > memory/history.jsonl", + "echo '{}' > ./workspace/memory/history.jsonl", + "tee -a history.jsonl < foo", + "tee history.jsonl", + "cp /tmp/fake.jsonl history.jsonl", + "mv backup.jsonl memory/history.jsonl", + "dd if=/dev/zero of=memory/history.jsonl", + "sed -i 's/old/new/' history.jsonl", + "echo x > .dream_cursor", + "cp /tmp/x memory/.dream_cursor", + ], +) +def test_exec_blocks_writes_to_history_jsonl(command): + """Direct writes to history.jsonl / .dream_cursor must be blocked (#2989).""" + tool = ExecTool() + result = tool._guard_command(command, "/tmp") + assert result is not None + assert "dangerous pattern" in result.lower() + + +@pytest.mark.parametrize( + "command", + [ + "cat history.jsonl", + "wc -l history.jsonl", + "tail -n 5 history.jsonl", + "grep foo history.jsonl", + "ls memory/", + "echo history.jsonl", + ], +) +def test_exec_allows_reads_of_history_jsonl(command): + """Read-only access to history.jsonl must still be allowed.""" + 
tool = ExecTool() + result = tool._guard_command(command, "/tmp") + assert result is None From 3f59bd1443b971474968dfb75f02fb13e635fc8c Mon Sep 17 00:00:00 2001 From: 04cb <0x04cb@gmail.com> Date: Sun, 12 Apr 2026 15:09:22 +0800 Subject: [PATCH 085/115] fix(shell): reject LLM-supplied working_dir outside workspace (#2826) --- nanobot/agent/tools/shell.py | 15 +++++++ tests/tools/test_exec_security.py | 68 +++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 6af9629aa..729afa60b 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -101,6 +101,21 @@ class ExecTool(Tool): timeout: int | None = None, **kwargs: Any, ) -> str: cwd = working_dir or self.working_dir or os.getcwd() + + # Prevent an LLM-supplied working_dir from escaping the configured + # workspace when restrict_to_workspace is enabled (#2826). Without + # this, a caller can pass working_dir="/etc" and then all absolute + # paths under /etc would pass the _guard_command check that anchors + # on cwd. 
+ if self.restrict_to_workspace and self.working_dir: + try: + requested = Path(cwd).expanduser().resolve() + workspace_root = Path(self.working_dir).expanduser().resolve() + except Exception: + return "Error: working_dir could not be resolved" + if requested != workspace_root and workspace_root not in requested.parents: + return "Error: working_dir is outside the configured workspace" + guard_error = self._guard_command(command, cwd) if guard_error: return guard_error diff --git a/tests/tools/test_exec_security.py b/tests/tools/test_exec_security.py index bb8fc21ec..9f001aaff 100644 --- a/tests/tools/test_exec_security.py +++ b/tests/tools/test_exec_security.py @@ -113,3 +113,71 @@ def test_exec_allows_reads_of_history_jsonl(command): tool = ExecTool() result = tool._guard_command(command, "/tmp") assert result is None + + +# --- #2826: working_dir must not escape the configured workspace --------- + + +@pytest.mark.asyncio +async def test_exec_blocks_working_dir_outside_workspace(tmp_path): + """An LLM-supplied working_dir outside the workspace must be rejected.""" + workspace = tmp_path / "workspace" + workspace.mkdir() + tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True) + result = await tool.execute(command="rm calendar.ics", working_dir="/etc") + assert "outside the configured workspace" in result + + +@pytest.mark.asyncio +async def test_exec_blocks_absolute_rm_via_hijacked_working_dir(tmp_path): + """Regression for #2826: `rm /abs/path` via working_dir hijack.""" + workspace = tmp_path / "workspace" + workspace.mkdir() + victim_dir = tmp_path / "outside" + victim_dir.mkdir() + victim = victim_dir / "file.ics" + victim.write_text("data") + + tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True) + result = await tool.execute( + command=f"rm {victim}", + working_dir=str(victim_dir), + ) + assert "outside the configured workspace" in result + assert victim.exists(), "victim file must not have been deleted" + + 
+@pytest.mark.asyncio +async def test_exec_allows_working_dir_within_workspace(tmp_path): + """A working_dir that is a subdirectory of the workspace is fine.""" + workspace = tmp_path / "workspace" + subdir = workspace / "project" + subdir.mkdir(parents=True) + tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True, timeout=5) + result = await tool.execute(command="echo ok", working_dir=str(subdir)) + assert "ok" in result + assert "outside the configured workspace" not in result + + +@pytest.mark.asyncio +async def test_exec_allows_working_dir_equal_to_workspace(tmp_path): + """Passing working_dir equal to the workspace root must be allowed.""" + workspace = tmp_path / "workspace" + workspace.mkdir() + tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True, timeout=5) + result = await tool.execute(command="echo ok", working_dir=str(workspace)) + assert "ok" in result + assert "outside the configured workspace" not in result + + +@pytest.mark.asyncio +async def test_exec_ignores_workspace_check_when_not_restricted(tmp_path): + """Without restrict_to_workspace, the LLM may still choose any working_dir.""" + workspace = tmp_path / "workspace" + workspace.mkdir() + other = tmp_path / "other" + other.mkdir() + tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=False, timeout=5) + result = await tool.execute(command="echo ok", working_dir=str(other)) + assert "ok" in result + assert "outside the configured workspace" not in result From 5dc238c7efe54d3693a247aa0ecded1d98c7af6e Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 12 Apr 2026 08:28:38 +0000 Subject: [PATCH 086/115] fix(shell): allow read-only copies from internal state files Keep the new exec guard focused on writes to history.jsonl and .dream_cursor while still allowing read-only copy operations out of those files. 
Made-with: Cursor --- nanobot/agent/tools/shell.py | 2 +- tests/tools/test_exec_security.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 729afa60b..aa8ca67b1 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -66,7 +66,7 @@ class ExecTool(Tool): # direct writes corrupt the cursor format and crash /dream. r">>?\s*\S*(?:history\.jsonl|\.dream_cursor)", # > / >> redirect r"\btee\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # tee / tee -a - r"\b(?:cp|mv)\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # cp/mv target + r"\b(?:cp|mv)\b(?:\s+[^\s|;&<>]+)+\s+\S*(?:history\.jsonl|\.dream_cursor)", # cp/mv target r"\bdd\b[^|;&<>]*\bof=\S*(?:history\.jsonl|\.dream_cursor)", # dd of= r"\bsed\s+-i[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # sed -i ] diff --git a/tests/tools/test_exec_security.py b/tests/tools/test_exec_security.py index 9f001aaff..20687dcbf 100644 --- a/tests/tools/test_exec_security.py +++ b/tests/tools/test_exec_security.py @@ -104,6 +104,7 @@ def test_exec_blocks_writes_to_history_jsonl(command): "wc -l history.jsonl", "tail -n 5 history.jsonl", "grep foo history.jsonl", + "cp history.jsonl /tmp/history.backup", "ls memory/", "echo history.jsonl", ], From 2a243bfe4f5747a04b5b541635e1742329d1d532 Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Sun, 12 Apr 2026 00:46:09 +0800 Subject: [PATCH 087/115] feat(agent): integrate skill discovery into Dream consolidation Instead of a separate skill discovery system, extend Dream's two-phase pipeline to also detect reusable behavioral patterns from conversation history and generate SKILL.md files. Phase 1 gains a [SKILL] output type for pattern detection. Phase 2 gains write_file (scoped to skills/) and read access to builtin skills, enabling it to check for duplicates and follow skill-creator's format conventions before creating new skills. Inspired by PR #3039 by @wanghesong2019. 
Co-authored-by: wanghesong2019 --- nanobot/agent/memory.py | 55 +++++++++++++++++++++++-- nanobot/templates/agent/dream_phase1.md | 7 ++++ nanobot/templates/agent/dream_phase2.md | 13 ++++++ 3 files changed, 71 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 04d988ee5..8980a4baa 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -582,14 +582,53 @@ class Dream: def _build_tools(self) -> ToolRegistry: """Build a minimal tool registry for the Dream agent.""" - from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool + from nanobot.agent.skills import BUILTIN_SKILLS_DIR + from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool tools = ToolRegistry() workspace = self.store.workspace - tools.register(ReadFileTool(workspace=workspace, allowed_dir=workspace)) + # Allow reading builtin skills for reference during skill creation + extra_read = [BUILTIN_SKILLS_DIR] if BUILTIN_SKILLS_DIR.exists() else None + tools.register(ReadFileTool( + workspace=workspace, + allowed_dir=workspace, + extra_allowed_dirs=extra_read, + )) tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace)) + # write_file scoped to skills/ directory for skill creation + skills_dir = workspace / "skills" + skills_dir.mkdir(parents=True, exist_ok=True) + tools.register(WriteFileTool(workspace=skills_dir, allowed_dir=skills_dir)) return tools + # -- skill listing -------------------------------------------------------- + + def _list_existing_skills(self) -> list[str]: + """List existing skills as 'name — description' for dedup context.""" + import re as _re + + from nanobot.agent.skills import BUILTIN_SKILLS_DIR + + _DESC_RE = _re.compile(r"^description:\s*(.+)$", _re.MULTILINE | _re.IGNORECASE) + entries: dict[str, str] = {} + for base in (self.store.workspace / "skills", BUILTIN_SKILLS_DIR): + if not base.exists(): + continue + for d in base.iterdir(): + if not d.is_dir(): + 
continue + skill_md = d / "SKILL.md" + if not skill_md.exists(): + continue + # Prefer workspace skills over builtin (same name) + if d.name in entries and base == BUILTIN_SKILLS_DIR: + continue + content = skill_md.read_text(encoding="utf-8")[:500] + m = _DESC_RE.search(content) + desc = m.group(1).strip() if m else "(no description)" + entries[d.name] = desc + return [f"{name} — {desc}" for name, desc in sorted(entries.items())] + # -- main entry ---------------------------------------------------------- async def run(self) -> bool: @@ -615,6 +654,7 @@ class Dream: current_memory = self.store.read_memory() or "(empty)" current_soul = self.store.read_soul() or "(empty)" current_user = self.store.read_user() or "(empty)" + file_context = ( f"## Current Date\n{current_date}\n\n" f"## Current MEMORY.md ({len(current_memory)} chars)\n{current_memory}\n\n" @@ -622,7 +662,7 @@ class Dream: f"## Current USER.md ({len(current_user)} chars)\n{current_user}" ) - # Phase 1: Analyze + # Phase 1: Analyze (no skills list — dedup is Phase 2's job) phase1_prompt = ( f"## Conversation History\n{history_text}\n\n{file_context}" ) @@ -647,7 +687,14 @@ class Dream: return False # Phase 2: Delegate to AgentRunner with read_file / edit_file - phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}" + existing_skills = self._list_existing_skills() + skills_section = "" + if existing_skills: + skills_section = ( + "\n\n## Existing Skills\n" + + "\n".join(f"- {s}" for s in existing_skills) + ) + phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}{skills_section}" tools = self._tools messages: list[dict[str, Any]] = [ diff --git a/nanobot/templates/agent/dream_phase1.md b/nanobot/templates/agent/dream_phase1.md index 596958e36..3cc19b186 100644 --- a/nanobot/templates/agent/dream_phase1.md +++ b/nanobot/templates/agent/dream_phase1.md @@ -3,6 +3,7 @@ Compare conversation history against current memory files. 
Also scan memory file Output one line per finding: [FILE] atomic fact (not already in memory) [FILE-REMOVE] reason for removal +[SKILL] kebab-case-name: one-line description of the reusable pattern Files: USER (identity, preferences), SOUL (bot behavior, tone), MEMORY (knowledge, project context) @@ -18,6 +19,12 @@ Staleness — flag for [FILE-REMOVE]: - Detailed incident info after 14 days — reduce to one-line summary - Superseded: approaches replaced by newer solutions, deprecated dependencies +Skill discovery — flag [SKILL] when ALL of these are true: +- A specific, repeatable workflow appeared 2+ times in the conversation history +- It involves clear steps (not vague preferences like "likes concise answers") +- It is substantial enough to warrant its own instruction set (not trivial like "read a file") +- Do not worry about duplicates — the next phase will check against existing skills + Do not add: current weather, transient status, temporary errors, conversational filler. [SKIP] if nothing needs updating. diff --git a/nanobot/templates/agent/dream_phase2.md b/nanobot/templates/agent/dream_phase2.md index 49c8020da..450be7096 100644 --- a/nanobot/templates/agent/dream_phase2.md +++ b/nanobot/templates/agent/dream_phase2.md @@ -1,11 +1,13 @@ Update memory files based on the analysis below. - [FILE] entries: add the described content to the appropriate file - [FILE-REMOVE] entries: delete the corresponding content from memory files +- [SKILL] entries: create a new skill under skills//SKILL.md using write_file ## File paths (relative to workspace root) - SOUL.md - USER.md - memory/MEMORY.md +- skills//SKILL.md (for [SKILL] entries only) Do NOT guess paths. @@ -17,6 +19,17 @@ Do NOT guess paths. 
- Surgical edits only — never rewrite entire files - If nothing to update, stop without calling tools +## Skill creation rules (for [SKILL] entries) +- Use write_file to create skills//SKILL.md +- Before writing, read_file skills/skill-creator/SKILL.md for format reference (frontmatter structure, naming conventions, quality standards) +- **Dedup check**: read existing skills listed below to verify the new skill is not functionally redundant. Skip creation if an existing skill already covers the same workflow. +- Include YAML frontmatter with name and description fields +- Keep SKILL.md under 2000 words — concise and actionable +- Include: when to use, steps, output format, at least one example +- Do NOT overwrite existing skills — skip if the skill directory already exists +- Reference specific tools the agent has access to (read_file, write_file, exec, web_search, etc.) +- Skills are instruction sets, not code — do not include implementation code + ## Quality - Every line must carry standalone value - Concise bullets under clear headers From 7a7f5c96893d9aa46d9540c5d0fbcf4be346de40 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 12 Apr 2026 08:46:12 +0000 Subject: [PATCH 088/115] fix(dream): use valid builtin skill template paths Point Dream skill creation at a readable builtin skill-creator template, keep skill writes rooted at the workspace, and document the new skill discovery behavior in README. Made-with: Cursor --- README.md | 1 + nanobot/agent/memory.py | 14 ++++++++++--- nanobot/templates/agent/dream_phase2.md | 2 +- tests/agent/test_dream.py | 28 +++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6d7763c42..b376d0991 100644 --- a/README.md +++ b/README.md @@ -1742,6 +1742,7 @@ time. 
- `memory/history.jsonl` stores append-only summarized history - `SOUL.md`, `USER.md`, and `memory/MEMORY.md` store long-term knowledge managed by Dream +- `Dream` can also promote repeated workflows into reusable workspace skills under `skills/` - `Dream` runs on a schedule and can also be triggered manually - memory changes can be inspected and restored with built-in commands diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 8980a4baa..3f8b24314 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -595,10 +595,11 @@ class Dream: extra_allowed_dirs=extra_read, )) tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace)) - # write_file scoped to skills/ directory for skill creation + # write_file resolves relative paths from workspace root, but can only + # write under skills/ so the prompt can safely use skills//SKILL.md. skills_dir = workspace / "skills" skills_dir.mkdir(parents=True, exist_ok=True) - tools.register(WriteFileTool(workspace=skills_dir, allowed_dir=skills_dir)) + tools.register(WriteFileTool(workspace=workspace, allowed_dir=skills_dir)) return tools # -- skill listing -------------------------------------------------------- @@ -633,6 +634,8 @@ class Dream: async def run(self) -> bool: """Process unprocessed history entries. 
Returns True if work was done.""" + from nanobot.agent.skills import BUILTIN_SKILLS_DIR + last_cursor = self.store.get_last_dream_cursor() entries = self.store.read_unprocessed_history(since_cursor=last_cursor) if not entries: @@ -697,10 +700,15 @@ class Dream: phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}{skills_section}" tools = self._tools + skill_creator_path = BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md" messages: list[dict[str, Any]] = [ { "role": "system", - "content": render_template("agent/dream_phase2.md", strip=True), + "content": render_template( + "agent/dream_phase2.md", + strip=True, + skill_creator_path=str(skill_creator_path), + ), }, {"role": "user", "content": phase2_prompt}, ] diff --git a/nanobot/templates/agent/dream_phase2.md b/nanobot/templates/agent/dream_phase2.md index 450be7096..f833afb6a 100644 --- a/nanobot/templates/agent/dream_phase2.md +++ b/nanobot/templates/agent/dream_phase2.md @@ -21,7 +21,7 @@ Do NOT guess paths. ## Skill creation rules (for [SKILL] entries) - Use write_file to create skills//SKILL.md -- Before writing, read_file skills/skill-creator/SKILL.md for format reference (frontmatter structure, naming conventions, quality standards) +- Before writing, read_file `{{ skill_creator_path }}` for format reference (frontmatter structure, naming conventions, quality standards) - **Dedup check**: read existing skills listed below to verify the new skill is not functionally redundant. Skip creation if an existing skill already covers the same workflow. 
- Include YAML frontmatter with name and description fields - Keep SKILL.md under 2000 words — concise and actionable diff --git a/tests/agent/test_dream.py b/tests/agent/test_dream.py index 38faafa7d..eece79ed9 100644 --- a/tests/agent/test_dream.py +++ b/tests/agent/test_dream.py @@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock from nanobot.agent.memory import Dream, MemoryStore from nanobot.agent.runner import AgentRunResult +from nanobot.agent.skills import BUILTIN_SKILLS_DIR @pytest.fixture @@ -95,3 +96,30 @@ class TestDreamRun: entries = store.read_unprocessed_history(since_cursor=0) assert all(e["cursor"] > 0 for e in entries) + async def test_skill_phase_uses_builtin_skill_creator_path(self, dream, mock_provider, mock_runner, store): + """Dream should point skill creation guidance at the builtin skill-creator template.""" + store.append_history("Repeated workflow one") + store.append_history("Repeated workflow two") + mock_provider.chat_with_retry.return_value = MagicMock(content="[SKILL] test-skill: test description") + mock_runner.run = AsyncMock(return_value=_make_run_result()) + + await dream.run() + + spec = mock_runner.run.call_args[0][0] + system_prompt = spec.initial_messages[0]["content"] + expected = str(BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md") + assert expected in system_prompt + + async def test_skill_write_tool_accepts_workspace_relative_skill_path(self, dream, store): + """Dream skill creation should allow skills//SKILL.md relative to workspace root.""" + write_tool = dream._tools.get("write_file") + assert write_tool is not None + + result = await write_tool.execute( + path="skills/test-skill/SKILL.md", + content="---\nname: test-skill\ndescription: Test\n---\n", + ) + + assert "Successfully wrote" in result + assert (store.workspace / "skills" / "test-skill" / "SKILL.md").exists() + From b2612019850840402611ee8659402241113451b4 Mon Sep 17 00:00:00 2001 From: yanghan-cyber Date: Sun, 12 Apr 2026 14:20:14 +0800 Subject: 
[PATCH 089/115] fix(retry): strip images in-place to prevent repeated error-retry cycles When a non-transient LLM error occurs with image content, the retry mechanism strips images from a copy but never updates the original conversation history. Subsequent iterations rebuild context from the unmodified history, causing the same error-retry cycle to repeat every iteration until max_iterations is reached. Add _strip_image_content_inplace() that mutates the original message content lists in-place after a successful no-image retry, so callers sharing those references (e.g. the runner's conversation history) also see the stripped version. --- nanobot/providers/base.py | 27 +++++++++++++++++++++++++- tests/providers/test_provider_retry.py | 7 ++++--- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index d2a0727fb..8ce2b9a7a 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -419,6 +419,26 @@ class LLMProvider(ABC): result.append(msg) return result if found else None + @staticmethod + def _strip_image_content_inplace(messages: list[dict[str, Any]]) -> bool: + """Replace image_url blocks with text placeholder *in-place*. + + Mutates the content lists of the original message dicts so that + callers holding references to those dicts also see the stripped + version. 
+ """ + found = False + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + for i, b in enumerate(content): + if isinstance(b, dict) and b.get("type") == "image_url": + path = (b.get("_meta") or {}).get("path", "") + placeholder = image_placeholder_text(path, empty="[image omitted]") + content[i] = {"type": "text", "text": placeholder} + found = True + return found + async def _safe_chat(self, **kwargs: Any) -> LLMResponse: """Call chat() and convert unexpected exceptions to error responses.""" try: @@ -670,7 +690,12 @@ class LLMProvider(ABC): ) retry_kw = dict(kw) retry_kw["messages"] = stripped - return await call(**retry_kw) + result = await call(**retry_kw) + # Permanently strip images from the original messages so + # subsequent iterations do not repeat the error-retry cycle. + if result.finish_reason != "error": + self._strip_image_content_inplace(original_messages) + return result return response if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT: diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index 78c2a791e..c64e2a0f8 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -1,4 +1,5 @@ import asyncio +import copy import pytest @@ -152,7 +153,7 @@ async def test_non_transient_error_with_images_retries_without_images() -> None: LLMResponse(content="ok, no image"), ]) - response = await provider.chat_with_retry(messages=_IMAGE_MSG) + response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG)) assert response.content == "ok, no image" assert provider.calls == 2 @@ -187,7 +188,7 @@ async def test_image_fallback_returns_error_on_second_failure() -> None: LLMResponse(content="still failing", finish_reason="error"), ]) - response = await provider.chat_with_retry(messages=_IMAGE_MSG) + response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG)) assert provider.calls == 2 assert 
response.content == "still failing" @@ -202,7 +203,7 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None: LLMResponse(content="ok"), ]) - response = await provider.chat_with_retry(messages=_IMAGE_MSG_NO_META) + response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG_NO_META)) assert response.content == "ok" assert provider.calls == 2 From 217e1fc957513c9e6804f5ab1fd3bc66cf105b4b Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 12 Apr 2026 12:08:43 +0000 Subject: [PATCH 090/115] test(retry): lock in-place image fallback behavior Add a focused regression test for the successful no-image retry path so the original message history stays stripped after fallback and the repeated retry loop cannot silently return. Made-with: Cursor --- tests/providers/test_provider_retry.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index c64e2a0f8..2ef784a3d 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -165,6 +165,24 @@ async def test_non_transient_error_with_images_retries_without_images() -> None: assert any("[image: /media/test.png]" in (b.get("text") or "") for b in content) +@pytest.mark.asyncio +async def test_successful_image_retry_mutates_original_messages_in_place() -> None: + """Successful no-image retry should update the caller's message history.""" + provider = ScriptedProvider([ + LLMResponse(content="model does not support images", finish_reason="error"), + LLMResponse(content="ok, no image"), + ]) + messages = copy.deepcopy(_IMAGE_MSG) + + response = await provider.chat_with_retry(messages=messages) + + assert response.content == "ok, no image" + content = messages[0]["content"] + assert isinstance(content, list) + assert all(block.get("type") != "image_url" for block in content) + assert any("[image: /media/test.png]" in (block.get("text") or "") for block in 
content) + + @pytest.mark.asyncio async def test_non_transient_error_without_images_no_retry() -> None: """Non-transient errors without image content are returned immediately.""" From 7e91aecd7dfc932557548940b1a147e7897a4c4e Mon Sep 17 00:00:00 2001 From: bahtya Date: Sun, 12 Apr 2026 02:48:22 +0800 Subject: [PATCH 091/115] fix(telegram): narrow exception catch in _send_text to prevent retry amplification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously _send_text() caught all exceptions (except Exception) when sending HTML-formatted messages, falling back to plain text even for network errors like TimedOut and NetworkError. This caused connection demand to double during pool exhaustion scenarios (3 retries × 2 fallback attempts = 6 calls per message instead of 3). Now only catches BadRequest (HTML parse errors), letting network errors propagate immediately to the retry layer where they belong. Fixes: HKUDS/nanobot#3050 --- nanobot/channels/telegram.py | 5 +- tests/channels/test_telegram_channel.py | 156 ++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 2dde232b1..d2572fac3 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -520,7 +520,10 @@ class TelegramChannel(BaseChannel): reply_parameters=reply_params, **(thread_kwargs or {}), ) - except Exception as e: + except BadRequest as e: + # Only fall back to plain text on actual HTML parse/format errors. + # Network errors (TimedOut, NetworkError) should propagate immediately + # to avoid doubling connection demand during pool exhaustion. 
logger.warning("HTML parse failed, falling back to plain text: {}", e) try: await self._call_with_retry( diff --git a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 5a1964127..7dfb094f9 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -1159,3 +1159,159 @@ async def test_on_message_location_with_text() -> None: assert len(handled) == 1 assert "meet me here" in handled[0]["content"] assert "[location: 51.5074, -0.1278]" in handled[0]["content"] + + +# --------------------------------------------------------------------------- +# Tests for retry amplification fix (issue #3050) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_text_does_not_fallback_on_network_timeout() -> None: + """TimedOut should propagate immediately, NOT trigger plain-text fallback. + + Before the fix, _send_text caught ALL exceptions (including TimedOut) + and retried as plain text, doubling connection demand during pool + exhaustion — see issue #3050. + """ + from telegram.error import TimedOut + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + call_count = 0 + + async def always_timeout(**kwargs): + nonlocal call_count + call_count += 1 + raise TimedOut() + + channel._app.bot.send_message = always_timeout + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + with pytest.raises(TimedOut): + await channel._send_text(123, "hello", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # With the fix: only _call_with_retry's 3 HTML attempts (no plain fallback). + # Before the fix: 3 HTML + 3 plain = 6 attempts. 
+ assert call_count == 3, ( + f"Expected 3 calls (HTML retries only), got {call_count} " + "(plain-text fallback should not trigger on TimedOut)" + ) + + +@pytest.mark.asyncio +async def test_send_text_does_not_fallback_on_network_error() -> None: + """NetworkError should propagate immediately, NOT trigger plain-text fallback.""" + from telegram.error import NetworkError + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + call_count = 0 + + async def always_network_error(**kwargs): + nonlocal call_count + call_count += 1 + raise NetworkError("Connection reset") + + channel._app.bot.send_message = always_network_error + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + with pytest.raises(NetworkError): + await channel._send_text(123, "hello", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # _call_with_retry does NOT retry NetworkError (only TimedOut/RetryAfter), + # so it raises after 1 attempt. The fix prevents plain-text fallback. + # Before the fix: 1 HTML + 1 plain = 2. After the fix: 1 HTML only. 
+ assert call_count == 1, ( + f"Expected 1 call (HTML only, no plain fallback), got {call_count}" + ) + + +@pytest.mark.asyncio +async def test_send_text_falls_back_on_bad_request() -> None: + """BadRequest (HTML parse error) should still trigger plain-text fallback.""" + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + original_send = channel._app.bot.send_message + html_call_count = 0 + + async def html_fails(**kwargs): + nonlocal html_call_count + if kwargs.get("parse_mode") == "HTML": + html_call_count += 1 + raise BadRequest("Can't parse entities") + return await original_send(**kwargs) + + channel._app.bot.send_message = html_fails + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + await channel._send_text(123, "hello **world**", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # HTML attempt failed with BadRequest → fallback to plain text succeeds. 
+ assert html_call_count == 1, f"Expected 1 HTML attempt, got {html_call_count}" + assert len(channel._app.bot.sent_messages) == 1 + # Plain text send should NOT have parse_mode + assert channel._app.bot.sent_messages[0].get("parse_mode") is None + + +@pytest.mark.asyncio +async def test_send_text_bad_request_plain_fallback_exhausted() -> None: + """When both HTML and plain-text fallback fail with BadRequest, the error propagates.""" + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + call_count = 0 + + async def always_bad_request(**kwargs): + nonlocal call_count + call_count += 1 + raise BadRequest("Bad request") + + channel._app.bot.send_message = always_bad_request + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + with pytest.raises(BadRequest): + await channel._send_text(123, "hello", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # _call_with_retry does NOT retry BadRequest (only TimedOut/RetryAfter), + # so HTML fails after 1 attempt → fallback to plain also fails after 1 attempt. + # Before the fix: 2 total. After the fix: still 2 (BadRequest SHOULD fallback). + assert call_count == 2, f"Expected 2 calls (1 HTML + 1 plain), got {call_count}" From fa9852494416c851c53067363ddcfaac83a9e6a8 Mon Sep 17 00:00:00 2001 From: bahtya Date: Sun, 12 Apr 2026 03:28:38 +0800 Subject: [PATCH 092/115] fix(channels): prevent retry amplification and silent message loss across channels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Audited all channel implementations for overly broad exception handling that causes retry amplification or silent message loss during network errors. This is the same class of bug as #3050 (Telegram _send_text). 
Fixes by channel: Telegram (send_delta): - _stream_end path used except Exception for HTML edit fallback - Network errors (TimedOut, NetworkError) triggered redundant plain text edit, doubling connection demand during pool exhaustion - Changed to except BadRequest, matching the _send_text fix Discord: - send() caught all exceptions without re-raising - ChannelManager._send_with_retry() saw successful return, never retried - Messages silently dropped on any send failure - Added raise after error logging DingTalk: - _send_batch_message() returned False on all exceptions including network errors — no retry, fallback text sent unnecessarily - _read_media_bytes() and _upload_media() swallowed transport errors, causing _send_media_ref() to cascade through doomed fallback attempts - Added except httpx.TransportError handlers that re-raise immediately WeChat: - Media send failure triggered text fallback even for network errors - During network issues: 3×(media + text) = 6 API calls per message - Added specific catches: TimeoutException/TransportError re-raise, 5xx HTTPStatusError re-raises, 4xx falls back to text QQ: - _send_media() returned False on all exceptions - Network errors triggered fallback text instead of retry - Added except (aiohttp.ClientError, OSError) that re-raises Tests: 331 passed (283 existing + 48 new across 5 channel test files) Fixes: #3054 Related: #3050, #3053 --- nanobot/channels/dingtalk.py | 9 + nanobot/channels/discord.py | 1 + nanobot/channels/qq.py | 5 + nanobot/channels/telegram.py | 5 +- nanobot/channels/weixin.py | 36 ++++ tests/channels/test_dingtalk_channel.py | 155 +++++++++++++++++ tests/channels/test_discord_channel.py | 97 +++++++++++ tests/channels/test_qq_channel.py | 221 ++++++++++++++++++++++++ tests/channels/test_telegram_channel.py | 78 +++++++++ tests/channels/test_weixin_channel.py | 182 +++++++++++++++++++ 10 files changed, 788 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/dingtalk.py 
b/nanobot/channels/dingtalk.py index 39b5818bd..a863ba0df 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -337,6 +337,9 @@ class DingTalkChannel(BaseChannel): content_type = (resp.headers.get("content-type") or "").split(";")[0].strip() filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref)) return resp.content, filename, content_type or None + except httpx.TransportError as e: + logger.error("DingTalk media download network error ref={} err={}", media_ref, e) + raise except Exception as e: logger.error("DingTalk media download error ref={} err={}", media_ref, e) return None, None, None @@ -388,6 +391,9 @@ class DingTalkChannel(BaseChannel): logger.error("DingTalk media upload missing media_id body={}", text[:500]) return None return str(media_id) + except httpx.TransportError as e: + logger.error("DingTalk media upload network error type={} err={}", media_type, e) + raise except Exception as e: logger.error("DingTalk media upload error type={} err={}", media_type, e) return None @@ -437,6 +443,9 @@ class DingTalkChannel(BaseChannel): return False logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key) return True + except httpx.TransportError as e: + logger.error("DingTalk network error sending message msgKey={} err={}", msg_key, e) + raise except Exception as e: logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e) return False diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 6e8c673a3..336b6148d 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -366,6 +366,7 @@ class DiscordChannel(BaseChannel): await client.send_outbound(msg) except Exception as e: logger.error("Error sending Discord message: {}", e) + raise finally: if not is_progress: await self._stop_typing(msg.chat_id) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 484eed6e2..96d9d5ecd 100644 --- a/nanobot/channels/qq.py +++ 
b/nanobot/channels/qq.py @@ -362,7 +362,12 @@ class QQChannel(BaseChannel): logger.info("QQ media sent: {}", filename) return True + except (aiohttp.ClientError, OSError) as e: + # Network / transport errors — propagate for retry by caller + logger.warning("QQ send media network error filename={} err={}", filename, e) + raise except Exception as e: + # API-level or other non-network errors — return False so send() can fallback logger.error("QQ send media failed filename={} err={}", filename, e) return False diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index d2572fac3..f63704aa7 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -570,7 +570,10 @@ class TelegramChannel(BaseChannel): chat_id=int_chat_id, message_id=buf.message_id, text=html, parse_mode="HTML", ) - except Exception as e: + except BadRequest as e: + # Only fall back to plain text on actual HTML parse/format errors. + # Network errors (TimedOut, NetworkError) should propagate immediately + # to avoid doubling connection demand during pool exhaustion. if self._is_not_modified_error(e): logger.debug("Final stream edit already applied for {}", chat_id) self._stream_bufs.pop(chat_id, None) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 3f87e2203..fbe84bcf8 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -985,7 +985,43 @@ class WeixinChannel(BaseChannel): for media_path in (msg.media or []): try: await self._send_media_file(msg.chat_id, media_path, ctx_token) + except (httpx.TimeoutException, httpx.TransportError) as net_err: + # Network/transport errors: do NOT fall back to text — + # the text send would also likely fail, and the outer + # except will re-raise so ChannelManager retries properly. 
+ logger.error( + "Network error sending WeChat media {}: {}", + media_path, + net_err, + ) + raise + except httpx.HTTPStatusError as http_err: + status_code = ( + http_err.response.status_code + if http_err.response is not None + else 0 + ) + if status_code >= 500: + # Server-side / retryable HTTP error — same as network. + logger.error( + "Server error ({} {}) sending WeChat media {}: {}", + status_code, + http_err.response.reason_phrase + if http_err.response is not None + else "", + media_path, + http_err, + ) + raise + # 4xx client errors are NOT retryable — fall back to text. + filename = Path(media_path).name + logger.error("Failed to send WeChat media {}: {}", media_path, http_err) + await self._send_text( + msg.chat_id, f"[Failed to send: {filename}]", ctx_token, + ) except Exception as e: + # Non-network errors (format, file-not-found, etc.): + # notify the user via text fallback. filename = Path(media_path).name logger.error("Failed to send WeChat media {}: {}", media_path, e) # Notify user about failure via text diff --git a/tests/channels/test_dingtalk_channel.py b/tests/channels/test_dingtalk_channel.py index f743c4e62..86de99bb5 100644 --- a/tests/channels/test_dingtalk_channel.py +++ b/tests/channels/test_dingtalk_channel.py @@ -2,7 +2,9 @@ import asyncio import zipfile from io import BytesIO from types import SimpleNamespace +from unittest.mock import AsyncMock +import httpx import pytest # Check optional dingtalk dependencies before running tests @@ -52,6 +54,21 @@ class _FakeHttp: return self._next_response() +class _NetworkErrorHttp: + """HTTP client stub that raises httpx.TransportError on every request.""" + + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def post(self, url: str, json=None, headers=None, **kwargs): + self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers}) + raise httpx.ConnectError("Connection refused") + + async def get(self, url: str, **kwargs): + self.calls.append({"method": 
"GET", "url": url}) + raise httpx.ConnectError("Connection refused") + + @pytest.mark.asyncio async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None: config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]) @@ -298,3 +315,141 @@ async def test_send_media_ref_zips_html_before_upload(tmp_path, monkeypatch) -> archive = zipfile.ZipFile(BytesIO(captured["data"])) assert archive.namelist() == ["report.html"] + + +# ── Exception handling tests ────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_send_batch_message_propagates_transport_error() -> None: + """Network/transport errors must re-raise so callers can retry.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + channel._http = _NetworkErrorHttp() + + with pytest.raises(httpx.ConnectError, match="Connection refused"): + await channel._send_batch_message( + "token", + "user123", + "sampleMarkdown", + {"text": "hello", "title": "Nanobot Reply"}, + ) + + # The POST was attempted exactly once + assert len(channel._http.calls) == 1 + assert channel._http.calls[0]["method"] == "POST" + + +@pytest.mark.asyncio +async def test_send_batch_message_returns_false_on_api_error() -> None: + """DingTalk API-level errors (non-200 status, errcode != 0) should return False.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + + # Non-200 status code → API error → return False + channel._http = _FakeHttp(responses=[_FakeResponse(400, {"errcode": 400})]) + result = await channel._send_batch_message( + "token", "user123", "sampleMarkdown", {"text": "hello"} + ) + assert result is False + + # 200 with non-zero errcode → API error → return False + channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 100})]) + result = await channel._send_batch_message( + "token", 
"user123", "sampleMarkdown", {"text": "hello"} + ) + assert result is False + + # 200 with errcode=0 → success → return True + channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 0})]) + result = await channel._send_batch_message( + "token", "user123", "sampleMarkdown", {"text": "hello"} + ) + assert result is True + + +@pytest.mark.asyncio +async def test_send_media_ref_short_circuits_on_transport_error() -> None: + """When the first send fails with a transport error, _send_media_ref must + re-raise immediately instead of trying download+upload+fallback.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + channel._http = _NetworkErrorHttp() + + # An image URL triggers the sampleImageMsg path first + with pytest.raises(httpx.ConnectError, match="Connection refused"): + await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg") + + # Only one POST should have been attempted — no download/upload/fallback + assert len(channel._http.calls) == 1 + assert channel._http.calls[0]["method"] == "POST" + + +@pytest.mark.asyncio +async def test_send_media_ref_short_circuits_on_download_transport_error() -> None: + """When the image URL send returns an API error (False) but the download + for the fallback hits a transport error, it must re-raise rather than + silently returning False.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + + # First POST (sampleImageMsg) returns API error → False, then GET (download) raises transport error + class _MixedHttp: + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def post(self, url, json=None, headers=None, **kwargs): + self.calls.append({"method": "POST", "url": url}) + # API-level failure: 200 with errcode != 0 + return _FakeResponse(200, {"errcode": 100}) + + async def get(self, url, **kwargs): + 
self.calls.append({"method": "GET", "url": url}) + raise httpx.ConnectError("Connection refused") + + channel._http = _MixedHttp() + + with pytest.raises(httpx.ConnectError, match="Connection refused"): + await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg") + + # Should have attempted POST (image URL) and GET (download), but NOT upload + assert len(channel._http.calls) == 2 + assert channel._http.calls[0]["method"] == "POST" + assert channel._http.calls[1]["method"] == "GET" + + +@pytest.mark.asyncio +async def test_send_media_ref_short_circuits_on_upload_transport_error() -> None: + """When download succeeds but upload hits a transport error, must re-raise.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + + image_bytes = b"\xff\xd8\xff\xe0" + b"\x00" * 100 # minimal JPEG-ish data + + class _UploadFailsHttp: + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def post(self, url, json=None, headers=None, files=None, **kwargs): + self.calls.append({"method": "POST", "url": url}) + # If it's the upload endpoint, raise transport error + if "media/upload" in url: + raise httpx.ConnectError("Connection refused") + # Otherwise (sampleImageMsg), return API error to trigger fallback + return _FakeResponse(200, {"errcode": 100}) + + async def get(self, url, **kwargs): + self.calls.append({"method": "GET", "url": url}) + resp = _FakeResponse(200) + resp.content = image_bytes + resp.headers = {"content-type": "image/jpeg"} + return resp + + channel._http = _UploadFailsHttp() + + with pytest.raises(httpx.ConnectError, match="Connection refused"): + await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg") + + # POST (image URL), GET (download), POST (upload) attempted — no further sends + methods = [c["method"] for c in channel._http.calls] + assert methods == ["POST", "GET", "POST"] diff --git 
a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 3a31a5912..7a39bff2b 100644 --- a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -867,3 +867,100 @@ async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None: assert channel.is_running is False assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890" assert _FakeDiscordClient.instances[0].proxy_auth is None + + +# --------------------------------------------------------------------------- +# Tests for the send() exception propagation fix +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_re_raises_network_error() -> None: + """Network errors during send must propagate so ChannelManager can retry.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + async def _failing_send_outbound(msg: OutboundMessage) -> None: + raise ConnectionError("network unreachable") + + client.send_outbound = _failing_send_outbound # type: ignore[method-assign] + + with pytest.raises(ConnectionError, match="network unreachable"): + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + +@pytest.mark.asyncio +async def test_send_re_raises_generic_exception() -> None: + """Any exception from send_outbound must propagate, not be swallowed.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + async def _failing_send_outbound(msg: OutboundMessage) -> None: + raise RuntimeError("discord API failure") + + client.send_outbound = _failing_send_outbound # type: ignore[method-assign] + + with pytest.raises(RuntimeError, match="discord API 
failure"): + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + +@pytest.mark.asyncio +async def test_send_still_stops_typing_on_error() -> None: + """Typing cleanup must still run in the finally block even when send raises.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + # Start a typing task so we can verify it gets cleaned up + start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing + await channel._start_typing(typing_channel) + await asyncio.wait_for(start.wait(), timeout=1.0) + + async def _failing_send_outbound(msg: OutboundMessage) -> None: + raise ConnectionError("timeout") + + client.send_outbound = _failing_send_outbound # type: ignore[method-assign] + + with pytest.raises(ConnectionError, match="timeout"): + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + release.set() + await asyncio.sleep(0) + + # Typing should have been cleaned up by the finally block + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_send_succeeds_normally() -> None: + """Successful sends should work without raising.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + sent_messages: list[OutboundMessage] = [] + + async def _capture_send_outbound(msg: OutboundMessage) -> None: + sent_messages.append(msg) + + client.send_outbound = _capture_send_outbound # type: ignore[method-assign] + + msg = OutboundMessage(channel="discord", chat_id="123", content="hello world") + await channel.send(msg) + + assert 
len(sent_messages) == 1 + assert sent_messages[0].content == "hello world" + assert sent_messages[0].chat_id == "123" diff --git a/tests/channels/test_qq_channel.py b/tests/channels/test_qq_channel.py index 729442a13..417648adf 100644 --- a/tests/channels/test_qq_channel.py +++ b/tests/channels/test_qq_channel.py @@ -1,6 +1,7 @@ import tempfile from pathlib import Path from types import SimpleNamespace +from unittest.mock import AsyncMock, patch import pytest @@ -14,6 +15,8 @@ except ImportError: if not QQ_AVAILABLE: pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True) +import aiohttp + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.qq import QQChannel, QQConfig @@ -170,3 +173,221 @@ async def test_read_media_bytes_missing_file() -> None: data, filename = await channel._read_media_bytes("/nonexistent/path/image.png") assert data is None assert filename is None + + +# ------------------------------------------------------- +# Tests for _send_media exception handling +# ------------------------------------------------------- + +def _make_channel_with_local_file(suffix: str = ".png", content: bytes = b"\x89PNG\r\n"): + """Create a QQChannel with a fake client and a temp file for media.""" + channel = QQChannel( + QQConfig(app_id="app", secret="secret", allow_from=["*"]), + MessageBus(), + ) + channel._client = _FakeClient() + channel._chat_type_cache["user1"] = "c2c" + + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + tmp.write(content) + tmp.close() + return channel, tmp.name + + +@pytest.mark.asyncio +async def test_send_media_network_error_propagates() -> None: + """aiohttp.ClientError (network/transport) should re-raise, not return False.""" + channel, tmp_path = _make_channel_with_local_file() + + # Make the base64 upload raise a network error + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + 
side_effect=aiohttp.ServerDisconnectedError("connection lost"), + ) + + with pytest.raises(aiohttp.ServerDisconnectedError): + await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_media_client_connector_error_propagates() -> None: + """aiohttp.ClientConnectorError (DNS/connection refused) should re-raise.""" + channel, tmp_path = _make_channel_with_local_file() + + from aiohttp.client_reqrep import ConnectionKey + conn_key = ConnectionKey("api.qq.com", 443, True, None, None, None, None) + connector_error = aiohttp.ClientConnectorError( + connection_key=conn_key, + os_error=OSError("Connection refused"), + ) + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=connector_error, + ) + + with pytest.raises(aiohttp.ClientConnectorError): + await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_media_oserror_propagates() -> None: + """OSError (low-level I/O) should re-raise for retry.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=OSError("Network is unreachable"), + ) + + with pytest.raises(OSError): + await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_media_api_error_returns_false() -> None: + """API-level errors (botpy RuntimeError subclasses) should return False, not raise.""" + channel, tmp_path = _make_channel_with_local_file() + + # Simulate a botpy API error (e.g. 
ServerError is a RuntimeError subclass) + from botpy.errors import ServerError + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=ServerError("internal server error"), + ) + + result = await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + assert result is False + + +@pytest.mark.asyncio +async def test_send_media_generic_runtime_error_returns_false() -> None: + """Generic RuntimeError (not network) should return False.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=RuntimeError("some API error"), + ) + + result = await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + assert result is False + + +@pytest.mark.asyncio +async def test_send_media_value_error_returns_false() -> None: + """ValueError (bad API response data) should return False.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=ValueError("bad response data"), + ) + + result = await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + assert result is False + + +@pytest.mark.asyncio +async def test_send_media_timeout_error_propagates() -> None: + """asyncio.TimeoutError inherits from Exception but not ClientError/OSError. + However, aiohttp.ServerTimeoutError IS a ClientError subclass, so that propagates. 
+ For a plain TimeoutError (which is also OSError in Python 3.11+), it should propagate.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=aiohttp.ServerTimeoutError("request timed out"), + ) + + with pytest.raises(aiohttp.ServerTimeoutError): + await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_fallback_text_on_api_error() -> None: + """When _send_media returns False (API error), send() should emit fallback text.""" + channel, tmp_path = _make_channel_with_local_file() + + from botpy.errors import ServerError + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=ServerError("internal server error"), + ) + + await channel.send( + OutboundMessage( + channel="qq", + chat_id="user1", + content="", + media=[tmp_path], + metadata={"message_id": "msg1"}, + ) + ) + + # Should have sent a fallback text message + assert len(channel._client.api.c2c_calls) == 1 + fallback_content = channel._client.api.c2c_calls[0]["content"] + assert "Attachment send failed" in fallback_content + + +@pytest.mark.asyncio +async def test_send_propagates_network_error_no_fallback() -> None: + """When _send_media raises a network error, send() should NOT silently fallback.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=aiohttp.ServerDisconnectedError("connection lost"), + ) + + with pytest.raises(aiohttp.ServerDisconnectedError): + await channel.send( + OutboundMessage( + channel="qq", + chat_id="user1", + content="hello", + media=[tmp_path], + metadata={"message_id": "msg1"}, + ) + ) + + # No fallback text should have been sent + assert len(channel._client.api.c2c_calls) == 0 diff --git 
a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 7dfb094f9..8d9431ba6 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -387,6 +387,84 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None: assert "123" not in channel._stream_bufs +@pytest.mark.asyncio +async def test_send_delta_stream_end_does_not_fallback_on_network_timeout() -> None: + """TimedOut during HTML edit should propagate, never fall back to plain text.""" + from telegram.error import TimedOut + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + # _call_with_retry retries TimedOut up to 3 times, so the mock will be called + # multiple times – but all calls must be with parse_mode="HTML" (no plain fallback). + channel._app.bot.edit_message_text = AsyncMock(side_effect=TimedOut("network timeout")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0) + + with pytest.raises(TimedOut, match="network timeout"): + await channel.send_delta("123", "", {"_stream_end": True}) + + # Every call to edit_message_text must have used parse_mode="HTML" — + # no plain-text fallback call should have been made. 
+ for call in channel._app.bot.edit_message_text.call_args_list: + assert call.kwargs.get("parse_mode") == "HTML" + # Buffer should still be present (not cleaned up on error) + assert "123" in channel._stream_bufs + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_does_not_fallback_on_network_error() -> None: + """NetworkError during HTML edit should propagate, never fall back to plain text.""" + from telegram.error import NetworkError + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._app.bot.edit_message_text = AsyncMock(side_effect=NetworkError("connection reset")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0) + + with pytest.raises(NetworkError, match="connection reset"): + await channel.send_delta("123", "", {"_stream_end": True}) + + # Every call to edit_message_text must have used parse_mode="HTML" — + # no plain-text fallback call should have been made. 
+ for call in channel._app.bot.edit_message_text.call_args_list: + assert call.kwargs.get("parse_mode") == "HTML" + # Buffer should still be present (not cleaned up on error) + assert "123" in channel._stream_bufs + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_falls_back_on_bad_request() -> None: + """BadRequest (HTML parse error) should still trigger plain-text fallback.""" + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + # First call (HTML) raises BadRequest, second call (plain) succeeds + channel._app.bot.edit_message_text = AsyncMock( + side_effect=[BadRequest("Can't parse entities"), None] + ) + channel._stream_bufs["123"] = _StreamBuf(text="hello ", message_id=7, last_edit=0.0) + + await channel.send_delta("123", "", {"_stream_end": True}) + + # edit_message_text should have been called twice: once for HTML, once for plain fallback + assert channel._app.bot.edit_message_text.call_count == 2 + # Second call should not use parse_mode="HTML" + second_call_kwargs = channel._app.bot.edit_message_text.call_args_list[1].kwargs + assert "parse_mode" not in second_call_kwargs or second_call_kwargs.get("parse_mode") is None + # Buffer should be cleaned up on success + assert "123" not in channel._stream_bufs + + @pytest.mark.asyncio async def test_send_delta_stream_end_splits_oversized_reply() -> None: """Final streamed reply exceeding Telegram limit is split into chunks.""" diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 3a847411b..2b455fca6 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1003,3 +1003,185 @@ async def test_download_media_item_non_image_requires_aes_key_even_with_full_url assert saved_path is None channel._client.get.assert_not_awaited() + + +# 
--------------------------------------------------------------------------- +# Tests for media-send error classification (network vs non-network errors) +# --------------------------------------------------------------------------- + + +def _make_outbound_msg(chat_id: str = "wx-user", content: str = "", media: list | None = None): + """Build a minimal OutboundMessage-like object for send() tests.""" + from nanobot.bus.events import OutboundMessage + + return OutboundMessage( + channel="weixin", + chat_id=chat_id, + content=content, + media=media or [], + metadata={}, + ) + + +@pytest.mark.asyncio +async def test_send_media_timeout_error_propagates_without_text_fallback() -> None: + """httpx.TimeoutException during media send must re-raise immediately, + NOT fall back to _send_text (which would also fail during network issues).""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock(side_effect=httpx.TimeoutException("timed out")) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + with pytest.raises(httpx.TimeoutException, match="timed out"): + await channel.send(msg) + + # _send_text must NOT have been called as a fallback + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_media_transport_error_propagates_without_text_fallback() -> None: + """httpx.TransportError during media send must re-raise immediately.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=httpx.TransportError("connection reset") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + with pytest.raises(httpx.TransportError, match="connection reset"): + await channel.send(msg) + 
+ channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_media_5xx_http_status_error_propagates_without_text_fallback() -> None: + """httpx.HTTPStatusError with a 5xx status must re-raise immediately.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + + fake_response = httpx.Response( + status_code=503, + request=httpx.Request("POST", "https://example.test/upload"), + ) + channel._send_media_file = AsyncMock( + side_effect=httpx.HTTPStatusError( + "Service Unavailable", request=fake_response.request, response=fake_response + ) + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + with pytest.raises(httpx.HTTPStatusError, match="Service Unavailable"): + await channel.send(msg) + + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_media_4xx_http_status_error_falls_back_to_text() -> None: + """httpx.HTTPStatusError with a 4xx status should fall back to text, not re-raise.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + + fake_response = httpx.Response( + status_code=400, + request=httpx.Request("POST", "https://example.test/upload"), + ) + channel._send_media_file = AsyncMock( + side_effect=httpx.HTTPStatusError( + "Bad Request", request=fake_response.request, response=fake_response + ) + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + # Should NOT raise — 4xx is a client error, non-retryable + await channel.send(msg) + + # _send_text should have been called with the fallback message + channel._send_text.assert_awaited_once_with( + "wx-user", "[Failed to send: photo.jpg]", "ctx-1" + ) + + +@pytest.mark.asyncio +async def test_send_media_file_not_found_falls_back_to_text() -> None: + 
"""FileNotFoundError (a non-network error) should fall back to text.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=FileNotFoundError("Media file not found: /tmp/missing.jpg") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/missing.jpg"]) + + # Should NOT raise + await channel.send(msg) + + channel._send_text.assert_awaited_once_with( + "wx-user", "[Failed to send: missing.jpg]", "ctx-1" + ) + + +@pytest.mark.asyncio +async def test_send_media_value_error_falls_back_to_text() -> None: + """ValueError (e.g. unsupported format) should fall back to text.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=ValueError("Unsupported media format") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/file.xyz"]) + + # Should NOT raise + await channel.send(msg) + + channel._send_text.assert_awaited_once_with( + "wx-user", "[Failed to send: file.xyz]", "ctx-1" + ) + + +@pytest.mark.asyncio +async def test_send_media_network_error_does_not_double_api_calls() -> None: + """During network issues, media send should make exactly 1 API call attempt, + not 2 (media + text fallback). 
Verify total call count.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=httpx.ConnectError("connection refused") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", content="hello", media=["/tmp/img.png"]) + + with pytest.raises(httpx.ConnectError): + await channel.send(msg) + + # _send_media_file called once, _send_text never called + channel._send_media_file.assert_awaited_once() + channel._send_text.assert_not_awaited() From f879d81b28cea4bc7ff07ffc2db2362080138c71 Mon Sep 17 00:00:00 2001 From: bahtya Date: Sun, 12 Apr 2026 09:24:06 +0800 Subject: [PATCH 093/115] fix(channels/qq): propagate network errors in send() instead of swallowing The catch-all except Exception in QQ send() was swallowing aiohttp.ClientError and OSError that _send_media correctly re-raises. Add explicit catch for network errors before the generic handler. 
--- nanobot/channels/qq.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 96d9d5ecd..f109f6da6 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -280,6 +280,9 @@ class QQChannel(BaseChannel): msg_id=msg_id, content=msg.content.strip(), ) + except (aiohttp.ClientError, OSError): + # Network / transport errors — propagate so ChannelManager can retry + raise except Exception: logger.exception("Error sending QQ message to chat_id={}", msg.chat_id) From c68b3edb9d085e2c3f3b0018a41e3d61dc07cf26 Mon Sep 17 00:00:00 2001 From: haosenwang1018 Date: Sun, 12 Apr 2026 20:06:47 +0000 Subject: [PATCH 094/115] fix(provider): clarify local 502 recovery hints --- nanobot/providers/openai_compat_provider.py | 13 +++++++++++-- tests/providers/test_custom_provider.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 101ee6c33..83fbd7fb3 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -798,8 +798,7 @@ class OpenAICompatProvider(LLMProvider): "error_should_retry": should_retry, } - @staticmethod - def _handle_error(e: Exception) -> LLMResponse: + def _handle_error(self, e: Exception) -> LLMResponse: body = ( getattr(e, "doc", None) or getattr(e, "body", None) @@ -807,6 +806,16 @@ class OpenAICompatProvider(LLMProvider): ) body_text = body if isinstance(body, str) else str(body) if body is not None else "" msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}" + + spec = self._spec + text = f"{body_text} {e}".lower() + if spec and spec.is_local and ("502" in text or "connection" in text or "refused" in text): + msg += ( + "\nHint: this is a local model endpoint. 
Check that the local server is reachable at " + f"{self.api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it " + "can reach your local Ollama/vLLM service instead of routing localhost through the remote host." + ) + response = getattr(e, "response", None) retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None)) if retry_after is None: diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index d2a9f4247..33c5d027a 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import patch from nanobot.providers.openai_compat_provider import OpenAICompatProvider +from nanobot.providers.registry import find_by_name def test_custom_provider_parse_handles_empty_choices() -> None: @@ -53,3 +54,16 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: assert result.finish_reason == "stop" assert result.content == "hello world" + + +def test_local_provider_502_error_includes_reachability_hint() -> None: + spec = find_by_name("ollama") + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider(api_base="http://localhost:11434/v1", spec=spec) + + result = provider._handle_error(Exception("Error code: 502")) + + assert result.finish_reason == "error" + assert "local model endpoint" in result.content + assert "http://localhost:11434/v1" in result.content + assert "proxy/tunnel" in result.content From 3573109408f2d9bb9cdda60938b12def6a53f0fb Mon Sep 17 00:00:00 2001 From: haosenwang1018 Date: Sun, 12 Apr 2026 20:53:18 +0000 Subject: [PATCH 095/115] fix(provider): preserve static error helper compatibility --- nanobot/providers/openai_compat_provider.py | 15 ++++++++++----- tests/providers/test_custom_provider.py | 6 +++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git 
a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 83fbd7fb3..4dea2d5fc 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -798,7 +798,13 @@ class OpenAICompatProvider(LLMProvider): "error_should_retry": should_retry, } - def _handle_error(self, e: Exception) -> LLMResponse: + @staticmethod + def _handle_error( + e: Exception, + *, + spec: ProviderSpec | None = None, + api_base: str | None = None, + ) -> LLMResponse: body = ( getattr(e, "doc", None) or getattr(e, "body", None) @@ -807,12 +813,11 @@ class OpenAICompatProvider(LLMProvider): body_text = body if isinstance(body, str) else str(body) if body is not None else "" msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}" - spec = self._spec text = f"{body_text} {e}".lower() if spec and spec.is_local and ("502" in text or "connection" in text or "refused" in text): msg += ( "\nHint: this is a local model endpoint. Check that the local server is reachable at " - f"{self.api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it " + f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it " "can reach your local Ollama/vLLM service instead of routing localhost through the remote host." 
) @@ -859,7 +864,7 @@ class OpenAICompatProvider(LLMProvider): ) return self._parse(await self._client.chat.completions.create(**kwargs)) except Exception as e: - return self._handle_error(e) + return self._handle_error(e, spec=self._spec, api_base=self.api_base) async def chat_stream( self, @@ -942,7 +947,7 @@ class OpenAICompatProvider(LLMProvider): error_kind="timeout", ) except Exception as e: - return self._handle_error(e) + return self._handle_error(e, spec=self._spec, api_base=self.api_base) def get_default_model(self) -> str: return self.default_model diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index 33c5d027a..85314dc79 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -61,7 +61,11 @@ def test_local_provider_502_error_includes_reachability_hint() -> None: with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): provider = OpenAICompatProvider(api_base="http://localhost:11434/v1", spec=spec) - result = provider._handle_error(Exception("Error code: 502")) + result = provider._handle_error( + Exception("Error code: 502"), + spec=spec, + api_base="http://localhost:11434/v1", + ) assert result.finish_reason == "error" assert "local model endpoint" in result.content From 92ef594b6a196a8084187d163058b82250695e76 Mon Sep 17 00:00:00 2001 From: haosenwang1018 Date: Mon, 13 Apr 2026 01:07:08 +0000 Subject: [PATCH 096/115] fix(mcp): hint on stdio protocol pollution --- nanobot/agent/tools/mcp.py | 18 +++++++++++++++++- tests/tools/test_mcp_tool.py | 27 +++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 1b5a71322..2aea19279 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -454,7 +454,23 @@ async def connect_mcp_servers( return name, server_stack except Exception as e: - logger.error("MCP server '{}': failed to connect: {}", name, 
e) + hint = "" + text = str(e).lower() + if any( + marker in text + for marker in ( + "parse error", + "invalid json", + "unexpected token", + "jsonrpc", + "content-length", + ) + ): + hint = ( + " Hint: this looks like stdio protocol pollution. Make sure the MCP server writes " + "only JSON-RPC to stdout and sends logs/debug output to stderr instead." + ) + logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint) try: await server_stack.aclose() except Exception: diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index da90c4d0d..a133f53db 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -356,6 +356,33 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries( assert "Available wrapped names: mcp_test_demo" in warnings[-1] +@pytest.mark.asyncio +async def test_connect_mcp_servers_logs_stdio_pollution_hint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + messages: list[str] = [] + + def _error(message: str, *args: object) -> None: + messages.append(message.format(*args)) + + @asynccontextmanager + async def _broken_stdio_client(_params: object): + raise RuntimeError("Parse error: Unexpected token 'INFO' before JSON-RPC headers") + yield # pragma: no cover + + monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _broken_stdio_client) + monkeypatch.setattr("nanobot.agent.tools.mcp.logger.error", _error) + + registry = ToolRegistry() + stacks = await connect_mcp_servers({"gh": MCPServerConfig(command="github-mcp")}, registry) + + assert stacks == {} + assert messages + assert "stdio protocol pollution" in messages[-1] + assert "stdout" in messages[-1] + assert "stderr" in messages[-1] + + @pytest.mark.asyncio async def test_connect_mcp_servers_one_failure_does_not_block_others( monkeypatch: pytest.MonkeyPatch, From 830644c35292402befa22a4fe01cfb2845c36805 Mon Sep 17 00:00:00 2001 From: ramonpaolo Date: Sun, 12 Apr 2026 20:56:36 -0300 Subject: [PATCH 097/115] fix: 
add guard for non-dict tool call parameters - Add type validation in registry.prepare_call() to catch list/other invalid params - Add logger.warning() in provider layer when non-dict args detected - Works for OpenAI-compatible and Anthropic providers - Registry returns clear error hint for model to self-correct --- nanobot/agent/tools/registry.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index 99d3ec63a..137038c0c 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -68,6 +68,13 @@ class ToolRegistry: params: dict[str, Any], ) -> tuple[Tool | None, dict[str, Any], str | None]: """Resolve, cast, and validate one tool call.""" + # Guard against invalid parameter types (e.g., list instead of dict) + if not isinstance(params, dict) and name in ('write_file', 'read_file'): + return None, params, ( + f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. " + "Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")" + ) + tool = self._tools.get(name) if not tool: return None, params, ( From 49355b2bd6025a44f2e8328c8956ac47be7c0e8b Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 13 Apr 2026 01:53:18 +0000 Subject: [PATCH 098/115] test(tools): lock non-object parameter validation Add focused registry coverage so the new read_file/write_file parameter guard stays actionable without changing generic validation behavior for other tools.
Made-with: Cursor --- tests/tools/test_tool_registry.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/tools/test_tool_registry.py b/tests/tools/test_tool_registry.py index 5b259119e..f9e8ce5e1 100644 --- a/tests/tools/test_tool_registry.py +++ b/tests/tools/test_tool_registry.py @@ -47,3 +47,27 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None: "mcp_fs_list", "mcp_git_status", ] + + +def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("read_file")) + + tool, params, error = registry.prepare_call("read_file", ["foo.txt"]) + + assert tool is None + assert params == ["foo.txt"] + assert error is not None + assert "must be a JSON object" in error + assert "Use named parameters" in error + + +def test_prepare_call_other_tools_keep_generic_object_validation() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("grep")) + + tool, params, error = registry.prepare_call("grep", ["TODO"]) + + assert tool is not None + assert params == ["TODO"] + assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list" From ea94a9c088bb12275fa67e8567ec145ab7231454 Mon Sep 17 00:00:00 2001 From: nikube Date: Sun, 12 Apr 2026 20:57:11 +0000 Subject: [PATCH 099/115] fix(agent): persist user message before running turn loop The existing runtime_checkpoint mechanism preserves the in-flight assistant/tool state if the process dies mid-turn, but the triggering user message is only written to session history at the end of the turn via _save_turn(). If the worker is killed (OOM, SIGKILL, a self- triggered systemctl restart, container eviction, etc.) before the turn completes, the user's message is silently lost: on restart, the session log only shows the interrupted assistant turn without any record of what the user asked. 
Any recovery tooling built on top of session logs cannot reply because it has no prompt to reply to. This patch appends the incoming user message to the session and flushes it to disk immediately after the session is loaded and before the agent loop runs, then adjusts the _save_turn skip offset so the final persistence step does not duplicate it. Limited to textual content (isinstance(msg.content, str)); list-shaped content (media blocks) still flows through _save_turn's sanitization at end of turn, preserving existing behavior for those cases. --- nanobot/agent/loop.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 5631e12a0..8bc65b7d3 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -693,6 +693,23 @@ class AgentLoop: ) ) + # Persist the triggering user message immediately, before running the + # agent loop. If the process is killed mid-turn (OOM, SIGKILL, self- + # restart, etc.), the existing runtime_checkpoint preserves the + # in-flight assistant/tool state but NOT the user message itself, so + # the user's prompt is silently lost on recovery. Saving it up front + # makes recovery possible from the session log alone. 
+ user_persisted_early = False + if isinstance(msg.content, str) and msg.content.strip(): + from datetime import datetime as _dt + session.messages.append({ + "role": "user", + "content": msg.content, + "timestamp": _dt.now().isoformat(), + }) + self.sessions.save(session) + user_persisted_early = True + final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, @@ -708,7 +725,9 @@ class AgentLoop: if final_content is None or not final_content.strip(): final_content = EMPTY_FINAL_RESPONSE_MESSAGE - self._save_turn(session, all_msgs, 1 + len(history)) + # Skip the already-persisted user message when saving the turn + save_skip = 1 + len(history) + (1 if user_persisted_early else 0) + self._save_turn(session, all_msgs, save_skip) self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) From b964a894d216c8c716647491fe0e63328bc93a70 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 13 Apr 2026 02:09:40 +0000 Subject: [PATCH 100/115] test(agent): cover early user-message persistence Use session.add_message for the pre-turn user-message flush and add focused regression tests for crash-time persistence and duplicate-free successful saves. Made-with: Cursor --- nanobot/agent/loop.py | 7 +--- tests/agent/test_loop_save_turn.py | 62 ++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 8bc65b7d3..96b5b30c6 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -701,12 +701,7 @@ class AgentLoop: # makes recovery possible from the session log alone. 
user_persisted_early = False if isinstance(msg.content, str) and msg.content.strip(): - from datetime import datetime as _dt - session.messages.append({ - "role": "user", - "content": msg.content, - "timestamp": _dt.now().isoformat(), - }) + session.add_message("user", msg.content) self.sessions.save(session) user_persisted_early = True diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index 8a0b54b86..c499282ab 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -1,5 +1,12 @@ +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + from nanobot.agent.context import ContextBuilder from nanobot.agent.loop import AgentLoop +from nanobot.bus.events import InboundMessage +from nanobot.bus.queue import MessageBus from nanobot.session.manager import Session @@ -11,6 +18,12 @@ def _mk_loop() -> AgentLoop: return loop +def _make_full_loop(tmp_path: Path) -> AgentLoop: + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") + + def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None: loop = _mk_loop() session = Session(key="test:runtime-only") @@ -200,3 +213,52 @@ def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None: assert session.messages[0]["role"] == "assistant" assert session.messages[1]["tool_call_id"] == "call_done" assert session.messages[2]["tool_call_id"] == "call_pending" + + +@pytest.mark.asyncio +async def test_process_message_persists_user_message_before_turn_completes(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + loop._run_agent_loop = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + + msg = InboundMessage(channel="feishu", sender_id="u1", 
chat_id="c1", content="persist me") + with pytest.raises(RuntimeError, match="boom"): + await loop._process_message(msg) + + loop.sessions.invalidate("feishu:c1") + persisted = loop.sessions.get_or_create("feishu:c1") + assert [m["role"] for m in persisted.messages] == ["user"] + assert persisted.messages[0]["content"] == "persist me" + assert persisted.updated_at >= persisted.created_at + + +@pytest.mark.asyncio +async def test_process_message_does_not_duplicate_early_persisted_user_message(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + loop._run_agent_loop = AsyncMock(return_value=( + "done", + None, + [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "done"}, + ], + "stop", + False, + )) # type: ignore[method-assign] + + result = await loop._process_message( + InboundMessage(channel="feishu", sender_id="u1", chat_id="c2", content="hello") + ) + + assert result is not None + assert result.content == "done" + session = loop.sessions.get_or_create("feishu:c2") + assert [ + {k: v for k, v in m.items() if k in {"role", "content"}} + for m in session.messages + ] == [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "done"}, + ] From 6484c7c47a74b157432b8e1e3b866fe3ad4711d7 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 13 Apr 2026 02:21:39 +0000 Subject: [PATCH 101/115] fix(agent): close interrupted early-persisted user turns Track text-only user messages that were flushed before the turn loop completes, then materialize an interrupted assistant placeholder on the next request so session history stays legal and later turns do not skip their own assistant reply. 
Made-with: Cursor --- nanobot/agent/loop.py | 34 ++++++++++++++++++++++ tests/agent/test_loop_save_turn.py | 46 ++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 96b5b30c6..0031c90c5 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -129,6 +129,7 @@ class AgentLoop: """ _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" + _PENDING_USER_TURN_KEY = "pending_user_turn" def __init__( self, @@ -618,6 +619,8 @@ class AgentLoop: session = self.sessions.get_or_create(key) if self._restore_runtime_checkpoint(session): self.sessions.save(session) + if self._restore_pending_user_turn(session): + self.sessions.save(session) session, pending = self.auto_compact.prepare_session(session, key) @@ -653,6 +656,8 @@ class AgentLoop: session = self.sessions.get_or_create(key) if self._restore_runtime_checkpoint(session): self.sessions.save(session) + if self._restore_pending_user_turn(session): + self.sessions.save(session) session, pending = self.auto_compact.prepare_session(session, key) @@ -702,6 +707,7 @@ class AgentLoop: user_persisted_early = False if isinstance(msg.content, str) and msg.content.strip(): session.add_message("user", msg.content) + self._mark_pending_user_turn(session) self.sessions.save(session) user_persisted_early = True @@ -723,6 +729,7 @@ class AgentLoop: # Skip the already-persisted user message when saving the turn save_skip = 1 + len(history) + (1 if user_persisted_early else 0) self._save_turn(session, all_msgs, save_skip) + self._clear_pending_user_turn(session) self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) @@ -840,6 +847,12 @@ class AgentLoop: session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload self.sessions.save(session) + def _mark_pending_user_turn(self, session: Session) -> None: + session.metadata[self._PENDING_USER_TURN_KEY] = True + + def 
_clear_pending_user_turn(self, session: Session) -> None: + session.metadata.pop(self._PENDING_USER_TURN_KEY, None) + def _clear_runtime_checkpoint(self, session: Session) -> None: if self._RUNTIME_CHECKPOINT_KEY in session.metadata: session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None) @@ -906,9 +919,30 @@ class AgentLoop: break session.messages.extend(restored_messages[overlap:]) + self._clear_pending_user_turn(session) self._clear_runtime_checkpoint(session) return True + def _restore_pending_user_turn(self, session: Session) -> bool: + """Close a turn that only persisted the user message before crashing.""" + from datetime import datetime + + if not session.metadata.get(self._PENDING_USER_TURN_KEY): + return False + + if session.messages and session.messages[-1].get("role") == "user": + session.messages.append( + { + "role": "assistant", + "content": "Error: Task interrupted before a response was generated.", + "timestamp": datetime.now().isoformat(), + } + ) + session.updated_at = datetime.now() + + self._clear_pending_user_turn(session) + return True + async def process_direct( self, content: str, diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index c499282ab..c965ccd8c 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -229,6 +229,7 @@ async def test_process_message_persists_user_message_before_turn_completes(tmp_p persisted = loop.sessions.get_or_create("feishu:c1") assert [m["role"] for m in persisted.messages] == ["user"] assert persisted.messages[0]["content"] == "persist me" + assert persisted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True assert persisted.updated_at >= persisted.created_at @@ -262,3 +263,48 @@ async def test_process_message_does_not_duplicate_early_persisted_user_message(t {"role": "user", "content": "hello"}, {"role": "assistant", "content": "done"}, ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + + +@pytest.mark.asyncio 
+async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + loop.provider.chat_with_retry = AsyncMock(return_value=MagicMock()) # unused because _run_agent_loop is stubbed + + session = loop.sessions.get_or_create("feishu:c3") + session.add_message("user", "old question") + session.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True + loop.sessions.save(session) + + loop._run_agent_loop = AsyncMock(return_value=( + "new answer", + None, + [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "Error: Task interrupted before a response was generated."}, + {"role": "user", "content": "new question"}, + {"role": "assistant", "content": "new answer"}, + ], + "stop", + False, + )) # type: ignore[method-assign] + + result = await loop._process_message( + InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question") + ) + + assert result is not None + assert result.content == "new answer" + session = loop.sessions.get_or_create("feishu:c3") + assert [ + {k: v for k, v in m.items() if k in {"role", "content"}} + for m in session.messages + ] == [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "Error: Task interrupted before a response was generated."}, + {"role": "user", "content": "new question"}, + {"role": "assistant", "content": "new answer"}, + ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata From becaff3e9d9710fe5a4d8d23d4cb8c64d46ef431 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Mon, 13 Apr 2026 11:27:16 +0800 Subject: [PATCH 102/115] fix(agent): skip auto-compact for sessions with active agent tasks Prevent proactive compaction from archiving sessions that have an in-flight agent task, avoiding mid-turn context truncation 
when a task runs longer than the idle TTL. --- nanobot/agent/autocompact.py | 17 ++++-- nanobot/agent/loop.py | 5 +- tests/agent/test_auto_compact.py | 100 ++++++++++++++++++++++++++++++- 3 files changed, 115 insertions(+), 7 deletions(-) diff --git a/nanobot/agent/autocompact.py b/nanobot/agent/autocompact.py index 47c7b5a36..ce70337cd 100644 --- a/nanobot/agent/autocompact.py +++ b/nanobot/agent/autocompact.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Collection from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, Coroutine @@ -23,12 +24,13 @@ class AutoCompact: self._archiving: set[str] = set() self._summaries: dict[str, tuple[str, datetime]] = {} - def _is_expired(self, ts: datetime | str | None) -> bool: + def _is_expired(self, ts: datetime | str | None, + now: datetime | None = None) -> bool: if self._ttl <= 0 or not ts: return False if isinstance(ts, str): ts = datetime.fromisoformat(ts) - return (datetime.now() - ts).total_seconds() >= self._ttl * 60 + return ((now or datetime.now()) - ts).total_seconds() >= self._ttl * 60 @staticmethod def _format_summary(text: str, last_active: datetime) -> str: @@ -56,10 +58,17 @@ class AutoCompact: cut = len(tail) - len(kept) return tail[:cut], kept - def check_expired(self, schedule_background: Callable[[Coroutine], None]) -> None: + def check_expired(self, schedule_background: Callable[[Coroutine], None], + active_session_keys: Collection[str] = ()) -> None: + """Schedule archival for idle sessions, skipping those with in-flight agent tasks.""" + now = datetime.now() for info in self.sessions.list_sessions(): key = info.get("key", "") - if key and key not in self._archiving and self._is_expired(info.get("updated_at")): + if not key or key in self._archiving: + continue + if key in active_session_keys: + continue + if self._is_expired(info.get("updated_at"), now): self._archiving.add(key) logger.debug("Auto-compact: scheduling archival for {} (idle > {} 
min)", key, self._ttl) schedule_background(self._archive(key)) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0031c90c5..39e1ce23a 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -434,7 +434,10 @@ class AgentLoop: try: msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0) except asyncio.TimeoutError: - self.auto_compact.check_expired(self._schedule_background) + self.auto_compact.check_expired( + self._schedule_background, + active_session_keys=self._pending_queues.keys(), + ) continue except asyncio.CancelledError: # Preserve real task cancellation so shutdown can complete cleanly. diff --git a/tests/agent/test_auto_compact.py b/tests/agent/test_auto_compact.py index b3462820b..1f6886ed0 100644 --- a/tests/agent/test_auto_compact.py +++ b/tests/agent/test_auto_compact.py @@ -560,9 +560,12 @@ class TestProactiveAutoCompact: """Test proactive auto-new on idle ticks (TimeoutError path in run loop).""" @staticmethod - async def _run_check_expired(loop): + async def _run_check_expired(loop, active_session_keys=()): """Helper: run check_expired via callback and wait for background tasks.""" - loop.auto_compact.check_expired(loop._schedule_background) + loop.auto_compact.check_expired( + loop._schedule_background, + active_session_keys=active_session_keys, + ) await asyncio.sleep(0.1) @pytest.mark.asyncio @@ -701,6 +704,99 @@ class TestProactiveAutoCompact: assert not archive_called await loop.close_mcp() + @pytest.mark.asyncio + async def test_skip_expired_session_with_active_agent_task(self, tmp_path): + """Expired session with an active agent task should NOT be archived.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + _add_turns(session, 6, prefix="old") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archive_count = 0 + + async def _fake_archive(messages): + nonlocal archive_count + archive_count += 1 + 
return "Summary." + + loop.consolidator.archive = _fake_archive + + # Simulate an active agent task for this session + await self._run_check_expired(loop, active_session_keys={"cli:test"}) + assert archive_count == 0 + + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 12 # All messages preserved + + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_archive_after_active_task_completes(self, tmp_path): + """Session should be archived on next tick after active task completes.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + session = loop.sessions.get_or_create("cli:test") + _add_turns(session, 6, prefix="old") + session.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(session) + + archive_count = 0 + + async def _fake_archive(messages): + nonlocal archive_count + archive_count += 1 + return "Summary." + + loop.consolidator.archive = _fake_archive + + # First tick: active task, skip + await self._run_check_expired(loop, active_session_keys={"cli:test"}) + assert archive_count == 0 + + # Second tick: task completed, should archive + await self._run_check_expired(loop) + assert archive_count == 1 + await loop.close_mcp() + + @pytest.mark.asyncio + async def test_partial_active_set_only_archives_inactive_expired(self, tmp_path): + """With multiple sessions, only the expired+inactive one should be archived.""" + loop = _make_loop(tmp_path, session_ttl_minutes=15) + # Session A: expired, no active task -> should be archived + s1 = loop.sessions.get_or_create("cli:expired_idle") + _add_turns(s1, 6, prefix="old_a") + s1.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(s1) + # Session B: expired, has active task -> should be skipped + s2 = loop.sessions.get_or_create("cli:expired_active") + _add_turns(s2, 6, prefix="old_b") + s2.updated_at = datetime.now() - timedelta(minutes=20) + loop.sessions.save(s2) + # Session C: recent, no active task -> should be 
skipped + s3 = loop.sessions.get_or_create("cli:recent") + s3.add_message("user", "recent") + loop.sessions.save(s3) + + archive_count = 0 + + async def _fake_archive(messages): + nonlocal archive_count + archive_count += 1 + return "Summary." + + loop.consolidator.archive = _fake_archive + + await self._run_check_expired(loop, active_session_keys={"cli:expired_active"}) + + assert archive_count == 1 + s1_after = loop.sessions.get_or_create("cli:expired_idle") + assert len(s1_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES + s2_after = loop.sessions.get_or_create("cli:expired_active") + assert len(s2_after.messages) == 12 # Preserved + s3_after = loop.sessions.get_or_create("cli:recent") + assert len(s3_after.messages) == 1 # Preserved + await loop.close_mcp() + @pytest.mark.asyncio async def test_no_reschedule_after_successful_archive(self, tmp_path): """Already-archived session should NOT be re-scheduled on subsequent ticks.""" From ac714803f67171ad5787142b575d90b8b28bfbf0 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Mon, 13 Apr 2026 11:30:54 +0800 Subject: [PATCH 103/115] fix(provider): recover trailing assistant message as user to prevent empty request When a subagent result is injected with current_role="assistant", _enforce_role_alternation drops the trailing assistant message, leaving only the system prompt. Providers like Zhipu/GLM reject such requests with error 1214 ("messages parameter invalid"). Now the last popped assistant message is recovered as a user message when no user/tool messages remain. 
--- nanobot/providers/base.py | 16 +++++++- .../test_enforce_role_alternation.py | 41 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 8ce2b9a7a..759d880a8 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -392,8 +392,22 @@ class LLMProvider(ABC): else: merged.append(dict(msg)) + last_popped = None while merged and merged[-1].get("role") == "assistant": - merged.pop() + last_popped = merged.pop() + + # If removing trailing assistant messages left only system messages, + # the request would be invalid for most providers (e.g. Zhipu/GLM + # error 1214). Recover by converting the last popped assistant + # message to a user message so the LLM can still see the content. + if ( + merged + and last_popped is not None + and not any(m.get("role") in ("user", "tool") for m in merged) + ): + recovered = dict(last_popped) + recovered["role"] = "user" + merged.append(recovered) return merged diff --git a/tests/providers/test_enforce_role_alternation.py b/tests/providers/test_enforce_role_alternation.py index aef57f474..333c5d04e 100644 --- a/tests/providers/test_enforce_role_alternation.py +++ b/tests/providers/test_enforce_role_alternation.py @@ -131,6 +131,47 @@ class TestEnforceRoleAlternation: assert msgs[0] == original_first assert len(msgs) == 2 + def test_trailing_assistant_recovered_as_user_when_only_system_remains(self): + """Subagent result injected as assistant message must not be silently dropped. + + When build_messages(current_role="assistant") produces [system, assistant], + _enforce_role_alternation would drop the assistant, leaving only [system]. + Most providers (e.g. Zhipu/GLM error 1214) reject such requests. + The trailing assistant should be recovered as a user message instead. 
+ """ + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "assistant", "content": "Subagent completed successfully."}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + assert "Subagent completed successfully." in result[1]["content"] + + def test_trailing_assistant_not_recovered_when_user_message_present(self): + """Recovery should NOT happen when a user message already exists.""" + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 2 + assert result[-1]["role"] == "user" + + def test_trailing_assistant_recovered_with_tool_result_preceding(self): + """When only [system, tool, assistant] remains, recovery is not needed + because tool messages are valid non-system content.""" + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "tool", "content": "result", "tool_call_id": "1"}, + {"role": "assistant", "content": "Done."}, + ] + result = LLMProvider._enforce_role_alternation(msgs) + assert len(result) == 2 + assert result[-1]["role"] == "tool" + def test_only_assistant_messages(self): msgs = [ {"role": "assistant", "content": "A"}, From 85c7996766d19f1a709325cb601dc3a41d4b2a87 Mon Sep 17 00:00:00 2001 From: haosenwang1018 Date: Mon, 13 Apr 2026 04:12:52 +0000 Subject: [PATCH 104/115] docs(api): clarify cross-channel message delivery --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index b376d0991..ebe88938d 100644 --- a/README.md +++ b/README.md @@ -1858,6 +1858,19 @@ By default, the API binds to `127.0.0.1:8900`. 
You can change this in `config.js - Single-message input: each request must contain exactly one `user` message - Fixed model: omit `model`, or pass the same model shown by `/v1/models` - No streaming: `stream=true` is not supported +- API requests run in the synthetic `api` channel, so the `message` tool does **not** automatically deliver to Telegram/Discord/etc. To proactively send to another chat, call `message` with an explicit `channel` and `chat_id` for an enabled channel. + +Example tool call for cross-channel delivery from an API session: + +```json +{ + "content": "Build finished successfully.", + "channel": "telegram", + "chat_id": "123456789" +} +``` + +If `channel` points to a channel that is not enabled in your config, nanobot will queue the outbound event but no platform delivery will occur. ### Endpoints From d33bf22e91bbc703be0075570261ea229804a7a9 Mon Sep 17 00:00:00 2001 From: haosenwang1018 Date: Mon, 13 Apr 2026 04:27:00 +0000 Subject: [PATCH 105/115] docs(provider): clarify responses api routing --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index ebe88938d..f593e26ec 100644 --- a/README.md +++ b/README.md @@ -1053,6 +1053,30 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To ``` > For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`). +> +> `custom` is the right choice for providers that expose an OpenAI-compatible **chat completions** API. It does **not** force third-party endpoints onto the OpenAI/Azure **Responses API**. 
+> +> If your proxy or gateway is specifically Responses-API-compatible, use the `azure_openai` provider shape instead and point `apiBase` at that endpoint: +> +> ```json +> { +> "providers": { +> "azure_openai": { +> "apiKey": "your-api-key", +> "apiBase": "https://api.your-provider.com", +> "defaultModel": "your-model-name" +> } +> }, +> "agents": { +> "defaults": { +> "provider": "azure_openai", +> "model": "your-model-name" +> } +> } +> } +> ``` +> +> In short: **chat-completions-compatible endpoint → `custom`**; **Responses-compatible endpoint → `azure_openai`**. From 3c06db7e4e7338c4944ef96f37dd35c9fe0349d6 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Mon, 13 Apr 2026 16:03:15 +0800 Subject: [PATCH 106/115] fix(log): remove noisy no-op logs from auto-compact Remove two debug log lines that fire on every idle channel check: - "scheduling archival" (logged before knowing if there's work) - "skipping, no un-consolidated messages" (the common no-op path) The meaningful "archived" info log (only on real work) is preserved. 
--- nanobot/agent/autocompact.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/nanobot/agent/autocompact.py b/nanobot/agent/autocompact.py index ce70337cd..eabd86155 100644 --- a/nanobot/agent/autocompact.py +++ b/nanobot/agent/autocompact.py @@ -70,7 +70,6 @@ class AutoCompact: continue if self._is_expired(info.get("updated_at"), now): self._archiving.add(key) - logger.debug("Auto-compact: scheduling archival for {} (idle > {} min)", key, self._ttl) schedule_background(self._archive(key)) async def _archive(self, key: str) -> None: @@ -79,7 +78,6 @@ class AutoCompact: session = self.sessions.get_or_create(key) archive_msgs, kept_msgs = self._split_unconsolidated(session) if not archive_msgs and not kept_msgs: - logger.debug("Auto-compact: skipping {}, no un-consolidated messages", key) session.updated_at = datetime.now() self.sessions.save(session) return @@ -95,13 +93,14 @@ class AutoCompact: session.last_consolidated = 0 session.updated_at = datetime.now() self.sessions.save(session) - logger.info( - "Auto-compact: archived {} (archived={}, kept={}, summary={})", - key, - len(archive_msgs), - len(kept_msgs), - bool(summary), - ) + if archive_msgs: + logger.info( + "Auto-compact: archived {} (archived={}, kept={}, summary={})", + key, + len(archive_msgs), + len(kept_msgs), + bool(summary), + ) except Exception: logger.exception("Auto-compact: failed for {}", key) finally: From d849a3fa060825ff02b73f74dc14d3ab97735b33 Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Mon, 13 Apr 2026 23:33:25 +0800 Subject: [PATCH 107/115] fix(agent): drain injection queue on error/edge-case exit paths When the agent runner exits due to LLM error, tool error, empty response, or max_iterations, it breaks out of the iteration loop without draining the pending injection queue. This causes leftover messages to be re-published as independent inbound messages, resulting in duplicate or confusing replies to the user. 
Extract the injection drain logic into a `_try_drain_injections` helper and call it before each break in the error/edge-case paths. If injections are found, continue the loop instead of breaking. For max_iterations (where the loop is exhausted), drain injections to prevent re-publish without continuing. --- nanobot/agent/runner.py | 107 ++++++++++++----- tests/agent/test_runner.py | 233 +++++++++++++++++++++++++++++++++++++ 2 files changed, 314 insertions(+), 26 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index e92d864f2..5cb7b4f0e 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -134,6 +134,36 @@ class AgentRunner: continue messages.append(injection) + async def _try_drain_injections( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + assistant_message: dict[str, Any] | None, + injection_cycles: int, + *, + phase: str = "after error", + ) -> tuple[bool, int]: + """Drain pending injections. Returns (should_continue, updated_cycles). + + If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES, + append them to *messages* and return (True, cycles+1) so the caller + continues the iteration loop. Otherwise return (False, cycles). + """ + if injection_cycles >= _MAX_INJECTION_CYCLES: + return False, injection_cycles + injections = await self._drain_injections(spec) + if not injections: + return False, injection_cycles + injection_cycles += 1 + if assistant_message is not None: + messages.append(assistant_message) + self._append_injected_messages(messages, injections) + logger.info( + "Injected {} follow-up message(s) {} ({}/{})", + len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES, + ) + return True, injection_cycles + async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]: """Drain pending user messages via the injection callback. 
@@ -287,6 +317,13 @@ class AgentRunner: context.error = error context.stop_reason = stop_reason await hook.after_iteration(context) + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after tool error", + ) + if should_continue: + had_injections = True + continue break await self._emit_checkpoint( spec, @@ -379,36 +416,31 @@ class AgentRunner: # Check for mid-turn injections BEFORE signaling stream end. # If injections are found we keep the stream alive (resuming=True) # so streaming channels don't prematurely finalize the card. - _injected_after_final = False - if injection_cycles < _MAX_INJECTION_CYCLES: - injections = await self._drain_injections(spec) - if injections: - had_injections = True - injection_cycles += 1 - _injected_after_final = True - if assistant_message is not None: - messages.append(assistant_message) - await self._emit_checkpoint( - spec, - { - "phase": "final_response", - "iteration": iteration, - "model": spec.model, - "assistant_message": assistant_message, - "completed_tool_results": [], - "pending_tool_calls": [], - }, - ) - self._append_injected_messages(messages, injections) - logger.info( - "Injected {} follow-up message(s) after final response ({}/{})", - len(injections), injection_cycles, _MAX_INJECTION_CYCLES, + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, assistant_message, injection_cycles, + phase="after final response", + ) + if should_continue: + had_injections = True + # Emit checkpoint for the assistant message that was appended + # by _try_drain_injections, then keep the stream alive. 
+ if assistant_message is not None: + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [], + }, ) if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=_injected_after_final) + await hook.on_stream_end(context, resuming=should_continue) - if _injected_after_final: + if should_continue: await hook.after_iteration(context) continue @@ -421,6 +453,13 @@ class AgentRunner: context.error = error context.stop_reason = stop_reason await hook.after_iteration(context) + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after LLM error", + ) + if should_continue: + had_injections = True + continue break if is_blank_text(clean): final_content = EMPTY_FINAL_RESPONSE_MESSAGE @@ -431,6 +470,13 @@ class AgentRunner: context.error = error context.stop_reason = stop_reason await hook.after_iteration(context) + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after empty response", + ) + if should_continue: + had_injections = True + continue break messages.append(assistant_message or build_assistant_message( @@ -467,6 +513,15 @@ class AgentRunner: max_iterations=spec.max_iterations, ) self._append_final_message(messages, final_content) + # Drain any remaining injections so they are appended to the + # conversation history instead of being re-published as + # independent inbound messages by _dispatch's finally block. + # We ignore should_continue here because the for-loop has already + # exhausted all iterations. 
+ _, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after max_iterations", + ) return AgentRunResult( final_content=final_content, diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index a62457aa8..4a943165c 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -2410,3 +2410,236 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path): contents = [m.content for m in msgs] assert "leftover-1" in contents assert "leftover-2" in contents + + +@pytest.mark.asyncio +async def test_drain_injections_on_fatal_tool_error(): + """Pending injections should be drained even when a fatal tool error occurs.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})], + usage={}, + ) + # Second call: respond normally to the injected follow-up + return LLMResponse(content="reply to follow-up", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded")) + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + 
fail_on_tool_error=True, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "reply to follow-up" + # The injection should be in the messages history + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "follow-up after error" + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_llm_error(): + """Pending injections should be drained when the LLM returns an error finish_reason.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content=None, + tool_calls=[], + finish_reason="error", + usage={}, + ) + # Second call: respond normally to the injected follow-up + return LLMResponse(content="recovered answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous response"}, + {"role": "user", "content": "trigger error"}, + ], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "recovered answer" + injected = [ + m for m in result.messages + if m.get("role") == "user" and 
"follow-up after LLM error" in str(m.get("content", "")) + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_empty_final_response(): + """Pending injections should be drained when the runner exits due to empty response.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] <= _MAX_EMPTY_RETRIES + 1: + return LLMResponse(content="", tool_calls=[], usage={}) + # After retries exhausted + injection drain, respond normally + return LLMResponse(content="answer after empty", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous response"}, + {"role": "user", "content": "trigger empty"}, + ], + tools=tools, + model="test-model", + max_iterations=10, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "answer after empty" + injected = [ + m for m in result.messages + if m.get("role") == "user" and "follow-up after empty" in str(m.get("content", "")) + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_max_iterations(): + """Pending injections should be drained when the runner hits max_iterations. 
+ + Unlike other error paths, max_iterations cannot continue the loop, so + injections are appended to messages but not processed by the LLM. + The key point is they are consumed from the queue to prevent re-publish. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.stop_reason == "max_iterations" + # The injection was consumed from the queue (preventing re-publish) + assert injection_queue.empty() + # The injection message is appended to conversation history + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "follow-up after max iters" + ] + assert len(injected) == 1 From a1e1eed2f13c19e9dcec2fa912103dc967b8daec Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Mon, 13 Apr 2026 23:51:23 +0800 Subject: [PATCH 108/115] refactor(runner): consolidate all injection drain paths and deduplicate tests - Migrate 
"after tools" inline drain to use _try_drain_injections, completing the refactoring (all 6 drain sites now use the helper). - Move checkpoint emission into _try_drain_injections via optional iteration parameter, eliminating the leaky split between helper and caller for the final-response path. - Extract _make_injection_callback() test helper to replace 7 identical inject_cb function bodies. - Add test_injection_cycle_cap_on_error_path to verify the cycle cap is enforced on error exit paths. --- nanobot/agent/runner.py | 49 ++++++++--------- tests/agent/test_runner.py | 109 +++++++++++++++++++++++-------------- 2 files changed, 90 insertions(+), 68 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 5cb7b4f0e..20226aed6 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -142,12 +142,14 @@ class AgentRunner: injection_cycles: int, *, phase: str = "after error", + iteration: int | None = None, ) -> tuple[bool, int]: """Drain pending injections. Returns (should_continue, updated_cycles). If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES, - append them to *messages* and return (True, cycles+1) so the caller - continues the iteration loop. Otherwise return (False, cycles). + append them to *messages* (and emit a checkpoint if *assistant_message* + and *iteration* are both provided) and return (True, cycles+1) so the + caller continues the iteration loop. Otherwise return (False, cycles). 
""" if injection_cycles >= _MAX_INJECTION_CYCLES: return False, injection_cycles @@ -157,6 +159,18 @@ class AgentRunner: injection_cycles += 1 if assistant_message is not None: messages.append(assistant_message) + if iteration is not None: + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [], + }, + ) self._append_injected_messages(messages, injections) logger.info( "Injected {} follow-up message(s) {} ({}/{})", @@ -339,16 +353,12 @@ class AgentRunner: empty_content_retries = 0 length_recovery_count = 0 # Checkpoint 1: drain injections after tools, before next LLM call - if injection_cycles < _MAX_INJECTION_CYCLES: - injections = await self._drain_injections(spec) - if injections: - had_injections = True - injection_cycles += 1 - self._append_injected_messages(messages, injections) - logger.info( - "Injected {} follow-up message(s) after tool execution ({}/{})", - len(injections), injection_cycles, _MAX_INJECTION_CYCLES, - ) + _drained, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after tool execution", + ) + if _drained: + had_injections = True await hook.after_iteration(context) continue @@ -419,23 +429,10 @@ class AgentRunner: should_continue, injection_cycles = await self._try_drain_injections( spec, messages, assistant_message, injection_cycles, phase="after final response", + iteration=iteration, ) if should_continue: had_injections = True - # Emit checkpoint for the assistant message that was appended - # by _try_drain_injections, then keep the stream alive. 
- if assistant_message is not None: - await self._emit_checkpoint( - spec, - { - "phase": "final_response", - "iteration": iteration, - "model": spec.model, - "assistant_message": assistant_message, - "completed_tool_results": [], - "pending_tool_calls": [], - }, - ) if hook.wants_streaming(): await hook.on_stream_end(context, resuming=should_continue) diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 4a943165c..53cd07e88 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -18,6 +18,16 @@ from nanobot.providers.base import LLMResponse, ToolCallRequest _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars +def _make_injection_callback(queue: asyncio.Queue): + """Return an async callback that drains *queue* into a list of dicts.""" + async def inject_cb(): + items = [] + while not queue.empty(): + items.append(await queue.get()) + return items + return inject_cb + + def _make_loop(tmp_path): from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus @@ -1888,12 +1898,7 @@ async def test_checkpoint1_injects_after_tool_execution(): tools.execute = AsyncMock(return_value="file content") injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) # Put a follow-up message in the queue before the run starts await injection_queue.put( @@ -1951,12 +1956,7 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): tools.get_definitions.return_value = [] injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) # Inject a follow-up that arrives during the first response await injection_queue.put( @@ -2005,12 +2005,7 @@ async def 
test_checkpoint2_preserves_final_response_in_history_before_followup() tools.get_definitions.return_value = [] injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) await injection_queue.put( InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") @@ -2438,12 +2433,7 @@ async def test_drain_injections_on_fatal_tool_error(): tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded")) injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) await injection_queue.put( InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error") @@ -2496,12 +2486,7 @@ async def test_drain_injections_on_llm_error(): tools.get_definitions.return_value = [] injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) await injection_queue.put( InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error") @@ -2551,12 +2536,7 @@ async def test_drain_injections_on_empty_final_response(): tools.get_definitions.return_value = [] injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) await injection_queue.put( InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty") @@ -2613,12 +2593,7 @@ async def test_drain_injections_on_max_iterations(): tools.execute = AsyncMock(return_value="file content") 
injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) await injection_queue.put( InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters") @@ -2643,3 +2618,53 @@ async def test_drain_injections_on_max_iterations(): if m.get("role") == "user" and m.get("content") == "follow-up after max iters" ] assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_injection_cycle_cap_on_error_path(): + """Injection cycles should be capped even when every iteration hits an LLM error.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content=None, + tool_calls=[], + finish_reason="error", + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + drain_count = {"n": 0} + + async def inject_cb(): + drain_count["n"] += 1 + if drain_count["n"] <= _MAX_INJECTION_CYCLES: + return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous"}, + {"role": "user", "content": "trigger error"}, + ], + tools=tools, + model="test-model", + max_iterations=20, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + # Should cap: _MAX_INJECTION_CYCLES drained rounds + 1 final round that breaks + assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 + assert drain_count["n"] == 
_MAX_INJECTION_CYCLES From a38bc637bdaaf1ce3e3090ba2d32afdbf79029f5 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 13 Apr 2026 16:28:35 +0000 Subject: [PATCH 109/115] fix(runner): preserve injection flag after max-iteration drain Keep late follow-up injections observable when they are drained during max-iteration shutdown so loop-level response suppression still makes the right decision. Made-with: Cursor --- nanobot/agent/runner.py | 4 ++- tests/agent/test_runner.py | 64 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 20226aed6..592af9de2 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -515,10 +515,12 @@ class AgentRunner: # independent inbound messages by _dispatch's finally block. # We ignore should_continue here because the for-loop has already # exhausted all iterations. - _, injection_cycles = await self._try_drain_injections( + drained_after_max_iterations, injection_cycles = await self._try_drain_injections( spec, messages, None, injection_cycles, phase="after max_iterations", ) + if drained_after_max_iterations: + had_injections = True return AgentRunResult( final_content=final_content, diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 53cd07e88..74025d779 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -2610,6 +2610,7 @@ async def test_drain_injections_on_max_iterations(): )) assert result.stop_reason == "max_iterations" + assert result.had_injections is True # The injection was consumed from the queue (preventing re-publish) assert injection_queue.empty() # The injection message is appended to conversation history @@ -2620,6 +2621,69 @@ async def test_drain_injections_on_max_iterations(): assert len(injected) == 1 +@pytest.mark.asyncio +async def test_drain_injections_set_flag_when_followup_arrives_after_last_iteration(): + """Late follow-ups drained in max_iterations 
should still flip had_injections.""" + from nanobot.agent.hook import AgentHook + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + class InjectOnLastAfterIterationHook(AgentHook): + def __init__(self) -> None: + self.after_iteration_calls = 0 + + async def after_iteration(self, context) -> None: + self.after_iteration_calls += 1 + if self.after_iteration_calls == 2: + await injection_queue.put( + InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content="late follow-up after max iters", + ) + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + hook=InjectOnLastAfterIterationHook(), + )) + + assert result.stop_reason == "max_iterations" + assert result.had_injections is True + assert injection_queue.empty() + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "late follow-up after max iters" + ] + assert len(injected) == 1 + + @pytest.mark.asyncio async def test_injection_cycle_cap_on_error_path(): """Injection cycles should be capped even when every iteration hits an LLM error.""" From ee061f0595f4258634bac4417aaf5b4089c96d13 Mon Sep 17 00:00:00 2001 From: yeyitech Date: Tue, 14 Apr 2026 
13:30:18 +0800 Subject: [PATCH 110/115] fix(web): serialize duckduckgo search calls --- nanobot/agent/tools/web.py | 27 ++++++++++++++ tests/agent/test_runner.py | 57 ++++++++++++++++++++++++++++- tests/tools/test_web_search_tool.py | 24 +++++++++--- 3 files changed, 102 insertions(+), 6 deletions(-) diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 38fc33d74..31d4cdef2 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -96,10 +96,37 @@ class WebSearchTool(Tool): self.config = config if config is not None else WebSearchConfig() self.proxy = proxy + def _effective_provider(self) -> str: + """Resolve the backend that execute() will actually use.""" + provider = self.config.provider.strip().lower() or "brave" + if provider == "duckduckgo": + return "duckduckgo" + if provider == "brave": + api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "") + return "brave" if api_key else "duckduckgo" + if provider == "tavily": + api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "") + return "tavily" if api_key else "duckduckgo" + if provider == "searxng": + base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip() + return "searxng" if base_url else "duckduckgo" + if provider == "jina": + api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "") + return "jina" if api_key else "duckduckgo" + if provider == "kagi": + api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "") + return "kagi" if api_key else "duckduckgo" + return provider + @property def read_only(self) -> bool: return True + @property + def exclusive(self) -> bool: + """DuckDuckGo searches are serialized because ddgs is not concurrency-safe.""" + return self._effective_provider() == "duckduckgo" + async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: provider = self.config.provider.strip().lower() or "brave" n = min(max(count or self.config.max_results, 1), 10) 
diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 74025d779..f742408b3 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -689,11 +689,20 @@ async def test_runner_keeps_going_when_tool_result_persistence_fails(): class _DelayTool(Tool): - def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]): + def __init__( + self, + name: str, + *, + delay: float, + read_only: bool, + shared_events: list[str], + exclusive: bool = False, + ): self._name = name self._delay = delay self._read_only = read_only self._shared_events = shared_events + self._exclusive = exclusive @property def name(self) -> str: @@ -711,6 +720,10 @@ class _DelayTool(Tool): def read_only(self) -> bool: return self._read_only + @property + def exclusive(self) -> bool: + return self._exclusive + async def execute(self, **kwargs): self._shared_events.append(f"start:{self._name}") await asyncio.sleep(self._delay) @@ -756,6 +769,48 @@ async def test_runner_batches_read_only_tools_before_exclusive_work(): assert shared_events[-2:] == ["start:write_a", "end:write_a"] +@pytest.mark.asyncio +async def test_runner_does_not_batch_exclusive_read_only_tools(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + tools = ToolRegistry() + shared_events: list[str] = [] + read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events) + read_b = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=shared_events) + ddg_like = _DelayTool( + "ddg_like", + delay=0.01, + read_only=True, + shared_events=shared_events, + exclusive=True, + ) + tools.register(read_a) + tools.register(ddg_like) + tools.register(read_b) + + runner = AgentRunner(MagicMock()) + await runner._execute_tools( + AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + concurrent_tools=True, + ), + [ + ToolCallRequest(id="ro1", name="read_a", 
arguments={}), + ToolCallRequest(id="ddg1", name="ddg_like", arguments={}), + ToolCallRequest(id="ro2", name="read_b", arguments={}), + ], + {}, + ) + + assert shared_events[0] == "start:read_a" + assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like") + assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b") + + @pytest.mark.asyncio async def test_runner_blocks_repeated_external_fetches(): from nanobot.agent.runner import AgentRunSpec, AgentRunner diff --git a/tests/tools/test_web_search_tool.py b/tests/tools/test_web_search_tool.py index 790d8adcd..a42e51e1a 100644 --- a/tests/tools/test_web_search_tool.py +++ b/tests/tools/test_web_search_tool.py @@ -1,7 +1,5 @@ """Tests for multi-provider web search.""" -import asyncio - import httpx import pytest @@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response: return r +def test_duckduckgo_search_is_exclusive(): + tool = _tool(provider="duckduckgo") + assert tool.exclusive is True + assert tool.concurrency_safe is False + + +def test_brave_with_api_key_remains_concurrency_safe(): + tool = _tool(provider="brave", api_key="brave-key") + assert tool.exclusive is False + assert tool.concurrency_safe is True + + +def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch): + monkeypatch.delenv("BRAVE_API_KEY", raising=False) + tool = _tool(provider="brave", api_key="") + assert tool.exclusive is True + assert tool.concurrency_safe is False + + @pytest.mark.asyncio async def test_brave_search(monkeypatch): async def mock_get(self, url, **kw): @@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch): import nanobot.agent.tools.web as web_mod monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False) - from ddgs import DDGS monkeypatch.setattr("ddgs.DDGS", MockDDGS) tool = _tool(provider="duckduckgo") @@ -265,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch): result = await 
tool.execute(query="test") gate.set() assert "Error" in result - - From 65a15f39ee7ebfe8b9585231165222cf5ee1cd76 Mon Sep 17 00:00:00 2001 From: yeyitech Date: Tue, 14 Apr 2026 13:42:59 +0800 Subject: [PATCH 111/115] test(loop): cover /stop checkpoint recovery --- tests/agent/test_loop_save_turn.py | 109 +++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index c965ccd8c..8885e0cc0 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -1,3 +1,4 @@ +import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -308,3 +309,111 @@ async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(t {"role": "assistant", "content": "new answer"}, ] assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + + +@pytest.mark.asyncio +async def test_stop_preserves_runtime_checkpoint_for_next_turn(tmp_path: Path) -> None: + from nanobot.command.builtin import cmd_stop + from nanobot.command.router import CommandContext + + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + checkpoint_saved = asyncio.Event() + + async def interrupted_run_agent_loop(_initial_messages, *, session=None, **_kwargs): + assert session is not None + loop._set_runtime_checkpoint( + session, + { + "assistant_message": { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + }, + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + "completed_tool_results": [ + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + } + ], + "pending_tool_calls": [ + { + "id": "call_pending", + "type": "function", + "function": 
{"name": "exec", "arguments": "{}"}, + } + ], + }, + ) + checkpoint_saved.set() + await asyncio.Event().wait() + + loop._run_agent_loop = interrupted_run_agent_loop # type: ignore[method-assign] + + first_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="keep progress") + task = asyncio.create_task(loop._process_message(first_msg)) + loop._active_tasks[first_msg.session_key] = [task] + await asyncio.wait_for(checkpoint_saved.wait(), timeout=1.0) + + stop_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="/stop") + stop_ctx = CommandContext(msg=stop_msg, session=None, key=stop_msg.session_key, raw="/stop", loop=loop) + stop_result = await cmd_stop(stop_ctx) + + assert "Stopped 1 task" in stop_result.content + assert task.done() + + loop.sessions.invalidate("feishu:c4") + interrupted = loop.sessions.get_or_create("feishu:c4") + assert interrupted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True + assert interrupted.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is not None + + async def resumed_run_agent_loop(initial_messages, **_kwargs): + return ( + "next answer", + None, + [*initial_messages, {"role": "assistant", "content": "next answer"}], + "stop", + False, + ) + + loop._run_agent_loop = resumed_run_agent_loop # type: ignore[method-assign] + result = await loop._process_message( + InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="continue here") + ) + + assert result is not None + assert result.content == "next answer" + + session = loop.sessions.get_or_create("feishu:c4") + assert [ + {k: v for k, v in m.items() if k in {"role", "content", "tool_call_id", "name"}} + for m in session.messages + ] == [ + {"role": "user", "content": "keep progress"}, + {"role": "assistant", "content": "working"}, + {"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"}, + { + "role": "tool", + "tool_call_id": "call_pending", + "name": "exec", + "content": "Error: Task 
interrupted before this tool finished.", + }, + {"role": "user", "content": "continue here"}, + {"role": "assistant", "content": "next answer"}, + ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + assert AgentLoop._RUNTIME_CHECKPOINT_KEY not in session.metadata From e4b3f9bd28b098704c5ce4dc6e8505da434bd9d2 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 14 Apr 2026 07:19:38 +0000 Subject: [PATCH 112/115] security(gateway): keep health endpoint local by default Bind the gateway health listener to localhost by default and reduce the probe response to a minimal status payload so accidental public exposure leaks less information. Made-with: Cursor --- README.md | 11 +++++++---- nanobot/cli/commands.py | 20 +------------------- nanobot/config/schema.py | 2 +- tests/cli/test_commands.py | 14 +++++--------- 4 files changed, 14 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index a5ddf1368..4dd7a93b3 100644 --- a/README.md +++ b/README.md @@ -1727,6 +1727,7 @@ Example config: } }, "gateway": { + "host": "127.0.0.1", "port": 18790 } } @@ -1739,11 +1740,13 @@ nanobot gateway --config ~/.nanobot-telegram/config.json nanobot gateway --config ~/.nanobot-discord/config.json ``` -Each gateway instance also exposes a lightweight HTTP status endpoint on -`gateway.host:gateway.port`: +Each gateway instance also exposes a lightweight HTTP health endpoint on +`gateway.host:gateway.port`. By default, the gateway binds to `127.0.0.1`, +so the endpoint stays local unless you explicitly set `gateway.host` to a +public or LAN-facing address. 
-- `GET /` returns `nanobot` -- `GET /health` returns JSON with service metadata, uptime, and enabled channels +- `GET /health` returns `{"status":"ok"}` +- Other paths return `404` Override workspace for one-off runs when needed: diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 1f3f00c85..953e8b1f9 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -824,9 +824,6 @@ def gateway( async def _health_server(host: str, health_port: int): """Lightweight HTTP health endpoint on the gateway port.""" import json as _json - import time - - start_time = time.monotonic() async def handle(reader, writer): try: @@ -842,28 +839,13 @@ def gateway( method, path = parts[0], parts[1] if method == "GET" and path == "/health": - uptime_s = int(time.monotonic() - start_time) - body = _json.dumps({ - "service": "nanobot", - "version": __version__, - "status": "running", - "uptime_seconds": uptime_s, - "channels": channels.enabled_channels, - }) + body = _json.dumps({"status": "ok"}) resp = ( f"HTTP/1.0 200 OK\r\n" f"Content-Type: application/json\r\n" f"Content-Length: {len(body)}\r\n" f"\r\n{body}" ) - elif method == "GET" and path == "/": - body = "nanobot" - resp = ( - f"HTTP/1.0 200 OK\r\n" - f"Content-Type: text/plain\r\n" - f"Content-Length: {len(body)}\r\n" - f"\r\n{body}" - ) else: body = "Not Found" resp = ( diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index aa5ab9932..fd73e0800 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -152,7 +152,7 @@ class ApiConfig(Base): class GatewayConfig(Base): """Gateway/server configuration.""" - host: str = "0.0.0.0" + host: str = "127.0.0.1" # Safer default: local-only bind. 
port: int = 18790 heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig) diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 1ae2ffd87..e4edfaf87 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -1131,7 +1131,6 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses( ) -> None: config_file = _write_instance_config(tmp_path) config = Config() - config.gateway.host = "127.0.0.9" config.gateway.port = 18791 captured: dict[str, object] = {} @@ -1245,9 +1244,9 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses( result = runner.invoke(app, ["gateway", "--config", str(config_file)]) assert result.exit_code == 0 - assert captured["host"] == "127.0.0.9" + assert captured["host"] == "127.0.0.1" assert captured["port"] == 18791 - assert "Health endpoint: http://127.0.0.9:18791/health" in result.stdout + assert "Health endpoint: http://127.0.0.1:18791/health" in result.stdout def _call_handler(path: str) -> tuple[str, _FakeWriter]: request = f"GET {path} HTTP/1.1\r\nHost: localhost\r\n\r\n".encode() @@ -1259,17 +1258,14 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses( root_response, root_writer = _call_handler("/") assert root_writer.closed is True - assert "HTTP/1.0 200 OK" in root_response - assert root_response.endswith("\r\n\r\nnanobot") + assert "HTTP/1.0 404 Not Found" in root_response + assert root_response.endswith("\r\n\r\nNot Found") health_response, health_writer = _call_handler("/health") assert health_writer.closed is True assert "HTTP/1.0 200 OK" in health_response health_body = json.loads(health_response.split("\r\n\r\n", 1)[1]) - assert health_body["service"] == "nanobot" - assert health_body["status"] == "running" - assert health_body["channels"] == ["telegram", "discord"] - assert health_body["uptime_seconds"] >= 0 + assert health_body == {"status": "ok"} missing_response, missing_writer = _call_handler("/missing") assert 
missing_writer.closed is True From 0adce5405b2f7733fd63d536c9b51028fab9f3f3 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Tue, 14 Apr 2026 14:14:14 +0800 Subject: [PATCH 113/115] fix(feishu): remove resuming to avoid 10-min streaming card timeout Feishu streaming cards auto-close after 10 minutes from creation, regardless of update activity. With resuming enabled, a single card lives across multiple tool-call rounds and can exceed this limit, causing the final response to be silently lost. Remove the _resuming logic from send_delta so each tool-call round gets its own short-lived streaming card (well under 10 min). Add a fallback that sends a regular interactive card when the final streaming update fails. --- docs/CHANNEL_PLUGIN_GUIDE.md | 1 - nanobot/channels/feishu.py | 112 +++++++----------- tests/channels/test_feishu_streaming.py | 65 ++-------- .../test_feishu_tool_hint_code_block.py | 87 +++++++------- 4 files changed, 93 insertions(+), 172 deletions(-) diff --git a/docs/CHANNEL_PLUGIN_GUIDE.md b/docs/CHANNEL_PLUGIN_GUIDE.md index 86e06bf63..65ff9eec9 100644 --- a/docs/CHANNEL_PLUGIN_GUIDE.md +++ b/docs/CHANNEL_PLUGIN_GUIDE.md @@ -290,7 +290,6 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | |------|---------| | `_stream_delta: True` | A content chunk (delta contains the new text) | | `_stream_end: True` | Streaming finished (delta is empty) | -| `_resuming: True` | More streaming rounds coming (e.g. tool call then another response) | ### Example: Webhook with Streaming diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 5afeca35f..1442c3637 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1290,7 +1290,6 @@ class FeishuChannel(BaseChannel): Supported metadata keys: _stream_end: Finalize the streaming card. - _resuming: Mid-turn pause – flush but keep the buffer alive. _tool_hint: Delta is a formatted tool hint (for display only). 
message_id: Original message id (used with _stream_end for reaction cleanup). reaction_id: Reaction id to remove on stream end. @@ -1309,50 +1308,44 @@ class FeishuChannel(BaseChannel): if self.config.done_emoji and message_id: await self._add_reaction(message_id, self.config.done_emoji) - resuming = meta.get("_resuming", False) - if resuming: - # Mid-turn pause (e.g. tool call between streaming segments). - # Flush current text to card but keep the buffer alive so the - # next segment appends to the same card. - buf = self._stream_bufs.get(chat_id) - if buf and buf.card_id and buf.text: - buf.sequence += 1 - await loop.run_in_executor( - None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence, - ) - return - buf = self._stream_bufs.pop(chat_id, None) if not buf or not buf.text: return + # Try to finalize via streaming card; if that fails (e.g. + # streaming mode was closed by Feishu due to timeout), fall + # back to sending a regular interactive card. if buf.card_id: buf.sequence += 1 - await loop.run_in_executor( + ok = await loop.run_in_executor( None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence, ) - # Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs). 
- buf.sequence += 1 - await loop.run_in_executor( - None, - self._close_streaming_mode_sync, - buf.card_id, - buf.sequence, - ) - else: - for chunk in self._split_elements_by_table_limit( - self._build_card_elements(buf.text) - ): - card = json.dumps( - {"config": {"wide_screen_mode": True}, "elements": chunk}, - ensure_ascii=False, - ) + if ok: + buf.sequence += 1 await loop.run_in_executor( - None, self._send_message_sync, rid_type, chat_id, "interactive", card + None, + self._close_streaming_mode_sync, + buf.card_id, + buf.sequence, ) + return + logger.warning( + "Streaming card {} final update failed, falling back to regular card", + buf.card_id, + ) + for chunk in self._split_elements_by_table_limit( + self._build_card_elements(buf.text) + ): + card = json.dumps( + {"config": {"wide_screen_mode": True}, "elements": chunk}, + ensure_ascii=False, + ) + await loop.run_in_executor( + None, self._send_message_sync, rid_type, chat_id, "interactive", card + ) return # --- accumulate delta --- @@ -1404,14 +1397,21 @@ class FeishuChannel(BaseChannel): if buf and buf.card_id: # Delegate to send_delta so tool hints get the same # throttling (and card creation) as regular text deltas. - lines = self.__class__._format_tool_hint_lines(hint).split("\n") - delta = "\n\n" + "\n".join( - f"{self.config.tool_hint_prefix} {ln}" for ln in lines if ln.strip() - ) + "\n\n" - await self.send_delta(msg.chat_id, delta) + await self.send_delta( + msg.chat_id, + "\n\n" + self._format_tool_hint_delta(hint) + "\n\n", + ) return - await self._send_tool_hint_card( - receive_id_type, msg.chat_id, hint + # No active streaming card — send as a regular + # interactive card with the same 🔧 prefix style. 
+ card = json.dumps( + {"config": {"wide_screen_mode": True}, "elements": [ + {"tag": "markdown", "content": self._format_tool_hint_delta(hint)}, + ]}, + ensure_ascii=False, + ) + await loop.run_in_executor( + None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card ) return @@ -1708,33 +1708,9 @@ class FeishuChannel(BaseChannel): return "\n".join(part for part in parts if part) - async def _send_tool_hint_card( - self, receive_id_type: str, receive_id: str, tool_hint: str - ) -> None: - """Send tool hint as an interactive card with formatted code block. - - Args: - receive_id_type: "chat_id" or "open_id" - receive_id: The target chat or user ID - tool_hint: Formatted tool hint string (e.g., 'web_search("q"), read_file("path")') - """ - loop = asyncio.get_running_loop() - - # Put each top-level tool call on its own line without altering commas inside arguments. - formatted_code = self.__class__._format_tool_hint_lines(tool_hint) - - card = { - "config": {"wide_screen_mode": True}, - "elements": [ - {"tag": "markdown", "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"} - ], - } - - await loop.run_in_executor( - None, - self._send_message_sync, - receive_id_type, - receive_id, - "interactive", - json.dumps(card, ensure_ascii=False), + def _format_tool_hint_delta(self, tool_hint: str) -> str: + """Format a tool hint string with the 🔧 prefix for each line.""" + lines = self.__class__._format_tool_hint_lines(tool_hint).split("\n") + return "\n".join( + f"{self.config.tool_hint_prefix} {ln}" for ln in lines if ln.strip() ) diff --git a/tests/channels/test_feishu_streaming.py b/tests/channels/test_feishu_streaming.py index a047c8c5f..4bef83548 100644 --- a/tests/channels/test_feishu_streaming.py +++ b/tests/channels/test_feishu_streaming.py @@ -205,53 +205,22 @@ class TestSendDelta: ch._client.im.v1.message.create.assert_called_once() @pytest.mark.asyncio - async def test_stream_end_resuming_keeps_buffer(self): - """_resuming=True 
flushes text to card but keeps the buffer for the next segment.""" + async def test_stream_end_fallback_when_final_update_fails(self): + """If streaming mode was closed (e.g. Feishu timeout), fall back to a regular card.""" ch = _make_channel() ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( - text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0, + text="Lost content", card_id="card_1", sequence=3, last_edit=0.0, ) - ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() + ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(success=False) + ch._client.im.v1.message.create.return_value = _mock_send_response("om_fb") - await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) - - assert "oc_chat1" in ch._stream_bufs - buf = ch._stream_bufs["oc_chat1"] - assert buf.card_id == "card_1" - assert buf.sequence == 3 - ch._client.cardkit.v1.card_element.content.assert_called_once() - ch._client.cardkit.v1.card.settings.assert_not_called() - - @pytest.mark.asyncio - async def test_stream_end_resuming_then_final_end(self): - """Full multi-segment flow: resuming mid-turn, then final end closes the card.""" - ch = _make_channel() - ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( - text="Seg1", card_id="card_1", sequence=1, last_edit=0.0, - ) - ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() - ch._client.cardkit.v1.card.settings.return_value = _mock_content_response() - - await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) - assert "oc_chat1" in ch._stream_bufs - - ch._stream_bufs["oc_chat1"].text += " Seg2" await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True}) assert "oc_chat1" not in ch._stream_bufs - ch._client.cardkit.v1.card.settings.assert_called_once() - - @pytest.mark.asyncio - async def test_stream_end_resuming_no_card_is_noop(self): - """_resuming with no card_id (card creation failed) is 
a safe no-op.""" - ch = _make_channel() - ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( - text="text", card_id=None, sequence=0, last_edit=0.0, - ) - await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) - - assert "oc_chat1" in ch._stream_bufs - ch._client.cardkit.v1.card_element.content.assert_not_called() + # Should NOT attempt to close streaming mode since update failed + ch._client.cardkit.v1.card.settings.assert_not_called() + # Should fall back to sending a regular interactive card + ch._client.im.v1.message.create.assert_called_once() @pytest.mark.asyncio async def test_stream_end_without_buf_is_noop(self): @@ -375,22 +344,6 @@ class TestToolHintInlineStreaming: assert "🔧 $ cd /project" in buf.text assert "🔧 $ git status" in buf.text - @pytest.mark.asyncio - async def test_tool_hint_preserved_on_resuming_flush(self): - """When _resuming flushes the buffer, tool hint is kept as permanent content.""" - ch = _make_channel() - ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( - text="Partial answer\n\n🔧 $ cd /project\n\n", - card_id="card_1", sequence=2, last_edit=0.0, - ) - ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response() - - await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True}) - - buf = ch._stream_bufs["oc_chat1"] - assert "Partial answer" in buf.text - assert "🔧 $ cd /project" in buf.text - @pytest.mark.asyncio async def test_tool_hint_preserved_on_final_stream_end(self): """When final _stream_end closes the card, tool hint is kept in the final text.""" diff --git a/tests/channels/test_feishu_tool_hint_code_block.py b/tests/channels/test_feishu_tool_hint_code_block.py index a5db5ad69..4f9d214c6 100644 --- a/tests/channels/test_feishu_tool_hint_code_block.py +++ b/tests/channels/test_feishu_tool_hint_code_block.py @@ -1,6 +1,7 @@ -"""Tests for FeishuChannel tool hint code block formatting.""" +"""Tests for FeishuChannel tool hint formatting.""" import json 
+from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -28,15 +29,24 @@ def mock_feishu_channel(): config.app_secret = "test_app_secret" config.encrypt_key = None config.verification_token = None + config.tool_hint_prefix = "\U0001f527" # 🔧 bus = MagicMock() channel = FeishuChannel(config, bus) - channel._client = MagicMock() # Simulate initialized client + channel._client = MagicMock() return channel +def _get_tool_hint_card(mock_send): + """Extract the interactive card from _send_message_sync calls.""" + call_args = mock_send.call_args[0] + _, _, msg_type, content = call_args + assert msg_type == "interactive" + return json.loads(content) + + @mark.asyncio -async def test_tool_hint_sends_code_message(mock_feishu_channel): - """Tool hint messages should be sent as interactive cards with code blocks.""" +async def test_tool_hint_sends_interactive_card(mock_feishu_channel): + """Tool hint without active buffer sends an interactive card with 🔧 style.""" msg = OutboundMessage( channel="feishu", chat_id="oc_123456", @@ -47,23 +57,12 @@ async def test_tool_hint_sends_code_message(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - # Verify interactive message with card was sent assert mock_send.call_count == 1 - call_args = mock_send.call_args[0] - receive_id_type, receive_id, msg_type, content = call_args - - assert receive_id_type == "chat_id" - assert receive_id == "oc_123456" - assert msg_type == "interactive" - - # Parse content to verify card structure - card = json.loads(content) + card = _get_tool_hint_card(mock_send) assert card["config"]["wide_screen_mode"] is True - assert len(card["elements"]) == 1 - assert card["elements"][0]["tag"] == "markdown" - # Check that code block is properly formatted with language hint - expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```" - assert card["elements"][0]["content"] == 
expected_md + md = card["elements"][0]["content"] + assert "\U0001f527" in md + assert "web_search" in md @mark.asyncio @@ -78,8 +77,6 @@ async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - - # Should not send any message mock_send.assert_not_called() @@ -96,7 +93,6 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - # Should send as text message (detected format) assert mock_send.call_count == 1 call_args = mock_send.call_args[0] _, _, msg_type, content = call_args @@ -106,7 +102,7 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel): @mark.asyncio async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel): - """Multiple tool calls should be displayed each on its own line in a code block.""" + """Multiple tool calls should each get the 🔧 prefix.""" msg = OutboundMessage( channel="feishu", chat_id="oc_123456", @@ -117,13 +113,11 @@ async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - call_args = mock_send.call_args[0] - msg_type = call_args[2] - content = json.loads(call_args[3]) - assert msg_type == "interactive" - # Each tool call should be on its own line - expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```" - assert content["elements"][0]["content"] == expected_md + card = _get_tool_hint_card(mock_send) + md = card["elements"][0]["content"] + assert "web_search" in md + assert "read_file" in md + assert "\U0001f527" in md @mark.asyncio @@ -139,8 +133,8 @@ async def test_tool_hint_new_format_basic(mock_feishu_channel): with patch.object(mock_feishu_channel, 
'_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - content = json.loads(mock_send.call_args[0][3]) - md = content["elements"][0]["content"] + card = _get_tool_hint_card(mock_send) + md = card["elements"][0]["content"] assert "read src/main.py" in md assert 'grep "TODO"' in md @@ -158,16 +152,15 @@ async def test_tool_hint_new_format_with_comma_in_quotes(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - content = json.loads(mock_send.call_args[0][3]) - md = content["elements"][0]["content"] - # The comma inside quotes should NOT cause a line break + card = _get_tool_hint_card(mock_send) + md = card["elements"][0]["content"] assert 'grep "hello, world"' in md assert "$ echo test" in md @mark.asyncio async def test_tool_hint_new_format_with_folding(mock_feishu_channel): - """Folded calls (× N) should display on separate lines.""" + """Folded calls (× N) should display correctly.""" msg = OutboundMessage( channel="feishu", chat_id="oc_123456", @@ -178,8 +171,8 @@ async def test_tool_hint_new_format_with_folding(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - content = json.loads(mock_send.call_args[0][3]) - md = content["elements"][0]["content"] + card = _get_tool_hint_card(mock_send) + md = card["elements"][0]["content"] assert "\u00d7 3" in md assert 'grep "pattern"' in md @@ -197,9 +190,12 @@ async def test_tool_hint_new_format_mcp(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - content = json.loads(mock_send.call_args[0][3]) - md = content["elements"][0]["content"] + card = _get_tool_hint_card(mock_send) + md = card["elements"][0]["content"] assert "4_5v::analyze_image" in md + + +@mark.asyncio async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel): """Commas inside a 
single tool argument must not be split onto a new line.""" msg = OutboundMessage( @@ -212,10 +208,7 @@ async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel): with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: await mock_feishu_channel.send(msg) - content = json.loads(mock_send.call_args[0][3]) - expected_md = ( - "**Tool Calls**\n\n```text\n" - "web_search(\"foo, bar\"),\n" - "read_file(\"/path/to/file\")\n```" - ) - assert content["elements"][0]["content"] == expected_md + card = _get_tool_hint_card(mock_send) + md = card["elements"][0]["content"] + assert 'web_search("foo, bar")' in md + assert 'read_file("/path/to/file")' in md From 873be5180b9a52ef495866732c584970cb481e41 Mon Sep 17 00:00:00 2001 From: yeyitech Date: Tue, 14 Apr 2026 14:31:33 +0800 Subject: [PATCH 114/115] feat(slack): resolve named message targets --- nanobot/channels/slack.py | 125 +++++++++++++++++++++++- tests/channels/test_slack_channel.py | 137 ++++++++++++++++++++++++++- 2 files changed, 255 insertions(+), 7 deletions(-) diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index 2503f6a2d..af03d4973 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -5,6 +5,7 @@ import re from typing import Any from loguru import logger +from pydantic import Field from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.websockets import SocketModeClient @@ -13,8 +14,6 @@ from slackify_markdown import slackify_markdown from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from pydantic import Field - from nanobot.channels.base import BaseChannel from nanobot.config.schema import Base @@ -50,6 +49,9 @@ class SlackChannel(BaseChannel): name = "slack" display_name = "Slack" + _SLACK_ID_RE = re.compile(r"^[CDGUW][A-Z0-9]{2,}$") + _SLACK_CHANNEL_REF_RE = re.compile(r"^<#([A-Z0-9]+)(?:\|[^>]+)?>$") + 
_SLACK_USER_REF_RE = re.compile(r"^<@([A-Z0-9]+)(?:\|[^>]+)?>$") @classmethod def default_config(cls) -> dict[str, Any]: @@ -63,6 +65,7 @@ class SlackChannel(BaseChannel): self._web_client: AsyncWebClient | None = None self._socket_client: SocketModeClient | None = None self._bot_user_id: str | None = None + self._target_cache: dict[str, str] = {} async def start(self) -> None: """Start the Slack Socket Mode client.""" @@ -113,6 +116,7 @@ class SlackChannel(BaseChannel): logger.warning("Slack client not running") return try: + target_chat_id = await self._resolve_target_chat_id(msg.chat_id) slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {} thread_ts = slack_meta.get("thread_ts") channel_type = slack_meta.get("channel_type") @@ -123,7 +127,7 @@ class SlackChannel(BaseChannel): # but send a single blank message when the bot has no text or files to send. if msg.content or not (msg.media or []): await self._web_client.chat_postMessage( - channel=msg.chat_id, + channel=target_chat_id, text=self._to_mrkdwn(msg.content) if msg.content else " ", thread_ts=thread_ts_param, ) @@ -131,7 +135,7 @@ class SlackChannel(BaseChannel): for media_path in msg.media or []: try: await self._web_client.files_upload_v2( - channel=msg.chat_id, + channel=target_chat_id, file=media_path, thread_ts=thread_ts_param, ) @@ -141,12 +145,123 @@ class SlackChannel(BaseChannel): # Update reaction emoji when the final (non-progress) response is sent if not (msg.metadata or {}).get("_progress"): event = slack_meta.get("event", {}) - await self._update_react_emoji(msg.chat_id, event.get("ts")) + await self._update_react_emoji(event.get("channel") or msg.chat_id, event.get("ts")) except Exception as e: logger.error("Error sending Slack message: {}", e) raise + async def _resolve_target_chat_id(self, target: str) -> str: + """Resolve human-friendly Slack targets to concrete IDs when needed.""" + if not self._web_client: + return target + + target = target.strip() + if not target: + 
return target + + if match := self._SLACK_CHANNEL_REF_RE.fullmatch(target): + return match.group(1) + if match := self._SLACK_USER_REF_RE.fullmatch(target): + return await self._open_dm_for_user(match.group(1)) + if self._SLACK_ID_RE.fullmatch(target): + if target.startswith(("U", "W")): + return await self._open_dm_for_user(target) + return target + + if target.startswith("#"): + return await self._resolve_channel_name(target[1:]) + if target.startswith("@"): + return await self._resolve_user_handle(target[1:]) + + try: + return await self._resolve_channel_name(target) + except ValueError: + return await self._resolve_user_handle(target) + + async def _resolve_channel_name(self, name: str) -> str: + normalized = self._normalize_target_name(name) + if not normalized: + raise ValueError("Slack target channel name is empty") + + cache_key = f"channel:{normalized}" + if cache_key in self._target_cache: + return self._target_cache[cache_key] + + cursor: str | None = None + while True: + response = await self._web_client.conversations_list( + types="public_channel,private_channel", + exclude_archived=True, + limit=200, + cursor=cursor, + ) + for channel in response.get("channels", []): + if self._normalize_target_name(str(channel.get("name") or "")) == normalized: + channel_id = str(channel.get("id") or "") + if channel_id: + self._target_cache[cache_key] = channel_id + return channel_id + cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip() + if not cursor: + break + + raise ValueError( + f"Slack channel '{name}' was not found. Use a joined channel name like " + f"'#general' or a concrete channel ID." 
+ ) + + async def _resolve_user_handle(self, handle: str) -> str: + normalized = self._normalize_target_name(handle) + if not normalized: + raise ValueError("Slack target user handle is empty") + + cache_key = f"user:{normalized}" + if cache_key in self._target_cache: + return self._target_cache[cache_key] + + cursor: str | None = None + while True: + response = await self._web_client.users_list(limit=200, cursor=cursor) + for member in response.get("members", []): + if self._member_matches_handle(member, normalized): + user_id = str(member.get("id") or "") + if not user_id: + continue + dm_id = await self._open_dm_for_user(user_id) + self._target_cache[cache_key] = dm_id + return dm_id + cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip() + if not cursor: + break + + raise ValueError( + f"Slack user '{handle}' was not found. Use '@name' or a concrete DM/channel ID." + ) + + async def _open_dm_for_user(self, user_id: str) -> str: + response = await self._web_client.conversations_open(users=user_id) + channel_id = str(((response.get("channel") or {}).get("id")) or "") + if not channel_id: + raise ValueError(f"Slack DM target for user '{user_id}' could not be opened.") + return channel_id + + @staticmethod + def _normalize_target_name(value: str) -> str: + return value.strip().lstrip("#@").lower() + + @classmethod + def _member_matches_handle(cls, member: dict[str, Any], normalized: str) -> bool: + profile = member.get("profile") or {} + candidates = { + str(member.get("name") or ""), + str(profile.get("display_name") or ""), + str(profile.get("display_name_normalized") or ""), + str(profile.get("real_name") or ""), + str(profile.get("real_name_normalized") or ""), + } + return normalized in {cls._normalize_target_name(candidate) for candidate in candidates if candidate} + async def _on_socket_request( self, client: SocketModeClient, diff --git a/tests/channels/test_slack_channel.py b/tests/channels/test_slack_channel.py index 
f7eec95c0..6fb05a912 100644 --- a/tests/channels/test_slack_channel.py +++ b/tests/channels/test_slack_channel.py @@ -10,8 +10,7 @@ except ImportError: from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.slack import SlackChannel -from nanobot.channels.slack import SlackConfig +from nanobot.channels.slack import SlackChannel, SlackConfig class _FakeAsyncWebClient: @@ -20,6 +19,12 @@ class _FakeAsyncWebClient: self.file_upload_calls: list[dict[str, object | None]] = [] self.reactions_add_calls: list[dict[str, object | None]] = [] self.reactions_remove_calls: list[dict[str, object | None]] = [] + self.conversations_list_calls: list[dict[str, object | None]] = [] + self.users_list_calls: list[dict[str, object | None]] = [] + self.conversations_open_calls: list[dict[str, object | None]] = [] + self._conversations_pages: list[dict[str, object]] = [] + self._users_pages: list[dict[str, object]] = [] + self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}} async def chat_postMessage( self, @@ -81,6 +86,22 @@ class _FakeAsyncWebClient: } ) + async def conversations_list(self, **kwargs): + self.conversations_list_calls.append(kwargs) + if self._conversations_pages: + return self._conversations_pages.pop(0) + return {"channels": [], "response_metadata": {"next_cursor": ""}} + + async def users_list(self, **kwargs): + self.users_list_calls.append(kwargs) + if self._users_pages: + return self._users_pages.pop(0) + return {"members": [], "response_metadata": {"next_cursor": ""}} + + async def conversations_open(self, **kwargs): + self.conversations_open_calls.append(kwargs) + return self._open_dm_response + @pytest.mark.asyncio async def test_send_uses_thread_for_channel_messages() -> None: @@ -151,3 +172,115 @@ async def test_send_updates_reaction_when_final_response_sent() -> None: assert fake_web.reactions_add_calls == [ {"channel": "C123", "name": "white_check_mark", "timestamp": 
"1700000000.000100"} ] + + +@pytest.mark.asyncio +async def test_send_resolves_channel_name_to_channel_id() -> None: + channel = SlackChannel(SlackConfig(enabled=True), MessageBus()) + fake_web = _FakeAsyncWebClient() + fake_web._conversations_pages = [ + { + "channels": [{"id": "C999", "name": "channel_x"}], + "response_metadata": {"next_cursor": ""}, + } + ] + channel._web_client = fake_web + + await channel.send( + OutboundMessage( + channel="slack", + chat_id="#channel_x", + content="hello", + ) + ) + + assert fake_web.chat_post_calls == [ + {"channel": "C999", "text": "hello\n", "thread_ts": None} + ] + assert len(fake_web.conversations_list_calls) == 1 + + +@pytest.mark.asyncio +async def test_send_resolves_user_handle_to_dm_channel() -> None: + channel = SlackChannel(SlackConfig(enabled=True), MessageBus()) + fake_web = _FakeAsyncWebClient() + fake_web._users_pages = [ + { + "members": [ + { + "id": "U234", + "name": "alice", + "profile": {"display_name": "Alice"}, + } + ], + "response_metadata": {"next_cursor": ""}, + } + ] + fake_web._open_dm_response = {"channel": {"id": "D234"}} + channel._web_client = fake_web + + await channel.send( + OutboundMessage( + channel="slack", + chat_id="@alice", + content="hello", + ) + ) + + assert fake_web.conversations_open_calls == [{"users": "U234"}] + assert fake_web.chat_post_calls == [ + {"channel": "D234", "text": "hello\n", "thread_ts": None} + ] + + +@pytest.mark.asyncio +async def test_send_updates_reaction_on_origin_channel_for_cross_channel_send() -> None: + channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus()) + fake_web = _FakeAsyncWebClient() + fake_web._conversations_pages = [ + { + "channels": [{"id": "C999", "name": "channel_x"}], + "response_metadata": {"next_cursor": ""}, + } + ] + channel._web_client = fake_web + + await channel.send( + OutboundMessage( + channel="slack", + chat_id="channel_x", + content="done", + metadata={ + "slack": { + "event": {"ts": 
"1700000000.000100", "channel": "D_ORIGIN"}, + "channel_type": "im", + }, + }, + ) + ) + + assert fake_web.chat_post_calls == [ + {"channel": "C999", "text": "done\n", "thread_ts": None} + ] + assert fake_web.reactions_remove_calls == [ + {"channel": "D_ORIGIN", "name": "eyes", "timestamp": "1700000000.000100"} + ] + assert fake_web.reactions_add_calls == [ + {"channel": "D_ORIGIN", "name": "white_check_mark", "timestamp": "1700000000.000100"} + ] + + +@pytest.mark.asyncio +async def test_send_raises_when_named_target_cannot_be_resolved() -> None: + channel = SlackChannel(SlackConfig(enabled=True), MessageBus()) + fake_web = _FakeAsyncWebClient() + channel._web_client = fake_web + + with pytest.raises(ValueError, match="was not found"): + await channel.send( + OutboundMessage( + channel="slack", + chat_id="#missing-channel", + content="hello", + ) + ) From 0a51344483d8210d3673c4b0489557a8e0b217f8 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 14 Apr 2026 11:53:06 +0000 Subject: [PATCH 115/115] fix(slack): keep cross-target sends out of origin threads When Slack resolves a named target to another conversation, do not reuse the origin thread timestamp on the destination send, and keep reaction cleanup anchored to the source conversation. Made-with: Cursor --- nanobot/channels/slack.py | 9 ++++++-- tests/channels/test_slack_channel.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index af03d4973..c68020ce7 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -120,8 +120,13 @@ class SlackChannel(BaseChannel): slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {} thread_ts = slack_meta.get("thread_ts") channel_type = slack_meta.get("channel_type") + origin_chat_id = str((slack_meta.get("event", {}) or {}).get("channel") or msg.chat_id) # Slack DMs don't use threads; channel/group replies may keep thread_ts. 
- thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None + thread_ts_param = ( + thread_ts + if thread_ts and channel_type != "im" and target_chat_id == origin_chat_id + else None + ) # Slack rejects empty text payloads. Keep media-only messages media-only, # but send a single blank message when the bot has no text or files to send. @@ -145,7 +150,7 @@ class SlackChannel(BaseChannel): # Update reaction emoji when the final (non-progress) response is sent if not (msg.metadata or {}).get("_progress"): event = slack_meta.get("event", {}) - await self._update_react_emoji(event.get("channel") or msg.chat_id, event.get("ts")) + await self._update_react_emoji(origin_chat_id, event.get("ts")) except Exception as e: logger.error("Error sending Slack message: {}", e) diff --git a/tests/channels/test_slack_channel.py b/tests/channels/test_slack_channel.py index 6fb05a912..2e72c4e61 100644 --- a/tests/channels/test_slack_channel.py +++ b/tests/channels/test_slack_channel.py @@ -270,6 +270,38 @@ async def test_send_updates_reaction_on_origin_channel_for_cross_channel_send() ] +@pytest.mark.asyncio +async def test_send_does_not_reuse_origin_thread_ts_for_cross_channel_send() -> None: + channel = SlackChannel(SlackConfig(enabled=True), MessageBus()) + fake_web = _FakeAsyncWebClient() + fake_web._conversations_pages = [ + { + "channels": [{"id": "C999", "name": "channel_x"}], + "response_metadata": {"next_cursor": ""}, + } + ] + channel._web_client = fake_web + + await channel.send( + OutboundMessage( + channel="slack", + chat_id="channel_x", + content="done", + metadata={ + "slack": { + "event": {"ts": "1700000000.000100", "channel": "C_ORIGIN"}, + "thread_ts": "1700000000.000200", + "channel_type": "channel", + }, + }, + ) + ) + + assert fake_web.chat_post_calls == [ + {"channel": "C999", "text": "done\n", "thread_ts": None} + ] + + @pytest.mark.asyncio async def test_send_raises_when_named_target_cannot_be_resolved() -> None: channel = 
SlackChannel(SlackConfig(enabled=True), MessageBus())