diff --git a/docs/channel-plugin-guide.md b/docs/channel-plugin-guide.md index d37a92883..da668c9ee 100644 --- a/docs/channel-plugin-guide.md +++ b/docs/channel-plugin-guide.md @@ -238,6 +238,9 @@ nanobot channels login --force # re-authenticate | `supports_streaming` (property) | `True` when config has `"streaming": true` **and** subclass overrides `send_delta()`. | | `is_running` | Returns `self._running`. | | `login(force=False)` | Perform interactive login (e.g. QR code scan). Returns `True` if already authenticated or login succeeds. Override in subclasses that support interactive login. | +| `send_reasoning_delta(chat_id, delta, metadata?)` | Optional hook for streamed model reasoning/thinking content. Default is no-op. | +| `send_reasoning_end(chat_id, metadata?)` | Optional hook marking the end of a reasoning block. Default is no-op. | +| `send_reasoning(msg)` | Optional one-shot reasoning fallback. Default translates to `send_reasoning_delta()` + `send_reasoning_end()`. | ### Optional (streaming) @@ -350,6 +353,112 @@ When `streaming` is `false` (default) or omitted, only `send()` is called — no | `async send_delta(chat_id, delta, metadata?)` | Override to handle streaming chunks. No-op by default. | | `supports_streaming` (property) | Returns `True` when config has `streaming: true` **and** subclass overrides `send_delta`. | +## Progress, Tool Hints, and Reasoning + +Besides normal assistant text, nanobot can emit low-emphasis trace blocks. These are intended for UI affordances like status rows, collapsible "used tools" groups, or reasoning/thinking blocks. Platforms that do not have a good place for them can ignore them safely. + +### Progress and Tool Hints + +Progress and tool hints arrive through the normal `send(msg)` path. Check `msg.metadata` before rendering: + +```python +async def send(self, msg: OutboundMessage) -> None: + meta = msg.metadata or {} + + if meta.get("_tool_hint"): + # A short tool breadcrumb, e.g. read_file("config.json") + await self._send_trace(msg.chat_id, msg.content, kind="tool") + return + + if meta.get("_progress"): + # Generic non-final status, e.g. "Thinking..." or "Running command..." + await self._send_trace(msg.chat_id, msg.content, kind="progress") + return + + await self._send_message(msg.chat_id, msg.content, media=msg.media) +``` + +Tool hints are off by default for most channels. Users can enable them globally or per channel: + +```json +{ + "channels": { + "sendToolHints": true, + "webhook": { + "enabled": true, + "sendToolHints": true + } + } +} +``` + +### Reasoning Blocks + +Reasoning is delivered through dedicated optional hooks, not `send()`. Override `send_reasoning_delta()` and `send_reasoning_end()` if your platform can show model reasoning as a subdued/collapsible block. The default implementation is a no-op, so unsupported channels simply drop reasoning content. + +```python +class WebhookChannel(BaseChannel): + name = "webhook" + display_name = "Webhook" + + def __init__(self, config: Any, bus: MessageBus): + if isinstance(config, dict): + config = WebhookConfig(**config) + super().__init__(config, bus) + self._reasoning_buffers: dict[str, str] = {} + + async def send_reasoning_delta( + self, + chat_id: str, + delta: str, + metadata: dict[str, Any] | None = None, + ) -> None: + meta = metadata or {} + stream_id = str(meta.get("_stream_id") or chat_id) + self._reasoning_buffers[stream_id] = self._reasoning_buffers.get(stream_id, "") + delta + await self._update_reasoning_block(chat_id, self._reasoning_buffers[stream_id], final=False) + + async def send_reasoning_end( + self, + chat_id: str, + metadata: dict[str, Any] | None = None, + ) -> None: + meta = metadata or {} + stream_id = str(meta.get("_stream_id") or chat_id) + text = self._reasoning_buffers.pop(stream_id, "") + if text: + await self._update_reasoning_block(chat_id, text, final=True) +``` + +**Reasoning metadata flags:** + +| Flag | Meaning | +|------|---------| +| `_reasoning_delta: True` | A reasoning/thinking chunk; `delta` contains the new text. | +| `_reasoning_end: True` | The current reasoning block is complete; `delta` is empty. | +| `_reasoning: True` | Legacy one-shot reasoning. `BaseChannel.send_reasoning()` converts it to delta + end. | +| `_stream_id` | Stable id for this assistant turn/segment. Use it to key buffers instead of only `chat_id`. | + +Reasoning visibility is controlled by `showReasoning` globally or per channel: + +```json +{ + "channels": { + "showReasoning": true, + "webhook": { + "enabled": true, + "showReasoning": true + } + } +} +``` + +Recommended rendering: + +- Render tool hints and progress as trace/status UI, not as normal assistant replies. +- Render reasoning with lower visual emphasis and collapse it after completion when the platform supports that. +- Keep reasoning separate from final answer text. A final answer still arrives through `send()` or `send_delta()`. + ## Config ### Why Pydantic model is required diff --git a/docs/configuration.md b/docs/configuration.md index c0d73e7b2..0123017d2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -743,6 +743,7 @@ Global settings that apply to all channels. Configure under the `channels` secti |---------|---------|-------------| | `sendProgress` | `true` | Stream agent's text progress to the channel | | `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) | +| `showReasoning` | `true` | Allow channels to surface model reasoning/thinking content (DeepSeek-R1 `reasoning_content`, Anthropic `thinking_blocks`, inline `` tags). Reasoning flows as a dedicated stream with `_reasoning_delta` / `_reasoning_end` markers — channels override `send_reasoning_delta` / `send_reasoning_end` to render in-place updates. Even with `true`, channels without those overrides stay no-op silently. Currently surfaced on CLI and WebSocket/WebUI (italic shimmer header, auto-collapses after the stream ends); Telegram / Slack / Discord / Feishu / WeChat / Matrix keep the base no-op until their bubble UI is adapted. Independent of `sendProgress`. | | `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (0-10 configured, minimum 1 actual attempt) | | `transcriptionProvider` | `"groq"` | Voice transcription backend: `"groq"` (free tier, default) or `"openai"`. API key is auto-resolved from the matching provider config. | | `transcriptionLanguage` | `null` | Optional ISO-639-1 language hint for audio transcription, e.g. `"en"`, `"ko"`, `"ja"`. | diff --git a/docs/websocket.md b/docs/websocket.md index 556bb5bb6..d6a816ac1 100644 --- a/docs/websocket.md +++ b/docs/websocket.md @@ -128,6 +128,29 @@ All frames are JSON text. Each message has an `event` field. } ``` +**`reasoning_delta`** — incremental model reasoning / thinking chunk for the active assistant turn. Mirrors `delta` but targets the reasoning bubble above the answer rather than the answer body: + +```json +{ + "event": "reasoning_delta", + "chat_id": "uuid-v4", + "text": "Let me decompose ", + "stream_id": "r1" +} +``` + +**`reasoning_end`** — close marker for the active reasoning stream. WebUI uses this to lock the in-place bubble and switch from the shimmer header to a static collapsed state: + +```json +{ + "event": "reasoning_end", + "chat_id": "uuid-v4", + "stream_id": "r1" +} +``` + +Reasoning frames only flow when the channel's `showReasoning` is `true` (default) and the model returns reasoning content (DeepSeek-R1 / Kimi / MiMo / OpenAI reasoning models, Anthropic extended thinking, or inline `` / `` tags). Models without reasoning produce zero `reasoning_delta` frames. + **`runtime_model_updated`** — broadcast when the gateway runtime model changes, for example after `/model `: ```json diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py index d0106cfb6..5b6fed445 100644 --- a/nanobot/agent/hook.py +++ b/nanobot/agent/hook.py @@ -22,6 +22,7 @@ class AgentHookContext: tool_results: list[Any] = field(default_factory=list) tool_events: list[dict[str, str]] = field(default_factory=list) streamed_content: bool = False + streamed_reasoning: bool = False final_content: str | None = None stop_reason: str | None = None error: str | None = None @@ -48,6 +49,17 @@ class AgentHook: async def before_execute_tools(self, context: AgentHookContext) -> None: pass + async def emit_reasoning(self, reasoning_content: str | None) -> None: + pass + + async def emit_reasoning_end(self) -> None: + """Mark the end of an in-flight reasoning stream. + + Hooks that buffer ``emit_reasoning`` chunks (for in-place UI updates) + flush and freeze the rendered group here. One-shot hooks ignore. + """ + pass + async def after_iteration(self, context: AgentHookContext) -> None: pass @@ -95,6 +107,12 @@ class CompositeHook(AgentHook): async def before_execute_tools(self, context: AgentHookContext) -> None: await self._for_each_hook_safe("before_execute_tools", context) + async def emit_reasoning(self, reasoning_content: str | None) -> None: + await self._for_each_hook_safe("emit_reasoning", reasoning_content) + + async def emit_reasoning_end(self) -> None: + await self._for_each_hook_safe("emit_reasoning_end") + async def after_iteration(self, context: AgentHookContext) -> None: await self._for_each_hook_safe("after_iteration", context) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index c73013379..9bfce39fb 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -4,7 +4,6 @@ from __future__ import annotations import asyncio import dataclasses -import json import os import time from contextlib import AsyncExitStack, nullcontext, suppress @@ -15,19 +14,14 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger +from nanobot.agent import model_presets as preset_helpers from nanobot.agent.autocompact import AutoCompact from nanobot.agent.context import ContextBuilder -from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook +from nanobot.agent.hook import AgentHook, CompositeHook from nanobot.agent.memory import Consolidator, Dream -from nanobot.agent import model_presets as preset_helpers +from nanobot.agent.progress_hook import AgentProgressHook from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.agent.subagent import SubagentManager -from nanobot.agent.tools.ask import ( - ask_user_options_from_messages, - ask_user_outbound, - ask_user_tool_result_messages, - pending_ask_user_id, -) from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.registry import ToolRegistry @@ -44,12 +38,6 @@ from nanobot.utils.document import extract_documents from nanobot.utils.helpers import image_placeholder_text from nanobot.utils.helpers import truncate_text as truncate_text_fn from nanobot.utils.image_generation_intent import image_generation_prompt -from nanobot.utils.progress_events import ( - build_tool_event_finish_payloads, - build_tool_event_start_payload, - invoke_on_progress, - on_progress_accepts_tool_events, -) from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE from nanobot.utils.webui_titles import mark_webui_session, maybe_generate_webui_title_after_turn @@ -65,114 +53,6 @@ if TYPE_CHECKING: UNIFIED_SESSION_KEY = "unified:default" -class _LoopHook(AgentHook): - """Core hook for the main loop.""" - - def __init__( - self, - agent_loop: AgentLoop, - on_progress: Callable[..., Awaitable[None]] | None = None, - on_stream: Callable[[str], Awaitable[None]] | None = None, - on_stream_end: Callable[..., Awaitable[None]] | None = None, - *, - channel: str = "cli", - chat_id: str = "direct", - message_id: str | None = None, - metadata: dict[str, Any] | None = None, - session_key: str | None = None, - ) -> None: - super().__init__(reraise=True) - self._loop = agent_loop - self._on_progress = on_progress - self._on_stream = on_stream - self._on_stream_end = on_stream_end - self._channel = channel - self._chat_id = chat_id - self._message_id = message_id - self._metadata = metadata or {} - self._session_key = session_key - self._stream_buf = "" - - def wants_streaming(self) -> bool: - return self._on_stream is not None - - async def on_stream(self, context: AgentHookContext, delta: str) -> None: - from nanobot.utils.helpers import strip_think - - prev_clean = strip_think(self._stream_buf) - self._stream_buf += delta - new_clean = strip_think(self._stream_buf) - incremental = new_clean[len(prev_clean) :] - if incremental and self._on_stream: - await self._on_stream(incremental) - - async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: - if self._on_stream_end: - await self._on_stream_end(resuming=resuming) - self._stream_buf = "" - - async def before_iteration(self, context: AgentHookContext) -> None: - self._loop._current_iteration = context.iteration - logger.debug( - "Starting agent loop iteration {} for session {}", - context.iteration, - self._session_key, - ) - - async def before_execute_tools(self, context: AgentHookContext) -> None: - if self._on_progress: - if not self._on_stream and not context.streamed_content: - thought = self._loop._strip_think( - context.response.content if context.response else None - ) - if thought: - await self._on_progress(thought) - tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls)) - tool_events = [build_tool_event_start_payload(tc) for tc in context.tool_calls] - await invoke_on_progress( - self._on_progress, - tool_hint, - tool_hint=True, - tool_events=tool_events, - ) - for tc in context.tool_calls: - args_str = json.dumps(tc.arguments, ensure_ascii=False) - logger.info("Tool call: {}({})", tc.name, args_str[:200]) - self._loop._set_tool_context( - self._channel, - self._chat_id, - self._message_id, - self._metadata, - session_key=self._session_key, - ) - - async def after_iteration(self, context: AgentHookContext) -> None: - if ( - self._on_progress - and context.tool_calls - and context.tool_events - and on_progress_accepts_tool_events(self._on_progress) - ): - tool_events = build_tool_event_finish_payloads(context) - if tool_events: - await invoke_on_progress( - self._on_progress, - "", - tool_hint=False, - tool_events=tool_events, - ) - u = context.usage or {} - logger.debug( - "LLM usage: prompt={} completion={} cached={}", - u.get("prompt_tokens", 0), - u.get("completion_tokens", 0), - u.get("cached_tokens", 0), - ) - - def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: - return self._loop._strip_think(content) - - class TurnState(Enum): RESTORE = auto() COMPACT = auto() @@ -623,26 +503,11 @@ class AgentLoop: if tool and isinstance(tool, ContextAware): tool.set_context(request_ctx) - @staticmethod - def _strip_think(text: str | None) -> str | None: - """Remove blocks that some models embed in content.""" - if not text: - return None - from nanobot.utils.helpers import strip_think - - return strip_think(text) or None - @staticmethod def _runtime_chat_id(msg: InboundMessage) -> str: """Return the chat id shown in runtime metadata for the model.""" return str(msg.metadata.get("context_chat_id") or msg.chat_id) - def _tool_hint(self, tool_calls: list) -> str: - """Format tool calls as concise hints with smart abbreviation.""" - from nanobot.utils.tool_hints import format_tool_hints - - return format_tool_hints(tool_calls, max_length=self.tool_hint_max_length) - async def _build_bus_progress_callback( self, msg: InboundMessage ) -> Callable[..., Awaitable[None]]: @@ -653,10 +518,16 @@ class AgentLoop: *, tool_hint: bool = False, tool_events: list[dict[str, Any]] | None = None, + reasoning: bool = False, + reasoning_end: bool = False, ) -> None: meta = dict(msg.metadata or {}) meta["_progress"] = True meta["_tool_hint"] = tool_hint + if reasoning: + meta["_reasoning_delta"] = True + if reasoning_end: + meta["_reasoning_end"] = True if tool_events: meta["_tool_events"] = tool_events await self.bus.publish_outbound( @@ -693,7 +564,6 @@ class AgentLoop: self, msg: InboundMessage, session: Session, - pending_ask_id: str | None, ) -> bool: """Persist the triggering user message before the turn starts. @@ -701,7 +571,7 @@ class AgentLoop: """ media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p] has_text = isinstance(msg.content, str) and msg.content.strip() - if not pending_ask_id and (has_text or media_paths): + if has_text or media_paths: extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {} text = msg.content if isinstance(msg.content, str) else "" session.add_message("user", text, **extra) @@ -715,21 +585,9 @@ class AgentLoop: msg: InboundMessage, session: Session, history: list[dict[str, Any]], - pending_ask_id: str | None, pending_summary: str | None, ) -> list[dict[str, Any]]: """Build the initial message list for the LLM turn.""" - if pending_ask_id: - system_prompt = self.context.build_system_prompt( - channel=msg.channel, - session_summary=pending_summary, - ) - return ask_user_tool_result_messages( - system_prompt, - history, - pending_ask_id, - image_generation_prompt(msg.content, msg.metadata), - ) return self.context.build_messages( history=history, current_message=image_generation_prompt(msg.content, msg.metadata), @@ -813,8 +671,7 @@ class AgentLoop: """ self._sync_subagent_runtime_limits() - loop_hook = _LoopHook( - self, + loop_hook = AgentProgressHook( on_progress=on_progress, on_stream=on_stream, on_stream_end=on_stream_end, @@ -823,6 +680,9 @@ class AgentLoop: message_id=message_id, metadata=metadata, session_key=session_key, + tool_hint_max_length=self.tool_hint_max_length, + set_tool_context=self._set_tool_context, + on_iteration=lambda iteration: setattr(self, "_current_iteration", iteration), ) hook: AgentHook = ( CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook @@ -1237,12 +1097,7 @@ class AgentLoop: replay_max_messages=self._max_messages, ) ) - options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] - content, buttons = ask_user_outbound( - final_content or "Background task completed.", - options, - channel, - ) + content = final_content or "Background task completed." outbound_metadata: dict[str, Any] = {} if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2: outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]} @@ -1252,7 +1107,6 @@ class AgentLoop: channel=channel, chat_id=chat_id, content=content, - buttons=buttons, metadata=outbound_metadata, ) @@ -1365,21 +1219,15 @@ class AgentLoop: logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) meta = dict(msg.metadata or {}) - content, buttons = ask_user_outbound( - final_content, - ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [], - msg.channel, - ) - if on_stream is not None and stop_reason not in {"ask_user", "error", "tool_error"}: + if on_stream is not None and stop_reason not in {"error", "tool_error"}: meta["_streamed"] = True return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content=content, + content=final_content, media=generated_media, metadata=meta, - buttons=buttons, ) async def _state_restore(self, ctx: TurnContext) -> TurnState: @@ -1446,12 +1294,11 @@ class AgentLoop: } ctx.history = ctx.session.get_history(**_hist_kwargs) - pending_ask_id = pending_ask_user_id(ctx.history) ctx.initial_messages = self._build_initial_messages( - ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary + ctx.msg, ctx.session, ctx.history, ctx.pending_summary ) ctx.user_persisted_early = self._persist_user_message_early( - ctx.msg, ctx.session, pending_ask_id + ctx.msg, ctx.session ) if ctx.on_progress is None: diff --git a/nanobot/agent/progress_hook.py b/nanobot/agent/progress_hook.py new file mode 100644 index 000000000..a9bf6a1e9 --- /dev/null +++ b/nanobot/agent/progress_hook.py @@ -0,0 +1,178 @@ +"""Agent hook that adapts runner events into channel progress UI.""" + +from __future__ import annotations + +import inspect +import json +from typing import Any, Awaitable, Callable + +from loguru import logger + +from nanobot.agent.hook import AgentHook, AgentHookContext +from nanobot.utils.helpers import IncrementalThinkExtractor, strip_think +from nanobot.utils.progress_events import ( + build_tool_event_finish_payloads, + build_tool_event_start_payload, + invoke_on_progress, + on_progress_accepts_tool_events, +) +from nanobot.utils.tool_hints import format_tool_hints + + +class AgentProgressHook(AgentHook): + """Translate runner lifecycle events into user-visible progress signals.""" + + def __init__( + self, + on_progress: Callable[..., Awaitable[None]] | None = None, + on_stream: Callable[[str], Awaitable[None]] | None = None, + on_stream_end: Callable[..., Awaitable[None]] | None = None, + *, + channel: str = "cli", + chat_id: str = "direct", + message_id: str | None = None, + metadata: dict[str, Any] | None = None, + session_key: str | None = None, + tool_hint_max_length: int = 40, + set_tool_context: Callable[..., None] | None = None, + on_iteration: Callable[[int], None] | None = None, + ) -> None: + super().__init__(reraise=True) + self._on_progress = on_progress + self._on_stream = on_stream + self._on_stream_end = on_stream_end + self._channel = channel + self._chat_id = chat_id + self._message_id = message_id + self._metadata = metadata or {} + self._session_key = session_key + self._tool_hint_max_length = tool_hint_max_length + self._set_tool_context = set_tool_context + self._on_iteration = on_iteration + self._stream_buf = "" + self._think_extractor = IncrementalThinkExtractor() + self._reasoning_open = False + + def wants_streaming(self) -> bool: + return self._on_stream is not None + + @staticmethod + def _strip_think(text: str | None) -> str | None: + if not text: + return None + return strip_think(text) or None + + def _tool_hint(self, tool_calls: list[Any]) -> str: + return format_tool_hints(tool_calls, max_length=self._tool_hint_max_length) + + @staticmethod + def _on_progress_accepts(cb: Callable[..., Any], name: str) -> bool: + try: + sig = inspect.signature(cb) + except (TypeError, ValueError): + return False + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): + return True + return name in sig.parameters + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + prev_clean = strip_think(self._stream_buf) + self._stream_buf += delta + new_clean = strip_think(self._stream_buf) + incremental = new_clean[len(prev_clean) :] + + if await self._think_extractor.feed(self._stream_buf, self.emit_reasoning): + context.streamed_reasoning = True + + if incremental: + # Answer text has started; close the reasoning segment so the UI can + # lock the bubble before the answer renders below it. + await self.emit_reasoning_end() + if self._on_stream: + await self._on_stream(incremental) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + await self.emit_reasoning_end() + if self._on_stream_end: + await self._on_stream_end(resuming=resuming) + self._stream_buf = "" + self._think_extractor.reset() + + async def before_iteration(self, context: AgentHookContext) -> None: + if self._on_iteration: + self._on_iteration(context.iteration) + logger.debug( + "Starting agent loop iteration {} for session {}", + context.iteration, + self._session_key, + ) + + async def before_execute_tools(self, context: AgentHookContext) -> None: + if self._on_progress: + if not self._on_stream and not context.streamed_content: + thought = self._strip_think(context.response.content if context.response else None) + if thought: + await self._on_progress(thought) + tool_hint = self._strip_think(self._tool_hint(context.tool_calls)) + tool_events = [build_tool_event_start_payload(tc) for tc in context.tool_calls] + await invoke_on_progress( + self._on_progress, + tool_hint, + tool_hint=True, + tool_events=tool_events, + ) + for tc in context.tool_calls: + args_str = json.dumps(tc.arguments, ensure_ascii=False) + logger.info("Tool call: {}({})", tc.name, args_str[:200]) + if self._set_tool_context: + self._set_tool_context( + self._channel, + self._chat_id, + self._message_id, + self._metadata, + session_key=self._session_key, + ) + + async def emit_reasoning(self, reasoning_content: str | None) -> None: + """Publish a reasoning chunk; channel plugins decide whether to render.""" + if ( + self._on_progress + and reasoning_content + and self._on_progress_accepts(self._on_progress, "reasoning") + ): + self._reasoning_open = True + await self._on_progress(reasoning_content, reasoning=True) + + async def emit_reasoning_end(self) -> None: + """Close the current reasoning stream segment, if any was open.""" + if self._reasoning_open and self._on_progress: + self._reasoning_open = False + await self._on_progress("", reasoning_end=True) + else: + self._reasoning_open = False + + async def after_iteration(self, context: AgentHookContext) -> None: + if ( + self._on_progress + and context.tool_calls + and context.tool_events + and on_progress_accepts_tool_events(self._on_progress) + ): + tool_events = build_tool_event_finish_payloads(context) + if tool_events: + await invoke_on_progress( + self._on_progress, + "", + tool_hint=False, + tool_events=tool_events, + ) + u = context.usage or {} + logger.debug( + "LLM usage: prompt={} completion={} cached={}", + u.get("prompt_tokens", 0), + u.get("completion_tokens", 0), + u.get("cached_tokens", 0), + ) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return self._strip_think(content) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 7fe92ad51..37da63872 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -13,13 +13,14 @@ from typing import Any from loguru import logger from nanobot.agent.hook import AgentHook, AgentHookContext -from nanobot.agent.tools.ask import AskUserInterrupt from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.utils.helpers import ( + IncrementalThinkExtractor, build_assistant_message, estimate_message_tokens, estimate_prompt_tokens_chain, + extract_reasoning, find_legal_message_start, maybe_persist_tool_result, strip_think, @@ -282,23 +283,30 @@ class AgentRunner: context.tool_calls = list(response.tool_calls) self._accumulate_usage(usage, raw_usage) + reasoning_text, cleaned_content = extract_reasoning( + response.reasoning_content, + response.thinking_blocks, + response.content, + ) + response.content = cleaned_content + if reasoning_text and not context.streamed_reasoning: + await hook.emit_reasoning(reasoning_text) + await hook.emit_reasoning_end() + context.streamed_reasoning = True + if response.should_execute_tools: - tool_calls = list(response.tool_calls) - ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None) - if ask_index is not None: - tool_calls = tool_calls[: ask_index + 1] - context.tool_calls = list(tool_calls) + context.tool_calls = list(response.tool_calls) if hook.wants_streaming(): await hook.on_stream_end(context, resuming=True) assistant_message = build_assistant_message( response.content or "", - tool_calls=[tc.to_openai_tool_call() for tc in tool_calls], + tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls], reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, ) messages.append(assistant_message) - tools_used.extend(tc.name for tc in tool_calls) + tools_used.extend(tc.name for tc in response.tool_calls) await self._emit_checkpoint( spec, { @@ -307,7 +315,7 @@ class AgentRunner: "model": spec.model, "assistant_message": assistant_message, "completed_tool_results": [], - "pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls], + "pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls], }, ) @@ -315,7 +323,7 @@ class AgentRunner: results, new_events, fatal_error = await self._execute_tools( spec, - tool_calls, + response.tool_calls, external_lookup_counts, workspace_violation_counts, ) @@ -323,9 +331,7 @@ class AgentRunner: context.tool_results = list(results) context.tool_events = list(new_events) completed_tool_results: list[dict[str, Any]] = [] - for tool_call, result in zip(tool_calls, results): - if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user": - continue + for tool_call, result in zip(response.tool_calls, results): tool_message = { "role": "tool", "tool_call_id": tool_call.id, @@ -340,15 +346,6 @@ class AgentRunner: messages.append(tool_message) completed_tool_results.append(tool_message) if fatal_error is not None: - if isinstance(fatal_error, AskUserInterrupt): - final_content = fatal_error.question - stop_reason = "ask_user" - context.final_content = final_content - context.stop_reason = stop_reason - if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=False) - await hook.after_iteration(context) - break error = f"Error: {type(fatal_error).__name__}: {fatal_error}" final_content = error stop_reason = "tool_error" @@ -621,6 +618,8 @@ class AgentRunner: and getattr(self.provider, "supports_progress_deltas", False) is True ) + progress_state: dict[str, bool] | None = None + if wants_streaming: async def _stream(delta: str) -> None: if delta: @@ -633,6 +632,8 @@ class AgentRunner: ) elif wants_progress_streaming: stream_buf = "" + think_extractor = IncrementalThinkExtractor() + progress_state = {"reasoning_open": False} async def _stream_progress(delta: str) -> None: nonlocal stream_buf @@ -642,7 +643,15 @@ class AgentRunner: stream_buf += delta new_clean = strip_think(stream_buf) incremental = new_clean[len(prev_clean):] + + if await think_extractor.feed(stream_buf, hook.emit_reasoning): + context.streamed_reasoning = True + progress_state["reasoning_open"] = True + if incremental: + if progress_state["reasoning_open"]: + await hook.emit_reasoning_end() + progress_state["reasoning_open"] = False context.streamed_content = True await spec.progress_callback(incremental) @@ -653,16 +662,20 @@ class AgentRunner: else: coro = self.provider.chat_with_retry(**kwargs) - if timeout_s is None: - return await coro try: - return await asyncio.wait_for(coro, timeout=timeout_s) + response = ( + await coro if timeout_s is None + else await asyncio.wait_for(coro, timeout=timeout_s) + ) except asyncio.TimeoutError: return LLMResponse( content=f"Error calling LLM: timed out after {timeout_s:g}s", finish_reason="error", error_kind="timeout", ) + if progress_state and progress_state.get("reasoning_open"): + await hook.emit_reasoning_end() + return response async def _request_finalization_retry( self, @@ -724,10 +737,6 @@ class AgentRunner: ) tool_results.append(result) batch_results.append(result) - if isinstance(result[2], AskUserInterrupt): - break - if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results): - break results: list[Any] = [] events: list[dict[str, str]] = [] @@ -799,9 +808,6 @@ class AgentRunner: "status": "error", "detail": str(exc), } - if isinstance(exc, AskUserInterrupt): - event["status"] = "waiting" - return "", event, exc payload = f"Error: {type(exc).__name__}: {exc}" handled = self._classify_violation( raw_text=str(exc), diff --git a/nanobot/agent/tools/ask.py b/nanobot/agent/tools/ask.py deleted file mode 100644 index db8c83a84..000000000 --- a/nanobot/agent/tools/ask.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Tool for pausing a turn until the user answers.""" - -import json -from typing import Any - -from nanobot.agent.tools.base import Tool, tool_parameters -from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema - -STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"}) - - -class AskUserInterrupt(BaseException): - """Internal signal: the runner should stop and wait for user input.""" - - def __init__(self, question: str, options: list[str] | None = None) -> None: - self.question = question - self.options = [str(option) for option in (options or []) if str(option)] - super().__init__(question) - - -@tool_parameters( - tool_parameters_schema( - question=StringSchema( - "The question to ask before continuing. Use this only when the task needs the user's answer." - ), - options=ArraySchema( - StringSchema("A possible answer label"), - description="Optional choices. The user may still reply with free text.", - ), - required=["question"], - ) -) -class AskUserTool(Tool): - """Ask the user a blocking question.""" - - @property - def name(self) -> str: - return "ask_user" - - @property - def description(self) -> str: - return ( - "Pause and ask the user a question when their answer is required to continue. " - "Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. " - "For non-blocking notifications or buttons, use the message tool instead." - ) - - @property - def exclusive(self) -> bool: - return True - - async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any: - raise AskUserInterrupt(question=question, options=options) - - -def _tool_call_name(tool_call: dict[str, Any]) -> str: - function = tool_call.get("function") - if isinstance(function, dict) and isinstance(function.get("name"), str): - return function["name"] - name = tool_call.get("name") - return name if isinstance(name, str) else "" - - -def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]: - function = tool_call.get("function") - raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments") - if isinstance(raw, dict): - return raw - if isinstance(raw, str): - try: - parsed = json.loads(raw) - except json.JSONDecodeError: - return {} - return parsed if isinstance(parsed, dict) else {} - return {} - - -def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None: - pending: dict[str, str] = {} - for message in history: - if message.get("role") == "assistant": - for tool_call in message.get("tool_calls") or []: - if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str): - pending[tool_call["id"]] = _tool_call_name(tool_call) - elif message.get("role") == "tool": - tool_call_id = message.get("tool_call_id") - if isinstance(tool_call_id, str): - pending.pop(tool_call_id, None) - for tool_call_id, name in reversed(pending.items()): - if name == "ask_user": - return tool_call_id - return None - - -def ask_user_tool_result_messages( - system_prompt: str, - history: list[dict[str, Any]], - tool_call_id: str, - content: str, -) -> list[dict[str, Any]]: - return [ - {"role": "system", "content": system_prompt}, - *history, - { - "role": "tool", - "tool_call_id": tool_call_id, - "name": "ask_user", - "content": content, - }, - ] - - -def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]: - for message in reversed(messages): - if message.get("role") != "assistant": - continue - for tool_call in reversed(message.get("tool_calls") or []): - if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user": - continue - options = _tool_call_arguments(tool_call).get("options") - if isinstance(options, list): - return [str(option) for option in options if isinstance(option, str)] - return [] - - -def ask_user_outbound( - content: str | None, - options: list[str], - channel: str, -) -> tuple[str | None, list[list[str]]]: - if not options: - return content, [] - if channel in STRUCTURED_BUTTON_CHANNELS: - return content, [options] - option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1)) - return f"{content}\n\n{option_text}" if content else option_text, [] diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 087677494..257127d5a 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -28,6 +28,7 @@ class BaseChannel(ABC): transcription_language: str | None = None send_progress: bool = True send_tool_hints: bool = False + show_reasoning: bool = True def __init__(self, config: Any, bus: MessageBus): """ @@ -120,6 +121,53 @@ class BaseChannel(ABC): """ pass + async def send_reasoning_delta( + self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None + ) -> None: + """Stream a chunk of model reasoning/thinking content. + + Default is no-op. Channels with a native low-emphasis primitive + (Slack context block, Telegram expandable blockquote, Discord + subtext, WebUI italic bubble, ...) override to render reasoning + as a subordinate trace that updates in place as the model thinks. + + Streaming contract mirrors :meth:`send_delta`: ``_reasoning_delta`` + is a chunk, ``_reasoning_end`` ends the current reasoning segment, + and stateful implementations should key buffers by ``_stream_id`` + rather than only by ``chat_id``. + """ + return + + async def send_reasoning_end( + self, chat_id: str, metadata: dict[str, Any] | None = None + ) -> None: + """Mark the end of a reasoning stream segment. + + Default is no-op. Channels that buffer ``send_reasoning_delta`` + chunks for in-place updates use this signal to flush and freeze + the rendered group; one-shot channels can ignore it entirely. + """ + return + + async def send_reasoning(self, msg: OutboundMessage) -> None: + """Deliver a complete reasoning block. + + Default implementation reuses the streaming pair so plugins only + need to override the delta/end methods. Equivalent to one delta + with the full content followed immediately by an end marker — + keeps a single rendering path for both streamed and one-shot + reasoning (e.g. DeepSeek-R1's final-response ``reasoning_content``). + """ + if not msg.content: + return + meta = dict(msg.metadata or {}) + meta.setdefault("_reasoning_delta", True) + await self.send_reasoning_delta(msg.chat_id, msg.content, meta) + end_meta = dict(meta) + end_meta.pop("_reasoning_delta", None) + end_meta["_reasoning_end"] = True + await self.send_reasoning_end(msg.chat_id, end_meta) + @property def supports_streaming(self) -> bool: """True when config enables streaming AND this subclass implements send_delta.""" diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 1d92bb879..3a6b6e50f 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -36,6 +36,7 @@ _SEND_RETRY_DELAYS = (1, 2, 4) _BOOL_CAMEL_ALIASES: dict[str, str] = { "send_progress": "sendProgress", "send_tool_hints": "sendToolHints", + "show_reasoning": "showReasoning", } class ChannelManager: @@ -104,6 +105,9 @@ class ChannelManager: channel.send_tool_hints = self._resolve_bool_override( section, "send_tool_hints", self.config.channels.send_tool_hints, ) + channel.show_reasoning = self._resolve_bool_override( + section, "show_reasoning", self.config.channels.show_reasoning, + ) self.channels[name] = channel logger.info("{} channel enabled", cls.display_name) except Exception as e: @@ -279,6 +283,23 @@ class ChannelManager: timeout=1.0 ) + if ( + msg.metadata.get("_reasoning_delta") + or msg.metadata.get("_reasoning_end") + or msg.metadata.get("_reasoning") + ): + # Reasoning rides its own plugin channel: only delivered + # when the destination channel opts in via ``show_reasoning`` + # and overrides the streaming primitives. Channels without + # a low-emphasis UI affordance keep the base no-op and the + # content silently drops here. ``_reasoning`` (one-shot) + # is accepted for backward compatibility with hooks that + # haven't migrated to delta/end yet. + channel = self.channels.get(msg.channel) + if channel is not None and channel.show_reasoning: + await self._send_with_retry(channel, msg) + continue + if msg.metadata.get("_progress"): if msg.metadata.get("_tool_hint") and not self._should_send_progress( msg.channel, tool_hint=True, @@ -329,7 +350,16 @@ class ChannelManager: @staticmethod async def _send_once(channel: BaseChannel, msg: OutboundMessage) -> None: """Send one outbound message without retry policy.""" - if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): + if msg.metadata.get("_reasoning_end"): + await channel.send_reasoning_end(msg.chat_id, msg.metadata) + elif msg.metadata.get("_reasoning_delta"): + await channel.send_reasoning_delta(msg.chat_id, msg.content, msg.metadata) + elif msg.metadata.get("_reasoning"): + # Back-compat: one-shot reasoning. BaseChannel translates this + # to a single delta + end pair so plugins only implement the + # streaming primitives. + await channel.send_reasoning(msg) + elif msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): await channel.send_delta(msg.chat_id, msg.content, msg.metadata) elif not msg.metadata.get("_streamed"): await channel.send(msg) diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index dc8899861..be3172bff 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -471,7 +471,7 @@ class SlackChannel(BaseChannel): return preview.startswith(_HTML_DOWNLOAD_PREFIXES) async def _on_block_action(self, client: SocketModeClient, req: SocketModeRequest) -> None: - """Handle button clicks from ask_user blocks.""" + """Handle button clicks from inline action buttons.""" await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id)) payload = req.payload or {} actions = payload.get("actions") or [] @@ -568,7 +568,7 @@ class SlackChannel(BaseChannel): @staticmethod def _build_button_blocks(text: str, buttons: list[list[str]]) -> list[dict[str, Any]]: - """Build Slack Block Kit blocks with action buttons for ask_user choices.""" + """Build Slack Block Kit blocks with action buttons.""" blocks: list[dict[str, Any]] = [ {"type": "section", "text": {"type": "mrkdwn", "text": text[:3000]}}, ] @@ -579,7 +579,7 @@ class SlackChannel(BaseChannel): "type": "button", "text": {"type": "plain_text", "text": label[:75]}, "value": label[:75], - "action_id": f"ask_user_{label[:50]}", + "action_id": f"btn_{label[:50]}", }) if elements: blocks.append({"type": "actions", "elements": elements[:25]}) diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 86a1e9654..a77c8594f 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -55,14 +55,6 @@ def _normalize_config_path(path: str) -> str: return _strip_trailing_slash(path) -def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str: - labels = [label for row in buttons for label in row if label] - if not labels: - return text - fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1)) - return f"{text}\n\n{fallback}" if text else fallback - - class WebSocketConfig(Base): """WebSocket server channel configuration. @@ -1468,16 +1460,11 @@ class WebSocketChannel(BaseChannel): await self.send_session_updated(msg.chat_id) return text = msg.content - if msg.buttons: - text = _append_buttons_as_text(text, msg.buttons) payload: dict[str, Any] = { "event": "message", "chat_id": msg.chat_id, "text": text, } - if msg.buttons: - payload["buttons"] = msg.buttons - payload["button_prompt"] = msg.content if msg.media: payload["media"] = msg.media urls: list[dict[str, str]] = [] @@ -1500,6 +1487,54 @@ class WebSocketChannel(BaseChannel): for connection in conns: await self._safe_send_to(connection, raw, label=" ") + async def send_reasoning_delta( + self, + chat_id: str, + delta: str, + metadata: dict[str, Any] | None = None, + ) -> None: + """Push one chunk of model reasoning. Mirrors ``send_delta`` shape so + WebUI receives a stream that opens, updates in place, and closes — + rendered above the active assistant bubble with a shimmer header + until the matching ``reasoning_end`` arrives. + """ + conns = list(self._subs.get(chat_id, ())) + if not conns or not delta: + return + meta = metadata or {} + body: dict[str, Any] = { + "event": "reasoning_delta", + "chat_id": chat_id, + "text": delta, + } + stream_id = meta.get("_stream_id") + if stream_id is not None: + body["stream_id"] = stream_id + raw = json.dumps(body, ensure_ascii=False) + for connection in conns: + await self._safe_send_to(connection, raw, label=" reasoning ") + + async def send_reasoning_end( + self, + chat_id: str, + metadata: dict[str, Any] | None = None, + ) -> None: + """Close the current reasoning stream segment for in-place renderers.""" + conns = list(self._subs.get(chat_id, ())) + if not conns: + return + meta = metadata or {} + body: dict[str, Any] = { + "event": "reasoning_end", + "chat_id": chat_id, + } + stream_id = meta.get("_stream_id") + if stream_id is not None: + body["stream_id"] = stream_id + raw = json.dumps(body, ensure_ascii=False) + for connection in conns: + await self._safe_send_to(connection, raw, label=" reasoning_end ") + async def send_delta( self, chat_id: str, diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 0d71d91db..e02653bf9 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -176,13 +176,15 @@ def _print_agent_response( response: str, render_markdown: bool, metadata: dict | None = None, + show_header: bool = True, ) -> None: """Render assistant response with consistent terminal styling.""" console = _make_console() content = response or "" body = _response_renderable(content, render_markdown, metadata) - console.print() - console.print(f"[cyan]{__logo__} nanobot[/cyan]") + if show_header: + console.print() + console.print(f"[cyan]{__logo__} nanobot[/cyan]") console.print(body) console.print() @@ -228,42 +230,70 @@ async def _print_interactive_response( await run_in_terminal(_write) -def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None: +def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None, renderer: StreamRenderer | None = None) -> None: """Print a CLI progress line, pausing the spinner if needed.""" if not text.strip(): return - with thinking.pause() if thinking else nullcontext(): - console.print(f" [dim]↳ {text}[/dim]") + target = renderer.console if renderer else console + pause = renderer.pause_spinner() if renderer else (thinking.pause() if thinking else nullcontext()) + with pause: + if renderer: + renderer.ensure_header() + target.print(f" [dim]↳ {text}[/dim]") -async def _print_interactive_progress_line(text: str, renderer: StreamRenderer | None) -> None: - """Print an interactive progress line, pausing the renderer's spinner if needed.""" +def _print_cli_reasoning(text: str, thinking: ThinkingSpinner | None, renderer: StreamRenderer | None = None) -> None: + """Print reasoning/thinking content in a distinct style.""" if not text.strip(): return - with renderer.pause() if renderer else nullcontext(): - await _print_interactive_line(text) + target = renderer.console if renderer else console + pause = renderer.pause_spinner() if renderer else (thinking.pause() if thinking else nullcontext()) + with pause: + if renderer: + renderer.ensure_header() + target.print(f"[dim italic]✻ {text}[/dim italic]") + + +async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None, renderer: StreamRenderer | None = None) -> None: + """Print an interactive progress line, pausing the spinner if needed.""" + if not text.strip(): + return + if renderer: + with renderer.pause_spinner(): + renderer.ensure_header() + renderer.console.print(f" [dim]↳ {text}[/dim]") + else: + with thinking.pause() if thinking else nullcontext(): + await _print_interactive_line(text) async def _maybe_print_interactive_progress( msg: Any, - renderer: StreamRenderer | None, + thinking: ThinkingSpinner | None, channels_config: Any, + renderer: StreamRenderer | None = None, ) -> bool: metadata = msg.metadata or {} if metadata.get("_retry_wait"): - await _print_interactive_progress_line(msg.content, renderer) + await _print_interactive_progress_line(msg.content, thinking, renderer) return True if not metadata.get("_progress"): return False is_tool_hint = metadata.get("_tool_hint", False) + is_reasoning = metadata.get("_reasoning", False) or metadata.get("_reasoning_delta", False) + if is_reasoning: + if channels_config and not channels_config.show_reasoning: + return True + _print_cli_reasoning(msg.content, thinking, renderer) + return True if channels_config and is_tool_hint and not channels_config.send_tool_hints: return True if channels_config and not is_tool_hint and not channels_config.send_progress: return True - await _print_interactive_progress_line(msg.content, renderer) + await _print_interactive_progress_line(msg.content, thinking, renderer) return True @@ -1064,13 +1094,20 @@ def agent( # Shared reference for progress callbacks _thinking: ThinkingSpinner | None = None - async def _cli_progress(content: str, *, tool_hint: bool = False, **_kwargs: Any) -> None: - ch = agent_loop.channels_config - if ch and tool_hint and not ch.send_tool_hints: - return - if ch and not tool_hint and not ch.send_progress: - return - _print_cli_progress_line(content, _thinking) + def _make_progress(renderer: StreamRenderer | None = None): + async def _cli_progress(content: str, *, tool_hint: bool = False, reasoning: bool = False, **_kwargs: Any) -> None: + ch = agent_loop.channels_config + if reasoning: + if ch and not ch.show_reasoning: + return + _print_cli_reasoning(content, _thinking, renderer) + return + if ch and tool_hint and not ch.send_tool_hints: + return + if ch and not tool_hint and not ch.send_progress: + return + _print_cli_progress_line(content, _thinking, renderer) + return _cli_progress if message: # Single message mode — direct call, no bus needed @@ -1082,16 +1119,20 @@ def agent( ) response = await agent_loop.process_direct( message, session_id, - on_progress=_cli_progress, + on_progress=_make_progress(renderer), on_stream=renderer.on_delta, on_stream_end=renderer.on_end, ) if not renderer.streamed: await renderer.close() + print_kwargs: dict[str, Any] = {} + if renderer.header_printed: + print_kwargs["show_header"] = False _print_agent_response( response.content if response else "", render_markdown=markdown, metadata=response.metadata if response else None, + **print_kwargs, ) await agent_loop.close_mcp() @@ -1154,6 +1195,7 @@ def agent( msg, renderer, agent_loop.channels_config, + renderer, ): continue @@ -1215,8 +1257,14 @@ def agent( if content and not meta.get("_streamed"): if renderer: await renderer.close() + print_kwargs: dict[str, Any] = {} + if renderer and renderer.header_printed: + print_kwargs["show_header"] = False _print_agent_response( - content, render_markdown=markdown, metadata=meta, + content, + render_markdown=markdown, + metadata=meta, + **print_kwargs, ) elif renderer and not renderer.streamed: await renderer.close() diff --git a/nanobot/cli/stream.py b/nanobot/cli/stream.py index c6b5a87ad..24a141cdd 100644 --- a/nanobot/cli/stream.py +++ b/nanobot/cli/stream.py @@ -1,13 +1,16 @@ """Streaming renderer for CLI output. -Uses Rich Live with auto_refresh=False for stable, flicker-free -markdown rendering during streaming. Ellipsis mode handles overflow. +Uses Rich Live with ``transient=True`` for in-place markdown updates during +streaming. After the live display stops, a final clean render is printed +so the content persists on screen. ``transient=True`` ensures the live +area is erased before ``stop()`` returns, avoiding the duplication bug +that plagued earlier approaches. """ from __future__ import annotations import sys -import time +from contextlib import contextmanager, nullcontext from rich.console import Console from rich.live import Live @@ -15,6 +18,16 @@ from rich.markdown import Markdown from rich.text import Text +def _clear_current_line(console: Console) -> None: + """Erase a transient status line before printing persistent output.""" + file = console.file + isatty = getattr(file, "isatty", lambda: False) + if not isatty(): + return + file.write("\r\x1b[2K") + file.flush() + + def _make_console() -> Console: """Create a Console that emits plain text when stdout is not a TTY. @@ -34,6 +47,7 @@ class ThinkingSpinner: def __init__(self, console: Console | None = None, bot_name: str = "nanobot"): c = console or _make_console() + self._console = c self._spinner = c.status(f"[dim]{bot_name} is thinking...[/dim]", spinner="dots") self._active = False @@ -45,6 +59,7 @@ class ThinkingSpinner: def __exit__(self, *exc): self._active = False self._spinner.stop() + _clear_current_line(self._console) return False def pause(self): @@ -55,6 +70,7 @@ class ThinkingSpinner: def _ctx(): if self._spinner and self._active: self._spinner.stop() + _clear_current_line(self._console) try: yield finally: @@ -65,13 +81,14 @@ class ThinkingSpinner: class StreamRenderer: - """Rich Live streaming with markdown. auto_refresh=False avoids render races. + """Streaming renderer with Rich Live for in-place updates. - Deltas arrive pre-filtered (no tags) from the agent loop. + During streaming: updates content in-place via Rich Live. + On end: stops Live (transient=True erases it), then prints final render. Flow per round: - spinner -> first visible delta -> header + Live renders -> - on_end -> Live stops (content stays on screen) + spinner -> first delta -> header + Live updates -> + on_end -> stop Live + final render """ def __init__( @@ -86,14 +103,24 @@ class StreamRenderer: self._bot_name = bot_name self._bot_icon = bot_icon self._buf = "" - self._live: Live | None = None - self._t = 0.0 self.streamed = False + self._console = _make_console() + self._live: Live | None = None self._spinner: ThinkingSpinner | None = None + self._header_printed = False self._start_spinner() - def _render(self): - return Markdown(self._buf) if self._md and self._buf else Text(self._buf or "") + def _renderable(self): + """Create a renderable from the current buffer.""" + if self._md and self._buf: + return Markdown(self._buf) + return Text(self._buf or "") + + def _render_str(self) -> str: + """Render current buffer to a plain string via Rich.""" + with self._console.capture() as cap: + self._console.print(self._renderable()) + return cap.get() def _start_spinner(self) -> None: if self._show_spinner: @@ -105,37 +132,85 @@ class StreamRenderer: self._spinner.__exit__(None, None, None) self._spinner = None + @property + def console(self) -> Console: + """Expose the Live's console so external print functions can use it.""" + return self._console + + @property + def header_printed(self) -> bool: + """Whether this turn has already opened the assistant output block.""" + return self._header_printed + + def ensure_header(self) -> None: + """Stop transient status and print the assistant header once.""" + # A turn can print trace rows before the final answer, then restart the + # spinner while tools run. The next answer delta still needs to stop + # that spinner even though the header was already printed. + self._stop_spinner() + if self._header_printed: + return + self._console.print() + header = f"{self._bot_icon} {self._bot_name}" if self._bot_icon else self._bot_name + self._console.print(f"[cyan]{header}[/cyan]") + self._header_printed = True + + def pause_spinner(self): + """Context manager: temporarily stop transient output for clean trace lines.""" + @contextmanager + def _pause(): + live_was_active = self._live is not None + if self._live: + # Trace/reasoning can arrive after answer streaming has started. + # Stop the transient Live view first so it does not leak a raw + # partial markdown frame before the trace line. + self._live.stop() + self._live = None + with self._spinner.pause() if self._spinner else nullcontext(): + yield + # If more answer deltas arrive after the trace, on_delta() will + # create a fresh Live using the existing buffer. If no deltas arrive, + # on_end() prints the final buffered answer once. + if live_was_active: + return + + return _pause() + async def on_delta(self, delta: str) -> None: self.streamed = True self._buf += delta if self._live is None: if not self._buf.strip(): return - self._stop_spinner() - c = _make_console() - c.print() - header = f"{self._bot_icon} {self._bot_name}" if self._bot_icon else self._bot_name - c.print(f"[cyan]{header}[/cyan]") - self._live = Live(self._render(), console=c, auto_refresh=False) + self.ensure_header() + self._live = Live( + self._renderable(), + console=self._console, + auto_refresh=False, + transient=True, + ) self._live.start() - now = time.monotonic() - if (now - self._t) > 0.15: - self._live.update(self._render()) - self._live.refresh() - self._t = now + else: + self._live.update(self._renderable()) + self._live.refresh() async def on_end(self, *, resuming: bool = False) -> None: if self._live: - self._live.update(self._render()) + # Double-refresh to sync _shape before stop() calls refresh(). + self._live.refresh() + self._live.update(self._renderable()) self._live.refresh() self._live.stop() self._live = None self._stop_spinner() + if self._buf.strip(): + # Print final rendered content (persists after Live is gone). + out = sys.stdout + out.write(self._render_str()) + out.flush() if resuming: self._buf = "" self._start_spinner() - else: - _make_console().print() def stop_for_input(self) -> None: """Stop spinner before user input to avoid prompt_toolkit conflicts.""" @@ -143,7 +218,6 @@ class StreamRenderer: def pause(self): """Context manager: pause spinner for external output. No-op once streaming has started.""" - from contextlib import nullcontext if self._spinner: return self._spinner.pause() return nullcontext() diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 1cab02763..a112b932d 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -35,6 +35,7 @@ class ChannelsConfig(Base): send_progress: bool = True # stream agent's text progress to the channel send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…")) + show_reasoning: bool = True # surface model reasoning when channel implements it send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included) transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai" transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639-1 hint for audio transcription diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index 47d98976b..188911435 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -25,6 +25,7 @@ FILE_MAX_MESSAGES = 2000 _MESSAGE_TIME_PREFIX_RE = re.compile(r"^\[Message Time: [^\]]+\]\n?") _LOCAL_IMAGE_BREADCRUMB_RE = re.compile(r"^\[image: (?:/|~)[^\]]+\]\s*$") _TOOL_CALL_ECHO_RE = re.compile(r'^\s*(?:generate_image|message)\([^)]*\)\s*$') +_SESSION_PREVIEW_MAX_CHARS = 120 def _sanitize_assistant_replay_text(content: str) -> str: @@ -43,6 +44,27 @@ def _sanitize_assistant_replay_text(content: str) -> str: return "\n".join(lines).strip() +def _text_preview(content: Any) -> str: + """Return compact display text for session lists.""" + if isinstance(content, str): + text = content + elif isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + value = block.get("text") + if isinstance(value, str): + parts.append(value) + text = " ".join(parts) + else: + return "" + text = _sanitize_assistant_replay_text(text) + text = re.sub(r"\s+", " ", text).strip() + if len(text) > _SESSION_PREVIEW_MAX_CHARS: + text = text[: _SESSION_PREVIEW_MAX_CHARS - 1].rstrip() + "…" + return text + + @dataclass class Session: """A conversation session.""" @@ -560,7 +582,7 @@ class SessionManager: for path in self.sessions_dir.glob("*.jsonl"): fallback_key = path.stem.replace("_", ":", 1) try: - # Read just the metadata line + # Read the metadata line and a small preview for WebUI/session lists. with open(path, encoding="utf-8") as f: first_line = f.readline().strip() if first_line: @@ -569,11 +591,29 @@ class SessionManager: key = data.get("key") or path.stem.replace("_", ":", 1) metadata = data.get("metadata", {}) title = metadata.get("title") if isinstance(metadata, dict) else None + preview = "" + fallback_preview = "" + for line in f: + if not line.strip(): + continue + item = json.loads(line) + if item.get("_type") == "metadata": + continue + text = _text_preview(item.get("content")) + if not text: + continue + if item.get("role") == "user": + preview = text + break + if not fallback_preview and item.get("role") == "assistant": + fallback_preview = text + preview = preview or fallback_preview sessions.append({ "key": key, "created_at": data.get("created_at"), "updated_at": data.get("updated_at"), "title": title if isinstance(title, str) else "", + "preview": preview, "path": str(path) }) except Exception: @@ -588,6 +628,14 @@ class SessionManager: if isinstance(repaired.metadata.get("title"), str) else "" ), + "preview": next( + ( + text + for msg in repaired.messages + if (text := _text_preview(msg.get("content"))) + ), + "", + ), "path": str(path) }) continue diff --git a/nanobot/skills/update-setup/SKILL.md b/nanobot/skills/update-setup/SKILL.md index 7e9d5cc60..0838168f5 100644 --- a/nanobot/skills/update-setup/SKILL.md +++ b/nanobot/skills/update-setup/SKILL.md @@ -11,7 +11,7 @@ Generate a personalized upgrade skill for this workspace. Use `read_file` to check if `skills/update/SKILL.md` already exists in the workspace. -If it exists, use `ask_user` to ask: "An upgrade skill already exists. Reconfigure?" with options ["yes", "no"]. If no, stop here. +If it exists, ask the user: "An upgrade skill already exists. Reconfigure?" Wait for the user's reply. If no, stop here. ## Step 2: Current Version and Install Clues @@ -38,9 +38,9 @@ answer or confirmation, not from inference alone. If you cannot get a clear answer, stop and ask the user to rerun this setup when they know how nanobot was installed. -Use `ask_user` for the questions below, one question per call. If `ask_user` is -not available or cannot collect the answer, ask in normal chat and stop without -writing the skill. +Ask the user the questions below, one at a time, in your response text. Wait for +the user's reply before proceeding to the next question. If you cannot get a clear +answer, stop without writing the skill. **Question 1 — Install method:** diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 0655b4439..2a969298c 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -71,6 +71,93 @@ def strip_think(text: str) -> str: return text.strip() +def extract_think(text: str) -> tuple[str | None, str]: + """Extract thinking content from inline ```` / ```` blocks. + + Returns ``(thinking_text, cleaned_text)``. Only closed blocks are + extracted; unclosed streaming prefixes are stripped from the cleaned + text but not surfaced — :func:`strip_think` handles that case. + """ + parts: list[str] = [] + for m in re.finditer(r"([\s\S]*?)", text): + parts.append(m.group(1).strip()) + for m in re.finditer(r"([\s\S]*?)", text): + parts.append(m.group(1).strip()) + thinking = "\n\n".join(parts) if parts else None + return thinking, strip_think(text) + + +class IncrementalThinkExtractor: + """Stateful inline ```` extractor for streaming buffers. + + Streaming providers expose only a single content delta channel. When a + model embeds reasoning in ``...`` blocks inside that + channel, callers need to surface the reasoning incrementally as it + arrives without re-emitting earlier text. This holds the "already + emitted" cursor so the runner and the loop hook share one shape. + """ + + __slots__ = ("_emitted",) + + def __init__(self) -> None: + self._emitted = "" + + def reset(self) -> None: + self._emitted = "" + + async def feed(self, buf: str, emit: Any) -> bool: + """Emit any new thinking text found in ``buf``. + + Returns True if anything was emitted this call. ``emit`` is an + async callable taking a single string (typically + ``hook.emit_reasoning``). + """ + thinking, _ = extract_think(buf) + if not thinking or thinking == self._emitted: + return False + new = thinking[len(self._emitted):].strip() + self._emitted = thinking + if not new: + return False + await emit(new) + return True + + +def extract_reasoning( + reasoning_content: str | None, + thinking_blocks: list[dict[str, Any]] | None, + content: str | None, +) -> tuple[str | None, str | None]: + """Return ``(reasoning_text, cleaned_content)`` from one model response. + + Single source of truth for "what reasoning did this response carry, and + what answer text remains after we peel it out". Fallback order: + + 1. Dedicated ``reasoning_content`` (DeepSeek-R1, Kimi, MiMo, OpenAI + reasoning models, Bedrock). + 2. Anthropic ``thinking_blocks``. + 3. Inline ```` / ```` blocks in ``content``. + + Only one source contributes per response; lower-priority sources are + ignored if a higher-priority one is present, but inline ```` + tags are still stripped from ``content`` so they never leak into the + final answer. + """ + if reasoning_content: + return reasoning_content, strip_think(content) if content else content + if thinking_blocks: + parts = [ + tb.get("thinking", "") + for tb in thinking_blocks + if isinstance(tb, dict) and tb.get("type") == "thinking" + ] + joined = "\n\n".join(p for p in parts if p) + return (joined or None), strip_think(content) if content else content + if content: + return extract_think(content) + return None, content + + def detect_image_mime(data: bytes) -> str | None: """Detect image MIME type from magic bytes, ignoring file extension.""" if data[:8] == b"\x89PNG\r\n\x1a\n": diff --git a/tests/agent/conftest.py b/tests/agent/conftest.py new file mode 100644 index 000000000..57f678aa9 --- /dev/null +++ b/tests/agent/conftest.py @@ -0,0 +1,93 @@ +"""Shared fixtures and helpers for agent tests.""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMProvider + + +def make_provider( + default_model: str = "test-model", + *, + max_tokens: int = 4096, + spec: bool = True, +) -> MagicMock: + """Create a spec-limited LLM provider mock.""" + mock_type = MagicMock(spec=LLMProvider) if spec else MagicMock() + provider = mock_type + provider.get_default_model.return_value = default_model + provider.generation = SimpleNamespace( + max_tokens=max_tokens, + temperature=0.1, + reasoning_effort=None, + ) + provider.estimate_prompt_tokens.return_value = (10_000, "test") + return provider + + +def make_loop( + tmp_path: Path, + *, + model: str = "test-model", + context_window_tokens: int = 128_000, + session_ttl_minutes: int = 0, + max_messages: int = 120, + unified_session: bool = False, + mcp_servers: dict | None = None, + tools_config=None, + model_presets: dict | None = None, + hooks: list | None = None, + provider: MagicMock | None = None, + patch_deps: bool = False, +) -> AgentLoop: + """Create a real AgentLoop for testing. + + Args: + patch_deps: If True, patch ContextBuilder/SessionManager/SubagentManager + during construction (needed when workspace has no real files). + """ + bus = MessageBus() + if provider is None: + provider = make_provider(default_model=model) + + kwargs = dict( + bus=bus, + provider=provider, + workspace=tmp_path, + model=model, + context_window_tokens=context_window_tokens, + session_ttl_minutes=session_ttl_minutes, + max_messages=max_messages, + unified_session=unified_session, + ) + if mcp_servers is not None: + kwargs["mcp_servers"] = mcp_servers + if tools_config is not None: + kwargs["tools_config"] = tools_config + if model_presets is not None: + kwargs["model_presets"] = model_presets + if hooks is not None: + kwargs["hooks"] = hooks + + if patch_deps: + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: + MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) + return AgentLoop(**kwargs) + return AgentLoop(**kwargs) + + +@pytest.fixture +def loop_factory(tmp_path): + """Fixture providing a factory for creating AgentLoop instances.""" + def _factory(**kwargs): + return make_loop(tmp_path, **kwargs) + return _factory diff --git a/tests/agent/test_ask_user.py b/tests/agent/test_ask_user.py deleted file mode 100644 index a192ee4a6..000000000 --- a/tests/agent/test_ask_user.py +++ /dev/null @@ -1,241 +0,0 @@ -import asyncio -from unittest.mock import MagicMock - -import pytest - -from nanobot.agent.loop import AgentLoop -from nanobot.agent.runner import AgentRunner, AgentRunSpec -from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool -from nanobot.agent.tools.base import Tool, tool_parameters -from nanobot.agent.tools.registry import ToolRegistry -from nanobot.agent.tools.schema import tool_parameters_schema -from nanobot.bus.events import InboundMessage -from nanobot.bus.queue import MessageBus -from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest - - -def _make_provider(chat_with_retry): - async def chat_stream_with_retry(**kwargs): - kwargs.pop("on_content_delta", None) - return await chat_with_retry(**kwargs) - - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - provider.generation = GenerationSettings() - provider.chat_with_retry = chat_with_retry - provider.chat_stream_with_retry = chat_stream_with_retry - return provider - - -def test_ask_user_tool_schema_and_interrupt(): - tool = AskUserTool() - schema = tool.to_schema()["function"] - - assert schema["name"] == "ask_user" - assert "question" in schema["parameters"]["required"] - assert schema["parameters"]["properties"]["options"]["type"] == "array" - - with pytest.raises(AskUserInterrupt) as exc: - asyncio.run(tool.execute("Continue?", options=["Yes", "No"])) - - assert exc.value.question == "Continue?" - assert exc.value.options == ["Yes", "No"] - - -@pytest.mark.asyncio -async def test_runner_pauses_on_ask_user_without_executing_later_tools(): - @tool_parameters(tool_parameters_schema(required=[])) - class LaterTool(Tool): - called = False - - @property - def name(self) -> str: - return "later" - - @property - def description(self) -> str: - return "Should not run after ask_user pauses the turn." - - async def execute(self, **kwargs): - self.called = True - return "later result" - - async def chat_with_retry(**kwargs): - return LLMResponse( - content="", - finish_reason="tool_calls", - tool_calls=[ - ToolCallRequest( - id="call_ask", - name="ask_user", - arguments={"question": "Install this package?", "options": ["Yes", "No"]}, - ), - ToolCallRequest(id="call_later", name="later", arguments={}), - ], - ) - - later = LaterTool() - tools = ToolRegistry() - tools.register(AskUserTool()) - tools.register(later) - - result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "continue"}], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=16_000, - concurrent_tools=True, - )) - - assert result.stop_reason == "ask_user" - assert result.final_content == "Install this package?" - assert "ask_user" in result.tools_used - assert later.called is False - assert result.messages[-1]["role"] == "assistant" - tool_calls = result.messages[-1]["tool_calls"] - assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"] - assert not any(message.get("name") == "ask_user" for message in result.messages) - - -@pytest.mark.asyncio -async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path): - seen_messages: list[list[dict]] = [] - - async def chat_with_retry(**kwargs): - seen_messages.append(kwargs["messages"]) - if len(seen_messages) == 1: - return LLMResponse( - content="", - finish_reason="tool_calls", - tool_calls=[ - ToolCallRequest( - id="call_ask", - name="ask_user", - arguments={ - "question": "Install the optional package?", - "options": ["Install", "Skip"], - }, - ) - ], - ) - return LLMResponse(content="Skipped install.", usage={}) - - loop = AgentLoop( - bus=MessageBus(), - provider=_make_provider(chat_with_retry), - workspace=tmp_path, - model="test-model", - ) - - async def on_stream(delta: str) -> None: - pass - - async def on_stream_end(**kwargs) -> None: - pass - - first = await loop._process_message( - InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"), - on_stream=on_stream, - on_stream_end=on_stream_end, - ) - - assert first is not None - assert first.content == "Install the optional package?\n\n1. Install\n2. Skip" - assert first.buttons == [] - assert "_streamed" not in first.metadata - - session = loop.sessions.get_or_create("cli:direct") - assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages) - assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages) - - second = await loop._process_message( - InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip") - ) - - assert second is not None - assert second.content == "Skipped install." - assert any( - message.get("role") == "tool" - and message.get("name") == "ask_user" - and message.get("content") == "Skip" - for message in seen_messages[-1] - ) - assert not any( - message.get("role") == "user" and message.get("content") == "Skip" - for message in session.messages - ) - assert any( - message.get("role") == "tool" - and message.get("name") == "ask_user" - and message.get("content") == "Skip" - for message in session.messages - ) - - -@pytest.mark.asyncio -async def test_ask_user_keeps_buttons_for_telegram(tmp_path): - async def chat_with_retry(**kwargs): - return LLMResponse( - content="", - finish_reason="tool_calls", - tool_calls=[ - ToolCallRequest( - id="call_ask", - name="ask_user", - arguments={ - "question": "Install the optional package?", - "options": ["Install", "Skip"], - }, - ) - ], - ) - - loop = AgentLoop( - bus=MessageBus(), - provider=_make_provider(chat_with_retry), - workspace=tmp_path, - model="test-model", - ) - - response = await loop._process_message( - InboundMessage(channel="telegram", sender_id="user", chat_id="123", content="set it up") - ) - - assert response is not None - assert response.content == "Install the optional package?" - assert response.buttons == [["Install", "Skip"]] - - -@pytest.mark.asyncio -async def test_ask_user_keeps_buttons_for_websocket(tmp_path): - async def chat_with_retry(**kwargs): - return LLMResponse( - content="", - finish_reason="tool_calls", - tool_calls=[ - ToolCallRequest( - id="call_ask", - name="ask_user", - arguments={ - "question": "Install the optional package?", - "options": ["Install", "Skip"], - }, - ) - ], - ) - - loop = AgentLoop( - bus=MessageBus(), - provider=_make_provider(chat_with_retry), - workspace=tmp_path, - model="test-model", - ) - - response = await loop._process_message( - InboundMessage(channel="websocket", sender_id="user", chat_id="123", content="set it up") - ) - - assert response is not None - assert response.content == "Install the optional package?" - assert response.buttons == [["Install", "Skip"]] diff --git a/tests/agent/test_autocompact_unit.py b/tests/agent/test_autocompact_unit.py new file mode 100644 index 000000000..d501770dd --- /dev/null +++ b/tests/agent/test_autocompact_unit.py @@ -0,0 +1,554 @@ +"""Direct unit tests for AutoCompact class methods in isolation.""" + +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.autocompact import AutoCompact +from nanobot.session.manager import Session, SessionManager + + +def _make_session( + key: str = "cli:test", + messages: list | None = None, + last_consolidated: int = 0, + updated_at: datetime | None = None, + metadata: dict | None = None, +) -> Session: + """Create a Session with sensible defaults for testing.""" + session = Session( + key=key, + messages=messages or [], + metadata=metadata or {}, + last_consolidated=last_consolidated, + ) + if updated_at is not None: + session.updated_at = updated_at + return session + + +def _make_autocompact( + ttl: int = 15, + sessions: SessionManager | None = None, + consolidator: MagicMock | None = None, +) -> AutoCompact: + """Create an AutoCompact with mock dependencies.""" + if sessions is None: + sessions = MagicMock(spec=SessionManager) + if consolidator is None: + consolidator = MagicMock() + consolidator.archive = AsyncMock(return_value="Summary.") + return AutoCompact( + sessions=sessions, + consolidator=consolidator, + session_ttl_minutes=ttl, + ) + + +def _add_turns(session: Session, turns: int, *, prefix: str = "msg") -> None: + """Append simple user/assistant turns to a session.""" + for i in range(turns): + session.add_message("user", f"{prefix} user {i}") + session.add_message("assistant", f"{prefix} assistant {i}") + + +# --------------------------------------------------------------------------- +# __init__ +# --------------------------------------------------------------------------- + + +class TestInit: + """Test AutoCompact.__init__ stores constructor arguments correctly.""" + + def test_stores_ttl(self): + """_ttl should match session_ttl_minutes argument.""" + ac = _make_autocompact(ttl=30) + assert ac._ttl == 30 + + def test_default_ttl_is_zero(self): + """Default TTL should be 0.""" + ac = _make_autocompact(ttl=0) + assert ac._ttl == 0 + + def test_archiving_set_is_empty(self): + """_archiving should start as an empty set.""" + ac = _make_autocompact() + assert ac._archiving == set() + + def test_summaries_dict_is_empty(self): + """_summaries should start as an empty dict.""" + ac = _make_autocompact() + assert ac._summaries == {} + + def test_stores_sessions_reference(self): + """sessions attribute should reference the passed SessionManager.""" + mock_sm = MagicMock(spec=SessionManager) + ac = _make_autocompact(sessions=mock_sm) + assert ac.sessions is mock_sm + + def test_stores_consolidator_reference(self): + """consolidator attribute should reference the passed Consolidator.""" + mock_c = MagicMock() + ac = _make_autocompact(consolidator=mock_c) + assert ac.consolidator is mock_c + + +# --------------------------------------------------------------------------- +# _is_expired +# --------------------------------------------------------------------------- + + +class TestIsExpired: + """Test AutoCompact._is_expired edge cases.""" + + def test_ttl_zero_always_false(self): + """TTL=0 means auto-compact is disabled; always returns False.""" + ac = _make_autocompact(ttl=0) + old = datetime.now() - timedelta(days=365) + assert ac._is_expired(old) is False + + def test_none_timestamp_returns_false(self): + """None timestamp should return False.""" + ac = _make_autocompact(ttl=15) + assert ac._is_expired(None) is False + + def test_empty_string_timestamp_returns_false(self): + """Empty string timestamp should return False (falsy).""" + ac = _make_autocompact(ttl=15) + assert ac._is_expired("") is False + + def test_exactly_at_boundary_is_expired(self): + """Timestamp exactly at TTL boundary should be expired (>=).""" + ac = _make_autocompact(ttl=15) + now = datetime(2026, 1, 1, 12, 0, 0) + ts = now - timedelta(minutes=15) + assert ac._is_expired(ts, now=now) is True + + def test_just_under_boundary_not_expired(self): + """Timestamp just under TTL boundary should NOT be expired.""" + ac = _make_autocompact(ttl=15) + now = datetime(2026, 1, 1, 12, 0, 0) + ts = now - timedelta(minutes=14, seconds=59) + assert ac._is_expired(ts, now=now) is False + + def test_iso_string_parses_correctly(self): + """ISO format string timestamp should be parsed and evaluated.""" + ac = _make_autocompact(ttl=15) + now = datetime(2026, 1, 1, 12, 0, 0) + ts = (now - timedelta(minutes=20)).isoformat() + assert ac._is_expired(ts, now=now) is True + + def test_custom_now_parameter(self): + """Custom 'now' parameter should override datetime.now().""" + ac = _make_autocompact(ttl=10) + ts = datetime(2026, 1, 1, 10, 0, 0) + # 9 minutes later → not expired + now_under = datetime(2026, 1, 1, 10, 9, 0) + assert ac._is_expired(ts, now=now_under) is False + # 10 minutes later → expired + now_over = datetime(2026, 1, 1, 10, 10, 0) + assert ac._is_expired(ts, now=now_over) is True + + +# --------------------------------------------------------------------------- +# _format_summary +# --------------------------------------------------------------------------- + + +class TestFormatSummary: + """Test AutoCompact._format_summary static method.""" + + def test_contains_isoformat_timestamp(self): + """Output should contain last_active as isoformat.""" + last_active = datetime(2026, 5, 13, 14, 30, 0) + result = AutoCompact._format_summary("Some text", last_active) + assert "2026-05-13T14:30:00" in result + + def test_contains_summary_text(self): + """Output should contain the provided text verbatim.""" + last_active = datetime(2026, 1, 1) + result = AutoCompact._format_summary("User discussed Python.", last_active) + assert "User discussed Python." in result + + def test_output_starts_with_label(self): + """Output should start with the standard prefix.""" + last_active = datetime(2026, 1, 1) + result = AutoCompact._format_summary("text", last_active) + assert result.startswith("Previous conversation summary (last active ") + + +# --------------------------------------------------------------------------- +# _split_unconsolidated +# --------------------------------------------------------------------------- + + +class TestSplitUnconsolidated: + """Test AutoCompact._split_unconsolidated splitting logic.""" + + def test_empty_session_returns_both_empty(self): + """Empty session should return ([], []).""" + ac = _make_autocompact() + session = _make_session(messages=[]) + archive, kept = ac._split_unconsolidated(session) + assert archive == [] + assert kept == [] + + def test_all_messages_archivable_when_more_than_suffix(self): + """Session with many messages should archive a prefix and keep suffix.""" + ac = _make_autocompact() + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + session = _make_session(messages=msgs) + archive, kept = ac._split_unconsolidated(session) + assert len(archive) > 0 + assert len(kept) <= AutoCompact._RECENT_SUFFIX_MESSAGES + + def test_fewer_messages_than_suffix_returns_empty_archive(self): + """Session with fewer messages than suffix should have empty archive.""" + ac = _make_autocompact() + msgs = [{"role": "user", "content": f"u{i}"} for i in range(3)] + session = _make_session(messages=msgs) + archive, kept = ac._split_unconsolidated(session) + assert archive == [] + assert len(kept) == len(msgs) + + def test_respects_last_consolidated_offset(self): + """Only messages after last_consolidated should be considered.""" + ac = _make_autocompact() + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + # First 10 are already consolidated + session = _make_session(messages=msgs, last_consolidated=10) + archive, kept = ac._split_unconsolidated(session) + # Only the tail of 10 messages is considered for splitting + assert all(m["content"] in [f"u{i}" for i in range(10, 20)] for m in kept) + assert all(m["content"] in [f"u{i}" for i in range(10, 20)] for m in archive) + + def test_retain_recent_legal_suffix_keeps_last_n(self): + """The kept suffix should be at most _RECENT_SUFFIX_MESSAGES long.""" + ac = _make_autocompact() + # 20 user messages = 20 messages total, all after last_consolidated=0 + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + session = _make_session(messages=msgs) + archive, kept = ac._split_unconsolidated(session) + assert len(kept) <= AutoCompact._RECENT_SUFFIX_MESSAGES + assert len(archive) == len(msgs) - len(kept) + + +# --------------------------------------------------------------------------- +# check_expired +# --------------------------------------------------------------------------- + + +class TestCheckExpired: + """Test AutoCompact.check_expired scheduling logic.""" + + def test_empty_sessions_list(self): + """No sessions → schedule_background should never be called.""" + ac = _make_autocompact(ttl=15) + mock_sm = MagicMock(spec=SessionManager) + mock_sm.list_sessions.return_value = [] + ac.sessions = mock_sm + scheduler = MagicMock() + ac.check_expired(scheduler) + scheduler.assert_not_called() + + def test_expired_session_schedules_background(self): + """Expired session should trigger schedule_background.""" + ac = _make_autocompact(ttl=15) + mock_sm = MagicMock(spec=SessionManager) + old_ts = (datetime.now() - timedelta(minutes=20)).isoformat() + mock_sm.list_sessions.return_value = [{"key": "cli:old", "updated_at": old_ts}] + ac.sessions = mock_sm + scheduler = MagicMock() + ac.check_expired(scheduler) + scheduler.assert_called_once() + assert "cli:old" in ac._archiving + + def test_active_session_key_skips(self): + """Session in active_session_keys should be skipped.""" + ac = _make_autocompact(ttl=15) + mock_sm = MagicMock(spec=SessionManager) + old_ts = (datetime.now() - timedelta(minutes=20)).isoformat() + mock_sm.list_sessions.return_value = [{"key": "cli:busy", "updated_at": old_ts}] + ac.sessions = mock_sm + scheduler = MagicMock() + ac.check_expired(scheduler, active_session_keys={"cli:busy"}) + scheduler.assert_not_called() + + def test_session_already_in_archiving_skips(self): + """Session already in _archiving set should be skipped.""" + ac = _make_autocompact(ttl=15) + mock_sm = MagicMock(spec=SessionManager) + old_ts = (datetime.now() - timedelta(minutes=20)).isoformat() + mock_sm.list_sessions.return_value = [{"key": "cli:dup", "updated_at": old_ts}] + ac.sessions = mock_sm + ac._archiving.add("cli:dup") + scheduler = MagicMock() + ac.check_expired(scheduler) + scheduler.assert_not_called() + + def test_session_with_no_key_skips(self): + """Session info with empty/missing key should be skipped.""" + ac = _make_autocompact(ttl=15) + mock_sm = MagicMock(spec=SessionManager) + mock_sm.list_sessions.return_value = [{"key": "", "updated_at": "old"}] + ac.sessions = mock_sm + scheduler = MagicMock() + ac.check_expired(scheduler) + scheduler.assert_not_called() + + def test_session_with_missing_key_field_skips(self): + """Session info dict without 'key' field should be skipped.""" + ac = _make_autocompact(ttl=15) + mock_sm = MagicMock(spec=SessionManager) + mock_sm.list_sessions.return_value = [{"updated_at": "old"}] + ac.sessions = mock_sm + scheduler = MagicMock() + ac.check_expired(scheduler) + scheduler.assert_not_called() + + +# --------------------------------------------------------------------------- +# _archive +# --------------------------------------------------------------------------- + + +class TestArchive: + """Test AutoCompact._archive async method.""" + + @pytest.mark.asyncio + async def test_empty_session_updates_timestamp_no_archive_call(self): + """Empty session should refresh updated_at and not call consolidator.archive.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + empty_session = _make_session(messages=[]) + mock_sm.get_or_create.return_value = empty_session + ac.sessions = mock_sm + ac.consolidator.archive = AsyncMock(return_value="Summary.") + + await ac._archive("cli:test") + + ac.consolidator.archive.assert_not_called() + mock_sm.save.assert_called_once_with(empty_session) + # updated_at was refreshed + assert empty_session.updated_at > datetime.now() - timedelta(seconds=5) + + @pytest.mark.asyncio + async def test_archive_returns_empty_string_no_summary_stored(self): + """If archive returns empty string, no summary should be stored.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + session = _make_session(messages=msgs) + mock_sm.get_or_create.return_value = session + ac.sessions = mock_sm + ac.consolidator.archive = AsyncMock(return_value="") + + await ac._archive("cli:test") + + assert "cli:test" not in ac._summaries + + @pytest.mark.asyncio + async def test_archive_returns_nothing_no_summary_stored(self): + """If archive returns '(nothing)', no summary should be stored.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + session = _make_session(messages=msgs) + mock_sm.get_or_create.return_value = session + ac.sessions = mock_sm + ac.consolidator.archive = AsyncMock(return_value="(nothing)") + + await ac._archive("cli:test") + + assert "cli:test" not in ac._summaries + + @pytest.mark.asyncio + async def test_archive_exception_caught_key_removed_from_archiving(self): + """If archive raises, exception is caught and key removed from _archiving.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + session = _make_session(messages=msgs) + mock_sm.get_or_create.return_value = session + ac.sessions = mock_sm + ac.consolidator.archive = AsyncMock(side_effect=RuntimeError("LLM down")) + + # Should not raise + await ac._archive("cli:test") + + assert "cli:test" not in ac._archiving + + @pytest.mark.asyncio + async def test_successful_archive_stores_summary_in_summaries_and_metadata(self): + """Successful archive should store summary in _summaries dict and metadata.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + last_active = datetime(2026, 5, 13, 10, 0, 0) + session = _make_session(messages=msgs, updated_at=last_active) + mock_sm.get_or_create.return_value = session + ac.sessions = mock_sm + ac.consolidator.archive = AsyncMock(return_value="User discussed AI.") + + await ac._archive("cli:test") + + # _summaries + entry = ac._summaries.get("cli:test") + assert entry is not None + assert entry[0] == "User discussed AI." + assert entry[1] == last_active + # metadata + meta = session.metadata.get("_last_summary") + assert meta is not None + assert meta["text"] == "User discussed AI." + assert "last_active" in meta + + @pytest.mark.asyncio + async def test_finally_block_always_removes_from_archiving(self): + """Finally block should always remove key from _archiving, even on error.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + session = _make_session(messages=msgs) + mock_sm.get_or_create.return_value = session + ac.sessions = mock_sm + ac.consolidator.archive = AsyncMock(side_effect=RuntimeError("fail")) + + # Pre-add key to archiving to verify it gets removed + ac._archiving.add("cli:test") + await ac._archive("cli:test") + assert "cli:test" not in ac._archiving + + @pytest.mark.asyncio + async def test_finally_removes_from_archiving_on_success(self): + """Finally block should remove key from _archiving on success too.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + msgs = [{"role": "user", "content": f"u{i}"} for i in range(20)] + session = _make_session(messages=msgs) + mock_sm.get_or_create.return_value = session + ac.sessions = mock_sm + ac.consolidator.archive = AsyncMock(return_value="Summary.") + + ac._archiving.add("cli:test") + await ac._archive("cli:test") + assert "cli:test" not in ac._archiving + + +# --------------------------------------------------------------------------- +# prepare_session +# --------------------------------------------------------------------------- + + +class TestPrepareSession: + """Test AutoCompact.prepare_session logic.""" + + def test_key_in_archiving_reloads_session(self): + """If key is in _archiving, session should be reloaded via get_or_create.""" + ac = _make_autocompact() + mock_sm = MagicMock(spec=SessionManager) + reloaded = _make_session(key="cli:test") + mock_sm.get_or_create.return_value = reloaded + ac.sessions = mock_sm + ac._archiving.add("cli:test") + + original_session = _make_session() + result_session, summary = ac.prepare_session(original_session, "cli:test") + + mock_sm.get_or_create.assert_called_once_with("cli:test") + assert result_session is reloaded + + def test_expired_session_reloads(self): + """If session is expired, it should be reloaded via get_or_create.""" + ac = _make_autocompact(ttl=15) + mock_sm = MagicMock(spec=SessionManager) + reloaded = _make_session(key="cli:test", updated_at=datetime.now()) + mock_sm.get_or_create.return_value = reloaded + ac.sessions = mock_sm + + old_session = _make_session(updated_at=datetime.now() - timedelta(minutes=20)) + result_session, summary = ac.prepare_session(old_session, "cli:test") + + mock_sm.get_or_create.assert_called_once_with("cli:test") + assert result_session is reloaded + + def test_hot_path_summary_from_summaries(self): + """Summary from _summaries dict should be returned (hot path).""" + ac = _make_autocompact() + session = _make_session() + last_active = datetime(2026, 5, 13, 14, 0, 0) + ac._summaries["cli:test"] = ("Hot summary.", last_active) + + result_session, summary = ac.prepare_session(session, "cli:test") + + assert result_session is session + assert summary is not None + assert "Hot summary." in summary + assert "Previous conversation summary" in summary + + def test_hot_path_pops_summary_one_shot(self): + """Hot path should pop the summary (one-shot; second call returns None).""" + ac = _make_autocompact() + session = _make_session() + last_active = datetime(2026, 1, 1) + ac._summaries["cli:test"] = ("One-shot.", last_active) + + _, summary1 = ac.prepare_session(session, "cli:test") + assert summary1 is not None + # Second call: hot path entry was popped + _, summary2 = ac.prepare_session(session, "cli:test") + assert summary2 is None + + def test_cold_path_summary_from_metadata(self): + """When _summaries is empty, summary should come from metadata (cold path).""" + ac = _make_autocompact() + last_active = datetime(2026, 5, 13, 14, 0, 0) + session = _make_session(metadata={ + "_last_summary": { + "text": "Cold summary.", + "last_active": last_active.isoformat(), + }, + }) + + result_session, summary = ac.prepare_session(session, "cli:test") + + assert result_session is session + assert summary is not None + assert "Cold summary." in summary + + def test_no_summary_available_returns_none(self): + """When no summary is available, should return (session, None).""" + ac = _make_autocompact() + session = _make_session() + + result_session, summary = ac.prepare_session(session, "cli:test") + + assert result_session is session + assert summary is None + + def test_cold_path_metadata_not_dict_returns_none(self): + """If metadata _last_summary is not a dict, should return None summary.""" + ac = _make_autocompact() + session = _make_session(metadata={"_last_summary": "not a dict"}) + + result_session, summary = ac.prepare_session(session, "cli:test") + + assert result_session is session + assert summary is None + + def test_hot_path_takes_priority_over_metadata(self): + """Hot path (_summaries) should take priority over metadata.""" + ac = _make_autocompact() + session = _make_session(metadata={ + "_last_summary": { + "text": "Cold summary.", + "last_active": datetime(2026, 1, 1).isoformat(), + }, + }) + last_active = datetime(2026, 5, 13, 14, 0, 0) + ac._summaries["cli:test"] = ("Hot summary.", last_active) + + _, summary = ac.prepare_session(session, "cli:test") + assert "Hot summary." in summary + # After hot path pops, cold path would kick in on next call diff --git a/tests/agent/test_context_builder.py b/tests/agent/test_context_builder.py new file mode 100644 index 000000000..862f1ff2b --- /dev/null +++ b/tests/agent/test_context_builder.py @@ -0,0 +1,333 @@ +"""Tests for ContextBuilder — system prompt and message assembly.""" + +import base64 +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from nanobot.agent.context import ContextBuilder + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _builder(tmp_path: Path, **kw) -> ContextBuilder: + return ContextBuilder(workspace=tmp_path, **kw) + + +# --------------------------------------------------------------------------- +# _build_runtime_context (static) +# --------------------------------------------------------------------------- + + +class TestBuildRuntimeContext: + def test_time_only(self): + ctx = ContextBuilder._build_runtime_context(None, None) + assert "[Runtime Context" in ctx + assert "[/Runtime Context]" in ctx + assert "Current Time:" in ctx + assert "Channel:" not in ctx + + def test_with_channel_and_chat_id(self): + ctx = ContextBuilder._build_runtime_context("telegram", "chat123") + assert "Channel: telegram" in ctx + assert "Chat ID: chat123" in ctx + + def test_with_sender_id(self): + ctx = ContextBuilder._build_runtime_context("cli", "direct", sender_id="user1") + assert "Sender ID: user1" in ctx + + def test_with_timezone(self): + ctx = ContextBuilder._build_runtime_context(None, None, timezone="Asia/Shanghai") + assert "Current Time:" in ctx + + def test_no_channel_no_chat_id_omits_both(self): + ctx = ContextBuilder._build_runtime_context(None, None) + assert "Channel:" not in ctx + assert "Chat ID:" not in ctx + + def test_no_sender_id_omits(self): + ctx = ContextBuilder._build_runtime_context("cli", "direct") + assert "Sender ID:" not in ctx + + +# --------------------------------------------------------------------------- +# _merge_message_content (static) +# --------------------------------------------------------------------------- + + +class TestMergeMessageContent: + def test_str_plus_str(self): + result = ContextBuilder._merge_message_content("hello", "world") + assert result == "hello\n\nworld" + + def test_empty_left_plus_str(self): + result = ContextBuilder._merge_message_content("", "world") + assert result == "world" + + def test_list_plus_list(self): + left = [{"type": "text", "text": "a"}] + right = [{"type": "text", "text": "b"}] + result = ContextBuilder._merge_message_content(left, right) + assert len(result) == 2 + assert result[0]["text"] == "a" + assert result[1]["text"] == "b" + + def test_str_plus_list(self): + right = [{"type": "text", "text": "b"}] + result = ContextBuilder._merge_message_content("hello", right) + assert len(result) == 2 + assert result[0]["text"] == "hello" + assert result[1]["text"] == "b" + + def test_list_plus_str(self): + left = [{"type": "text", "text": "a"}] + result = ContextBuilder._merge_message_content(left, "world") + assert len(result) == 2 + assert result[0]["text"] == "a" + assert result[1]["text"] == "world" + + def test_none_plus_str(self): + result = ContextBuilder._merge_message_content(None, "hello") + assert result == [{"type": "text", "text": "hello"}] + + def test_str_plus_none(self): + result = ContextBuilder._merge_message_content("hello", None) + assert result == [{"type": "text", "text": "hello"}] + + def test_none_plus_none(self): + result = ContextBuilder._merge_message_content(None, None) + assert result == [] + + def test_list_items_not_dicts_wrapped(self): + result = ContextBuilder._merge_message_content(["raw_item"], None) + assert result == [{"type": "text", "text": "raw_item"}] + + +# --------------------------------------------------------------------------- +# _load_bootstrap_files +# --------------------------------------------------------------------------- + + +class TestLoadBootstrapFiles: + def test_no_bootstrap_files(self, tmp_path): + builder = _builder(tmp_path) + assert builder._load_bootstrap_files() == "" + + def test_agents_md(self, tmp_path): + (tmp_path / "AGENTS.md").write_text("Be helpful.", encoding="utf-8") + builder = _builder(tmp_path) + result = builder._load_bootstrap_files() + assert "## AGENTS.md" in result + assert "Be helpful." in result + + def test_multiple_bootstrap_files(self, tmp_path): + (tmp_path / "AGENTS.md").write_text("Rules.", encoding="utf-8") + (tmp_path / "SOUL.md").write_text("Soul.", encoding="utf-8") + builder = _builder(tmp_path) + result = builder._load_bootstrap_files() + assert "## AGENTS.md" in result + assert "## SOUL.md" in result + assert "Rules." in result + assert "Soul." in result + + def test_all_bootstrap_files(self, tmp_path): + for name in ContextBuilder.BOOTSTRAP_FILES: + (tmp_path / name).write_text(f"Content of {name}", encoding="utf-8") + builder = _builder(tmp_path) + result = builder._load_bootstrap_files() + for name in ContextBuilder.BOOTSTRAP_FILES: + assert f"## {name}" in result + + def test_utf8_content(self, tmp_path): + (tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8") + builder = _builder(tmp_path) + result = builder._load_bootstrap_files() + assert "用中文回复" in result + + +# --------------------------------------------------------------------------- +# _is_template_content (static) +# --------------------------------------------------------------------------- + + +class TestIsTemplateContent: + def test_nonexistent_template_returns_false(self): + assert ContextBuilder._is_template_content("anything", "nonexistent/path.md") is False + + def test_content_matching_template(self): + from importlib.resources import files as pkg_files + tpl = pkg_files("nanobot") / "templates" / "memory" / "MEMORY.md" + if not tpl.is_file(): + pytest.skip("MEMORY.md template not bundled") + original = tpl.read_text(encoding="utf-8") + assert ContextBuilder._is_template_content(original, "memory/MEMORY.md") is True + + def test_modified_content_returns_false(self): + from importlib.resources import files as pkg_files + tpl = pkg_files("nanobot") / "templates" / "memory" / "MEMORY.md" + if not tpl.is_file(): + pytest.skip("MEMORY.md template not bundled") + assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False + + +# --------------------------------------------------------------------------- +# _build_user_content +# --------------------------------------------------------------------------- + + +class TestBuildUserContent: + def test_no_media_returns_string(self, tmp_path): + builder = _builder(tmp_path) + result = builder._build_user_content("hello", None) + assert result == "hello" + + def test_empty_media_returns_string(self, tmp_path): + builder = _builder(tmp_path) + result = builder._build_user_content("hello", []) + assert result == "hello" + + def test_nonexistent_media_file_returns_string(self, tmp_path): + builder = _builder(tmp_path) + result = builder._build_user_content("hello", ["/nonexistent/image.png"]) + assert result == "hello" + + def test_non_image_file_returns_string(self, tmp_path): + txt = tmp_path / "doc.txt" + txt.write_text("not an image", encoding="utf-8") + builder = _builder(tmp_path) + result = builder._build_user_content("hello", [str(txt)]) + assert result == "hello" + + def test_valid_image_returns_list(self, tmp_path): + png = tmp_path / "test.png" + png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16) + builder = _builder(tmp_path) + result = builder._build_user_content("hello", [str(png)]) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0]["type"] == "image_url" + assert result[0]["image_url"]["url"].startswith("data:image/png;base64,") + assert result[1]["type"] == "text" + assert result[1]["text"] == "hello" + + def test_image_meta_includes_path(self, tmp_path): + png = tmp_path / "test.png" + png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16) + builder = _builder(tmp_path) + result = builder._build_user_content("hello", [str(png)]) + assert "_meta" in result[0] + assert "path" in result[0]["_meta"] + + +# --------------------------------------------------------------------------- +# build_system_prompt +# --------------------------------------------------------------------------- + + +class TestBuildSystemPrompt: + def test_returns_nonempty_string(self, tmp_path): + builder = _builder(tmp_path) + result = builder.build_system_prompt() + assert isinstance(result, str) + assert len(result) > 0 + + def test_includes_identity_section(self, tmp_path): + builder = _builder(tmp_path) + result = builder.build_system_prompt() + assert "workspace" in result.lower() or "python" in result.lower() + + def test_includes_bootstrap_files(self, tmp_path): + (tmp_path / "AGENTS.md").write_text("Be helpful and concise.", encoding="utf-8") + builder = _builder(tmp_path) + result = builder.build_system_prompt() + assert "Be helpful and concise." in result + + def test_includes_session_summary(self, tmp_path): + builder = _builder(tmp_path) + result = builder.build_system_prompt(session_summary="Previous chat about Python.") + assert "Previous chat about Python." in result + assert "[Archived Context Summary]" in result + + def test_sections_separated_by_separator(self, tmp_path): + (tmp_path / "AGENTS.md").write_text("Rules.", encoding="utf-8") + builder = _builder(tmp_path) + result = builder.build_system_prompt(session_summary="Summary.") + assert "\n\n---\n\n" in result + + def test_no_bootstrap_no_summary(self, tmp_path): + builder = _builder(tmp_path) + result = builder.build_system_prompt() + assert "## AGENTS.md" not in result + assert "[Archived Context Summary]" not in result + + +# --------------------------------------------------------------------------- +# build_messages +# --------------------------------------------------------------------------- + + +class TestBuildMessages: + def test_basic_empty_history(self, tmp_path): + builder = _builder(tmp_path) + messages = builder.build_messages([], "hello") + assert len(messages) == 2 + assert messages[0]["role"] == "system" + assert messages[1]["role"] == "user" + assert "hello" in str(messages[1]["content"]) + + def test_runtime_context_injected(self, tmp_path): + builder = _builder(tmp_path) + messages = builder.build_messages([], "hello", channel="cli", chat_id="direct") + user_msg = str(messages[-1]["content"]) + assert "[Runtime Context" in user_msg + assert "hello" in user_msg + + def test_consecutive_same_role_merged(self, tmp_path): + builder = _builder(tmp_path) + history = [{"role": "user", "content": "previous user message"}] + messages = builder.build_messages(history, "new message") + assert len(messages) == 2 # system + merged user + assert "previous user message" in str(messages[1]["content"]) + assert "new message" in str(messages[1]["content"]) + + def test_different_role_appended(self, tmp_path): + builder = _builder(tmp_path) + history = [{"role": "assistant", "content": "previous response"}] + messages = builder.build_messages(history, "new message") + assert len(messages) == 3 # system + assistant + user + + def test_media_with_history(self, tmp_path): + png = tmp_path / "img.png" + png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16) + builder = _builder(tmp_path) + history = [{"role": "assistant", "content": "see this"}] + messages = builder.build_messages(history, "check image", media=[str(png)]) + user_msg = messages[-1]["content"] + assert isinstance(user_msg, list) + assert any(b.get("type") == "image_url" for b in user_msg) + + +# --------------------------------------------------------------------------- +# add_tool_result +# --------------------------------------------------------------------------- + + +class TestAddToolResult: + def test_appends_tool_message(self, tmp_path): + builder = _builder(tmp_path) + msgs = [{"role": "user", "content": "hello"}] + result = builder.add_tool_result(msgs, "call_123", "read_file", "file content") + assert len(result) == 2 + assert result[1]["role"] == "tool" + assert result[1]["tool_call_id"] == "call_123" + assert result[1]["name"] == "read_file" + assert result[1]["content"] == "file content" + + def test_returns_same_list(self, tmp_path): + builder = _builder(tmp_path) + msgs = [] + result = builder.add_tool_result(msgs, "id", "tool", "ok") + assert result is msgs diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 8971d48ec..9b6c2820d 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -13,6 +13,17 @@ def _ctx() -> AgentHookContext: return AgentHookContext(iteration=0, messages=[]) +# --------------------------------------------------------------------------- +# Base AgentHook emit_reasoning: no-op +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_base_hook_emit_reasoning_is_noop(): + hook = AgentHook() + await hook.emit_reasoning("should not raise") + + # --------------------------------------------------------------------------- # Fan-out: every hook is called in order # --------------------------------------------------------------------------- @@ -45,6 +56,9 @@ async def test_composite_fans_out_all_async_methods(): async def before_iteration(self, context: AgentHookContext) -> None: events.append("before_iteration") + async def emit_reasoning(self, reasoning_content: str | None) -> None: + events.append(f"emit_reasoning:{reasoning_content}") + async def on_stream(self, context: AgentHookContext, delta: str) -> None: events.append(f"on_stream:{delta}") @@ -61,6 +75,7 @@ async def test_composite_fans_out_all_async_methods(): ctx = _ctx() await hook.before_iteration(ctx) + await hook.emit_reasoning("thinking...") await hook.on_stream(ctx, "hi") await hook.on_stream_end(ctx, resuming=True) await hook.before_execute_tools(ctx) @@ -68,6 +83,7 @@ async def test_composite_fans_out_all_async_methods(): assert events == [ "before_iteration", "before_iteration", + "emit_reasoning:thinking...", "emit_reasoning:thinking...", "on_stream:hi", "on_stream:hi", "on_stream_end:True", "on_stream_end:True", "before_execute_tools", "before_execute_tools", @@ -120,6 +136,8 @@ async def test_composite_error_isolation_all_async(): calls: list[str] = [] class Bad(AgentHook): + async def emit_reasoning(self, reasoning_content): + raise RuntimeError("err") async def on_stream_end(self, context, *, resuming): raise RuntimeError("err") async def before_execute_tools(self, context): @@ -128,6 +146,8 @@ async def test_composite_error_isolation_all_async(): raise RuntimeError("err") class Good(AgentHook): + async def emit_reasoning(self, reasoning_content): + calls.append("emit_reasoning") async def on_stream_end(self, context, *, resuming): calls.append("on_stream_end") async def before_execute_tools(self, context): @@ -137,10 +157,11 @@ async def test_composite_error_isolation_all_async(): hook = CompositeHook([Bad(), Good()]) ctx = _ctx() + await hook.emit_reasoning("test") await hook.on_stream_end(ctx, resuming=False) await hook.before_execute_tools(ctx) await hook.after_iteration(ctx) - assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"] + assert calls == ["emit_reasoning", "on_stream_end", "before_execute_tools", "after_iteration"] # --------------------------------------------------------------------------- diff --git a/tests/agent/test_loop_runner_integration.py b/tests/agent/test_loop_runner_integration.py new file mode 100644 index 000000000..3cfe07f41 --- /dev/null +++ b/tests/agent/test_loop_runner_integration.py @@ -0,0 +1,301 @@ +"""Tests for AgentLoop integration with AgentRunner: streaming, think-filter, error handling, subagent.""" + +from __future__ import annotations + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +def _make_loop(tmp_path): + from nanobot.agent.loop import AgentLoop + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: + MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path) + return loop + +@pytest.mark.asyncio +async def test_loop_max_iterations_message_stays_stable(tmp_path): + loop = _make_loop(tmp_path) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + )) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.execute = AsyncMock(return_value="ok") + loop.max_iterations = 2 + + final_content, _, _, _, _ = await loop._run_agent_loop([]) + + assert final_content == ( + "I reached the maximum number of tool call iterations (2) " + "without completing the task. You can try breaking the task into smaller steps." + ) + + +@pytest.mark.asyncio +async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path): + loop = _make_loop(tmp_path) + deltas: list[str] = [] + endings: list[bool] = [] + + async def chat_stream_with_retry(*, on_content_delta, **kwargs): + await on_content_delta("hidden") + await on_content_delta("Hello") + return LLMResponse(content="hiddenHello", tool_calls=[], usage={}) + + loop.provider.chat_stream_with_retry = chat_stream_with_retry + + async def on_stream(delta: str) -> None: + deltas.append(delta) + + async def on_stream_end(*, resuming: bool = False) -> None: + endings.append(resuming) + + final_content, _, _, _, _ = await loop._run_agent_loop( + [], + on_stream=on_stream, + on_stream_end=on_stream_end, + ) + + assert final_content == "Hello" + assert deltas == ["Hello"] + assert endings == [False] + + +@pytest.mark.asyncio +async def test_loop_stream_filter_hides_partial_trailing_think_prefix(tmp_path): + loop = _make_loop(tmp_path) + deltas: list[str] = [] + + async def chat_stream_with_retry(*, on_content_delta, **kwargs): + await on_content_delta("Hello hiddenWorld") + return LLMResponse(content="Hello hiddenWorld", tool_calls=[], usage={}) + + loop.provider.chat_stream_with_retry = chat_stream_with_retry + + async def on_stream(delta: str) -> None: + deltas.append(delta) + + final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=on_stream) + + assert final_content == "Hello World" + assert deltas == ["Hello", " World"] + + +@pytest.mark.asyncio +async def test_loop_stream_filter_hides_complete_trailing_think_tag(tmp_path): + loop = _make_loop(tmp_path) + deltas: list[str] = [] + + async def chat_stream_with_retry(*, on_content_delta, **kwargs): + await on_content_delta("Hello ") + await on_content_delta("hiddenWorld") + return LLMResponse(content="Hello hiddenWorld", tool_calls=[], usage={}) + + loop.provider.chat_stream_with_retry = chat_stream_with_retry + + async def on_stream(delta: str) -> None: + deltas.append(delta) + + final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=on_stream) + + assert final_content == "Hello World" + assert deltas == ["Hello", " World"] + + +@pytest.mark.asyncio +async def test_loop_retries_think_only_final_response(tmp_path): + loop = _make_loop(tmp_path) + call_count = {"n": 0} + + async def chat_with_retry(**kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse(content="hidden", tool_calls=[], usage={}) + return LLMResponse(content="Recovered answer", tool_calls=[], usage={}) + + loop.provider.chat_with_retry = chat_with_retry + + final_content, _, _, _, _ = await loop._run_agent_loop([]) + + assert final_content == "Recovered answer" + assert call_count["n"] == 2 + + +@pytest.mark.asyncio +async def test_streamed_flag_not_set_on_llm_error(tmp_path): + """When LLM errors during a streaming-capable channel interaction, + _streamed must NOT be set so ChannelManager delivers the error.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + error_resp = LLMResponse( + content="503 service unavailable", finish_reason="error", tool_calls=[], usage={}, + ) + loop.provider.chat_with_retry = AsyncMock(return_value=error_resp) + loop.provider.chat_stream_with_retry = AsyncMock(return_value=error_resp) + loop.tools.get_definitions = MagicMock(return_value=[]) + + msg = InboundMessage( + channel="feishu", sender_id="u1", chat_id="c1", content="hi", + ) + result = await loop._process_message( + msg, + on_stream=AsyncMock(), + on_stream_end=AsyncMock(), + ) + + assert result is not None + assert "503" in result.content + assert not result.metadata.get("_streamed"), \ + "_streamed must not be set when stop_reason is error" + + +@pytest.mark.asyncio +async def test_ssrf_soft_block_can_finalize_after_streamed_tool_call(tmp_path): + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + tool_call_resp = LLMResponse( + content="checking metadata", + tool_calls=[ToolCallRequest( + id="call_ssrf", + name="exec", + arguments={"command": "curl http://169.254.169.254/latest/meta-data/"}, + )], + usage={}, + ) + provider.chat_stream_with_retry = AsyncMock(side_effect=[ + tool_call_resp, + LLMResponse( + content="I cannot access private URLs. Please share the local file.", + tool_calls=[], + usage={}, + ), + ]) + + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.prepare_call = MagicMock(return_value=(None, {}, None)) + loop.tools.execute = AsyncMock(return_value=( + "Error: Command blocked by safety guard (internal/private URL detected)" + )) + + result = await loop._process_message( + InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="hi"), + on_stream=AsyncMock(), + on_stream_end=AsyncMock(), + ) + + assert result is not None + assert result.content == "I cannot access private URLs. Please share the local file." + assert result.metadata.get("_streamed") is True + + +@pytest.mark.asyncio +async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path): + from nanobot.agent.loop import AgentLoop + from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.chat_with_retry = AsyncMock(side_effect=[ + LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}), + LLMResponse(content="Recovered answer", tool_calls=[], usage={}), + ]) + + loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + first = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question") + ) + assert first is not None + assert first.content == "429 rate limit exceeded" + + session = loop.sessions.get_or_create("cli:test") + assert [ + {key: value for key, value in message.items() if key in {"role", "content"}} + for message in session.messages + ] == [ + {"role": "user", "content": "first question"}, + {"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER}, + ] + + second = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question") + ) + assert second is not None + assert second.content == "Recovered answer" + + request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"] + non_system = [message for message in request_messages if message.get("role") != "system"] + assert non_system[0]["role"] == "user" + assert "first question" in non_system[0]["content"] + assert non_system[1]["role"] == "assistant" + assert _PERSISTED_MODEL_ERROR_PLACEHOLDER in non_system[1]["content"] + assert non_system[2]["role"] == "user" + assert "second question" in non_system[2]["content"] + + +@pytest.mark.asyncio +async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch): + from nanobot.agent.subagent import SubagentManager, SubagentStatus + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + )) + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) + mgr._announce_result = AsyncMock() + + async def fake_execute(self, **kwargs): + return "tool result" + + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) + + status = SubagentStatus(task_id="sub-1", label="label", task_description="do task", started_at=time.monotonic()) + await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}, status) + + mgr._announce_result.assert_awaited_once() + args = mgr._announce_result.await_args.args + assert args[3] == "Task completed but no final response was generated." + assert args[5] == "ok" diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py deleted file mode 100644 index b821d9bab..000000000 --- a/tests/agent/test_runner.py +++ /dev/null @@ -1,3313 +0,0 @@ -"""Tests for the shared agent runner and its integration contracts.""" - -from __future__ import annotations - -import asyncio -import base64 -import os -import time -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from nanobot.config.schema import AgentDefaults -from nanobot.agent.tools.base import Tool -from nanobot.agent.tools.registry import ToolRegistry -from nanobot.providers.base import LLMResponse, ToolCallRequest - -_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars - - -def _make_injection_callback(queue: asyncio.Queue): - """Return an async callback that drains *queue* into a list of dicts.""" - async def inject_cb(): - items = [] - while not queue.empty(): - items.append(await queue.get()) - return items - return inject_cb - - -def _make_loop(tmp_path): - from nanobot.agent.loop import AgentLoop - from nanobot.bus.queue import MessageBus - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - - with patch("nanobot.agent.loop.ContextBuilder"), \ - patch("nanobot.agent.loop.SessionManager"), \ - patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: - MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) - loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path) - return loop - - -@pytest.mark.asyncio -async def test_runner_preserves_reasoning_fields_and_tool_results(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_second_call: list[dict] = [] - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], - reasoning_content="hidden reasoning", - thinking_blocks=[{"type": "thinking", "thinking": "step"}], - usage={"prompt_tokens": 5, "completion_tokens": 3}, - ) - captured_second_call[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="tool result") - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[ - {"role": "system", "content": "system"}, - {"role": "user", "content": "do task"}, - ], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == "done" - assert result.tools_used == ["list_dir"] - assert result.tool_events == [ - {"name": "list_dir", "status": "ok", "detail": "tool result"} - ] - - assistant_messages = [ - msg for msg in captured_second_call - if msg.get("role") == "assistant" and msg.get("tool_calls") - ] - assert len(assistant_messages) == 1 - assert assistant_messages[0]["reasoning_content"] == "hidden reasoning" - assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}] - assert any( - msg.get("role") == "tool" and msg.get("content") == "tool result" - for msg in captured_second_call - ) - - -@pytest.mark.asyncio -async def test_runner_calls_hooks_in_order(): - from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - call_count = {"n": 0} - events: list[tuple] = [] - - async def chat_with_retry(**kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], - ) - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="tool result") - - class RecordingHook(AgentHook): - async def before_iteration(self, context: AgentHookContext) -> None: - events.append(("before_iteration", context.iteration)) - - async def before_execute_tools(self, context: AgentHookContext) -> None: - events.append(( - "before_execute_tools", - context.iteration, - [tc.name for tc in context.tool_calls], - )) - - async def after_iteration(self, context: AgentHookContext) -> None: - events.append(( - "after_iteration", - context.iteration, - context.final_content, - list(context.tool_results), - list(context.tool_events), - context.stop_reason, - )) - - def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: - events.append(("finalize_content", context.iteration, content)) - return content.upper() if content else content - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=RecordingHook(), - )) - - assert result.final_content == "DONE" - assert events == [ - ("before_iteration", 0), - ("before_execute_tools", 0, ["list_dir"]), - ( - "after_iteration", - 0, - None, - ["tool result"], - [{"name": "list_dir", "status": "ok", "detail": "tool result"}], - None, - ), - ("before_iteration", 1), - ("finalize_content", 1, "done"), - ("after_iteration", 1, "DONE", [], [], "completed"), - ] - - -@pytest.mark.asyncio -async def test_runner_streaming_hook_receives_deltas_and_end_signal(): - from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - streamed: list[str] = [] - endings: list[bool] = [] - - async def chat_stream_with_retry(*, on_content_delta, **kwargs): - await on_content_delta("he") - await on_content_delta("llo") - return LLMResponse(content="hello", tool_calls=[], usage={}) - - provider.chat_stream_with_retry = chat_stream_with_retry - provider.chat_with_retry = AsyncMock() - tools = MagicMock() - tools.get_definitions.return_value = [] - - class StreamingHook(AgentHook): - def wants_streaming(self) -> bool: - return True - - async def on_stream(self, context: AgentHookContext, delta: str) -> None: - streamed.append(delta) - - async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: - endings.append(resuming) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=StreamingHook(), - )) - - assert result.final_content == "hello" - assert streamed == ["he", "llo"] - assert endings == [False] - provider.chat_with_retry.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_runner_returns_max_iterations_fallback(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - provider.chat_with_retry = AsyncMock(return_value=LLMResponse( - content="still working", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], - )) - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="tool result") - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=2, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.stop_reason == "max_iterations" - assert result.final_content == ( - "I reached the maximum number of tool call iterations (2) " - "without completing the task. You can try breaking the task into smaller steps." - ) - assert result.messages[-1]["role"] == "assistant" - assert result.messages[-1]["content"] == result.final_content - - -@pytest.mark.asyncio -async def test_runner_times_out_hung_llm_request(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - - async def chat_with_retry(**kwargs): - await asyncio.sleep(3600) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - started = time.monotonic() - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - llm_timeout_s=0.05, - )) - - assert (time.monotonic() - started) < 1.0 - assert result.stop_reason == "error" - assert "timed out" in (result.final_content or "").lower() - -@pytest.mark.asyncio -async def test_runner_returns_structured_tool_error(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - provider.chat_with_retry = AsyncMock(return_value=LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], - )) - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(side_effect=RuntimeError("boom")) - - runner = AgentRunner(provider) - - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=2, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - fail_on_tool_error=True, - )) - - assert result.stop_reason == "tool_error" - assert result.error == "Error: RuntimeError: boom" - assert result.tool_events == [ - {"name": "list_dir", "status": "error", "detail": "boom"} - ] - - -@pytest.mark.asyncio -async def test_runner_does_not_abort_on_workspace_violation_anymore(): - """v2 behavior: workspace-bound rejections are *soft* tool errors. - - Previously (PR #3493) any workspace boundary error became a fatal - RuntimeError that aborted the turn. That silently killed legitimate - workspace commands once the heuristic guard misfired (#3599 #3605), so - we now hand the error back to the LLM as a recoverable tool result and - rely on ``repeated_workspace_violation_error`` to throttle bypass loops. - """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - provider.chat_with_retry = AsyncMock(side_effect=[ - LLMResponse( - content="trying outside", - tool_calls=[ToolCallRequest( - id="call_1", name="read_file", arguments={"path": "/tmp/outside.md"}, - )], - ), - LLMResponse(content="ok, telling the user instead", tool_calls=[]), - ]) - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock( - side_effect=PermissionError( - "Path /tmp/outside.md is outside allowed directory /workspace" - ) - ) - - runner = AgentRunner(provider) - - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert provider.chat_with_retry.await_count == 2, ( - "workspace violation must NOT short-circuit the loop" - ) - assert result.stop_reason != "tool_error" - assert result.error is None - assert result.final_content == "ok, telling the user instead" - assert result.tool_events and result.tool_events[0]["status"] == "error" - # Detail still carries the workspace_violation breadcrumb for telemetry, - # but the runner did not raise. - assert "workspace_violation" in result.tool_events[0]["detail"] - - -def test_is_ssrf_violation_recognizes_private_url_blocks(): - """SSRF rejections are classified separately from workspace boundaries.""" - from nanobot.agent.runner import AgentRunner - - ssrf_msg = "Error: Command blocked by safety guard (internal/private URL detected)" - assert AgentRunner._is_ssrf_violation(ssrf_msg) is True - assert AgentRunner._is_ssrf_violation( - "URL validation failed: Blocked: host resolves to private/internal address 192.168.1.2" - ) is True - - # Workspace-bound markers are NOT classified as SSRF. - assert AgentRunner._is_ssrf_violation( - "Error: Command blocked by safety guard (path outside working dir)" - ) is False - assert AgentRunner._is_ssrf_violation( - "Path /tmp/x is outside allowed directory /ws" - ) is False - # Deny / allowlist filter messages stay non-fatal too. - assert AgentRunner._is_ssrf_violation( - "Error: Command blocked by deny pattern filter" - ) is False - - -@pytest.mark.asyncio -async def test_runner_returns_non_retryable_hint_on_ssrf_violation(): - """SSRF stays blocked, but the runtime gives the LLM a final chance to recover.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - provider.chat_with_retry = AsyncMock(side_effect=[ - LLMResponse( - content="curl-ing metadata", - tool_calls=[ToolCallRequest( - id="call_ssrf", - name="exec", - arguments={"command": "curl http://169.254.169.254"}, - )], - ), - LLMResponse( - content="I cannot access that private URL. Please share local files.", - tool_calls=[], - ), - ]) - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value=( - "Error: Command blocked by safety guard (internal/private URL detected)" - )) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert provider.chat_with_retry.await_count == 2 - assert result.stop_reason == "completed" - assert result.error is None - assert result.final_content == "I cannot access that private URL. Please share local files." - assert result.tool_events and result.tool_events[0]["detail"].startswith("ssrf_violation:") - tool_messages = [m for m in result.messages if m.get("role") == "tool"] - assert tool_messages - assert "non-bypassable security boundary" in tool_messages[0]["content"] - assert "Do not retry" in tool_messages[0]["content"] - assert "tools.ssrfWhitelist" in tool_messages[0]["content"] - - -@pytest.mark.asyncio -async def test_runner_lets_llm_recover_from_shell_guard_path_outside(): - """Reporter scenario for #3599 / #3605 -- guard hit, agent recovers. - - The shell `_guard_command` heuristic fires on `2>/dev/null`-style - redirects and other shell idioms. Before v2 that abort'd the whole - turn (silent hang on Telegram per #3605); now the LLM gets the soft - error back and can finalize on the next iteration. - """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_second_call: list[dict] = [] - - async def chat_with_retry(*, messages, **kwargs): - if provider.chat_with_retry.await_count == 1: - return LLMResponse( - content="trying noisy cleanup", - tool_calls=[ToolCallRequest( - id="call_blocked", - name="exec", - arguments={"command": "rm scratch.txt 2>/dev/null"}, - )], - ) - captured_second_call[:] = list(messages) - return LLMResponse(content="recovered final answer", tool_calls=[]) - - provider.chat_with_retry = AsyncMock(side_effect=chat_with_retry) - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock( - return_value="Error: Command blocked by safety guard (path outside working dir)" - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert provider.chat_with_retry.await_count == 2, ( - "guard hit must NOT short-circuit the loop -- LLM should get a second turn" - ) - assert result.stop_reason != "tool_error" - assert result.error is None - assert result.final_content == "recovered final answer" - assert result.tool_events and result.tool_events[0]["status"] == "error" - # v2: detail keeps the breadcrumb but the runner did not raise. - assert "workspace_violation" in result.tool_events[0]["detail"] - - -@pytest.mark.asyncio -async def test_runner_throttles_repeated_workspace_bypass_attempts(): - """#3493 motivation: stop the LLM bypass loop without aborting the turn. - - LLM keeps switching tools (read_file -> exec cat -> python -c open(...)) - against the same outside path. After the soft retry budget is exhausted - the runner replaces the tool result with a hard "stop trying" message - so the model finally gives up and surfaces the boundary to the user. - """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - bypass_attempts = [ - ToolCallRequest( - id=f"a{i}", name="exec", - arguments={"command": f"cat /Users/x/Downloads/01.md # try {i}"}, - ) - for i in range(4) - ] - responses: list[LLMResponse] = [ - LLMResponse(content=f"try {i}", tool_calls=[bypass_attempts[i]]) - for i in range(4) - ] - responses.append(LLMResponse(content="ok telling user", tool_calls=[])) - - provider = MagicMock() - provider.chat_with_retry = AsyncMock(side_effect=responses) - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock( - return_value="Error: Command blocked by safety guard (path outside working dir)" - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=10, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - # All 4 bypass attempts surface to the LLM (no fatal abort), and the - # runner finally completes once the LLM stops asking. - assert result.stop_reason != "tool_error" - assert result.error is None - assert result.final_content == "ok telling user" - # The third+ attempts must have been escalated -- look at the events. - escalated = [ - ev for ev in result.tool_events - if ev["status"] == "error" - and ev["detail"].startswith("workspace_violation_escalated:") - ] - assert escalated, ( - "expected at least one escalated workspace_violation event, got: " - f"{result.tool_events}" - ) - - -@pytest.mark.asyncio -async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_second_call: list[dict] = [] - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})], - usage={"prompt_tokens": 5, "completion_tokens": 3}, - ) - captured_second_call[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="x" * 20_000) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do task"}], - tools=tools, - model="test-model", - max_iterations=2, - workspace=tmp_path, - session_key="test:runner", - max_tool_result_chars=2048, - )) - - assert result.final_content == "done" - tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") - assert "[tool output persisted]" in tool_message["content"] - assert "tool-results" in tool_message["content"] - assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists() - - -def test_persist_tool_result_prunes_old_session_buckets(tmp_path): - from nanobot.utils.helpers import maybe_persist_tool_result - - root = tmp_path / ".nanobot" / "tool-results" - old_bucket = root / "old_session" - recent_bucket = root / "recent_session" - old_bucket.mkdir(parents=True) - recent_bucket.mkdir(parents=True) - (old_bucket / "old.txt").write_text("old", encoding="utf-8") - (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8") - - stale = time.time() - (8 * 24 * 60 * 60) - os.utime(old_bucket, (stale, stale)) - os.utime(old_bucket / "old.txt", (stale, stale)) - - persisted = maybe_persist_tool_result( - tmp_path, - "current:session", - "call_big", - "x" * 5000, - max_chars=64, - ) - - assert "[tool output persisted]" in persisted - assert not old_bucket.exists() - assert recent_bucket.exists() - assert (root / "current_session" / "call_big.txt").exists() - - -def test_persist_tool_result_leaves_no_temp_files(tmp_path): - from nanobot.utils.helpers import maybe_persist_tool_result - - root = tmp_path / ".nanobot" / "tool-results" - maybe_persist_tool_result( - tmp_path, - "current:session", - "call_big", - "x" * 5000, - max_chars=64, - ) - - assert (root / "current_session" / "call_big.txt").exists() - assert list((root / "current_session").glob("*.tmp")) == [] - - -def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path): - from nanobot.utils.helpers import maybe_persist_tool_result - - warnings: list[str] = [] - - monkeypatch.setattr( - "nanobot.utils.helpers._cleanup_tool_result_buckets", - lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")), - ) - monkeypatch.setattr( - "nanobot.utils.helpers.logger.exception", - lambda message, *args: warnings.append(message.format(*args)), - ) - - persisted = maybe_persist_tool_result( - tmp_path, - "current:session", - "call_big", - "x" * 5000, - max_chars=64, - ) - - assert "[tool output persisted]" in persisted - assert warnings and "Failed to clean stale tool result buckets" in warnings[0] - - -@pytest.mark.asyncio -async def test_runner_replaces_empty_tool_result_with_marker(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_second_call: list[dict] = [] - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})], - usage={}, - ) - captured_second_call[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="") - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do task"}], - tools=tools, - model="test-model", - max_iterations=2, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == "done" - tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") - assert tool_message["content"] == "(noop completed with no output)" - - -@pytest.mark.asyncio -async def test_runner_uses_raw_messages_when_context_governance_fails(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_messages: list[dict] = [] - - async def chat_with_retry(*, messages, **kwargs): - captured_messages[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - initial_messages = [ - {"role": "system", "content": "system"}, - {"role": "user", "content": "hello"}, - ] - - runner = AgentRunner(provider) - runner._snip_history = MagicMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] - result = await runner.run(AgentRunSpec( - initial_messages=initial_messages, - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == "done" - assert captured_messages == initial_messages - - -@pytest.mark.asyncio -async def test_runner_retries_empty_final_response_with_summary_prompt(): - """Empty responses get 2 silent retries before finalization kicks in.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - calls: list[dict] = [] - - async def chat_with_retry(*, messages, tools=None, **kwargs): - calls.append({"messages": messages, "tools": tools}) - if len(calls) <= 2: - return LLMResponse( - content=None, - tool_calls=[], - usage={"prompt_tokens": 5, "completion_tokens": 1}, - ) - return LLMResponse( - content="final answer", - tool_calls=[], - usage={"prompt_tokens": 3, "completion_tokens": 7}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do task"}], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == "final answer" - # 2 silent retries (iterations 0,1) + finalization on iteration 1 - assert len(calls) == 3 - assert calls[0]["tools"] is not None - assert calls[1]["tools"] is not None - assert calls[2]["tools"] is None - assert result.usage["prompt_tokens"] == 13 - assert result.usage["completion_tokens"] == 9 - - -@pytest.mark.asyncio -async def test_runner_uses_specific_message_after_empty_finalization_retry(): - """After silent retries + finalization all return empty, stop_reason is empty_final_response.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE - - provider = MagicMock() - - async def chat_with_retry(*, messages, **kwargs): - return LLMResponse(content=None, tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do task"}], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE - assert result.stop_reason == "empty_final_response" - - -@pytest.mark.asyncio -async def test_runner_empty_response_does_not_break_tool_chain(): - """An empty intermediate response must not kill an ongoing tool chain. - - Sequence: tool_call → empty → tool_call → final text. - The runner should recover via silent retry and complete normally. - """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - call_count = 0 - - async def chat_with_retry(*, messages, tools=None, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 1: - return LLMResponse( - content=None, - tool_calls=[ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a.txt"})], - usage={"prompt_tokens": 10, "completion_tokens": 5}, - ) - if call_count == 2: - return LLMResponse(content=None, tool_calls=[], usage={"prompt_tokens": 10, "completion_tokens": 1}) - if call_count == 3: - return LLMResponse( - content=None, - tool_calls=[ToolCallRequest(id="tc2", name="read_file", arguments={"path": "b.txt"})], - usage={"prompt_tokens": 10, "completion_tokens": 5}, - ) - return LLMResponse( - content="Here are the results.", - tool_calls=[], - usage={"prompt_tokens": 10, "completion_tokens": 10}, - ) - - provider.chat_with_retry = chat_with_retry - provider.chat_stream_with_retry = chat_with_retry - - async def fake_tool(name, args, **kw): - return "file content" - - tool_registry = MagicMock() - tool_registry.get_definitions.return_value = [{"type": "function", "function": {"name": "read_file"}}] - tool_registry.execute = AsyncMock(side_effect=fake_tool) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "read both files"}], - tools=tool_registry, - model="test-model", - max_iterations=10, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == "Here are the results." - assert result.stop_reason == "completed" - assert call_count == 4 - assert "read_file" in result.tools_used - - -def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - tools = MagicMock() - tools.get_definitions.return_value = [] - runner = AgentRunner(provider) - messages = [ - {"role": "system", "content": "system"}, - {"role": "user", "content": "old user"}, - { - "role": "assistant", - "content": "tool call", - "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}], - }, - {"role": "tool", "tool_call_id": "call_1", "content": "tool output"}, - {"role": "assistant", "content": "after tool"}, - ] - spec = AgentRunSpec( - initial_messages=messages, - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - context_window_tokens=2000, - context_block_limit=100, - ) - - monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_args, **_kwargs: (500, None)) - token_sizes = { - "old user": 120, - "tool call": 120, - "tool output": 40, - "after tool": 40, - "system": 0, - } - monkeypatch.setattr( - "nanobot.agent.runner.estimate_message_tokens", - lambda msg: token_sizes.get(str(msg.get("content")), 40), - ) - - trimmed = runner._snip_history(spec, messages) - - # After the fix, the user message is recovered so the sequence is valid - # for providers that require system → user (e.g. GLM error 1214). - assert trimmed[0]["role"] == "system" - non_system = [m for m in trimmed if m["role"] != "system"] - assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}" - - -@pytest.mark.asyncio -async def test_runner_keeps_going_when_tool_result_persistence_fails(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_second_call: list[dict] = [] - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], - usage={"prompt_tokens": 5, "completion_tokens": 3}, - ) - captured_second_call[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="tool result") - - runner = AgentRunner(provider) - with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")): - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do task"}], - tools=tools, - model="test-model", - max_iterations=2, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == "done" - tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") - assert tool_message["content"] == "tool result" - - -class _DelayTool(Tool): - def __init__( - self, - name: str, - *, - delay: float, - read_only: bool, - shared_events: list[str], - exclusive: bool = False, - ): - self._name = name - self._delay = delay - self._read_only = read_only - self._shared_events = shared_events - self._exclusive = exclusive - - @property - def name(self) -> str: - return self._name - - @property - def description(self) -> str: - return self._name - - @property - def parameters(self) -> dict: - return {"type": "object", "properties": {}, "required": []} - - @property - def read_only(self) -> bool: - return self._read_only - - @property - def exclusive(self) -> bool: - return self._exclusive - - async def execute(self, **kwargs): - self._shared_events.append(f"start:{self._name}") - await asyncio.sleep(self._delay) - self._shared_events.append(f"end:{self._name}") - return self._name - - -@pytest.mark.asyncio -async def test_runner_batches_read_only_tools_before_exclusive_work(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - tools = ToolRegistry() - shared_events: list[str] = [] - read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events) - read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events) - write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events) - tools.register(read_a) - tools.register(read_b) - tools.register(write_a) - - runner = AgentRunner(MagicMock()) - await runner._execute_tools( - AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - concurrent_tools=True, - ), - [ - ToolCallRequest(id="ro1", name="read_a", arguments={}), - ToolCallRequest(id="ro2", name="read_b", arguments={}), - ToolCallRequest(id="rw1", name="write_a", arguments={}), - ], - {}, - {}, - ) - - assert shared_events[0:2] == ["start:read_a", "start:read_b"] - assert "end:read_a" in shared_events and "end:read_b" in shared_events - assert shared_events.index("end:read_a") < shared_events.index("start:write_a") - assert shared_events.index("end:read_b") < shared_events.index("start:write_a") - assert shared_events[-2:] == ["start:write_a", "end:write_a"] - - -@pytest.mark.asyncio -async def test_runner_does_not_batch_exclusive_read_only_tools(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - tools = ToolRegistry() - shared_events: list[str] = [] - read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events) - read_b = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=shared_events) - ddg_like = _DelayTool( - "ddg_like", - delay=0.01, - read_only=True, - shared_events=shared_events, - exclusive=True, - ) - tools.register(read_a) - tools.register(ddg_like) - tools.register(read_b) - - runner = AgentRunner(MagicMock()) - await runner._execute_tools( - AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - concurrent_tools=True, - ), - [ - ToolCallRequest(id="ro1", name="read_a", arguments={}), - ToolCallRequest(id="ddg1", name="ddg_like", arguments={}), - ToolCallRequest(id="ro2", name="read_b", arguments={}), - ], - {}, - {}, - ) - - assert shared_events[0] == "start:read_a" - assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like") - assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b") - - -@pytest.mark.asyncio -async def test_runner_blocks_repeated_external_fetches(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_final_call: list[dict] = [] - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] <= 3: - return LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id=f"call_{call_count['n']}", name="web_fetch", arguments={"url": "https://example.com"})], - usage={}, - ) - captured_final_call[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="page content") - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "research task"}], - tools=tools, - model="test-model", - max_iterations=4, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.final_content == "done" - assert tools.execute.await_count == 2 - blocked_tool_message = [ - msg for msg in captured_final_call - if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_3" - ][0] - assert "repeated external lookup blocked" in blocked_tool_message["content"] - - -@pytest.mark.asyncio -async def test_loop_max_iterations_message_stays_stable(tmp_path): - loop = _make_loop(tmp_path) - loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], - )) - loop.tools.get_definitions = MagicMock(return_value=[]) - loop.tools.execute = AsyncMock(return_value="ok") - loop.max_iterations = 2 - - final_content, _, _, _, _ = await loop._run_agent_loop([]) - - assert final_content == ( - "I reached the maximum number of tool call iterations (2) " - "without completing the task. You can try breaking the task into smaller steps." - ) - - -@pytest.mark.asyncio -async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path): - loop = _make_loop(tmp_path) - deltas: list[str] = [] - endings: list[bool] = [] - - async def chat_stream_with_retry(*, on_content_delta, **kwargs): - await on_content_delta("hidden") - await on_content_delta("Hello") - return LLMResponse(content="hiddenHello", tool_calls=[], usage={}) - - loop.provider.chat_stream_with_retry = chat_stream_with_retry - - async def on_stream(delta: str) -> None: - deltas.append(delta) - - async def on_stream_end(*, resuming: bool = False) -> None: - endings.append(resuming) - - final_content, _, _, _, _ = await loop._run_agent_loop( - [], - on_stream=on_stream, - on_stream_end=on_stream_end, - ) - - assert final_content == "Hello" - assert deltas == ["Hello"] - assert endings == [False] - - -@pytest.mark.asyncio -async def test_loop_stream_filter_hides_partial_trailing_think_prefix(tmp_path): - loop = _make_loop(tmp_path) - deltas: list[str] = [] - - async def chat_stream_with_retry(*, on_content_delta, **kwargs): - await on_content_delta("Hello hiddenWorld") - return LLMResponse(content="Hello hiddenWorld", tool_calls=[], usage={}) - - loop.provider.chat_stream_with_retry = chat_stream_with_retry - - async def on_stream(delta: str) -> None: - deltas.append(delta) - - final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=on_stream) - - assert final_content == "Hello World" - assert deltas == ["Hello", " World"] - - -@pytest.mark.asyncio -async def test_loop_stream_filter_hides_complete_trailing_think_tag(tmp_path): - loop = _make_loop(tmp_path) - deltas: list[str] = [] - - async def chat_stream_with_retry(*, on_content_delta, **kwargs): - await on_content_delta("Hello ") - await on_content_delta("hiddenWorld") - return LLMResponse(content="Hello hiddenWorld", tool_calls=[], usage={}) - - loop.provider.chat_stream_with_retry = chat_stream_with_retry - - async def on_stream(delta: str) -> None: - deltas.append(delta) - - final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=on_stream) - - assert final_content == "Hello World" - assert deltas == ["Hello", " World"] - - -@pytest.mark.asyncio -async def test_loop_retries_think_only_final_response(tmp_path): - loop = _make_loop(tmp_path) - call_count = {"n": 0} - - async def chat_with_retry(**kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse(content="hidden", tool_calls=[], usage={}) - return LLMResponse(content="Recovered answer", tool_calls=[], usage={}) - - loop.provider.chat_with_retry = chat_with_retry - - final_content, _, _, _, _ = await loop._run_agent_loop([]) - - assert final_content == "Recovered answer" - assert call_count["n"] == 2 - - -@pytest.mark.asyncio -async def test_llm_error_not_appended_to_session_messages(): - """When LLM returns finish_reason='error', the error content must NOT be - appended to the messages list (prevents polluting session history).""" - from nanobot.agent.runner import ( - AgentRunSpec, - AgentRunner, - _PERSISTED_MODEL_ERROR_PLACEHOLDER, - ) - - provider = MagicMock() - provider.chat_with_retry = AsyncMock(return_value=LLMResponse( - content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}, - )) - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=5, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.stop_reason == "error" - assert result.final_content == "429 rate limit exceeded" - assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"] - assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \ - "Error content should not appear in session messages" - assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER - - -@pytest.mark.asyncio -async def test_streamed_flag_not_set_on_llm_error(tmp_path): - """When LLM errors during a streaming-capable channel interaction, - _streamed must NOT be set so ChannelManager delivers the error.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") - error_resp = LLMResponse( - content="503 service unavailable", finish_reason="error", tool_calls=[], usage={}, - ) - loop.provider.chat_with_retry = AsyncMock(return_value=error_resp) - loop.provider.chat_stream_with_retry = AsyncMock(return_value=error_resp) - loop.tools.get_definitions = MagicMock(return_value=[]) - - msg = InboundMessage( - channel="feishu", sender_id="u1", chat_id="c1", content="hi", - ) - result = await loop._process_message( - msg, - on_stream=AsyncMock(), - on_stream_end=AsyncMock(), - ) - - assert result is not None - assert "503" in result.content - assert not result.metadata.get("_streamed"), \ - "_streamed must not be set when stop_reason is error" - - -@pytest.mark.asyncio -async def test_ssrf_soft_block_can_finalize_after_streamed_tool_call(tmp_path): - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - tool_call_resp = LLMResponse( - content="checking metadata", - tool_calls=[ToolCallRequest( - id="call_ssrf", - name="exec", - arguments={"command": "curl http://169.254.169.254/latest/meta-data/"}, - )], - usage={}, - ) - provider.chat_stream_with_retry = AsyncMock(side_effect=[ - tool_call_resp, - LLMResponse( - content="I cannot access private URLs. Please share the local file.", - tool_calls=[], - usage={}, - ), - ]) - - loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") - loop.tools.get_definitions = MagicMock(return_value=[]) - loop.tools.prepare_call = MagicMock(return_value=(None, {}, None)) - loop.tools.execute = AsyncMock(return_value=( - "Error: Command blocked by safety guard (internal/private URL detected)" - )) - - result = await loop._process_message( - InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="hi"), - on_stream=AsyncMock(), - on_stream_end=AsyncMock(), - ) - - assert result is not None - assert result.content == "I cannot access private URLs. Please share the local file." - assert result.metadata.get("_streamed") is True - - -@pytest.mark.asyncio -async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path): - from nanobot.agent.loop import AgentLoop - from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - provider.chat_with_retry = AsyncMock(side_effect=[ - LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}), - LLMResponse(content="Recovered answer", tool_calls=[], usage={}), - ]) - - loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") - loop.tools.get_definitions = MagicMock(return_value=[]) - loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] - - first = await loop._process_message( - InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question") - ) - assert first is not None - assert first.content == "429 rate limit exceeded" - - session = loop.sessions.get_or_create("cli:test") - assert [ - {key: value for key, value in message.items() if key in {"role", "content"}} - for message in session.messages - ] == [ - {"role": "user", "content": "first question"}, - {"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER}, - ] - - second = await loop._process_message( - InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question") - ) - assert second is not None - assert second.content == "Recovered answer" - - request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"] - non_system = [message for message in request_messages if message.get("role") != "system"] - assert non_system[0]["role"] == "user" - assert "first question" in non_system[0]["content"] - assert non_system[1]["role"] == "assistant" - assert _PERSISTED_MODEL_ERROR_PLACEHOLDER in non_system[1]["content"] - assert non_system[2]["role"] == "user" - assert "second question" in non_system[2]["content"] - - -@pytest.mark.asyncio -async def test_runner_tool_error_sets_final_content(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - - async def chat_with_retry(*, messages, **kwargs): - return LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], - usage={}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(side_effect=RuntimeError("boom")) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do task"}], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - fail_on_tool_error=True, - )) - - assert result.final_content == "Error: RuntimeError: boom" - assert result.stop_reason == "tool_error" - - -@pytest.mark.asyncio -async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch): - from nanobot.agent.subagent import SubagentManager, SubagentStatus - from nanobot.bus.queue import MessageBus - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - provider.chat_with_retry = AsyncMock(return_value=LLMResponse( - content="working", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], - )) - mgr = SubagentManager( - provider=provider, - workspace=tmp_path, - bus=bus, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - ) - mgr._announce_result = AsyncMock() - - async def fake_execute(self, **kwargs): - return "tool result" - - monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) - - status = SubagentStatus(task_id="sub-1", label="label", task_description="do task", started_at=time.monotonic()) - await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}, status) - - mgr._announce_result.assert_awaited_once() - args = mgr._announce_result.await_args.args - assert args[3] == "Task completed but no final response was generated." - assert args[5] == "ok" - - -@pytest.mark.asyncio -async def test_runner_accumulates_usage_and_preserves_cached_tokens(): - """Runner should accumulate prompt/completion tokens across iterations - and preserve cached_tokens from provider responses.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], - usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80}, - ) - return LLMResponse( - content="done", - tool_calls=[], - usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="file content") - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do task"}], - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - # Usage should be accumulated across iterations - assert result.usage["prompt_tokens"] == 300 # 100 + 200 - assert result.usage["completion_tokens"] == 30 # 10 + 20 - assert result.usage["cached_tokens"] == 230 # 80 + 150 - - -@pytest.mark.asyncio -async def test_runner_passes_cached_tokens_to_hook_context(): - """Hook context.usage should contain cached_tokens.""" - from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_usage: list[dict] = [] - - class UsageHook(AgentHook): - async def after_iteration(self, context: AgentHookContext) -> None: - captured_usage.append(dict(context.usage)) - - async def chat_with_retry(**kwargs): - return LLMResponse( - content="done", - tool_calls=[], - usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - await runner.run(AgentRunSpec( - initial_messages=[], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=UsageHook(), - )) - - assert len(captured_usage) == 1 - assert captured_usage[0]["cached_tokens"] == 150 - - -# --------------------------------------------------------------------------- -# Length recovery (auto-continue on finish_reason == "length") -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_length_recovery_continues_from_truncated_output(): - """When finish_reason is 'length', runner should insert a continuation - prompt and retry, stitching partial outputs into the final result.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] <= 2: - return LLMResponse( - content=f"part{call_count['n']} ", - finish_reason="length", - usage={}, - ) - return LLMResponse(content="final", finish_reason="stop", usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "write a long essay"}], - tools=tools, - model="test-model", - max_iterations=10, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.stop_reason == "completed" - assert result.final_content == "final" - assert call_count["n"] == 3 - roles = [m["role"] for m in result.messages if m["role"] == "user"] - assert len(roles) >= 3 # original + 2 recovery prompts - - -@pytest.mark.asyncio -async def test_length_recovery_streaming_calls_on_stream_end_with_resuming(): - """During length recovery with streaming, on_stream_end should be called - with resuming=True so the hook knows the conversation is continuing.""" - from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - call_count = {"n": 0} - stream_end_calls: list[bool] = [] - - class StreamHook(AgentHook): - def wants_streaming(self) -> bool: - return True - - async def on_stream(self, context: AgentHookContext, delta: str) -> None: - pass - - async def on_stream_end(self, context: AgentHookContext, resuming: bool = False) -> None: - stream_end_calls.append(resuming) - - async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse(content="partial ", finish_reason="length", usage={}) - return LLMResponse(content="done", finish_reason="stop", usage={}) - - provider.chat_stream_with_retry = chat_stream_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "go"}], - tools=tools, - model="test-model", - max_iterations=10, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=StreamHook(), - )) - - assert len(stream_end_calls) == 2 - assert stream_end_calls[0] is True # length recovery: resuming - assert stream_end_calls[1] is False # final response: done - - -@pytest.mark.asyncio -async def test_length_recovery_gives_up_after_max_retries(): - """After _MAX_LENGTH_RECOVERIES attempts the runner should stop retrying.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_LENGTH_RECOVERIES - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - return LLMResponse( - content=f"chunk{call_count['n']}", - finish_reason="length", - usage={}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "go"}], - tools=tools, - model="test-model", - max_iterations=20, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert call_count["n"] == _MAX_LENGTH_RECOVERIES + 1 - assert result.final_content is not None - - -# --------------------------------------------------------------------------- -# Backfill missing tool_results -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_backfill_missing_tool_results_inserts_error(): - """Orphaned tool_use (no matching tool_result) should get a synthetic error.""" - from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT - - messages = [ - {"role": "user", "content": "hi"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - {"id": "call_a", "type": "function", "function": {"name": "exec", "arguments": "{}"}}, - {"id": "call_b", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, - ], - }, - {"role": "tool", "tool_call_id": "call_a", "name": "exec", "content": "ok"}, - ] - result = AgentRunner._backfill_missing_tool_results(messages) - tool_msgs = [m for m in result if m.get("role") == "tool"] - assert len(tool_msgs) == 2 - backfilled = [m for m in tool_msgs if m.get("tool_call_id") == "call_b"] - assert len(backfilled) == 1 - assert backfilled[0]["content"] == _BACKFILL_CONTENT - assert backfilled[0]["name"] == "read_file" - - -def test_drop_orphan_tool_results_removes_unmatched_tool_messages(): - from nanobot.agent.runner import AgentRunner - - messages = [ - {"role": "system", "content": "system"}, - {"role": "user", "content": "old user"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, - ], - }, - {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"}, - {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}, - {"role": "assistant", "content": "after tool"}, - ] - - cleaned = AgentRunner._drop_orphan_tool_results(messages) - - assert cleaned == [ - {"role": "system", "content": "system"}, - {"role": "user", "content": "old user"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, - ], - }, - {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"}, - {"role": "assistant", "content": "after tool"}, - ] - - -@pytest.mark.asyncio -async def test_backfill_noop_when_complete(): - """Complete message chains should not be modified.""" - from nanobot.agent.runner import AgentRunner - - messages = [ - {"role": "user", "content": "hi"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - {"id": "call_x", "type": "function", "function": {"name": "exec", "arguments": "{}"}}, - ], - }, - {"role": "tool", "tool_call_id": "call_x", "name": "exec", "content": "done"}, - {"role": "assistant", "content": "all good"}, - ] - result = AgentRunner._backfill_missing_tool_results(messages) - assert result is messages # same object — no copy - - -@pytest.mark.asyncio -async def test_runner_drops_orphan_tool_results_before_model_request(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - captured_messages: list[dict] = [] - - async def chat_with_retry(*, messages, **kwargs): - captured_messages[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[ - {"role": "system", "content": "system"}, - {"role": "user", "content": "old user"}, - {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}, - {"role": "assistant", "content": "after orphan"}, - {"role": "user", "content": "new prompt"}, - ], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert all( - message.get("tool_call_id") != "call_orphan" - for message in captured_messages - if message.get("role") == "tool" - ) - assert result.messages[2]["tool_call_id"] == "call_orphan" - assert result.final_content == "done" - - -@pytest.mark.asyncio -async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path): - """Historical backfill should not duplicate old tail messages on persist.""" - from nanobot.agent.loop import AgentLoop - from nanobot.agent.runner import _BACKFILL_CONTENT - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - response = LLMResponse(content="new answer", tool_calls=[], usage={}) - provider.chat_with_retry = AsyncMock(return_value=response) - provider.chat_stream_with_retry = AsyncMock(return_value=response) - - loop = AgentLoop( - bus=MessageBus(), - provider=provider, - workspace=tmp_path, - model="test-model", - ) - loop.tools.get_definitions = MagicMock(return_value=[]) - loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] - - session = loop.sessions.get_or_create("cli:test") - session.messages = [ - {"role": "user", "content": "old user", "timestamp": "2026-01-01T00:00:00"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_missing", - "type": "function", - "function": {"name": "read_file", "arguments": "{}"}, - } - ], - "timestamp": "2026-01-01T00:00:01", - }, - {"role": "assistant", "content": "old tail", "timestamp": "2026-01-01T00:00:02"}, - ] - loop.sessions.save(session) - - result = await loop._process_message( - InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new prompt") - ) - - assert result is not None - assert result.content == "new answer" - - request_messages = provider.chat_with_retry.await_args.kwargs["messages"] - synthetic = [ - message - for message in request_messages - if message.get("role") == "tool" and message.get("tool_call_id") == "call_missing" - ] - assert len(synthetic) == 1 - assert synthetic[0]["content"] == _BACKFILL_CONTENT - - session_after = loop.sessions.get_or_create("cli:test") - assert [ - { - key: value - for key, value in message.items() - if key in {"role", "content", "tool_call_id", "name", "tool_calls"} - } - for message in session_after.messages - ] == [ - {"role": "user", "content": "old user"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_missing", - "type": "function", - "function": {"name": "read_file", "arguments": "{}"}, - } - ], - }, - {"role": "assistant", "content": "old tail"}, - {"role": "user", "content": "new prompt"}, - {"role": "assistant", "content": "new answer"}, - ] - - -@pytest.mark.asyncio -async def test_runner_backfill_only_mutates_model_context_not_returned_messages(): - """Runner should repair orphaned tool calls for the model without rewriting result.messages.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _BACKFILL_CONTENT - - provider = MagicMock() - captured_messages: list[dict] = [] - - async def chat_with_retry(*, messages, **kwargs): - captured_messages[:] = messages - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - initial_messages = [ - {"role": "system", "content": "system"}, - {"role": "user", "content": "old user"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_missing", - "type": "function", - "function": {"name": "read_file", "arguments": "{}"}, - } - ], - }, - {"role": "assistant", "content": "old tail"}, - {"role": "user", "content": "new prompt"}, - ] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=initial_messages, - tools=tools, - model="test-model", - max_iterations=3, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - synthetic = [ - message - for message in captured_messages - if message.get("role") == "tool" and message.get("tool_call_id") == "call_missing" - ] - assert len(synthetic) == 1 - assert synthetic[0]["content"] == _BACKFILL_CONTENT - - assert [ - { - key: value - for key, value in message.items() - if key in {"role", "content", "tool_call_id", "name", "tool_calls"} - } - for message in result.messages - ] == [ - {"role": "system", "content": "system"}, - {"role": "user", "content": "old user"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_missing", - "type": "function", - "function": {"name": "read_file", "arguments": "{}"}, - } - ], - }, - {"role": "assistant", "content": "old tail"}, - {"role": "user", "content": "new prompt"}, - {"role": "assistant", "content": "done"}, - ] - - -# --------------------------------------------------------------------------- -# Microcompact (stale tool result compaction) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_microcompact_replaces_old_tool_results(): - """Tool results beyond _MICROCOMPACT_KEEP_RECENT should be summarized.""" - from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT - - total = _MICROCOMPACT_KEEP_RECENT + 5 - long_content = "x" * 600 - messages: list[dict] = [{"role": "system", "content": "sys"}] - for i in range(total): - messages.append({ - "role": "assistant", - "content": "", - "tool_calls": [{"id": f"c{i}", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}], - }) - messages.append({ - "role": "tool", "tool_call_id": f"c{i}", "name": "read_file", - "content": long_content, - }) - - result = AgentRunner._microcompact(messages) - tool_msgs = [m for m in result if m.get("role") == "tool"] - stale_count = total - _MICROCOMPACT_KEEP_RECENT - compacted = [m for m in tool_msgs if "omitted from context" in str(m.get("content", ""))] - preserved = [m for m in tool_msgs if m.get("content") == long_content] - assert len(compacted) == stale_count - assert len(preserved) == _MICROCOMPACT_KEEP_RECENT - - -@pytest.mark.asyncio -async def test_microcompact_preserves_short_results(): - """Short tool results (< _MICROCOMPACT_MIN_CHARS) should not be replaced.""" - from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT - - total = _MICROCOMPACT_KEEP_RECENT + 5 - messages: list[dict] = [] - for i in range(total): - messages.append({ - "role": "assistant", - "content": "", - "tool_calls": [{"id": f"c{i}", "type": "function", "function": {"name": "exec", "arguments": "{}"}}], - }) - messages.append({ - "role": "tool", "tool_call_id": f"c{i}", "name": "exec", - "content": "short", - }) - - result = AgentRunner._microcompact(messages) - assert result is messages # no copy needed — all stale results are short - - -@pytest.mark.asyncio -async def test_microcompact_skips_non_compactable_tools(): - """Non-compactable tools (e.g. 'message') should never be replaced.""" - from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT - - total = _MICROCOMPACT_KEEP_RECENT + 5 - long_content = "y" * 1000 - messages: list[dict] = [] - for i in range(total): - messages.append({ - "role": "assistant", - "content": "", - "tool_calls": [{"id": f"c{i}", "type": "function", "function": {"name": "message", "arguments": "{}"}}], - }) - messages.append({ - "role": "tool", "tool_call_id": f"c{i}", "name": "message", - "content": long_content, - }) - - result = AgentRunner._microcompact(messages) - assert result is messages # no compactable tools found - - -@pytest.mark.asyncio -async def test_runner_tool_error_preserves_tool_results_in_messages(): - """When a tool raises a fatal error, its results must still be appended - to messages so the session never contains orphan tool_calls (#2943).""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - - async def chat_with_retry(*, messages, **kwargs): - return LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a"}), - ToolCallRequest(id="tc2", name="exec", arguments={"cmd": "bad"}), - ], - usage={}, - ) - - provider.chat_with_retry = chat_with_retry - provider.chat_stream_with_retry = chat_with_retry - - call_idx = 0 - - async def fake_execute(name, args, **kw): - nonlocal call_idx - call_idx += 1 - if call_idx == 2: - raise RuntimeError("boom") - return "file content" - - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(side_effect=fake_execute) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "do stuff"}], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - fail_on_tool_error=True, - )) - - assert result.stop_reason == "tool_error" - # Both tool results must be in messages even though tc2 had a fatal error. - tool_msgs = [m for m in result.messages if m.get("role") == "tool"] - assert len(tool_msgs) == 2 - assert tool_msgs[0]["tool_call_id"] == "tc1" - assert tool_msgs[1]["tool_call_id"] == "tc2" - # The assistant message with tool_calls must precede the tool results. - asst_tc_idx = next( - i for i, m in enumerate(result.messages) - if m.get("role") == "assistant" and m.get("tool_calls") - ) - tool_indices = [ - i for i, m in enumerate(result.messages) if m.get("role") == "tool" - ] - assert all(ti > asst_tc_idx for ti in tool_indices) - - -def test_governance_repairs_orphans_after_snip(): - """After _snip_history clips an assistant+tool_calls, the second - _drop_orphan_tool_results pass must clean up the resulting orphans.""" - from nanobot.agent.runner import AgentRunner - - messages = [ - {"role": "system", "content": "system"}, - {"role": "user", "content": "old msg"}, - {"role": "assistant", "content": None, - "tool_calls": [{"id": "tc_old", "type": "function", - "function": {"name": "search", "arguments": "{}"}}]}, - {"role": "tool", "tool_call_id": "tc_old", "name": "search", - "content": "old result"}, - {"role": "assistant", "content": "old answer"}, - {"role": "user", "content": "new msg"}, - ] - - # Simulate snipping that keeps only the tail: drop the assistant with - # tool_calls but keep its tool result (orphan). - snipped = [ - {"role": "system", "content": "system"}, - {"role": "tool", "tool_call_id": "tc_old", "name": "search", - "content": "old result"}, - {"role": "assistant", "content": "old answer"}, - {"role": "user", "content": "new msg"}, - ] - - cleaned = AgentRunner._drop_orphan_tool_results(snipped) - # The orphan tool result should be removed. - assert not any( - m.get("role") == "tool" and m.get("tool_call_id") == "tc_old" - for m in cleaned - ) - - -def test_governance_fallback_still_repairs_orphans(): - """When full governance fails, the fallback must still run - _drop_orphan_tool_results and _backfill_missing_tool_results.""" - from nanobot.agent.runner import AgentRunner - - # Messages with an orphan tool result (no matching assistant tool_call). - messages = [ - {"role": "user", "content": "hello"}, - {"role": "tool", "tool_call_id": "orphan_tc", "name": "read", - "content": "stale"}, - {"role": "assistant", "content": "hi"}, - ] - - repaired = AgentRunner._drop_orphan_tool_results(messages) - repaired = AgentRunner._backfill_missing_tool_results(repaired) - # Orphan tool result should be gone. - assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired) -# ── Mid-turn injection tests ────────────────────────────────────────────── - - -@pytest.mark.asyncio -async def test_drain_injections_returns_empty_when_no_callback(): - """No injection_callback → empty list.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - runner = AgentRunner(provider) - tools = MagicMock() - tools.get_definitions.return_value = [] - spec = AgentRunSpec( - initial_messages=[], tools=tools, model="m", - max_iterations=1, max_tool_result_chars=1000, - injection_callback=None, - ) - result = await runner._drain_injections(spec) - assert result == [] - - -@pytest.mark.asyncio -async def test_drain_injections_extracts_content_from_inbound_messages(): - """Should extract .content from InboundMessage objects.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - runner = AgentRunner(provider) - tools = MagicMock() - tools.get_definitions.return_value = [] - - msgs = [ - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"), - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"), - ] - - async def cb(): - return msgs - - spec = AgentRunSpec( - initial_messages=[], tools=tools, model="m", - max_iterations=1, max_tool_result_chars=1000, - injection_callback=cb, - ) - result = await runner._drain_injections(spec) - assert result == [ - {"role": "user", "content": "hello"}, - {"role": "user", "content": "world"}, - ] - - -@pytest.mark.asyncio -async def test_drain_injections_passes_limit_to_callback_when_supported(): - """Limit-aware callbacks can preserve overflow in their own queue.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - runner = AgentRunner(provider) - tools = MagicMock() - tools.get_definitions.return_value = [] - seen_limits: list[int] = [] - - msgs = [ - InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") - for i in range(_MAX_INJECTIONS_PER_TURN + 3) - ] - - async def cb(*, limit: int): - seen_limits.append(limit) - return msgs[:limit] - - spec = AgentRunSpec( - initial_messages=[], tools=tools, model="m", - max_iterations=1, max_tool_result_chars=1000, - injection_callback=cb, - ) - result = await runner._drain_injections(spec) - assert seen_limits == [_MAX_INJECTIONS_PER_TURN] - assert result == [ - {"role": "user", "content": "msg0"}, - {"role": "user", "content": "msg1"}, - {"role": "user", "content": "msg2"}, - ] - - -@pytest.mark.asyncio -async def test_drain_injections_skips_empty_content(): - """Messages with blank content should be filtered out.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - runner = AgentRunner(provider) - tools = MagicMock() - tools.get_definitions.return_value = [] - - msgs = [ - InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""), - InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "), - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"), - ] - - async def cb(): - return msgs - - spec = AgentRunSpec( - initial_messages=[], tools=tools, model="m", - max_iterations=1, max_tool_result_chars=1000, - injection_callback=cb, - ) - result = await runner._drain_injections(spec) - assert result == [{"role": "user", "content": "valid"}] - - -@pytest.mark.asyncio -async def test_drain_injections_handles_callback_exception(): - """If the callback raises, return empty list (error is logged).""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - runner = AgentRunner(provider) - tools = MagicMock() - tools.get_definitions.return_value = [] - - async def cb(): - raise RuntimeError("boom") - - spec = AgentRunSpec( - initial_messages=[], tools=tools, model="m", - max_iterations=1, max_tool_result_chars=1000, - injection_callback=cb, - ) - result = await runner._drain_injections(spec) - assert result == [] - - -@pytest.mark.asyncio -async def test_checkpoint1_injects_after_tool_execution(): - """Follow-up messages are injected after tool execution, before next LLM call.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - captured_messages = [] - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - captured_messages.append(list(messages)) - if call_count["n"] == 1: - return LLMResponse( - content="using tool", - tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})], - usage={}, - ) - return LLMResponse(content="final answer", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="file content") - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - # Put a follow-up message in the queue before the run starts - await injection_queue.put( - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=5, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.had_injections is True - assert result.final_content == "final answer" - # The second call should have the injected user message - assert call_count["n"] == 2 - last_messages = captured_messages[-1] - injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"] - assert len(injected) == 1 - - -@pytest.mark.asyncio -async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): - """After final response, if injections exist, stream_end should get resuming=True.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - stream_end_calls = [] - - class TrackingHook(AgentHook): - def wants_streaming(self) -> bool: - return True - - async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: - stream_end_calls.append(resuming) - - def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: - return content - - async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse(content="first answer", tool_calls=[], usage={}) - return LLMResponse(content="second answer", tool_calls=[], usage={}) - - provider.chat_stream_with_retry = chat_stream_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - # Inject a follow-up that arrives during the first response - await injection_queue.put( - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up") - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=5, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=TrackingHook(), - injection_callback=inject_cb, - )) - - assert result.had_injections is True - assert result.final_content == "second answer" - assert call_count["n"] == 2 - # First stream_end should have resuming=True (because injections found) - assert stream_end_calls[0] is True - # Second (final) stream_end should have resuming=False - assert stream_end_calls[-1] is False - - -@pytest.mark.asyncio -async def test_checkpoint2_preserves_final_response_in_history_before_followup(): - """A follow-up injected after a final answer must still see that answer in history.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - captured_messages = [] - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - captured_messages.append([dict(message) for message in messages]) - if call_count["n"] == 1: - return LLMResponse(content="first answer", tool_calls=[], usage={}) - return LLMResponse(content="second answer", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - await injection_queue.put( - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=5, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.final_content == "second answer" - assert call_count["n"] == 2 - assert captured_messages[-1] == [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "first answer"}, - {"role": "user", "content": "follow-up question"}, - ] - assert [ - {"role": message["role"], "content": message["content"]} - for message in result.messages - if message.get("role") == "assistant" - ] == [ - {"role": "assistant", "content": "first answer"}, - {"role": "assistant", "content": "second answer"}, - ] - - -@pytest.mark.asyncio -async def test_loop_injected_followup_preserves_image_media(tmp_path): - """Mid-turn follow-ups with images should keep multimodal content.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - - image_path = tmp_path / "followup.png" - image_path.write_bytes(base64.b64decode( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII=" - )) - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - captured_messages: list[list[dict]] = [] - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - captured_messages.append(list(messages)) - if call_count["n"] == 1: - return LLMResponse(content="first answer", tool_calls=[], usage={}) - return LLMResponse(content="second answer", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") - loop.tools.get_definitions = MagicMock(return_value=[]) - - pending_queue = asyncio.Queue() - await pending_queue.put(InboundMessage( - channel="cli", - sender_id="u", - chat_id="c", - content="", - media=[str(image_path)], - )) - - final_content, _, _, _, had_injections = await loop._run_agent_loop( - [{"role": "user", "content": "hello"}], - channel="cli", - chat_id="c", - pending_queue=pending_queue, - ) - - assert final_content == "second answer" - assert had_injections is True - assert call_count["n"] == 2 - injected_user_messages = [ - message for message in captured_messages[-1] - if message.get("role") == "user" and isinstance(message.get("content"), list) - ] - assert injected_user_messages - assert any( - block.get("type") == "image_url" - for block in injected_user_messages[-1]["content"] - if isinstance(block, dict) - ) - - -@pytest.mark.asyncio -async def test_runner_merges_multiple_injected_user_messages_without_losing_media(): - """Multiple injected follow-ups should not create lossy consecutive user messages.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - call_count = {"n": 0} - captured_messages = [] - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - captured_messages.append([dict(message) for message in messages]) - if call_count["n"] == 1: - return LLMResponse(content="first answer", tool_calls=[], usage={}) - return LLMResponse(content="second answer", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - async def inject_cb(): - if call_count["n"] == 1: - return [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, - {"type": "text", "text": "look at this"}, - ], - }, - {"role": "user", "content": "and answer briefly"}, - ] - return [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=5, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.final_content == "second answer" - assert call_count["n"] == 2 - second_call = captured_messages[-1] - user_messages = [message for message in second_call if message.get("role") == "user"] - assert len(user_messages) == 2 - injected = user_messages[-1] - assert isinstance(injected["content"], list) - assert any( - block.get("type") == "image_url" - for block in injected["content"] - if isinstance(block, dict) - ) - assert any( - block.get("type") == "text" and block.get("text") == "and answer briefly" - for block in injected["content"] - if isinstance(block, dict) - ) - - -@pytest.mark.asyncio -async def test_injection_cycles_capped_at_max(): - """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - drain_count = {"n": 0} - - async def inject_cb(): - drain_count["n"] += 1 - # Only inject for the first _MAX_INJECTION_CYCLES drains - if drain_count["n"] <= _MAX_INJECTION_CYCLES: - return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] - return [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "start"}], - tools=tools, - model="test-model", - max_iterations=20, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.had_injections is True - # Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round - assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 - - -@pytest.mark.asyncio -async def test_no_injections_flag_is_false_by_default(): - """had_injections should be False when no injection callback or no messages.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - - async def chat_with_retry(**kwargs): - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hi"}], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - )) - - assert result.had_injections is False - - -@pytest.mark.asyncio -async def test_pending_queue_cleanup_on_dispatch(tmp_path): - """_pending_queues should be cleaned up after _dispatch completes.""" - loop = _make_loop(tmp_path) - - async def chat_with_retry(**kwargs): - return LLMResponse(content="done", tool_calls=[], usage={}) - - loop.provider.chat_with_retry = chat_with_retry - - from nanobot.bus.events import InboundMessage - - msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello") - # The queue should not exist before dispatch - assert msg.session_key not in loop._pending_queues - - await loop._dispatch(msg) - - # The queue should be cleaned up after dispatch - assert msg.session_key not in loop._pending_queues - - -@pytest.mark.asyncio -async def test_followup_routed_to_pending_queue(tmp_path): - """Unified-session follow-ups should route into the active pending queue.""" - from nanobot.agent.loop import UNIFIED_SESSION_KEY - from nanobot.bus.events import InboundMessage - - loop = _make_loop(tmp_path) - loop._unified_session = True - loop._dispatch = AsyncMock() # type: ignore[method-assign] - - pending = asyncio.Queue(maxsize=20) - loop._pending_queues[UNIFIED_SESSION_KEY] = pending - - run_task = asyncio.create_task(loop.run()) - msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up") - await loop.bus.publish_inbound(msg) - - deadline = time.time() + 2 - while pending.empty() and time.time() < deadline: - await asyncio.sleep(0.01) - - loop.stop() - await asyncio.wait_for(run_task, timeout=2) - - assert loop._dispatch.await_count == 0 - assert not pending.empty() - queued_msg = pending.get_nowait() - assert queued_msg.content == "follow-up" - assert queued_msg.session_key == UNIFIED_SESSION_KEY - - -@pytest.mark.asyncio -async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path): - """Pending queue should leave overflow messages queued for later drains.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - captured_messages: list[list[dict]] = [] - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - captured_messages.append([dict(message) for message in messages]) - return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") - loop.tools.get_definitions = MagicMock(return_value=[]) - - pending_queue = asyncio.Queue() - total_followups = _MAX_INJECTIONS_PER_TURN + 2 - for idx in range(total_followups): - await pending_queue.put(InboundMessage( - channel="cli", - sender_id="u", - chat_id="c", - content=f"follow-up-{idx}", - )) - - final_content, _, _, _, had_injections = await loop._run_agent_loop( - [{"role": "user", "content": "hello"}], - channel="cli", - chat_id="c", - pending_queue=pending_queue, - ) - - assert final_content == "answer-3" - assert had_injections is True - assert call_count["n"] == 3 - flattened_user_content = "\n".join( - message["content"] - for message in captured_messages[-1] - if message.get("role") == "user" and isinstance(message.get("content"), str) - ) - for idx in range(total_followups): - assert f"follow-up-{idx}" in flattened_user_content - assert pending_queue.empty() - - -@pytest.mark.asyncio -async def test_pending_queue_full_falls_back_to_queued_task(tmp_path): - """QueueFull should preserve the message by dispatching a queued task.""" - from nanobot.bus.events import InboundMessage - - loop = _make_loop(tmp_path) - loop._dispatch = AsyncMock() # type: ignore[method-assign] - - pending = asyncio.Queue(maxsize=1) - pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued")) - loop._pending_queues["cli:c"] = pending - - run_task = asyncio.create_task(loop.run()) - msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") - await loop.bus.publish_inbound(msg) - - deadline = time.time() + 2 - while loop._dispatch.await_count == 0 and time.time() < deadline: - await asyncio.sleep(0.01) - - loop.stop() - await asyncio.wait_for(run_task, timeout=2) - - assert loop._dispatch.await_count == 1 - dispatched_msg = loop._dispatch.await_args.args[0] - assert dispatched_msg.content == "follow-up" - assert pending.qsize() == 1 - - -@pytest.mark.asyncio -async def test_dispatch_republishes_leftover_queue_messages(tmp_path): - """Messages left in the pending queue after _dispatch are re-published to the bus. - - This tests the finally-block cleanup that prevents message loss when - the runner exits early (e.g., max_iterations, tool_error) with messages - still in the queue. - """ - from nanobot.bus.events import InboundMessage - - loop = _make_loop(tmp_path) - bus = loop.bus - - # Simulate a completed dispatch by manually registering a queue - # with leftover messages, then running the cleanup logic directly. - pending = asyncio.Queue(maxsize=20) - session_key = "cli:c" - loop._pending_queues[session_key] = pending - pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-1")) - pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-2")) - - # Execute the cleanup logic from the finally block - queue = loop._pending_queues.pop(session_key, None) - assert queue is not None - leftover = 0 - while True: - try: - item = queue.get_nowait() - except asyncio.QueueEmpty: - break - await bus.publish_inbound(item) - leftover += 1 - - assert leftover == 2 - - # Verify the messages are now on the bus - msgs = [] - while not bus.inbound.empty(): - msgs.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5)) - contents = [m.content for m in msgs] - assert "leftover-1" in contents - assert "leftover-2" in contents - - -@pytest.mark.asyncio -async def test_drain_injections_on_fatal_tool_error(): - """Pending injections should be drained even when a fatal tool error occurs.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content="", - tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})], - usage={}, - ) - # Second call: respond normally to the injected follow-up - return LLMResponse(content="reply to follow-up", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded")) - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - await injection_queue.put( - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error") - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=5, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - fail_on_tool_error=True, - injection_callback=inject_cb, - )) - - assert result.had_injections is True - assert result.final_content == "reply to follow-up" - # The injection should be in the messages history - injected = [ - m for m in result.messages - if m.get("role") == "user" and m.get("content") == "follow-up after error" - ] - assert len(injected) == 1 - - -@pytest.mark.asyncio -async def test_drain_injections_on_llm_error(): - """Pending injections should be drained when the LLM returns an error finish_reason.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] == 1: - return LLMResponse( - content=None, - tool_calls=[], - finish_reason="error", - usage={}, - ) - # Second call: respond normally to the injected follow-up - return LLMResponse(content="recovered answer", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - await injection_queue.put( - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error") - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "previous response"}, - {"role": "user", "content": "trigger error"}, - ], - tools=tools, - model="test-model", - max_iterations=5, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.had_injections is True - assert result.final_content == "recovered answer" - injected = [ - m for m in result.messages - if m.get("role") == "user" and "follow-up after LLM error" in str(m.get("content", "")) - ] - assert len(injected) == 1 - - -@pytest.mark.asyncio -async def test_drain_injections_on_empty_final_response(): - """Pending injections should be drained when the runner exits due to empty response.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - if call_count["n"] <= _MAX_EMPTY_RETRIES + 1: - return LLMResponse(content="", tool_calls=[], usage={}) - # After retries exhausted + injection drain, respond normally - return LLMResponse(content="answer after empty", tool_calls=[], usage={}) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - await injection_queue.put( - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty") - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "previous response"}, - {"role": "user", "content": "trigger empty"}, - ], - tools=tools, - model="test-model", - max_iterations=10, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.had_injections is True - assert result.final_content == "answer after empty" - injected = [ - m for m in result.messages - if m.get("role") == "user" and "follow-up after empty" in str(m.get("content", "")) - ] - assert len(injected) == 1 - - -@pytest.mark.asyncio -async def test_drain_injections_on_max_iterations(): - """Pending injections should be drained when the runner hits max_iterations. - - Unlike other error paths, max_iterations cannot continue the loop, so - injections are appended to messages but not processed by the LLM. - The key point is they are consumed from the queue to prevent re-publish. - """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - return LLMResponse( - content="", - tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], - usage={}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="file content") - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - await injection_queue.put( - InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters") - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=2, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.stop_reason == "max_iterations" - assert result.had_injections is True - # The injection was consumed from the queue (preventing re-publish) - assert injection_queue.empty() - # The injection message is appended to conversation history - injected = [ - m for m in result.messages - if m.get("role") == "user" and m.get("content") == "follow-up after max iters" - ] - assert len(injected) == 1 - - -@pytest.mark.asyncio -async def test_drain_injections_set_flag_when_followup_arrives_after_last_iteration(): - """Late follow-ups drained in max_iterations should still flip had_injections.""" - from nanobot.agent.hook import AgentHook - from nanobot.agent.runner import AgentRunSpec, AgentRunner - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - return LLMResponse( - content="", - tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], - usage={}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="file content") - - injection_queue = asyncio.Queue() - inject_cb = _make_injection_callback(injection_queue) - - class InjectOnLastAfterIterationHook(AgentHook): - def __init__(self) -> None: - self.after_iteration_calls = 0 - - async def after_iteration(self, context) -> None: - self.after_iteration_calls += 1 - if self.after_iteration_calls == 2: - await injection_queue.put( - InboundMessage( - channel="cli", - sender_id="u", - chat_id="c", - content="late follow-up after max iters", - ) - ) - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[{"role": "user", "content": "hello"}], - tools=tools, - model="test-model", - max_iterations=2, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - hook=InjectOnLastAfterIterationHook(), - )) - - assert result.stop_reason == "max_iterations" - assert result.had_injections is True - assert injection_queue.empty() - injected = [ - m for m in result.messages - if m.get("role") == "user" and m.get("content") == "late follow-up after max iters" - ] - assert len(injected) == 1 - - -@pytest.mark.asyncio -async def test_injection_cycle_cap_on_error_path(): - """Injection cycles should be capped even when every iteration hits an LLM error.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES - from nanobot.bus.events import InboundMessage - - provider = MagicMock() - call_count = {"n": 0} - - async def chat_with_retry(*, messages, **kwargs): - call_count["n"] += 1 - return LLMResponse( - content=None, - tool_calls=[], - finish_reason="error", - usage={}, - ) - - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - drain_count = {"n": 0} - - async def inject_cb(): - drain_count["n"] += 1 - if drain_count["n"] <= _MAX_INJECTION_CYCLES: - return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] - return [] - - runner = AgentRunner(provider) - result = await runner.run(AgentRunSpec( - initial_messages=[ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "previous"}, - {"role": "user", "content": "trigger error"}, - ], - tools=tools, - model="test-model", - max_iterations=20, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - injection_callback=inject_cb, - )) - - assert result.had_injections is True - # Should cap: _MAX_INJECTION_CYCLES drained rounds + 1 final round that breaks - assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 - - -# --------------------------------------------------------------------------- -# Regression tests for GLM-1214: _snip_history must preserve a user message -# --------------------------------------------------------------------------- - - -def test_snip_history_preserves_user_message_after_truncation(monkeypatch): - """When _snip_history truncates messages and the only user message ends up - outside the kept window, the method must recover the nearest user message - so the resulting sequence is valid for providers like GLM (which reject - system→assistant with error 1214). - - This reproduces the exact scenario from the bug report: - - Normal interaction: user asks, assistant calls tool, tool returns, - assistant replies. - - Injection adds a phantom user message, triggering more tool calls. - - _snip_history activates, keeping only recent assistant/tool pairs. - - The injected user message is in the truncated prefix and gets lost. - """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - tools = MagicMock() - tools.get_definitions.return_value = [] - runner = AgentRunner(provider) - - messages = [ - {"role": "system", "content": "system"}, - {"role": "assistant", "content": "previous reply"}, - {"role": "user", "content": ".nanobot的同目录"}, - { - "role": "assistant", - "content": None, - "tool_calls": [{"id": "tc_1", "type": "function", "function": {"name": "exec", "arguments": "{}"}}], - }, - {"role": "tool", "tool_call_id": "tc_1", "content": "tool output 1"}, - { - "role": "assistant", - "content": None, - "tool_calls": [{"id": "tc_2", "type": "function", "function": {"name": "exec", "arguments": "{}"}}], - }, - {"role": "tool", "tool_call_id": "tc_2", "content": "tool output 2"}, - ] - - spec = AgentRunSpec( - initial_messages=messages, - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - context_window_tokens=2000, - context_block_limit=100, - ) - - # Make estimate_prompt_tokens_chain report above budget so _snip_history activates. - monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_a, **_kw: (500, None)) - # Make kept window small: only the last 2 messages fit the budget. - token_sizes = { - "system": 0, - "previous reply": 200, - ".nanobot的同目录": 80, - "tool output 1": 80, - "tool output 2": 80, - } - monkeypatch.setattr( - "nanobot.agent.runner.estimate_message_tokens", - lambda msg: token_sizes.get(str(msg.get("content")), 100), - ) - - trimmed = runner._snip_history(spec, messages) - - # The first non-system message MUST be user (not assistant). - non_system = [m for m in trimmed if m.get("role") != "system"] - assert non_system, "trimmed should contain at least one non-system message" - assert non_system[0]["role"] == "user", ( - f"First non-system message must be 'user', got '{non_system[0]['role']}'. " - f"Roles: {[m['role'] for m in trimmed]}" - ) - - -def test_snip_history_no_user_at_all_falls_back_gracefully(monkeypatch): - """Edge case: if non_system has zero user messages, _snip_history should - still return a valid sequence (not crash or produce system→assistant).""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - provider = MagicMock() - tools = MagicMock() - tools.get_definitions.return_value = [] - runner = AgentRunner(provider) - - messages = [ - {"role": "system", "content": "system"}, - {"role": "assistant", "content": "reply"}, - {"role": "tool", "tool_call_id": "tc_1", "content": "result"}, - {"role": "assistant", "content": "reply 2"}, - {"role": "tool", "tool_call_id": "tc_2", "content": "result 2"}, - ] - - spec = AgentRunSpec( - initial_messages=messages, - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - context_window_tokens=2000, - context_block_limit=100, - ) - - monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_a, **_kw: (500, None)) - monkeypatch.setattr( - "nanobot.agent.runner.estimate_message_tokens", - lambda msg: 100, - ) - - trimmed = runner._snip_history(spec, messages) - - # Should not crash. The result should still be a valid list. - assert isinstance(trimmed, list) - # Must have at least system. - assert any(m.get("role") == "system" for m in trimmed) - # The _enforce_role_alternation safety net must be able to fix whatever - # _snip_history returns here — verify it produces a valid sequence. - from nanobot.providers.base import LLMProvider - fixed = LLMProvider._enforce_role_alternation(trimmed) - non_system = [m for m in fixed if m["role"] != "system"] - if non_system: - assert non_system[0]["role"] in ("user", "tool"), ( - f"Safety net should ensure first non-system is user/tool, got {non_system[0]['role']}" - ) - - -@pytest.mark.asyncio -async def test_runner_binds_on_retry_wait_to_retry_callback_not_progress(): - """Regression: provider retry heartbeats must route through - ``retry_wait_callback``, not ``progress_callback``. Binding them to - the progress callback (as an earlier runtime refactor did) caused - internal retry diagnostics like "Model request failed, retry in 1s" - to leak to end-user channels as normal progress updates. - """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner - - captured: dict = {} - - async def chat_with_retry(**kwargs): - captured.update(kwargs) - return LLMResponse(content="done", tool_calls=[], usage={}) - - provider = MagicMock() - provider.chat_with_retry = chat_with_retry - tools = MagicMock() - tools.get_definitions.return_value = [] - - progress_cb = AsyncMock() - retry_wait_cb = AsyncMock() - - runner = AgentRunner(provider) - await runner.run(AgentRunSpec( - initial_messages=[ - {"role": "system", "content": "system"}, - {"role": "user", "content": "hi"}, - ], - tools=tools, - model="test-model", - max_iterations=1, - max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - progress_callback=progress_cb, - retry_wait_callback=retry_wait_cb, - )) - - assert captured["on_retry_wait"] is retry_wait_cb - assert captured["on_retry_wait"] is not progress_cb diff --git a/tests/agent/test_runner_core.py b/tests/agent/test_runner_core.py new file mode 100644 index 000000000..dd28fa1cc --- /dev/null +++ b/tests/agent/test_runner_core.py @@ -0,0 +1,481 @@ +"""Tests for core AgentRunner behavior: message passing, iteration limits, +timeouts, empty-response handling, usage accumulation, and config passthrough.""" + +from __future__ import annotations + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.agent.tools.registry import ToolRegistry +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +@pytest.mark.asyncio +async def test_runner_preserves_reasoning_fields_and_tool_results(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + reasoning_content="hidden reasoning", + thinking_blocks=[{"type": "thinking", "thinking": "step"}], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="tool result") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "system", "content": "system"}, + {"role": "user", "content": "do task"}, + ], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assert result.tools_used == ["list_dir"] + assert result.tool_events == [ + {"name": "list_dir", "status": "ok", "detail": "tool result"} + ] + + assistant_messages = [ + msg for msg in captured_second_call + if msg.get("role") == "assistant" and msg.get("tool_calls") + ] + assert len(assistant_messages) == 1 + assert assistant_messages[0]["reasoning_content"] == "hidden reasoning" + assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}] + assert any( + msg.get("role") == "tool" and msg.get("content") == "tool result" + for msg in captured_second_call + ) + + +@pytest.mark.asyncio +async def test_runner_returns_max_iterations_fallback(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="still working", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + )) + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="tool result") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.stop_reason == "max_iterations" + assert result.final_content == ( + "I reached the maximum number of tool call iterations (2) " + "without completing the task. You can try breaking the task into smaller steps." + ) + assert result.messages[-1]["role"] == "assistant" + assert result.messages[-1]["content"] == result.final_content + + +@pytest.mark.asyncio +async def test_runner_times_out_hung_llm_request(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + + async def chat_with_retry(**kwargs): + await asyncio.sleep(3600) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + started = time.monotonic() + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + llm_timeout_s=0.05, + )) + + assert (time.monotonic() - started) < 1.0 + assert result.stop_reason == "error" + assert "timed out" in (result.final_content or "").lower() + + +@pytest.mark.asyncio +async def test_runner_replaces_empty_tool_result_with_marker(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})], + usage={}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert tool_message["content"] == "(noop completed with no output)" + + +@pytest.mark.asyncio +async def test_runner_retries_empty_final_response_with_summary_prompt(): + """Empty responses get 2 silent retries before finalization kicks in.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + calls: list[dict] = [] + + async def chat_with_retry(*, messages, tools=None, **kwargs): + calls.append({"messages": messages, "tools": tools}) + if len(calls) <= 2: + return LLMResponse( + content=None, + tool_calls=[], + usage={"prompt_tokens": 5, "completion_tokens": 1}, + ) + return LLMResponse( + content="final answer", + tool_calls=[], + usage={"prompt_tokens": 3, "completion_tokens": 7}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "final answer" + # 2 silent retries (iterations 0,1) + finalization on iteration 1 + assert len(calls) == 3 + assert calls[0]["tools"] is not None + assert calls[1]["tools"] is not None + assert calls[2]["tools"] is None + assert result.usage["prompt_tokens"] == 13 + assert result.usage["completion_tokens"] == 9 + + +@pytest.mark.asyncio +async def test_runner_uses_specific_message_after_empty_finalization_retry(): + """After silent retries + finalization all return empty, stop_reason is empty_final_response.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE + + provider = MagicMock(spec=LLMProvider) + + async def chat_with_retry(*, messages, **kwargs): + return LLMResponse(content=None, tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE + assert result.stop_reason == "empty_final_response" + + +@pytest.mark.asyncio +async def test_runner_empty_response_does_not_break_tool_chain(): + """An empty intermediate response must not kill an ongoing tool chain. + + Sequence: tool_call -> empty -> tool_call -> final text. + The runner should recover via silent retry and complete normally. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + call_count = 0 + + async def chat_with_retry(*, messages, tools=None, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return LLMResponse( + content=None, + tool_calls=[ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a.txt"})], + usage={"prompt_tokens": 10, "completion_tokens": 5}, + ) + if call_count == 2: + return LLMResponse(content=None, tool_calls=[], usage={"prompt_tokens": 10, "completion_tokens": 1}) + if call_count == 3: + return LLMResponse( + content=None, + tool_calls=[ToolCallRequest(id="tc2", name="read_file", arguments={"path": "b.txt"})], + usage={"prompt_tokens": 10, "completion_tokens": 5}, + ) + return LLMResponse( + content="Here are the results.", + tool_calls=[], + usage={"prompt_tokens": 10, "completion_tokens": 10}, + ) + + provider.chat_with_retry = chat_with_retry + provider.chat_stream_with_retry = chat_with_retry + + async def fake_tool(name, args, **kw): + return "file content" + + tool_registry = MagicMock() + tool_registry.get_definitions.return_value = [{"type": "function", "function": {"name": "read_file"}}] + tool_registry.execute = AsyncMock(side_effect=fake_tool) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "read both files"}], + tools=tool_registry, + model="test-model", + max_iterations=10, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "Here are the results." + assert result.stop_reason == "completed" + assert call_count == 4 + assert "read_file" in result.tools_used + + +@pytest.mark.asyncio +async def test_runner_accumulates_usage_and_preserves_cached_tokens(): + """Runner should accumulate prompt/completion tokens across iterations + and preserve cached_tokens from provider responses.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], + usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80}, + ) + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + # Usage should be accumulated across iterations + assert result.usage["prompt_tokens"] == 300 # 100 + 200 + assert result.usage["completion_tokens"] == 30 # 10 + 20 + assert result.usage["cached_tokens"] == 230 # 80 + 150 + + +@pytest.mark.asyncio +async def test_runner_binds_on_retry_wait_to_retry_callback_not_progress(): + """Regression: provider retry heartbeats must route through + ``retry_wait_callback``, not ``progress_callback``. Binding them to + the progress callback (as an earlier runtime refactor did) caused + internal retry diagnostics like "Model request failed, retry in 1s" + to leak to end-user channels as normal progress updates. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + captured: dict = {} + + async def chat_with_retry(**kwargs): + captured.update(kwargs) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider = MagicMock(spec=LLMProvider) + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + progress_cb = AsyncMock() + retry_wait_cb = AsyncMock() + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hi"}, + ], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + progress_callback=progress_cb, + retry_wait_callback=retry_wait_cb, + )) + + assert captured["on_retry_wait"] is retry_wait_cb + assert captured["on_retry_wait"] is not progress_cb + + +# --------------------------------------------------------------------------- +# Config passthrough tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_runner_passes_temperature_to_provider(): + """temperature from AgentRunSpec should reach provider.chat_with_retry.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + captured: dict = {} + + async def chat_with_retry(**kwargs): + captured.update(kwargs) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider = MagicMock(spec=LLMProvider) + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + temperature=0.7, + )) + + assert captured["temperature"] == 0.7 + + +@pytest.mark.asyncio +async def test_runner_passes_max_tokens_to_provider(): + """max_tokens from AgentRunSpec should reach provider.chat_with_retry.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + captured: dict = {} + + async def chat_with_retry(**kwargs): + captured.update(kwargs) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider = MagicMock(spec=LLMProvider) + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + max_tokens=8192, + )) + + assert captured["max_tokens"] == 8192 + + +@pytest.mark.asyncio +async def test_runner_passes_reasoning_effort_to_provider(): + """reasoning_effort from AgentRunSpec should reach provider.chat_with_retry.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + captured: dict = {} + + async def chat_with_retry(**kwargs): + captured.update(kwargs) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider = MagicMock(spec=LLMProvider) + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + reasoning_effort="high", + )) + + assert captured["reasoning_effort"] == "high" diff --git a/tests/agent/test_runner_errors.py b/tests/agent/test_runner_errors.py new file mode 100644 index 000000000..8df7ad8f3 --- /dev/null +++ b/tests/agent/test_runner_errors.py @@ -0,0 +1,171 @@ +"""Tests for AgentRunner error handling: tool errors, LLM errors, +session message isolation, and tool result preservation.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +@pytest.mark.asyncio +async def test_runner_returns_structured_tool_error(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + )) + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(side_effect=RuntimeError("boom")) + + runner = AgentRunner(provider) + + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + fail_on_tool_error=True, + )) + + assert result.stop_reason == "tool_error" + assert result.error == "Error: RuntimeError: boom" + assert result.tool_events == [ + {"name": "list_dir", "status": "error", "detail": "boom"} + ] + + +@pytest.mark.asyncio +async def test_llm_error_not_appended_to_session_messages(): + """When LLM returns finish_reason='error', the error content must NOT be + appended to the messages list (prevents polluting session history).""" + from nanobot.agent.runner import ( + AgentRunSpec, + AgentRunner, + _PERSISTED_MODEL_ERROR_PLACEHOLDER, + ) + + provider = MagicMock(spec=LLMProvider) + provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}, + )) + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.stop_reason == "error" + assert result.final_content == "429 rate limit exceeded" + assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"] + assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \ + "Error content should not appear in session messages" + assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER + + +@pytest.mark.asyncio +async def test_runner_tool_error_sets_final_content(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + + async def chat_with_retry(*, messages, **kwargs): + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(side_effect=RuntimeError("boom")) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + fail_on_tool_error=True, + )) + + assert result.final_content == "Error: RuntimeError: boom" + assert result.stop_reason == "tool_error" + + +@pytest.mark.asyncio +async def test_runner_tool_error_preserves_tool_results_in_messages(): + """When a tool raises a fatal error, its results must still be appended + to messages so the session never contains orphan tool_calls (#2943).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + + async def chat_with_retry(*, messages, **kwargs): + return LLMResponse( + content=None, + tool_calls=[ + ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a"}), + ToolCallRequest(id="tc2", name="exec", arguments={"cmd": "bad"}), + ], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + provider.chat_stream_with_retry = chat_with_retry + + call_idx = 0 + + async def fake_execute(name, args, **kw): + nonlocal call_idx + call_idx += 1 + if call_idx == 2: + raise RuntimeError("boom") + return "file content" + + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(side_effect=fake_execute) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do stuff"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + fail_on_tool_error=True, + )) + + assert result.stop_reason == "tool_error" + # Both tool results must be in messages even though tc2 had a fatal error. + tool_msgs = [m for m in result.messages if m.get("role") == "tool"] + assert len(tool_msgs) == 2 + assert tool_msgs[0]["tool_call_id"] == "tc1" + assert tool_msgs[1]["tool_call_id"] == "tc2" + # The assistant message with tool_calls must precede the tool results. + asst_tc_idx = next( + i for i, m in enumerate(result.messages) + if m.get("role") == "assistant" and m.get("tool_calls") + ) + tool_indices = [ + i for i, m in enumerate(result.messages) if m.get("role") == "tool" + ] + assert all(ti > asst_tc_idx for ti in tool_indices) diff --git a/tests/agent/test_runner_governance.py b/tests/agent/test_runner_governance.py new file mode 100644 index 000000000..50e882ca6 --- /dev/null +++ b/tests/agent/test_runner_governance.py @@ -0,0 +1,643 @@ +"""Tests for AgentRunner context governance: backfill, orphan cleanup, microcompact, snip_history.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +def _make_loop(tmp_path): + from nanobot.agent.loop import AgentLoop + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: + MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path) + return loop + +async def test_runner_uses_raw_messages_when_context_governance_fails(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_messages: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + captured_messages[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + initial_messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hello"}, + ] + + runner = AgentRunner(provider) + runner._snip_history = MagicMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + result = await runner.run(AgentRunSpec( + initial_messages=initial_messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assert captured_messages == initial_messages +def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + runner = AgentRunner(provider) + messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "tool call", + "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "tool output"}, + {"role": "assistant", "content": "after tool"}, + ] + spec = AgentRunSpec( + initial_messages=messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + context_window_tokens=2000, + context_block_limit=100, + ) + + monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_args, **_kwargs: (500, None)) + token_sizes = { + "old user": 120, + "tool call": 120, + "tool output": 40, + "after tool": 40, + "system": 0, + } + monkeypatch.setattr( + "nanobot.agent.runner.estimate_message_tokens", + lambda msg: token_sizes.get(str(msg.get("content")), 40), + ) + + trimmed = runner._snip_history(spec, messages) + + # After the fix, the user message is recovered so the sequence is valid + # for providers that require system → user (e.g. GLM error 1214). + assert trimmed[0]["role"] == "system" + non_system = [m for m in trimmed if m["role"] != "system"] + assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}" +async def test_backfill_missing_tool_results_inserts_error(): + """Orphaned tool_use (no matching tool_result) should get a synthetic error.""" + from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "call_a", "type": "function", "function": {"name": "exec", "arguments": "{}"}}, + {"id": "call_b", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_a", "name": "exec", "content": "ok"}, + ] + result = AgentRunner._backfill_missing_tool_results(messages) + tool_msgs = [m for m in result if m.get("role") == "tool"] + assert len(tool_msgs) == 2 + backfilled = [m for m in tool_msgs if m.get("tool_call_id") == "call_b"] + assert len(backfilled) == 1 + assert backfilled[0]["content"] == _BACKFILL_CONTENT + assert backfilled[0]["name"] == "read_file" + + +def test_drop_orphan_tool_results_removes_unmatched_tool_messages(): + from nanobot.agent.runner import AgentRunner + + messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"}, + {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}, + {"role": "assistant", "content": "after tool"}, + ] + + cleaned = AgentRunner._drop_orphan_tool_results(messages) + + assert cleaned == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"}, + {"role": "assistant", "content": "after tool"}, + ] + + +@pytest.mark.asyncio +async def test_backfill_noop_when_complete(): + """Complete message chains should not be modified.""" + from nanobot.agent.runner import AgentRunner + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "call_x", "type": "function", "function": {"name": "exec", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_x", "name": "exec", "content": "done"}, + {"role": "assistant", "content": "all good"}, + ] + result = AgentRunner._backfill_missing_tool_results(messages) + assert result is messages # same object — no copy + + +@pytest.mark.asyncio +async def test_runner_drops_orphan_tool_results_before_model_request(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_messages: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + captured_messages[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}, + {"role": "assistant", "content": "after orphan"}, + {"role": "user", "content": "new prompt"}, + ], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert all( + message.get("tool_call_id") != "call_orphan" + for message in captured_messages + if message.get("role") == "tool" + ) + assert result.messages[2]["tool_call_id"] == "call_orphan" + assert result.final_content == "done" + + +@pytest.mark.asyncio +async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path): + """Historical backfill should not duplicate old tail messages on persist.""" + from nanobot.agent.loop import AgentLoop + from nanobot.agent.runner import _BACKFILL_CONTENT + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + response = LLMResponse(content="new answer", tool_calls=[], usage={}) + provider.chat_with_retry = AsyncMock(return_value=response) + provider.chat_stream_with_retry = AsyncMock(return_value=response) + + loop = AgentLoop( + bus=MessageBus(), + provider=provider, + workspace=tmp_path, + model="test-model", + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "old user", "timestamp": "2026-01-01T00:00:00"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + "timestamp": "2026-01-01T00:00:01", + }, + {"role": "assistant", "content": "old tail", "timestamp": "2026-01-01T00:00:02"}, + ] + loop.sessions.save(session) + + result = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new prompt") + ) + + assert result is not None + assert result.content == "new answer" + + request_messages = provider.chat_with_retry.await_args.kwargs["messages"] + synthetic = [ + message + for message in request_messages + if message.get("role") == "tool" and message.get("tool_call_id") == "call_missing" + ] + assert len(synthetic) == 1 + assert synthetic[0]["content"] == _BACKFILL_CONTENT + + session_after = loop.sessions.get_or_create("cli:test") + assert [ + { + key: value + for key, value in message.items() + if key in {"role", "content", "tool_call_id", "name", "tool_calls"} + } + for message in session_after.messages + ] == [ + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + }, + {"role": "assistant", "content": "old tail"}, + {"role": "user", "content": "new prompt"}, + {"role": "assistant", "content": "new answer"}, + ] + + +@pytest.mark.asyncio +async def test_runner_backfill_only_mutates_model_context_not_returned_messages(): + """Runner should repair orphaned tool calls for the model without rewriting result.messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _BACKFILL_CONTENT + + provider = MagicMock() + captured_messages: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + captured_messages[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + initial_messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + }, + {"role": "assistant", "content": "old tail"}, + {"role": "user", "content": "new prompt"}, + ] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=initial_messages, + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + synthetic = [ + message + for message in captured_messages + if message.get("role") == "tool" and message.get("tool_call_id") == "call_missing" + ] + assert len(synthetic) == 1 + assert synthetic[0]["content"] == _BACKFILL_CONTENT + + assert [ + { + key: value + for key, value in message.items() + if key in {"role", "content", "tool_call_id", "name", "tool_calls"} + } + for message in result.messages + ] == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_missing", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + }, + {"role": "assistant", "content": "old tail"}, + {"role": "user", "content": "new prompt"}, + {"role": "assistant", "content": "done"}, + ] + + +# --------------------------------------------------------------------------- +# Microcompact (stale tool result compaction) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_microcompact_replaces_old_tool_results(): + """Tool results beyond _MICROCOMPACT_KEEP_RECENT should be summarized.""" + from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT + + total = _MICROCOMPACT_KEEP_RECENT + 5 + long_content = "x" * 600 + messages: list[dict] = [{"role": "system", "content": "sys"}] + for i in range(total): + messages.append({ + "role": "assistant", + "content": "", + "tool_calls": [{"id": f"c{i}", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}], + }) + messages.append({ + "role": "tool", "tool_call_id": f"c{i}", "name": "read_file", + "content": long_content, + }) + + result = AgentRunner._microcompact(messages) + tool_msgs = [m for m in result if m.get("role") == "tool"] + stale_count = total - _MICROCOMPACT_KEEP_RECENT + compacted = [m for m in tool_msgs if "omitted from context" in str(m.get("content", ""))] + preserved = [m for m in tool_msgs if m.get("content") == long_content] + assert len(compacted) == stale_count + assert len(preserved) == _MICROCOMPACT_KEEP_RECENT + + +@pytest.mark.asyncio +async def test_microcompact_preserves_short_results(): + """Short tool results (< _MICROCOMPACT_MIN_CHARS) should not be replaced.""" + from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT + + total = _MICROCOMPACT_KEEP_RECENT + 5 + messages: list[dict] = [] + for i in range(total): + messages.append({ + "role": "assistant", + "content": "", + "tool_calls": [{"id": f"c{i}", "type": "function", "function": {"name": "exec", "arguments": "{}"}}], + }) + messages.append({ + "role": "tool", "tool_call_id": f"c{i}", "name": "exec", + "content": "short", + }) + + result = AgentRunner._microcompact(messages) + assert result is messages # no copy needed — all stale results are short + + +@pytest.mark.asyncio +async def test_microcompact_skips_non_compactable_tools(): + """Non-compactable tools (e.g. 'message') should never be replaced.""" + from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT + + total = _MICROCOMPACT_KEEP_RECENT + 5 + long_content = "y" * 1000 + messages: list[dict] = [] + for i in range(total): + messages.append({ + "role": "assistant", + "content": "", + "tool_calls": [{"id": f"c{i}", "type": "function", "function": {"name": "message", "arguments": "{}"}}], + }) + messages.append({ + "role": "tool", "tool_call_id": f"c{i}", "name": "message", + "content": long_content, + }) + + result = AgentRunner._microcompact(messages) + assert result is messages # no compactable tools found + + +def test_governance_repairs_orphans_after_snip(): + """After _snip_history clips an assistant+tool_calls, the second + _drop_orphan_tool_results pass must clean up the resulting orphans.""" + from nanobot.agent.runner import AgentRunner + + messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old msg"}, + {"role": "assistant", "content": None, + "tool_calls": [{"id": "tc_old", "type": "function", + "function": {"name": "search", "arguments": "{}"}}]}, + {"role": "tool", "tool_call_id": "tc_old", "name": "search", + "content": "old result"}, + {"role": "assistant", "content": "old answer"}, + {"role": "user", "content": "new msg"}, + ] + + # Simulate snipping that keeps only the tail: drop the assistant with + # tool_calls but keep its tool result (orphan). + snipped = [ + {"role": "system", "content": "system"}, + {"role": "tool", "tool_call_id": "tc_old", "name": "search", + "content": "old result"}, + {"role": "assistant", "content": "old answer"}, + {"role": "user", "content": "new msg"}, + ] + + cleaned = AgentRunner._drop_orphan_tool_results(snipped) + # The orphan tool result should be removed. + assert not any( + m.get("role") == "tool" and m.get("tool_call_id") == "tc_old" + for m in cleaned + ) + + +def test_governance_fallback_still_repairs_orphans(): + """When full governance fails, the fallback must still run + _drop_orphan_tool_results and _backfill_missing_tool_results.""" + from nanobot.agent.runner import AgentRunner + + # Messages with an orphan tool result (no matching assistant tool_call). + messages = [ + {"role": "user", "content": "hello"}, + {"role": "tool", "tool_call_id": "orphan_tc", "name": "read", + "content": "stale"}, + {"role": "assistant", "content": "hi"}, + ] + + repaired = AgentRunner._drop_orphan_tool_results(messages) + repaired = AgentRunner._backfill_missing_tool_results(repaired) + # Orphan tool result should be gone. + assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired) +def test_snip_history_preserves_user_message_after_truncation(monkeypatch): + """When _snip_history truncates messages and the only user message ends up + outside the kept window, the method must recover the nearest user message + so the resulting sequence is valid for providers like GLM (which reject + system→assistant with error 1214). + + This reproduces the exact scenario from the bug report: + - Normal interaction: user asks, assistant calls tool, tool returns, + assistant replies. + - Injection adds a phantom user message, triggering more tool calls. + - _snip_history activates, keeping only recent assistant/tool pairs. + - The injected user message is in the truncated prefix and gets lost. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + runner = AgentRunner(provider) + + messages = [ + {"role": "system", "content": "system"}, + {"role": "assistant", "content": "previous reply"}, + {"role": "user", "content": ".nanobot的同目录"}, + { + "role": "assistant", + "content": None, + "tool_calls": [{"id": "tc_1", "type": "function", "function": {"name": "exec", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "tc_1", "content": "tool output 1"}, + { + "role": "assistant", + "content": None, + "tool_calls": [{"id": "tc_2", "type": "function", "function": {"name": "exec", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "tc_2", "content": "tool output 2"}, + ] + + spec = AgentRunSpec( + initial_messages=messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + context_window_tokens=2000, + context_block_limit=100, + ) + + # Make estimate_prompt_tokens_chain report above budget so _snip_history activates. + monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_a, **_kw: (500, None)) + # Make kept window small: only the last 2 messages fit the budget. + token_sizes = { + "system": 0, + "previous reply": 200, + ".nanobot的同目录": 80, + "tool output 1": 80, + "tool output 2": 80, + } + monkeypatch.setattr( + "nanobot.agent.runner.estimate_message_tokens", + lambda msg: token_sizes.get(str(msg.get("content")), 100), + ) + + trimmed = runner._snip_history(spec, messages) + + # The first non-system message MUST be user (not assistant). + non_system = [m for m in trimmed if m.get("role") != "system"] + assert non_system, "trimmed should contain at least one non-system message" + assert non_system[0]["role"] == "user", ( + f"First non-system message must be 'user', got '{non_system[0]['role']}'. " + f"Roles: {[m['role'] for m in trimmed]}" + ) + + +def test_snip_history_no_user_at_all_falls_back_gracefully(monkeypatch): + """Edge case: if non_system has zero user messages, _snip_history should + still return a valid sequence (not crash or produce system→assistant).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + runner = AgentRunner(provider) + + messages = [ + {"role": "system", "content": "system"}, + {"role": "assistant", "content": "reply"}, + {"role": "tool", "tool_call_id": "tc_1", "content": "result"}, + {"role": "assistant", "content": "reply 2"}, + {"role": "tool", "tool_call_id": "tc_2", "content": "result 2"}, + ] + + spec = AgentRunSpec( + initial_messages=messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + context_window_tokens=2000, + context_block_limit=100, + ) + + monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_a, **_kw: (500, None)) + monkeypatch.setattr( + "nanobot.agent.runner.estimate_message_tokens", + lambda msg: 100, + ) + + trimmed = runner._snip_history(spec, messages) + + # Should not crash. The result should still be a valid list. + assert isinstance(trimmed, list) + # Must have at least system. + assert any(m.get("role") == "system" for m in trimmed) + # The _enforce_role_alternation safety net must be able to fix whatever + # _snip_history returns here — verify it produces a valid sequence. + from nanobot.providers.base import LLMProvider + fixed = LLMProvider._enforce_role_alternation(trimmed) + non_system = [m for m in fixed if m["role"] != "system"] + if non_system: + assert non_system[0]["role"] in ("user", "tool"), ( + f"Safety net should ensure first non-system is user/tool, got {non_system[0]['role']}" + ) diff --git a/tests/agent/test_runner_hooks.py b/tests/agent/test_runner_hooks.py new file mode 100644 index 000000000..7718eee20 --- /dev/null +++ b/tests/agent/test_runner_hooks.py @@ -0,0 +1,172 @@ +"""Tests for AgentRunner hook lifecycle: ordering, streaming deltas, +cached-token propagation, and hook context.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +@pytest.mark.asyncio +async def test_runner_calls_hooks_in_order(): + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + call_count = {"n": 0} + events: list[tuple] = [] + + async def chat_with_retry(**kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + ) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="tool result") + + class RecordingHook(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + events.append(("before_iteration", context.iteration)) + + async def before_execute_tools(self, context: AgentHookContext) -> None: + events.append(( + "before_execute_tools", + context.iteration, + [tc.name for tc in context.tool_calls], + )) + + async def after_iteration(self, context: AgentHookContext) -> None: + events.append(( + "after_iteration", + context.iteration, + context.final_content, + list(context.tool_results), + list(context.tool_events), + context.stop_reason, + )) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + events.append(("finalize_content", context.iteration, content)) + return content.upper() if content else content + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=RecordingHook(), + )) + + assert result.final_content == "DONE" + assert events == [ + ("before_iteration", 0), + ("before_execute_tools", 0, ["list_dir"]), + ( + "after_iteration", + 0, + None, + ["tool result"], + [{"name": "list_dir", "status": "ok", "detail": "tool result"}], + None, + ), + ("before_iteration", 1), + ("finalize_content", 1, "done"), + ("after_iteration", 1, "DONE", [], [], "completed"), + ] + + +@pytest.mark.asyncio +async def test_runner_streaming_hook_receives_deltas_and_end_signal(): + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + streamed: list[str] = [] + endings: list[bool] = [] + + async def chat_stream_with_retry(*, on_content_delta, **kwargs): + await on_content_delta("he") + await on_content_delta("llo") + return LLMResponse(content="hello", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + + class StreamingHook(AgentHook): + def wants_streaming(self) -> bool: + return True + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + streamed.append(delta) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + endings.append(resuming) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=StreamingHook(), + )) + + assert result.final_content == "hello" + assert streamed == ["he", "llo"] + assert endings == [False] + provider.chat_with_retry.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_runner_passes_cached_tokens_to_hook_context(): + """Hook context.usage should contain cached_tokens.""" + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock(spec=LLMProvider) + captured_usage: list[dict] = [] + + class UsageHook(AgentHook): + async def after_iteration(self, context: AgentHookContext) -> None: + captured_usage.append(dict(context.usage)) + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=UsageHook(), + )) + + assert len(captured_usage) == 1 + assert captured_usage[0]["cached_tokens"] == 150 diff --git a/tests/agent/test_runner_injections.py b/tests/agent/test_runner_injections.py new file mode 100644 index 000000000..1aa504e32 --- /dev/null +++ b/tests/agent/test_runner_injections.py @@ -0,0 +1,1038 @@ +"""Tests for the mid-turn injection system: drain, checkpoints, pending queues, error paths.""" + +from __future__ import annotations + +import asyncio +import base64 +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +def _make_injection_callback(queue: asyncio.Queue): + """Return an async callback that drains *queue* into a list of dicts.""" + async def inject_cb(): + items = [] + while not queue.empty(): + items.append(await queue.get()) + return items + return inject_cb + + +def _make_loop(tmp_path): + from nanobot.agent.loop import AgentLoop + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: + MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path) + return loop + +@pytest.mark.asyncio +async def test_drain_injections_returns_empty_when_no_callback(): + """No injection_callback → empty list.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=None, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_drain_injections_extracts_content_from_inbound_messages(): + """Should extract .content from InboundMessage objects.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [ + {"role": "user", "content": "hello"}, + {"role": "user", "content": "world"}, + ] + + +@pytest.mark.asyncio +async def test_drain_injections_passes_limit_to_callback_when_supported(): + """Limit-aware callbacks can preserve overflow in their own queue.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + seen_limits: list[int] = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") + for i in range(_MAX_INJECTIONS_PER_TURN + 3) + ] + + async def cb(*, limit: int): + seen_limits.append(limit) + return msgs[:limit] + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert seen_limits == [_MAX_INJECTIONS_PER_TURN] + assert result == [ + {"role": "user", "content": "msg0"}, + {"role": "user", "content": "msg1"}, + {"role": "user", "content": "msg2"}, + ] + + +@pytest.mark.asyncio +async def test_drain_injections_skips_empty_content(): + """Messages with blank content should be filtered out.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [{"role": "user", "content": "valid"}] + + +@pytest.mark.asyncio +async def test_drain_injections_handles_callback_exception(): + """If the callback raises, return empty list (error is logged).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def cb(): + raise RuntimeError("boom") + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_checkpoint1_injects_after_tool_execution(): + """Follow-up messages are injected after tool execution, before next LLM call.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse( + content="using tool", + tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})], + usage={}, + ) + return LLMResponse(content="final answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + # Put a follow-up message in the queue before the run starts + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "final answer" + # The second call should have the injected user message + assert call_count["n"] == 2 + last_messages = captured_messages[-1] + injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): + """After final response, if injections exist, stream_end should get resuming=True.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + stream_end_calls = [] + + class TrackingHook(AgentHook): + def wants_streaming(self) -> bool: + return True + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + stream_end_calls.append(resuming) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return content + + async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + # Inject a follow-up that arrives during the first response + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=TrackingHook(), + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "second answer" + assert call_count["n"] == 2 + # First stream_end should have resuming=True (because injections found) + assert stream_end_calls[0] is True + # Second (final) stream_end should have resuming=False + assert stream_end_calls[-1] is False + + +@pytest.mark.asyncio +async def test_checkpoint2_preserves_final_response_in_history_before_followup(): + """A follow-up injected after a final answer must still see that answer in history.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + assert captured_messages[-1] == [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "first answer"}, + {"role": "user", "content": "follow-up question"}, + ] + assert [ + {"role": message["role"], "content": message["content"]} + for message in result.messages + if message.get("role") == "assistant" + ] == [ + {"role": "assistant", "content": "first answer"}, + {"role": "assistant", "content": "second answer"}, + ] + + +@pytest.mark.asyncio +async def test_loop_injected_followup_preserves_image_media(tmp_path): + """Mid-turn follow-ups with images should keep multimodal content.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + image_path = tmp_path / "followup.png" + image_path.write_bytes(base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII=" + )) + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content="", + media=[str(image_path)], + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "second answer" + assert had_injections is True + assert call_count["n"] == 2 + injected_user_messages = [ + message for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), list) + ] + assert injected_user_messages + assert any( + block.get("type") == "image_url" + for block in injected_user_messages[-1]["content"] + if isinstance(block, dict) + ) + + +@pytest.mark.asyncio +async def test_runner_merges_multiple_injected_user_messages_without_losing_media(): + """Multiple injected follow-ups should not create lossy consecutive user messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def inject_cb(): + if call_count["n"] == 1: + return [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + {"type": "text", "text": "look at this"}, + ], + }, + {"role": "user", "content": "and answer briefly"}, + ] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + second_call = captured_messages[-1] + user_messages = [message for message in second_call if message.get("role") == "user"] + assert len(user_messages) == 2 + injected = user_messages[-1] + assert isinstance(injected["content"], list) + assert any( + block.get("type") == "image_url" + for block in injected["content"] + if isinstance(block, dict) + ) + assert any( + block.get("type") == "text" and block.get("text") == "and answer briefly" + for block in injected["content"] + if isinstance(block, dict) + ) + + +@pytest.mark.asyncio +async def test_injection_cycles_capped_at_max(): + """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + drain_count = {"n": 0} + + async def inject_cb(): + drain_count["n"] += 1 + # Only inject for the first _MAX_INJECTION_CYCLES drains + if drain_count["n"] <= _MAX_INJECTION_CYCLES: + return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "start"}], + tools=tools, + model="test-model", + max_iterations=20, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + # Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round + assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 + + +@pytest.mark.asyncio +async def test_no_injections_flag_is_false_by_default(): + """had_injections should be False when no injection callback or no messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.had_injections is False + + +@pytest.mark.asyncio +async def test_pending_queue_cleanup_on_dispatch(tmp_path): + """_pending_queues should be cleaned up after _dispatch completes.""" + loop = _make_loop(tmp_path) + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + loop.provider.chat_with_retry = chat_with_retry + + from nanobot.bus.events import InboundMessage + + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello") + # The queue should not exist before dispatch + assert msg.session_key not in loop._pending_queues + + await loop._dispatch(msg) + + # The queue should be cleaned up after dispatch + assert msg.session_key not in loop._pending_queues + + +@pytest.mark.asyncio +async def test_followup_routed_to_pending_queue(tmp_path): + """Unified-session follow-ups should route into the active pending queue.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + loop._unified_session = True + loop._dispatch = AsyncMock() # type: ignore[method-assign] + + pending = asyncio.Queue(maxsize=20) + loop._pending_queues[UNIFIED_SESSION_KEY] = pending + + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) + + deadline = time.time() + 2 + while pending.empty() and time.time() < deadline: + await asyncio.sleep(0.01) + + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 0 + assert not pending.empty() + queued_msg = pending.get_nowait() + assert queued_msg.content == "follow-up" + assert queued_msg.session_key == UNIFIED_SESSION_KEY + + +@pytest.mark.asyncio +async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path): + """Pending queue should leave overflow messages queued for later drains.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + total_followups = _MAX_INJECTIONS_PER_TURN + 2 + for idx in range(total_followups): + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content=f"follow-up-{idx}", + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "answer-3" + assert had_injections is True + assert call_count["n"] == 3 + flattened_user_content = "\n".join( + message["content"] + for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), str) + ) + for idx in range(total_followups): + assert f"follow-up-{idx}" in flattened_user_content + assert pending_queue.empty() + + +@pytest.mark.asyncio +async def test_pending_queue_full_falls_back_to_queued_task(tmp_path): + """QueueFull should preserve the message by dispatching a queued task.""" + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + loop._dispatch = AsyncMock() # type: ignore[method-assign] + + pending = asyncio.Queue(maxsize=1) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued")) + loop._pending_queues["cli:c"] = pending + + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) + + deadline = time.time() + 2 + while loop._dispatch.await_count == 0 and time.time() < deadline: + await asyncio.sleep(0.01) + + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 1 + dispatched_msg = loop._dispatch.await_args.args[0] + assert dispatched_msg.content == "follow-up" + assert pending.qsize() == 1 + + +@pytest.mark.asyncio +async def test_dispatch_republishes_leftover_queue_messages(tmp_path): + """Messages left in the pending queue after _dispatch are re-published to the bus. + + This tests the finally-block cleanup that prevents message loss when + the runner exits early (e.g., max_iterations, tool_error) with messages + still in the queue. + """ + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + bus = loop.bus + + # Simulate a completed dispatch by manually registering a queue + # with leftover messages, then running the cleanup logic directly. + pending = asyncio.Queue(maxsize=20) + session_key = "cli:c" + loop._pending_queues[session_key] = pending + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-1")) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-2")) + + # Execute the cleanup logic from the finally block + queue = loop._pending_queues.pop(session_key, None) + assert queue is not None + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await bus.publish_inbound(item) + leftover += 1 + + assert leftover == 2 + + # Verify the messages are now on the bus + msgs = [] + while not bus.inbound.empty(): + msgs.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5)) + contents = [m.content for m in msgs] + assert "leftover-1" in contents + assert "leftover-2" in contents + + +@pytest.mark.asyncio +async def test_drain_injections_on_fatal_tool_error(): + """Pending injections should be drained even when a fatal tool error occurs.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})], + usage={}, + ) + # Second call: respond normally to the injected follow-up + return LLMResponse(content="reply to follow-up", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded")) + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + fail_on_tool_error=True, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "reply to follow-up" + # The injection should be in the messages history + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "follow-up after error" + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_llm_error(): + """Pending injections should be drained when the LLM returns an error finish_reason.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content=None, + tool_calls=[], + finish_reason="error", + usage={}, + ) + # Second call: respond normally to the injected follow-up + return LLMResponse(content="recovered answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous response"}, + {"role": "user", "content": "trigger error"}, + ], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "recovered answer" + injected = [ + m for m in result.messages + if m.get("role") == "user" and "follow-up after LLM error" in str(m.get("content", "")) + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_empty_final_response(): + """Pending injections should be drained when the runner exits due to empty response.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] <= _MAX_EMPTY_RETRIES + 1: + return LLMResponse(content="", tool_calls=[], usage={}) + # After retries exhausted + injection drain, respond normally + return LLMResponse(content="answer after empty", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous response"}, + {"role": "user", "content": "trigger empty"}, + ], + tools=tools, + model="test-model", + max_iterations=10, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "answer after empty" + injected = [ + m for m in result.messages + if m.get("role") == "user" and "follow-up after empty" in str(m.get("content", "")) + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_max_iterations(): + """Pending injections should be drained when the runner hits max_iterations. + + Unlike other error paths, max_iterations cannot continue the loop, so + injections are appended to messages but not processed by the LLM. + The key point is they are consumed from the queue to prevent re-publish. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.stop_reason == "max_iterations" + assert result.had_injections is True + # The injection was consumed from the queue (preventing re-publish) + assert injection_queue.empty() + # The injection message is appended to conversation history + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "follow-up after max iters" + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_set_flag_when_followup_arrives_after_last_iteration(): + """Late follow-ups drained in max_iterations should still flip had_injections.""" + from nanobot.agent.hook import AgentHook + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + class InjectOnLastAfterIterationHook(AgentHook): + def __init__(self) -> None: + self.after_iteration_calls = 0 + + async def after_iteration(self, context) -> None: + self.after_iteration_calls += 1 + if self.after_iteration_calls == 2: + await injection_queue.put( + InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content="late follow-up after max iters", + ) + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + hook=InjectOnLastAfterIterationHook(), + )) + + assert result.stop_reason == "max_iterations" + assert result.had_injections is True + assert injection_queue.empty() + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "late follow-up after max iters" + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_injection_cycle_cap_on_error_path(): + """Injection cycles should be capped even when every iteration hits an LLM error.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content=None, + tool_calls=[], + finish_reason="error", + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + drain_count = {"n": 0} + + async def inject_cb(): + drain_count["n"] += 1 + if drain_count["n"] <= _MAX_INJECTION_CYCLES: + return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous"}, + {"role": "user", "content": "trigger error"}, + ], + tools=tools, + model="test-model", + max_iterations=20, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + # Should cap: _MAX_INJECTION_CYCLES drained rounds + 1 final round that breaks + assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 + diff --git a/tests/agent/test_runner_persistence.py b/tests/agent/test_runner_persistence.py new file mode 100644 index 000000000..d2bcfa9d4 --- /dev/null +++ b/tests/agent/test_runner_persistence.py @@ -0,0 +1,161 @@ +"""Tests for tool result persistence: large results, pruning, temp files, cleanup.""" + +from __future__ import annotations + +import os +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + +async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="x" * 20_000) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + workspace=tmp_path, + session_key="test:runner", + max_tool_result_chars=2048, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert "[tool output persisted]" in tool_message["content"] + assert "tool-results" in tool_message["content"] + assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists() + + +def test_persist_tool_result_prunes_old_session_buckets(tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + root = tmp_path / ".nanobot" / "tool-results" + old_bucket = root / "old_session" + recent_bucket = root / "recent_session" + old_bucket.mkdir(parents=True) + recent_bucket.mkdir(parents=True) + (old_bucket / "old.txt").write_text("old", encoding="utf-8") + (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8") + + stale = time.time() - (8 * 24 * 60 * 60) + os.utime(old_bucket, (stale, stale)) + os.utime(old_bucket / "old.txt", (stale, stale)) + + persisted = maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + max_chars=64, + ) + + assert "[tool output persisted]" in persisted + assert not old_bucket.exists() + assert recent_bucket.exists() + assert (root / "current_session" / "call_big.txt").exists() + + +def test_persist_tool_result_leaves_no_temp_files(tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + root = tmp_path / ".nanobot" / "tool-results" + maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + max_chars=64, + ) + + assert (root / "current_session" / "call_big.txt").exists() + assert list((root / "current_session").glob("*.tmp")) == [] + + +def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + warnings: list[str] = [] + + monkeypatch.setattr( + "nanobot.utils.helpers._cleanup_tool_result_buckets", + lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")), + ) + monkeypatch.setattr( + "nanobot.utils.helpers.logger.exception", + lambda message, *args: warnings.append(message.format(*args)), + ) + + persisted = maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + max_chars=64, + ) + + assert "[tool output persisted]" in persisted + assert warnings and "Failed to clean stale tool result buckets" in warnings[0] +async def test_runner_keeps_going_when_tool_result_persistence_fails(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="tool result") + + runner = AgentRunner(provider) + with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")): + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert tool_message["content"] == "tool result" diff --git a/tests/agent/test_runner_reasoning.py b/tests/agent/test_runner_reasoning.py new file mode 100644 index 000000000..d971e05a1 --- /dev/null +++ b/tests/agent/test_runner_reasoning.py @@ -0,0 +1,321 @@ +"""Tests for AgentRunner reasoning extraction and emission. + +Covers the three sources of model reasoning (dedicated ``reasoning_content``, +Anthropic ``thinking_blocks``, inline ````/```` tags) plus +the streaming interaction: reasoning and answer streams are independent +channels, gated by ``context.streamed_reasoning`` rather than +``context.streamed_content``. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.hook import AgentHook +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +class _RecordingHook(AgentHook): + def __init__(self) -> None: + super().__init__() + self.emitted: list[str] = [] + self.end_calls = 0 + + async def emit_reasoning(self, reasoning_content: str | None) -> None: + if reasoning_content: + self.emitted.append(reasoning_content) + + async def emit_reasoning_end(self) -> None: + self.end_calls += 1 + + +@pytest.mark.asyncio +async def test_runner_preserves_reasoning_fields_in_assistant_history(): + """Reasoning fields ride along on the persisted assistant message so + follow-up provider calls retain the model's prior thinking context.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + reasoning_content="hidden reasoning", + thinking_blocks=[{"type": "thinking", "thinking": "step"}], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="tool result") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "system", "content": "system"}, + {"role": "user", "content": "do task"}, + ], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assistant_messages = [ + msg for msg in captured_second_call + if msg.get("role") == "assistant" and msg.get("tool_calls") + ] + assert len(assistant_messages) == 1 + assert assistant_messages[0]["reasoning_content"] == "hidden reasoning" + assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}] + + +@pytest.mark.asyncio +async def test_runner_emits_anthropic_thinking_blocks(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="The answer is 42.", + thinking_blocks=[ + {"type": "thinking", "thinking": "Let me analyze this step by step.", "signature": "sig1"}, + {"type": "thinking", "thinking": "After careful consideration.", "signature": "sig2"}, + ], + tool_calls=[], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + hook = _RecordingHook() + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "question"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=hook, + )) + + assert result.final_content == "The answer is 42." + assert len(hook.emitted) == 1 + assert "Let me analyze this" in hook.emitted[0] + assert "After careful consideration" in hook.emitted[0] + + +@pytest.mark.asyncio +async def test_runner_emits_inline_think_content_as_reasoning(): + """Models embedding reasoning in ... blocks should have + that content extracted and emitted, and stripped from the answer.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="Let me think about this...\nThe answer is 42.The answer is 42.", + tool_calls=[], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + hook = _RecordingHook() + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "what is the answer?"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=hook, + )) + + assert result.final_content == "The answer is 42." + assert len(hook.emitted) == 1 + assert "Let me think about this" in hook.emitted[0] + + +@pytest.mark.asyncio +async def test_runner_prefers_reasoning_content_over_inline_think(): + """Fallback priority: dedicated reasoning_content wins; inline + is still scrubbed from the answer content.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="inline thinkingThe answer.", + reasoning_content="dedicated reasoning field", + tool_calls=[], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + hook = _RecordingHook() + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "question"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=hook, + )) + + assert result.final_content == "The answer." + assert hook.emitted == ["dedicated reasoning field"] + + +@pytest.mark.asyncio +async def test_runner_emits_reasoning_content_even_when_answer_was_streamed(): + """`reasoning_content` arrives only on the final response; streaming the + answer must not suppress it (the answer stream and the reasoning channel + are independent — only the reasoning-already-emitted bit matters).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + provider.supports_progress_deltas = True + + async def chat_stream_with_retry(*, on_content_delta=None, **kwargs): + if on_content_delta: + await on_content_delta("The ") + await on_content_delta("answer.") + return LLMResponse( + content="The answer.", + reasoning_content="step-by-step deduction", + tool_calls=[], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + + provider.chat_stream_with_retry = chat_stream_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + progress_calls: list[str] = [] + + async def _progress(content: str, **_kwargs): + progress_calls.append(content) + + hook = _RecordingHook() + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "question"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=hook, + stream_progress_deltas=True, + progress_callback=_progress, + )) + + assert result.final_content == "The answer." + assert progress_calls, "answer should have streamed via progress callback" + assert hook.emitted == ["step-by-step deduction"] + + +@pytest.mark.asyncio +async def test_runner_does_not_double_emit_when_inline_think_already_streamed(): + """Inline `` blocks streamed incrementally during the answer + stream must not be re-emitted from the final response.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + provider.supports_progress_deltas = True + + async def chat_stream_with_retry(*, on_content_delta=None, **kwargs): + if on_content_delta: + await on_content_delta("working...") + await on_content_delta("The answer.") + return LLMResponse( + content="working...The answer.", + tool_calls=[], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + + provider.chat_stream_with_retry = chat_stream_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def _progress(content: str, **_kwargs): + pass + + hook = _RecordingHook() + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "question"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=hook, + stream_progress_deltas=True, + progress_callback=_progress, + )) + + assert result.final_content == "The answer." + assert hook.emitted == ["working..."] + assert hook.end_calls >= 1, "reasoning stream must be closed once the answer starts" + + +@pytest.mark.asyncio +async def test_runner_closes_reasoning_stream_after_one_shot_response(): + """A non-streaming response carrying ``reasoning_content`` must emit + both a reasoning delta and an end marker so channels can finalize the + in-place bubble.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="answer", + reasoning_content="hidden thought", + tool_calls=[], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + hook = _RecordingHook() + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "q"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=hook, + )) + + assert result.final_content == "answer" + assert hook.emitted == ["hidden thought"] + assert hook.end_calls == 1 diff --git a/tests/agent/test_runner_safety.py b/tests/agent/test_runner_safety.py new file mode 100644 index 000000000..14565e203 --- /dev/null +++ b/tests/agent/test_runner_safety.py @@ -0,0 +1,244 @@ +"""Tests for AgentRunner security: workspace violations, SSRF, shell guard, throttling.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + +async def test_runner_does_not_abort_on_workspace_violation_anymore(): + """v2 behavior: workspace-bound rejections are *soft* tool errors. + + Previously (PR #3493) any workspace boundary error became a fatal + RuntimeError that aborted the turn. That silently killed legitimate + workspace commands once the heuristic guard misfired (#3599 #3605), so + we now hand the error back to the LLM as a recoverable tool result and + rely on ``repeated_workspace_violation_error`` to throttle bypass loops. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + provider.chat_with_retry = AsyncMock(side_effect=[ + LLMResponse( + content="trying outside", + tool_calls=[ToolCallRequest( + id="call_1", name="read_file", arguments={"path": "/tmp/outside.md"}, + )], + ), + LLMResponse(content="ok, telling the user instead", tool_calls=[]), + ]) + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock( + side_effect=PermissionError( + "Path /tmp/outside.md is outside allowed directory /workspace" + ) + ) + + runner = AgentRunner(provider) + + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert provider.chat_with_retry.await_count == 2, ( + "workspace violation must NOT short-circuit the loop" + ) + assert result.stop_reason != "tool_error" + assert result.error is None + assert result.final_content == "ok, telling the user instead" + assert result.tool_events and result.tool_events[0]["status"] == "error" + # Detail still carries the workspace_violation breadcrumb for telemetry, + # but the runner did not raise. + assert "workspace_violation" in result.tool_events[0]["detail"] + + +def test_is_ssrf_violation_recognizes_private_url_blocks(): + """SSRF rejections are classified separately from workspace boundaries.""" + from nanobot.agent.runner import AgentRunner + + ssrf_msg = "Error: Command blocked by safety guard (internal/private URL detected)" + assert AgentRunner._is_ssrf_violation(ssrf_msg) is True + assert AgentRunner._is_ssrf_violation( + "URL validation failed: Blocked: host resolves to private/internal address 192.168.1.2" + ) is True + + # Workspace-bound markers are NOT classified as SSRF. + assert AgentRunner._is_ssrf_violation( + "Error: Command blocked by safety guard (path outside working dir)" + ) is False + assert AgentRunner._is_ssrf_violation( + "Path /tmp/x is outside allowed directory /ws" + ) is False + # Deny / allowlist filter messages stay non-fatal too. + assert AgentRunner._is_ssrf_violation( + "Error: Command blocked by deny pattern filter" + ) is False + + +@pytest.mark.asyncio +async def test_runner_returns_non_retryable_hint_on_ssrf_violation(): + """SSRF stays blocked, but the runtime gives the LLM a final chance to recover.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + provider.chat_with_retry = AsyncMock(side_effect=[ + LLMResponse( + content="curl-ing metadata", + tool_calls=[ToolCallRequest( + id="call_ssrf", + name="exec", + arguments={"command": "curl http://169.254.169.254"}, + )], + ), + LLMResponse( + content="I cannot access that private URL. Please share local files.", + tool_calls=[], + ), + ]) + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value=( + "Error: Command blocked by safety guard (internal/private URL detected)" + )) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert provider.chat_with_retry.await_count == 2 + assert result.stop_reason == "completed" + assert result.error is None + assert result.final_content == "I cannot access that private URL. Please share local files." + assert result.tool_events and result.tool_events[0]["detail"].startswith("ssrf_violation:") + tool_messages = [m for m in result.messages if m.get("role") == "tool"] + assert tool_messages + assert "non-bypassable security boundary" in tool_messages[0]["content"] + assert "Do not retry" in tool_messages[0]["content"] + assert "tools.ssrfWhitelist" in tool_messages[0]["content"] + + +@pytest.mark.asyncio +async def test_runner_lets_llm_recover_from_shell_guard_path_outside(): + """Reporter scenario for #3599 / #3605 -- guard hit, agent recovers. + + The shell `_guard_command` heuristic fires on `2>/dev/null`-style + redirects and other shell idioms. Before v2 that abort'd the whole + turn (silent hang on Telegram per #3605); now the LLM gets the soft + error back and can finalize on the next iteration. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + if provider.chat_with_retry.await_count == 1: + return LLMResponse( + content="trying noisy cleanup", + tool_calls=[ToolCallRequest( + id="call_blocked", + name="exec", + arguments={"command": "rm scratch.txt 2>/dev/null"}, + )], + ) + captured_second_call[:] = list(messages) + return LLMResponse(content="recovered final answer", tool_calls=[]) + + provider.chat_with_retry = AsyncMock(side_effect=chat_with_retry) + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock( + return_value="Error: Command blocked by safety guard (path outside working dir)" + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert provider.chat_with_retry.await_count == 2, ( + "guard hit must NOT short-circuit the loop -- LLM should get a second turn" + ) + assert result.stop_reason != "tool_error" + assert result.error is None + assert result.final_content == "recovered final answer" + assert result.tool_events and result.tool_events[0]["status"] == "error" + # v2: detail keeps the breadcrumb but the runner did not raise. + assert "workspace_violation" in result.tool_events[0]["detail"] + + +@pytest.mark.asyncio +async def test_runner_throttles_repeated_workspace_bypass_attempts(): + """#3493 motivation: stop the LLM bypass loop without aborting the turn. + + LLM keeps switching tools (read_file -> exec cat -> python -c open(...)) + against the same outside path. After the soft retry budget is exhausted + the runner replaces the tool result with a hard "stop trying" message + so the model finally gives up and surfaces the boundary to the user. + """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + bypass_attempts = [ + ToolCallRequest( + id=f"a{i}", name="exec", + arguments={"command": f"cat /Users/x/Downloads/01.md # try {i}"}, + ) + for i in range(4) + ] + responses: list[LLMResponse] = [ + LLMResponse(content=f"try {i}", tool_calls=[bypass_attempts[i]]) + for i in range(4) + ] + responses.append(LLMResponse(content="ok telling user", tool_calls=[])) + + provider = MagicMock() + provider.chat_with_retry = AsyncMock(side_effect=responses) + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock( + return_value="Error: Command blocked by safety guard (path outside working dir)" + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=10, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + # All 4 bypass attempts surface to the LLM (no fatal abort), and the + # runner finally completes once the LLM stops asking. + assert result.stop_reason != "tool_error" + assert result.error is None + assert result.final_content == "ok telling user" + # The third+ attempts must have been escalated -- look at the events. + escalated = [ + ev for ev in result.tool_events + if ev["status"] == "error" + and ev["detail"].startswith("workspace_violation_escalated:") + ] + assert escalated, ( + "expected at least one escalated workspace_violation event, got: " + f"{result.tool_events}" + ) diff --git a/tests/agent/test_runner_tool_execution.py b/tests/agent/test_runner_tool_execution.py new file mode 100644 index 000000000..a0380e871 --- /dev/null +++ b/tests/agent/test_runner_tool_execution.py @@ -0,0 +1,181 @@ +"""Tests for AgentRunner tool execution: batching, concurrency, exclusive tools.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.tools.base import Tool +from nanobot.agent.tools.registry import ToolRegistry +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse, ToolCallRequest + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + +class _DelayTool(Tool): + def __init__( + self, + name: str, + *, + delay: float, + read_only: bool, + shared_events: list[str], + exclusive: bool = False, + ): + self._name = name + self._delay = delay + self._read_only = read_only + self._shared_events = shared_events + self._exclusive = exclusive + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._name + + @property + def parameters(self) -> dict: + return {"type": "object", "properties": {}, "required": []} + + @property + def read_only(self) -> bool: + return self._read_only + + @property + def exclusive(self) -> bool: + return self._exclusive + + async def execute(self, **kwargs): + self._shared_events.append(f"start:{self._name}") + await asyncio.sleep(self._delay) + self._shared_events.append(f"end:{self._name}") + return self._name + + +@pytest.mark.asyncio +async def test_runner_batches_read_only_tools_before_exclusive_work(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + tools = ToolRegistry() + shared_events: list[str] = [] + read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events) + read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events) + write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events) + tools.register(read_a) + tools.register(read_b) + tools.register(write_a) + + runner = AgentRunner(MagicMock()) + await runner._execute_tools( + AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + concurrent_tools=True, + ), + [ + ToolCallRequest(id="ro1", name="read_a", arguments={}), + ToolCallRequest(id="ro2", name="read_b", arguments={}), + ToolCallRequest(id="rw1", name="write_a", arguments={}), + ], + {}, + {}, + ) + + assert shared_events[0:2] == ["start:read_a", "start:read_b"] + assert "end:read_a" in shared_events and "end:read_b" in shared_events + assert shared_events.index("end:read_a") < shared_events.index("start:write_a") + assert shared_events.index("end:read_b") < shared_events.index("start:write_a") + assert shared_events[-2:] == ["start:write_a", "end:write_a"] + + +@pytest.mark.asyncio +async def test_runner_does_not_batch_exclusive_read_only_tools(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + tools = ToolRegistry() + shared_events: list[str] = [] + read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events) + read_b = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=shared_events) + ddg_like = _DelayTool( + "ddg_like", + delay=0.01, + read_only=True, + shared_events=shared_events, + exclusive=True, + ) + tools.register(read_a) + tools.register(ddg_like) + tools.register(read_b) + + runner = AgentRunner(MagicMock()) + await runner._execute_tools( + AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + concurrent_tools=True, + ), + [ + ToolCallRequest(id="ro1", name="read_a", arguments={}), + ToolCallRequest(id="ddg1", name="ddg_like", arguments={}), + ToolCallRequest(id="ro2", name="read_b", arguments={}), + ], + {}, + {}, + ) + + assert shared_events[0] == "start:read_a" + assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like") + assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b") + + +@pytest.mark.asyncio +async def test_runner_blocks_repeated_external_fetches(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_final_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] <= 3: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id=f"call_{call_count['n']}", name="web_fetch", arguments={"url": "https://example.com"})], + usage={}, + ) + captured_final_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="page content") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "research task"}], + tools=tools, + model="test-model", + max_iterations=4, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assert tools.execute.await_count == 2 + blocked_tool_message = [ + msg for msg in captured_final_call + if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_3" + ][0] + assert "repeated external lookup blocked" in blocked_tool_message["content"] diff --git a/tests/agent/test_session_manager_history.py b/tests/agent/test_session_manager_history.py index 9fb77fafd..ffc41583d 100644 --- a/tests/agent/test_session_manager_history.py +++ b/tests/agent/test_session_manager_history.py @@ -43,6 +43,19 @@ def test_list_sessions_includes_metadata_title(tmp_path): assert rows[0]["title"] == "自动生成标题" +def test_list_sessions_includes_user_preview(tmp_path): + manager = SessionManager(tmp_path) + session = manager.get_or_create("websocket:chat-preview") + session.add_message("user", "帮我总结一下 OpenAI 的最新硬件计划") + session.add_message("assistant", "可以,我会先查最新消息。") + manager.save(session) + + rows = manager.list_sessions() + + assert rows[0]["key"] == "websocket:chat-preview" + assert rows[0]["preview"] == "帮我总结一下 OpenAI 的最新硬件计划" + + # --- Original regression test (from PR 2075) --- def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls(): diff --git a/tests/agent/test_stop_preserves_context.py b/tests/agent/test_stop_preserves_context.py index 2a082850f..c7e766be1 100644 --- a/tests/agent/test_stop_preserves_context.py +++ b/tests/agent/test_stop_preserves_context.py @@ -10,6 +10,7 @@ See: https://github.com/HKUDS/nanobot/issues/2966 from __future__ import annotations import asyncio +from pathlib import Path from types import SimpleNamespace from typing import Any from unittest.mock import MagicMock, patch, AsyncMock @@ -17,42 +18,47 @@ from unittest.mock import MagicMock, patch, AsyncMock import pytest from nanobot.agent.loop import AgentLoop +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMProvider -@pytest.fixture -def mock_loop(): - """Create a minimal AgentLoop with mocked dependencies.""" - with patch.object(AgentLoop, "__init__", lambda self: None): - loop = AgentLoop() - loop.sessions = MagicMock() - loop._pending_queues = {} - loop._session_locks = {} - loop._active_tasks = {} - loop._concurrency_gate = None - loop._RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" - loop._PENDING_USER_TURN_KEY = "pending_user_turn" - loop.bus = MagicMock() - loop.bus.publish_outbound = AsyncMock() - loop.bus.publish_inbound = AsyncMock() - loop.commands = MagicMock() - loop.commands.dispatch_priority = AsyncMock(return_value=None) - return loop +def _make_provider(): + """Create an LLM provider mock with required attributes.""" + from types import SimpleNamespace + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.generation = SimpleNamespace(max_tokens=4096, temperature=0.1, reasoning_effort=None) + provider.estimate_prompt_tokens.return_value = (10_000, "test") + return provider + + +def _make_loop(tmp_path: Path) -> AgentLoop: + """Create a real AgentLoop with mocked provider — avoids patching __init__.""" + bus = MessageBus() + provider = _make_provider() + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: + MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) + return AgentLoop(bus=bus, provider=provider, workspace=tmp_path) class TestStopPreservesContext: """Verify that /stop restores partial context via checkpoint.""" - def test_restore_checkpoint_method_exists(self, mock_loop): + def test_restore_checkpoint_method_exists(self, tmp_path): """AgentLoop should have _restore_runtime_checkpoint.""" - assert hasattr(mock_loop, "_restore_runtime_checkpoint") + loop = _make_loop(tmp_path) + assert hasattr(loop, "_restore_runtime_checkpoint") - def test_checkpoint_key_constant(self, mock_loop): + def test_checkpoint_key_constant(self, tmp_path): """The runtime checkpoint key should be defined.""" - assert mock_loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint" + loop = _make_loop(tmp_path) + assert loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint" - def test_cancel_dispatch_restores_checkpoint(self, mock_loop): + def test_cancel_dispatch_restores_checkpoint(self, tmp_path): """When a task is cancelled, the checkpoint should be restored.""" - # Create a mock session with a checkpoint + loop = _make_loop(tmp_path) session = MagicMock() session.metadata = { "runtime_checkpoint": { @@ -74,14 +80,11 @@ class TestStopPreservesContext: session.messages = [ {"role": "user", "content": "Search for something"}, ] - mock_loop.sessions.get_or_create.return_value = session + loop.sessions.get_or_create.return_value = session - # The restore method should add checkpoint messages to session history - restored = mock_loop._restore_runtime_checkpoint(session) + restored = loop._restore_runtime_checkpoint(session) assert restored is True - # After restore, session should have more messages assert len(session.messages) > 1 - # The checkpoint should be cleared assert "runtime_checkpoint" not in session.metadata diff --git a/tests/agent/test_subagent_lifecycle.py b/tests/agent/test_subagent_lifecycle.py new file mode 100644 index 000000000..bf3564f28 --- /dev/null +++ b/tests/agent/test_subagent_lifecycle.py @@ -0,0 +1,558 @@ +"""Tests for SubagentManager lifecycle — spawn, run, announce, cancel.""" + +import asyncio +import time +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.agent.hook import AgentHookContext +from nanobot.agent.runner import AgentRunResult +from nanobot.agent.subagent import ( + SubagentManager, + SubagentStatus, + _SubagentHook, +) +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMProvider + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _manager(tmp_path: Path, **kw) -> SubagentManager: + provider = MagicMock(spec=LLMProvider) + provider.get_default_model.return_value = "test-model" + defaults = dict( + provider=provider, + workspace=tmp_path, + bus=MessageBus(), + model="test-model", + max_tool_result_chars=16_000, + ) + defaults.update(kw) + return SubagentManager(**defaults) + + +def _make_hook_context(**overrides) -> AgentHookContext: + defaults = dict( + iteration=1, + tool_calls=[], + tool_events=[], + messages=[], + usage={}, + error=None, + stop_reason="completed", + final_content="ok", + ) + defaults.update(overrides) + return AgentHookContext(**defaults) + + +# --------------------------------------------------------------------------- +# SubagentStatus defaults +# --------------------------------------------------------------------------- + + +class TestSubagentStatus: + def test_defaults(self): + s = SubagentStatus( + task_id="abc", label="test", task_description="do stuff", + started_at=time.monotonic(), + ) + assert s.phase == "initializing" + assert s.iteration == 0 + assert s.tool_events == [] + assert s.usage == {} + assert s.stop_reason is None + assert s.error is None + + +# --------------------------------------------------------------------------- +# set_provider +# --------------------------------------------------------------------------- + + +class TestSetProvider: + def test_updates_provider_model_runner(self, tmp_path): + sm = _manager(tmp_path) + new_provider = MagicMock(spec=LLMProvider) + sm.set_provider(new_provider, "new-model") + assert sm.provider is new_provider + assert sm.model == "new-model" + assert sm.runner.provider is new_provider + + +# --------------------------------------------------------------------------- +# spawn +# --------------------------------------------------------------------------- + + +class TestSpawn: + @pytest.mark.asyncio + async def test_returns_string_with_task_id(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(return_value=AgentRunResult( + final_content="done", messages=[], stop_reason="completed", + )) + result = await sm.spawn("do something") + assert "started" in result + assert "id:" in result + + @pytest.mark.asyncio + async def test_creates_task_in_running_tasks(self, tmp_path): + sm = _manager(tmp_path) + block = asyncio.Event() + async def _slow_run(spec): + await block.wait() + return AgentRunResult(final_content="done", messages=[], stop_reason="completed") + sm.runner.run = _slow_run + + await sm.spawn("task", session_key="s1") + assert len(sm._running_tasks) == 1 + + block.set() + await asyncio.sleep(0.1) + assert len(sm._running_tasks) == 0 + + @pytest.mark.asyncio + async def test_creates_status(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(return_value=AgentRunResult( + final_content="done", messages=[], stop_reason="completed", + )) + await sm.spawn("my task") + await asyncio.sleep(0.1) + # Status cleaned up after task completes + assert len(sm._task_statuses) == 0 + + @pytest.mark.asyncio + async def test_registers_in_session_tasks(self, tmp_path): + sm = _manager(tmp_path) + block = asyncio.Event() + async def _slow_run(spec): + await block.wait() + return AgentRunResult(final_content="done", messages=[], stop_reason="completed") + sm.runner.run = _slow_run + + await sm.spawn("task", session_key="s1") + assert "s1" in sm._session_tasks + assert len(sm._session_tasks["s1"]) == 1 + + block.set() + await asyncio.sleep(0.1) + assert "s1" not in sm._session_tasks + + @pytest.mark.asyncio + async def test_no_session_key_no_registration(self, tmp_path): + sm = _manager(tmp_path) + block = asyncio.Event() + async def _slow_run(spec): + await block.wait() + return AgentRunResult(final_content="done", messages=[], stop_reason="completed") + sm.runner.run = _slow_run + + await sm.spawn("task") + assert len(sm._session_tasks) == 0 + + block.set() + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_label_defaults_to_truncated_task(self, tmp_path): + sm = _manager(tmp_path) + block = asyncio.Event() + async def _slow_run(spec): + await block.wait() + return AgentRunResult(final_content="done", messages=[], stop_reason="completed") + sm.runner.run = _slow_run + + long_task = "A" * 50 + await sm.spawn(long_task, session_key="s1") + status = next(iter(sm._task_statuses.values())) + assert status.label == long_task[:30] + "..." + + block.set() + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_custom_label(self, tmp_path): + sm = _manager(tmp_path) + block = asyncio.Event() + async def _slow_run(spec): + await block.wait() + return AgentRunResult(final_content="done", messages=[], stop_reason="completed") + sm.runner.run = _slow_run + + await sm.spawn("task", label="Custom Label", session_key="s1") + status = next(iter(sm._task_statuses.values())) + assert status.label == "Custom Label" + + block.set() + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_cleanup_callback_removes_all_entries(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(return_value=AgentRunResult( + final_content="done", messages=[], stop_reason="completed", + )) + await sm.spawn("task", session_key="s1") + await asyncio.sleep(0.1) + assert len(sm._running_tasks) == 0 + assert len(sm._task_statuses) == 0 + assert len(sm._session_tasks) == 0 + + +# --------------------------------------------------------------------------- +# _run_subagent +# --------------------------------------------------------------------------- + + +class TestRunSubagent: + @pytest.mark.asyncio + async def test_successful_run(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(return_value=AgentRunResult( + final_content="Task done!", messages=[], stop_reason="completed", + )) + with patch.object(sm, "_announce_result", new_callable=AsyncMock) as mock_announce: + await sm._run_subagent( + "t1", "do task", "label", + {"channel": "cli", "chat_id": "direct"}, + SubagentStatus(task_id="t1", label="label", task_description="do task", started_at=time.monotonic()), + ) + mock_announce.assert_called_once() + assert mock_announce.call_args.args[-2] == "ok" + + @pytest.mark.asyncio + async def test_tool_error_run(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(return_value=AgentRunResult( + final_content=None, messages=[], stop_reason="tool_error", + tool_events=[{"name": "read_file", "status": "error", "detail": "not found"}], + )) + status = SubagentStatus(task_id="t1", label="label", task_description="do task", started_at=time.monotonic()) + with patch.object(sm, "_announce_result", new_callable=AsyncMock) as mock_announce: + await sm._run_subagent( + "t1", "do task", "label", + {"channel": "cli", "chat_id": "direct"}, status, + ) + assert mock_announce.call_args.args[-2] == "error" + + @pytest.mark.asyncio + async def test_exception_run(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(side_effect=RuntimeError("LLM down")) + status = SubagentStatus(task_id="t1", label="label", task_description="do task", started_at=time.monotonic()) + with patch.object(sm, "_announce_result", new_callable=AsyncMock) as mock_announce: + await sm._run_subagent( + "t1", "do task", "label", + {"channel": "cli", "chat_id": "direct"}, status, + ) + assert status.phase == "error" + assert "LLM down" in status.error + assert mock_announce.call_args.args[-2] == "error" + + @pytest.mark.asyncio + async def test_status_updated_on_success(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(return_value=AgentRunResult( + final_content="ok", messages=[], stop_reason="completed", + )) + status = SubagentStatus(task_id="t1", label="label", task_description="do task", started_at=time.monotonic()) + with patch.object(sm, "_announce_result", new_callable=AsyncMock): + await sm._run_subagent( + "t1", "do task", "label", + {"channel": "cli", "chat_id": "direct"}, status, + ) + assert status.phase == "done" + assert status.stop_reason == "completed" + + +# --------------------------------------------------------------------------- +# _announce_result +# --------------------------------------------------------------------------- + + +class TestAnnounceResult: + @pytest.mark.asyncio + async def test_publishes_inbound_message(self, tmp_path): + sm = _manager(tmp_path) + published = [] + sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg)) + + await sm._announce_result( + "t1", "label", "task", "result text", + {"channel": "cli", "chat_id": "direct"}, "ok", + ) + + assert len(published) == 1 + msg = published[0] + assert msg.channel == "system" + assert msg.sender_id == "subagent" + assert msg.metadata["injected_event"] == "subagent_result" + assert msg.metadata["subagent_task_id"] == "t1" + + @pytest.mark.asyncio + async def test_session_key_override(self, tmp_path): + sm = _manager(tmp_path) + published = [] + sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg)) + + await sm._announce_result( + "t1", "label", "task", "result", + {"channel": "telegram", "chat_id": "123", "session_key": "s1"}, "ok", + ) + + assert published[0].session_key_override == "s1" + + @pytest.mark.asyncio + async def test_session_key_override_fallback(self, tmp_path): + sm = _manager(tmp_path) + published = [] + sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg)) + + await sm._announce_result( + "t1", "label", "task", "result", + {"channel": "telegram", "chat_id": "123"}, "ok", + ) + + assert published[0].session_key_override == "telegram:123" + + @pytest.mark.asyncio + async def test_ok_status_text(self, tmp_path): + sm = _manager(tmp_path) + published = [] + sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg)) + + await sm._announce_result( + "t1", "label", "task", "result", + {"channel": "cli", "chat_id": "direct"}, "ok", + ) + + assert "completed successfully" in published[0].content + + @pytest.mark.asyncio + async def test_error_status_text(self, tmp_path): + sm = _manager(tmp_path) + published = [] + sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg)) + + await sm._announce_result( + "t1", "label", "task", "error details", + {"channel": "cli", "chat_id": "direct"}, "error", + ) + + assert "failed" in published[0].content + + @pytest.mark.asyncio + async def test_origin_message_id_in_metadata(self, tmp_path): + sm = _manager(tmp_path) + published = [] + sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg)) + + await sm._announce_result( + "t1", "label", "task", "result", + {"channel": "cli", "chat_id": "direct"}, "ok", + origin_message_id="msg-123", + ) + + assert published[0].metadata["origin_message_id"] == "msg-123" + + +# --------------------------------------------------------------------------- +# _format_partial_progress +# --------------------------------------------------------------------------- + + +class TestFormatPartialProgress: + def _make_result(self, tool_events=None, error=None): + return MagicMock(tool_events=tool_events or [], error=error) + + def test_completed_only(self): + result = self._make_result(tool_events=[ + {"name": "read_file", "status": "ok", "detail": "file content"}, + {"name": "exec", "status": "ok", "detail": "output"}, + ]) + text = SubagentManager._format_partial_progress(result) + assert "Completed steps:" in text + assert "read_file" in text + assert "exec" in text + + def test_failure_only(self): + result = self._make_result(tool_events=[ + {"name": "read_file", "status": "error", "detail": "not found"}, + ]) + text = SubagentManager._format_partial_progress(result) + assert "Failure:" in text + assert "not found" in text + + def test_completed_and_failure(self): + result = self._make_result(tool_events=[ + {"name": "read_file", "status": "ok", "detail": "content"}, + {"name": "exec", "status": "error", "detail": "timeout"}, + ]) + text = SubagentManager._format_partial_progress(result) + assert "Completed steps:" in text + assert "Failure:" in text + + def test_limited_to_last_three(self): + result = self._make_result(tool_events=[ + {"name": f"tool_{i}", "status": "ok", "detail": f"result_{i}"} + for i in range(5) + ]) + text = SubagentManager._format_partial_progress(result) + assert "tool_2" in text + assert "tool_3" in text + assert "tool_4" in text + assert "tool_0" not in text + assert "tool_1" not in text + + def test_error_without_failure_event(self): + result = self._make_result( + tool_events=[{"name": "read_file", "status": "ok", "detail": "ok"}], + error="Something went wrong", + ) + text = SubagentManager._format_partial_progress(result) + assert "Something went wrong" in text + + def test_empty_events_with_error(self): + result = self._make_result(error="Total failure") + text = SubagentManager._format_partial_progress(result) + assert "Total failure" in text + + def test_empty_no_error_returns_fallback(self): + result = self._make_result() + text = SubagentManager._format_partial_progress(result) + assert "Error" in text + + +# --------------------------------------------------------------------------- +# cancel_by_session +# --------------------------------------------------------------------------- + + +class TestCancelBySession: + @pytest.mark.asyncio + async def test_cancels_running_tasks(self, tmp_path): + sm = _manager(tmp_path) + block = asyncio.Event() + async def _slow_run(spec): + await block.wait() + return AgentRunResult(final_content="done", messages=[], stop_reason="completed") + sm.runner.run = _slow_run + + await sm.spawn("task1", session_key="s1") + await sm.spawn("task2", session_key="s1") + assert len(sm._session_tasks.get("s1", set())) == 2 + + count = await sm.cancel_by_session("s1") + assert count == 2 + block.set() + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_no_tasks_returns_zero(self, tmp_path): + sm = _manager(tmp_path) + count = await sm.cancel_by_session("nonexistent") + assert count == 0 + + @pytest.mark.asyncio + async def test_already_done_not_counted(self, tmp_path): + sm = _manager(tmp_path) + sm.runner.run = AsyncMock(return_value=AgentRunResult( + final_content="done", messages=[], stop_reason="completed", + )) + await sm.spawn("task1", session_key="s1") + await asyncio.sleep(0.1) # Wait for completion + + count = await sm.cancel_by_session("s1") + assert count == 0 + + +# --------------------------------------------------------------------------- +# get_running_count / get_running_count_by_session +# --------------------------------------------------------------------------- + + +class TestRunningCounts: + @pytest.mark.asyncio + async def test_running_count_zero(self, tmp_path): + sm = _manager(tmp_path) + assert sm.get_running_count() == 0 + + @pytest.mark.asyncio + async def test_running_count_tracks_tasks(self, tmp_path): + sm = _manager(tmp_path) + block = asyncio.Event() + async def _slow_run(spec): + await block.wait() + return AgentRunResult(final_content="done", messages=[], stop_reason="completed") + sm.runner.run = _slow_run + + await sm.spawn("t1", session_key="s1") + await sm.spawn("t2", session_key="s1") + assert sm.get_running_count() == 2 + assert sm.get_running_count_by_session("s1") == 2 + + block.set() + await asyncio.sleep(0.1) + assert sm.get_running_count() == 0 + + @pytest.mark.asyncio + async def test_running_count_by_session_nonexistent(self, tmp_path): + sm = _manager(tmp_path) + assert sm.get_running_count_by_session("nonexistent") == 0 + + +# --------------------------------------------------------------------------- +# _SubagentHook +# --------------------------------------------------------------------------- + + +class TestSubagentHook: + @pytest.mark.asyncio + async def test_before_execute_tools_logs(self, tmp_path): + hook = _SubagentHook("t1") + tool_call = MagicMock() + tool_call.name = "read_file" + tool_call.arguments = {"path": "/tmp/test"} + ctx = _make_hook_context(tool_calls=[tool_call]) + # Should not raise + await hook.before_execute_tools(ctx) + + @pytest.mark.asyncio + async def test_after_iteration_updates_status(self): + status = SubagentStatus( + task_id="t1", label="test", task_description="do", started_at=time.monotonic(), + ) + hook = _SubagentHook("t1", status) + ctx = _make_hook_context( + iteration=3, + tool_events=[{"name": "read_file", "status": "ok", "detail": ""}], + usage={"prompt_tokens": 100}, + ) + await hook.after_iteration(ctx) + assert status.iteration == 3 + assert len(status.tool_events) == 1 + assert status.usage == {"prompt_tokens": 100} + + @pytest.mark.asyncio + async def test_after_iteration_no_status_noop(self): + hook = _SubagentHook("t1", status=None) + ctx = _make_hook_context(iteration=5) + # Should not raise + await hook.after_iteration(ctx) + + @pytest.mark.asyncio + async def test_after_iteration_sets_error(self): + status = SubagentStatus( + task_id="t1", label="test", task_description="do", started_at=time.monotonic(), + ) + hook = _SubagentHook("t1", status) + ctx = _make_hook_context(error="something broke") + await hook.after_iteration(ctx) + assert status.error == "something broke" diff --git a/tests/channels/test_channel_manager_reasoning.py b/tests/channels/test_channel_manager_reasoning.py new file mode 100644 index 000000000..bc2a640c6 --- /dev/null +++ b/tests/channels/test_channel_manager_reasoning.py @@ -0,0 +1,228 @@ +"""Tests for ChannelManager routing of model reasoning content. + +Reasoning is delivered through plugin streaming primitives +(``send_reasoning_delta`` / ``send_reasoning_end``) so each channel +controls in-place rendering — mirroring the existing answer ``send_delta`` +/ ``stream_end`` pair. The manager forwards reasoning frames only to +channels that opt in via ``channel.show_reasoning``; plugins without a +low-emphasis UI primitive keep the base no-op and the content silently +drops at dispatch. + +One-shot ``_reasoning`` frames are accepted for back-compat with hooks +that haven't migrated yet — ``BaseChannel.send_reasoning`` expands them +to a single delta + end pair so plugins only implement the streaming +primitives. +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.channels.manager import ChannelManager +from nanobot.config.schema import Config + + +class _MockChannel(BaseChannel): + name = "mock" + display_name = "Mock" + + def __init__(self, config, bus): + super().__init__(config, bus) + self._send_mock = AsyncMock() + self._delta_mock = AsyncMock() + self._end_mock = AsyncMock() + + async def start(self): # pragma: no cover - not exercised + pass + + async def stop(self): # pragma: no cover - not exercised + pass + + async def send(self, msg): + return await self._send_mock(msg) + + async def send_reasoning_delta(self, chat_id, delta, metadata=None): + return await self._delta_mock(chat_id, delta, metadata) + + async def send_reasoning_end(self, chat_id, metadata=None): + return await self._end_mock(chat_id, metadata) + + +@pytest.fixture +def manager() -> ChannelManager: + mgr = ChannelManager(Config(), MessageBus()) + mgr.channels["mock"] = _MockChannel({}, mgr.bus) + return mgr + + +@pytest.mark.asyncio +async def test_reasoning_delta_routes_to_send_reasoning_delta(manager): + channel = manager.channels["mock"] + msg = OutboundMessage( + channel="mock", + chat_id="c1", + content="step-by-step", + metadata={"_progress": True, "_reasoning_delta": True, "_stream_id": "r1"}, + ) + await manager._send_once(channel, msg) + channel._delta_mock.assert_awaited_once() + args = channel._delta_mock.await_args.args + assert args[0] == "c1" + assert args[1] == "step-by-step" + channel._send_mock.assert_not_awaited() + channel._end_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_reasoning_end_routes_to_send_reasoning_end(manager): + channel = manager.channels["mock"] + msg = OutboundMessage( + channel="mock", + chat_id="c1", + content="", + metadata={"_progress": True, "_reasoning_end": True, "_stream_id": "r1"}, + ) + await manager._send_once(channel, msg) + channel._end_mock.assert_awaited_once() + channel._delta_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_legacy_one_shot_reasoning_expands_to_delta_plus_end(manager): + """`_reasoning` (no delta/end pair) falls back through `send_reasoning` + which the base class expands to a single delta + end. Hooks that haven't + migrated still surface in WebUI as a complete stream segment.""" + channel = manager.channels["mock"] + msg = OutboundMessage( + channel="mock", + chat_id="c1", + content="one-shot reasoning", + metadata={"_progress": True, "_reasoning": True}, + ) + await manager._send_once(channel, msg) + channel._delta_mock.assert_awaited_once() + channel._end_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_dispatch_drops_reasoning_when_channel_opts_out(manager): + channel = manager.channels["mock"] + channel.show_reasoning = False + msg = OutboundMessage( + channel="mock", + chat_id="c1", + content="hidden thinking", + metadata={"_progress": True, "_reasoning_delta": True}, + ) + await manager.bus.publish_outbound(msg) + + await _pump_one(manager) + + channel._delta_mock.assert_not_awaited() + channel._end_mock.assert_not_awaited() + channel._send_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_dispatch_delivers_reasoning_when_channel_opts_in(manager): + channel = manager.channels["mock"] + channel.show_reasoning = True + for chunk in ("first ", "second"): + await manager.bus.publish_outbound(OutboundMessage( + channel="mock", + chat_id="c1", + content=chunk, + metadata={"_progress": True, "_reasoning_delta": True, "_stream_id": "r1"}, + )) + await manager.bus.publish_outbound(OutboundMessage( + channel="mock", + chat_id="c1", + content="", + metadata={"_progress": True, "_reasoning_end": True, "_stream_id": "r1"}, + )) + + await _pump_one(manager) + + assert channel._delta_mock.await_count == 2 + channel._end_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_dispatch_silently_drops_reasoning_for_unknown_channel(manager): + msg = OutboundMessage( + channel="ghost", + chat_id="c1", + content="nobody home", + metadata={"_progress": True, "_reasoning_delta": True}, + ) + await manager.bus.publish_outbound(msg) + + await _pump_one(manager) + + manager.channels["mock"]._delta_mock.assert_not_awaited() + manager.channels["mock"]._send_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_base_channel_reasoning_primitives_are_noop_safe(): + """Plugins that don't override the streaming primitives must not blow up.""" + + class _Plain(BaseChannel): + name = "plain" + display_name = "Plain" + + async def start(self): # pragma: no cover + pass + + async def stop(self): # pragma: no cover + pass + + async def send(self, msg): # pragma: no cover + pass + + channel = _Plain({}, MessageBus()) + assert await channel.send_reasoning_delta("c", "x") is None + assert await channel.send_reasoning_end("c") is None + # And the one-shot wrapper translates without raising. + assert await channel.send_reasoning( + OutboundMessage(channel="plain", chat_id="c", content="x", metadata={}) + ) is None + + +@pytest.mark.asyncio +async def test_reasoning_routing_does_not_consult_send_progress(manager): + """`show_reasoning` is orthogonal to `send_progress` — turning off + progress streaming must not silence reasoning.""" + channel = manager.channels["mock"] + channel.send_progress = False + channel.show_reasoning = True + await manager.bus.publish_outbound(OutboundMessage( + channel="mock", + chat_id="c1", + content="still surfaces", + metadata={"_progress": True, "_reasoning_delta": True}, + )) + + await _pump_one(manager) + + channel._delta_mock.assert_awaited_once() + + +async def _pump_one(manager: ChannelManager) -> None: + """Drive the dispatcher until the outbound queue drains, then cancel.""" + task = asyncio.create_task(manager._dispatch_outbound()) + for _ in range(50): + await asyncio.sleep(0.01) + if manager.bus.outbound.qsize() == 0: + break + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/tests/channels/test_slack_channel.py b/tests/channels/test_slack_channel.py index 630685eed..d0f41766a 100644 --- a/tests/channels/test_slack_channel.py +++ b/tests/channels/test_slack_channel.py @@ -234,13 +234,13 @@ async def test_send_renders_buttons_on_last_message_chunk() -> None: "type": "button", "text": {"type": "plain_text", "text": "Yes"}, "value": "Yes", - "action_id": "ask_user_Yes", + "action_id": "btn_Yes", }, { "type": "button", "text": {"type": "plain_text", "text": "No"}, "value": "No", - "action_id": "ask_user_No", + "action_id": "btn_No", }, ], } diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index af144dbf7..f11cb21b4 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -224,11 +224,9 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None: payload = json.loads(mock_ws.send.call_args[0][0]) assert payload["event"] == "message" assert payload["chat_id"] == "chat-1" - assert payload["text"] == "hello\n\n1. Yes\n2. No" - assert payload["button_prompt"] == "hello" + assert payload["text"] == "hello" assert payload["reply_to"] == "m1" assert payload["media"] == ["/tmp/a.png"] - assert payload["buttons"] == [["Yes", "No"]] @pytest.mark.asyncio @@ -360,6 +358,87 @@ async def test_send_delta_emits_delta_and_stream_end() -> None: assert second["stream_id"] == "sid" +@pytest.mark.asyncio +async def test_send_reasoning_delta_emits_streaming_frame() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._attach(mock_ws, "chat-1") + + await channel.send_reasoning_delta( + "chat-1", + "step-by-step thinking", + {"_reasoning_delta": True, "_stream_id": "r1"}, + ) + + mock_ws.send.assert_awaited_once() + payload = json.loads(mock_ws.send.await_args.args[0]) + assert payload["event"] == "reasoning_delta" + assert payload["chat_id"] == "chat-1" + assert payload["text"] == "step-by-step thinking" + assert payload["stream_id"] == "r1" + + +@pytest.mark.asyncio +async def test_send_reasoning_end_emits_close_frame() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._attach(mock_ws, "chat-1") + + await channel.send_reasoning_end("chat-1", {"_reasoning_end": True, "_stream_id": "r1"}) + + payload = json.loads(mock_ws.send.await_args.args[0]) + assert payload == {"event": "reasoning_end", "chat_id": "chat-1", "stream_id": "r1"} + + +@pytest.mark.asyncio +async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None: + """``send_reasoning`` is back-compat for hooks that haven't migrated: + the base implementation must produce one delta and one end so the + WebUI sees the same shape either way.""" + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._attach(mock_ws, "chat-1") + + await channel.send_reasoning(OutboundMessage( + channel="websocket", + chat_id="chat-1", + content="thinking", + metadata={"_reasoning": True}, + )) + + assert mock_ws.send.await_count == 2 + first = json.loads(mock_ws.send.call_args_list[0][0][0]) + second = json.loads(mock_ws.send.call_args_list[1][0][0]) + assert first["event"] == "reasoning_delta" + assert first["text"] == "thinking" + assert second["event"] == "reasoning_end" + + +@pytest.mark.asyncio +async def test_send_reasoning_delta_drops_empty_chunks() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._attach(mock_ws, "chat-1") + + await channel.send_reasoning_delta("chat-1", "", {"_reasoning_delta": True}) + + mock_ws.send.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_reasoning_without_subscribers_is_noop() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + + await channel.send_reasoning_delta("unattached", "thinking", None) + await channel.send_reasoning_end("unattached", None) + # No subscribers, no exception, no send. + + @pytest.mark.asyncio async def test_send_turn_end_emits_turn_end_event() -> None: bus = MagicMock() diff --git a/tests/cli/test_cli_input.py b/tests/cli/test_cli_input.py index e648e818c..34046e8d4 100644 --- a/tests/cli/test_cli_input.py +++ b/tests/cli/test_cli_input.py @@ -1,4 +1,6 @@ import asyncio +from contextlib import nullcontext +from io import StringIO from unittest.mock import AsyncMock, MagicMock, call, patch import pytest @@ -96,6 +98,66 @@ def test_print_cli_progress_line_pauses_spinner_before_printing(): assert order == ["start", "stop", "print", "start", "stop"] +def test_thinking_spinner_clears_status_line_when_paused(): + """Stopping the spinner should erase its transient line before output.""" + stream = StringIO() + stream.isatty = lambda: True # type: ignore[method-assign] + mock_console = MagicMock() + mock_console.file = stream + spinner = MagicMock() + mock_console.status.return_value = spinner + + thinking = stream_mod.ThinkingSpinner(console=mock_console) + with thinking: + with thinking.pause(): + pass + + assert "\r\x1b[2K" in stream.getvalue() + + +def test_stream_renderer_stops_spinner_even_after_header_printed(): + """A later answer delta must stop the spinner even when header already exists.""" + stream = StringIO() + stream.isatty = lambda: True # type: ignore[method-assign] + mock_console = MagicMock() + mock_console.file = stream + spinner = MagicMock() + mock_console.status.return_value = spinner + + with patch.object(stream_mod, "_make_console", return_value=mock_console): + renderer = stream_mod.StreamRenderer(show_spinner=True) + renderer._header_printed = True + renderer.ensure_header() + + spinner.stop.assert_called_once() + assert "\r\x1b[2K" in stream.getvalue() + + +def test_print_cli_progress_line_opens_renderer_header_before_trace(): + """Trace lines should appear under the assistant header, not under You.""" + order: list[str] = [] + renderer = MagicMock() + renderer.console.print.side_effect = lambda *_args, **_kwargs: order.append("print") + renderer.ensure_header.side_effect = lambda: order.append("header") + renderer.pause_spinner.return_value = nullcontext() + + commands._print_cli_progress_line("tool running", None, renderer) + + assert order == ["header", "print"] + + +def test_print_cli_progress_line_stops_live_before_trace(): + """A trace line should not leak the current transient Live frame.""" + mock_live = MagicMock() + renderer = stream_mod.StreamRenderer(show_spinner=False) + renderer._live = mock_live + + commands._print_cli_progress_line("tool running", None, renderer) + + mock_live.stop.assert_called_once() + assert renderer._live is None + + @pytest.mark.asyncio async def test_print_interactive_progress_line_pauses_spinner_before_printing(): """Interactive progress output should also pause spinner cleanly.""" @@ -156,17 +218,65 @@ def test_stream_renderer_stop_for_input_stops_spinner(): # Create renderer with mocked console with patch.object(stream_mod, "_make_console", return_value=mock_console): renderer = stream_mod.StreamRenderer(show_spinner=True) - + # Verify spinner started spinner.start.assert_called_once() - + # Stop for input renderer.stop_for_input() - + # Verify spinner stopped spinner.stop.assert_called_once() +@pytest.mark.asyncio +async def test_on_end_writes_final_content_to_stdout_after_stopping_live(): + """on_end should stop Live (transient erases it) then print final content to stdout.""" + mock_live = MagicMock() + mock_console = MagicMock() + mock_console.capture.return_value.__enter__ = MagicMock( + return_value=MagicMock(get=lambda: "final output\n") + ) + mock_console.capture.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(stream_mod, "_make_console", return_value=mock_console): + renderer = stream_mod.StreamRenderer(show_spinner=False) + renderer._live = mock_live + renderer._buf = "final output" + + written: list[str] = [] + with patch("sys.stdout") as mock_stdout: + mock_stdout.write = lambda s: written.append(s) + mock_stdout.flush = MagicMock() + await renderer.on_end() + + mock_live.stop.assert_called_once() + assert renderer._live is None + assert written == ["final output\n"] + + +@pytest.mark.asyncio +async def test_on_end_resuming_clears_buffer_and_restarts_spinner(): + """on_end(resuming=True) should reset state for the next iteration.""" + spinner = MagicMock() + mock_console = MagicMock() + mock_console.status.return_value = spinner + mock_console.capture.return_value.__enter__ = MagicMock( + return_value=MagicMock(get=lambda: "") + ) + mock_console.capture.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(stream_mod, "_make_console", return_value=mock_console): + renderer = stream_mod.StreamRenderer(show_spinner=True) + renderer._buf = "some content" + + await renderer.on_end(resuming=True) + + assert renderer._buf == "" + # Spinner should have been restarted (start called twice: __init__ + resuming) + assert spinner.start.call_count == 2 + + def test_make_console_force_terminal_when_stdout_is_tty(): """Console should set force_terminal=True when stdout is a TTY (rich output).""" import sys diff --git a/tests/cli/test_interactive_retry_wait.py b/tests/cli/test_interactive_retry_wait.py index 5cc217c56..52c27d2c9 100644 --- a/tests/cli/test_interactive_retry_wait.py +++ b/tests/cli/test_interactive_retry_wait.py @@ -17,7 +17,7 @@ async def test_interactive_retry_wait_is_rendered_as_progress_even_when_progress metadata={"_retry_wait": True}, ) - async def fake_print(text: str, active_thinking: object | None) -> None: + async def fake_print(text: str, active_thinking: object | None, renderer=None) -> None: calls.append((text, active_thinking)) with patch("nanobot.cli.commands._print_interactive_progress_line", side_effect=fake_print): @@ -29,3 +29,104 @@ async def test_interactive_retry_wait_is_rendered_as_progress_even_when_progress assert handled is True assert calls == [("Model request failed, retry in 2s (attempt 1).", thinking)] + + +@pytest.mark.asyncio +async def test_reasoning_displayed_when_show_reasoning_enabled(): + """Reasoning content should be displayed when show_reasoning is True.""" + calls: list[str] = [] + channels_config = SimpleNamespace( + send_progress=True, send_tool_hints=False, show_reasoning=True, + ) + msg = SimpleNamespace( + content="Let me think about this...", + metadata={"_progress": True, "_reasoning": True}, + ) + + with patch("nanobot.cli.commands._print_cli_reasoning", side_effect=lambda t, th, r=None: calls.append(t)): + handled = await commands._maybe_print_interactive_progress(msg, None, channels_config) + + assert handled is True + assert calls == ["Let me think about this..."] + + +@pytest.mark.asyncio +async def test_reasoning_delta_displayed_when_show_reasoning_enabled(): + """Streamed reasoning delta frames should use the reasoning renderer.""" + calls: list[str] = [] + channels_config = SimpleNamespace( + send_progress=True, send_tool_hints=False, show_reasoning=True, + ) + msg = SimpleNamespace( + content="I should search first.", + metadata={"_progress": True, "_reasoning_delta": True}, + ) + + with patch("nanobot.cli.commands._print_cli_reasoning", side_effect=lambda t, th, r=None: calls.append(t)): + handled = await commands._maybe_print_interactive_progress(msg, None, channels_config) + + assert handled is True + assert calls == ["I should search first."] + + +@pytest.mark.asyncio +async def test_reasoning_hidden_when_show_reasoning_disabled(): + """Reasoning content should be suppressed when show_reasoning is False.""" + channels_config = SimpleNamespace( + send_progress=True, send_tool_hints=False, show_reasoning=False, + ) + msg = SimpleNamespace( + content="Let me think about this...", + metadata={"_progress": True, "_reasoning": True}, + ) + + with patch("nanobot.cli.commands._print_cli_reasoning") as mock_reasoning: + handled = await commands._maybe_print_interactive_progress(msg, None, channels_config) + + assert handled is True + mock_reasoning.assert_not_called() + + +@pytest.mark.asyncio +async def test_non_reasoning_progress_not_affected_by_show_reasoning(): + """Regular progress lines should display regardless of show_reasoning.""" + calls: list[str] = [] + channels_config = SimpleNamespace( + send_progress=True, send_tool_hints=False, show_reasoning=False, + ) + msg = SimpleNamespace( + content="working on it...", + metadata={"_progress": True}, + ) + + async def fake_print(text: str, thinking=None, renderer=None): + calls.append(text) + + with patch("nanobot.cli.commands._print_interactive_progress_line", side_effect=fake_print): + handled = await commands._maybe_print_interactive_progress(msg, None, channels_config) + + assert handled is True + assert calls == ["working on it..."] + + +@pytest.mark.asyncio +async def test_reasoning_shown_when_send_progress_disabled(): + """Reasoning display is governed by `show_reasoning` alone, independent + of `send_progress` — the two knobs are orthogonal.""" + calls: list[str] = [] + channels_config = SimpleNamespace( + send_progress=False, send_tool_hints=False, show_reasoning=True, + ) + msg = SimpleNamespace( + content="Let me think about this...", + metadata={"_progress": True, "_reasoning": True}, + ) + + with patch( + "nanobot.cli.commands._print_cli_reasoning", + side_effect=lambda t, th, r=None: calls.append(t), + ): + handled = await commands._maybe_print_interactive_progress(msg, None, channels_config) + + assert handled is True + assert calls == ["Let me think about this..."] diff --git a/tests/tools/test_tool_loader.py b/tests/tools/test_tool_loader.py index 60ad8057b..fa33b140b 100644 --- a/tests/tools/test_tool_loader.py +++ b/tests/tools/test_tool_loader.py @@ -405,7 +405,7 @@ def test_loader_registers_same_tools_as_old_hardcoded(): registered = loader.load(ctx, registry) expected = { - "ask_user", "read_file", "write_file", "edit_file", "list_dir", + "read_file", "write_file", "edit_file", "list_dir", "glob", "grep", "notebook_edit", "exec", "web_search", "web_fetch", "message", "spawn", "cron", } diff --git a/tests/utils/test_strip_think.py b/tests/utils/test_strip_think.py index 5db93e658..f1048f40c 100644 --- a/tests/utils/test_strip_think.py +++ b/tests/utils/test_strip_think.py @@ -1,4 +1,4 @@ -from nanobot.utils.helpers import strip_think +from nanobot.utils.helpers import extract_reasoning, extract_think, strip_think class TestStripThinkTag: @@ -144,3 +144,130 @@ class TestStripThinkConservativePreserve: def test_literal_channel_marker_in_code_block_preserved(self): text = "Example:\n```\nif line.startswith(''):\n skip()\n```" assert strip_think(text) == text + + +class TestExtractThink: + + def test_no_think_tags(self): + thinking, clean = extract_think("Hello World") + assert thinking is None + assert clean == "Hello World" + + def test_single_think_block(self): + text = "Hello reasoning content\nhere World" + thinking, clean = extract_think(text) + assert thinking == "reasoning content\nhere" + assert clean == "Hello World" + + def test_single_thought_block(self): + text = "Hello reasoning content World" + thinking, clean = extract_think(text) + assert thinking == "reasoning content" + assert clean == "Hello World" + + def test_multiple_think_blocks(self): + text = "AfirstBsecondC" + thinking, clean = extract_think(text) + assert thinking == "first\n\nsecond" + assert clean == "ABC" + + def test_think_only_no_content(self): + text = "just thinking" + thinking, clean = extract_think(text) + assert thinking == "just thinking" + assert clean == "" + + def test_unclosed_think_not_extracted(self): + # Unclosed blocks at start are stripped but NOT extracted + text = "unclosed thinking..." + thinking, clean = extract_think(text) + assert thinking is None + assert clean == "" + + def test_empty_think_block(self): + text = "Hello World" + thinking, clean = extract_think(text) + # Empty blocks result in empty string after strip + assert thinking == "" + assert clean == "Hello World" + + def test_think_with_whitespace_only(self): + text = "Hello \n World" + thinking, clean = extract_think(text) + assert thinking is None + assert clean == "Hello \n World" + + def test_mixed_think_and_thought(self): + text = "Startfirst reasoningmiddlesecond reasoningEnd" + thinking, clean = extract_think(text) + assert thinking == "first reasoning\n\nsecond reasoning" + assert clean == "StartmiddleEnd" + + def test_real_world_ollama_response(self): + text = """ +The user is asking about Python list comprehensions. +Let me explain the syntax and give examples. + + +List comprehensions in Python provide a concise way to create lists. Here's the syntax: + +```python +[expression for item in iterable if condition] +``` + +For example: +```python +squares = [x**2 for x in range(10)] +```""" + thinking, clean = extract_think(text) + assert "list comprehensions" in thinking.lower() + assert "Let me explain" in thinking + assert "List comprehensions in Python" in clean + assert "" not in clean + assert "" not in clean + + +class TestExtractReasoning: + """Single source of truth for reasoning extraction across all providers.""" + + def test_prefers_reasoning_content_and_strips_inline_think(self): + # Dedicated field wins; inline tags are still scrubbed from content. + reasoning, content = extract_reasoning( + "dedicated", + None, + "inlinevisible answer", + ) + assert reasoning == "dedicated" + assert content == "visible answer" + + def test_falls_back_to_thinking_blocks(self): + reasoning, content = extract_reasoning( + None, + [ + {"type": "thinking", "thinking": "step 1"}, + {"type": "thinking", "thinking": "step 2"}, + {"type": "redacted_thinking"}, + ], + "hello", + ) + assert reasoning == "step 1\n\nstep 2" + assert content == "hello" + + def test_falls_back_to_inline_think_tags(self): + reasoning, content = extract_reasoning( + None, None, "plananswer" + ) + assert reasoning == "plan" + assert content == "answer" + + def test_no_reasoning_returns_none(self): + reasoning, content = extract_reasoning(None, None, "plain answer") + assert reasoning is None + assert content == "plain answer" + + def test_empty_thinking_blocks_falls_through_to_inline(self): + reasoning, content = extract_reasoning( + None, [], "plananswer" + ) + assert reasoning == "plan" + assert content == "answer" diff --git a/webui/src/App.tsx b/webui/src/App.tsx index 1cadcc231..d5b7485a6 100644 --- a/webui/src/App.tsx +++ b/webui/src/App.tsx @@ -250,7 +250,6 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName: key: string; label: string; } | null>(null); - const lastSessionsLen = useRef(0); const restartSawDisconnectRef = useRef(false); const [restartToast, setRestartToast] = useState(null); const [isRestarting, setIsRestarting] = useState(false); @@ -266,13 +265,7 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName: } }, [desktopSidebarOpen]); - useEffect(() => { - if (activeKey) return; - if (sessions.length > 0 && lastSessionsLen.current === 0) { - setActiveKey(sessions[0].key); - } - lastSessionsLen.current = sessions.length; - }, [sessions, activeKey]); + const activeSession = useMemo(() => { if (!activeKey) return null; @@ -335,9 +328,8 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName: setView("chat"); setMobileSidebarOpen(false); setActiveKey((current) => { - if (current && sessions.some((session) => session.key === current)) { - return current; - } + if (!current) return null; + if (sessions.some((session) => session.key === current)) return current; return sessions[0]?.key ?? null; }); }, [sessions]); @@ -479,18 +471,13 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName: ) : null} -
- {view === "settings" ? ( - - ) : ( +
+
+
+ {view === "settings" && ( +
+ +
)}
diff --git a/webui/src/components/MessageBubble.tsx b/webui/src/components/MessageBubble.tsx index 3bd580567..bd1d8c93b 100644 --- a/webui/src/components/MessageBubble.tsx +++ b/webui/src/components/MessageBubble.tsx @@ -1,5 +1,5 @@ import { useCallback, useEffect, useRef, useState } from "react"; -import { Check, ChevronRight, Copy, FileIcon, ImageIcon, PlaySquare, Wrench } from "lucide-react"; +import { Check, ChevronRight, Copy, FileIcon, ImageIcon, PlaySquare, Sparkles, Wrench } from "lucide-react"; import { useTranslation } from "react-i18next"; import { ImageLightbox } from "@/components/ImageLightbox"; @@ -85,12 +85,18 @@ export function MessageBubble({ message }: MessageBubbleProps) { const empty = message.content.trim().length === 0; const media = message.media ?? []; + const reasoning = message.role === "assistant" ? message.reasoning ?? "" : ""; + const reasoningStreaming = !!(message.role === "assistant" && message.reasoningStreaming); + const hasReasoning = reasoning.length > 0 || reasoningStreaming; const showAssistantActions = message.role === "assistant" && !message.isStreaming && !empty; return (
- {empty && message.isStreaming ? ( + {hasReasoning ? ( + + ) : null} + {empty && message.isStreaming && !hasReasoning ? ( - ) : ( + ) : empty && message.isStreaming ? null : ( <> {message.content} {message.isStreaming && } @@ -380,14 +386,14 @@ interface TraceGroupProps { /** * Collapsible group of tool-call / progress breadcrumbs. Defaults to - * expanded for discoverability; a single click on the header folds the - * group down to a one-line summary so it never dominates the thread. + * collapsed because tool traces are supporting evidence, not the answer. + * A single click expands the exact calls when the user wants details. */ function TraceGroup({ message, animClass }: TraceGroupProps) { const { t } = useTranslation(); const lines = message.traces ?? [message.content]; const count = lines.length; - const [open, setOpen] = useState(true); + const [open, setOpen] = useState(false); return (
+ {open && text.length > 0 && ( +
+ {text} +
+ )} +
+ ); +} diff --git a/webui/src/components/thread/AskUserPrompt.tsx b/webui/src/components/thread/AskUserPrompt.tsx deleted file mode 100644 index 4de76307c..000000000 --- a/webui/src/components/thread/AskUserPrompt.tsx +++ /dev/null @@ -1,108 +0,0 @@ -import { useCallback, useEffect, useRef, useState } from "react"; -import { MessageSquareText } from "lucide-react"; - -import { Button } from "@/components/ui/button"; -import { cn } from "@/lib/utils"; - -interface AskUserPromptProps { - question: string; - buttons: string[][]; - onAnswer: (answer: string) => void; -} - -export function AskUserPrompt({ - question, - buttons, - onAnswer, -}: AskUserPromptProps) { - const [customOpen, setCustomOpen] = useState(false); - const [custom, setCustom] = useState(""); - const inputRef = useRef(null); - const options = buttons.flat().filter(Boolean); - - useEffect(() => { - if (customOpen) { - inputRef.current?.focus(); - } - }, [customOpen]); - - const submitCustom = useCallback(() => { - const answer = custom.trim(); - if (!answer) return; - onAnswer(answer); - setCustom(""); - setCustomOpen(false); - }, [custom, onAnswer]); - - if (options.length === 0) return null; - - return ( -
-
-
- -
-

- {question} -

-
- -
- {options.map((option) => ( - - ))} - -
- - {customOpen ? ( -
-