mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
Merge remote-tracking branch 'origin/main' into pr-3756
This commit is contained in:
commit
eaa8ebd5d3
@ -238,6 +238,9 @@ nanobot channels login <channel_name> --force # re-authenticate
|
||||
| `supports_streaming` (property) | `True` when config has `"streaming": true` **and** subclass overrides `send_delta()`. |
|
||||
| `is_running` | Returns `self._running`. |
|
||||
| `login(force=False)` | Perform interactive login (e.g. QR code scan). Returns `True` if already authenticated or login succeeds. Override in subclasses that support interactive login. |
|
||||
| `send_reasoning_delta(chat_id, delta, metadata?)` | Optional hook for streamed model reasoning/thinking content. Default is no-op. |
|
||||
| `send_reasoning_end(chat_id, metadata?)` | Optional hook marking the end of a reasoning block. Default is no-op. |
|
||||
| `send_reasoning(msg)` | Optional one-shot reasoning fallback. Default translates to `send_reasoning_delta()` + `send_reasoning_end()`. |
|
||||
|
||||
### Optional (streaming)
|
||||
|
||||
@ -350,6 +353,112 @@ When `streaming` is `false` (default) or omitted, only `send()` is called — no
|
||||
| `async send_delta(chat_id, delta, metadata?)` | Override to handle streaming chunks. No-op by default. |
|
||||
| `supports_streaming` (property) | Returns `True` when config has `streaming: true` **and** subclass overrides `send_delta`. |
|
||||
|
||||
## Progress, Tool Hints, and Reasoning
|
||||
|
||||
Besides normal assistant text, nanobot can emit low-emphasis trace blocks. These are intended for UI affordances like status rows, collapsible "used tools" groups, or reasoning/thinking blocks. Platforms that do not have a good place for them can ignore them safely.
|
||||
|
||||
### Progress and Tool Hints
|
||||
|
||||
Progress and tool hints arrive through the normal `send(msg)` path. Check `msg.metadata` before rendering:
|
||||
|
||||
```python
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
meta = msg.metadata or {}
|
||||
|
||||
if meta.get("_tool_hint"):
|
||||
# A short tool breadcrumb, e.g. read_file("config.json")
|
||||
await self._send_trace(msg.chat_id, msg.content, kind="tool")
|
||||
return
|
||||
|
||||
if meta.get("_progress"):
|
||||
# Generic non-final status, e.g. "Thinking..." or "Running command..."
|
||||
await self._send_trace(msg.chat_id, msg.content, kind="progress")
|
||||
return
|
||||
|
||||
await self._send_message(msg.chat_id, msg.content, media=msg.media)
|
||||
```
|
||||
|
||||
Tool hints are off by default for most channels. Users can enable them globally or per channel:
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"sendToolHints": true,
|
||||
"webhook": {
|
||||
"enabled": true,
|
||||
"sendToolHints": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Reasoning Blocks
|
||||
|
||||
Reasoning is delivered through dedicated optional hooks, not `send()`. Override `send_reasoning_delta()` and `send_reasoning_end()` if your platform can show model reasoning as a subdued/collapsible block. Both hooks default to no-ops, so channels that do not override them simply drop reasoning content.
|
||||
|
||||
```python
|
||||
class WebhookChannel(BaseChannel):
|
||||
name = "webhook"
|
||||
display_name = "Webhook"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = WebhookConfig(**config)
|
||||
super().__init__(config, bus)
|
||||
self._reasoning_buffers: dict[str, str] = {}
|
||||
|
||||
async def send_reasoning_delta(
|
||||
self,
|
||||
chat_id: str,
|
||||
delta: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
meta = metadata or {}
|
||||
stream_id = str(meta.get("_stream_id") or chat_id)
|
||||
self._reasoning_buffers[stream_id] = self._reasoning_buffers.get(stream_id, "") + delta
|
||||
await self._update_reasoning_block(chat_id, self._reasoning_buffers[stream_id], final=False)
|
||||
|
||||
async def send_reasoning_end(
|
||||
self,
|
||||
chat_id: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
meta = metadata or {}
|
||||
stream_id = str(meta.get("_stream_id") or chat_id)
|
||||
text = self._reasoning_buffers.pop(stream_id, "")
|
||||
if text:
|
||||
await self._update_reasoning_block(chat_id, text, final=True)
|
||||
```
|
||||
|
||||
**Reasoning metadata flags:**
|
||||
|
||||
| Flag | Meaning |
|
||||
|------|---------|
|
||||
| `_reasoning_delta: True` | A reasoning/thinking chunk; `delta` contains the new text. |
|
||||
| `_reasoning_end: True` | The current reasoning block is complete; `delta` is empty. |
|
||||
| `_reasoning: True` | Legacy one-shot reasoning. `BaseChannel.send_reasoning()` converts it to delta + end. |
|
||||
| `_stream_id` | Stable id for this assistant turn/segment. Use it to key buffers instead of only `chat_id`. |
|
||||
|
||||
Reasoning visibility is controlled by `showReasoning` globally or per channel:
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"showReasoning": true,
|
||||
"webhook": {
|
||||
"enabled": true,
|
||||
"showReasoning": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Recommended rendering:
|
||||
|
||||
- Render tool hints and progress as trace/status UI, not as normal assistant replies.
|
||||
- Render reasoning with lower visual emphasis and collapse it after completion when the platform supports that.
|
||||
- Keep reasoning separate from final answer text. A final answer still arrives through `send()` or `send_delta()`.
|
||||
|
||||
## Config
|
||||
|
||||
### Why Pydantic model is required
|
||||
|
||||
@ -743,6 +743,7 @@ Global settings that apply to all channels. Configure under the `channels` secti
|
||||
|---------|---------|-------------|
|
||||
| `sendProgress` | `true` | Stream agent's text progress to the channel |
|
||||
| `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) |
|
||||
| `showReasoning` | `true` | Allow channels to surface model reasoning/thinking content (DeepSeek-R1 `reasoning_content`, Anthropic `thinking_blocks`, inline `<think>` tags). Reasoning flows as a dedicated stream with `_reasoning_delta` / `_reasoning_end` markers — channels override `send_reasoning_delta` / `send_reasoning_end` to render in-place updates. Even when set to `true`, channels without those overrides silently drop reasoning, because the base hooks are no-ops. Currently surfaced on CLI and WebSocket/WebUI (italic shimmer header, auto-collapses after the stream ends); Telegram / Slack / Discord / Feishu / WeChat / Matrix keep the base no-op until their bubble UI is adapted. Independent of `sendProgress`. |
|
||||
| `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (0-10 configured, minimum 1 actual attempt) |
|
||||
| `transcriptionProvider` | `"groq"` | Voice transcription backend: `"groq"` (free tier, default) or `"openai"`. API key is auto-resolved from the matching provider config. |
|
||||
| `transcriptionLanguage` | `null` | Optional ISO-639-1 language hint for audio transcription, e.g. `"en"`, `"ko"`, `"ja"`. |
|
||||
|
||||
@ -128,6 +128,29 @@ All frames are JSON text. Each message has an `event` field.
|
||||
}
|
||||
```
|
||||
|
||||
**`reasoning_delta`** — incremental model reasoning / thinking chunk for the active assistant turn. Mirrors `delta` but targets the reasoning bubble above the answer rather than the answer body:
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "reasoning_delta",
|
||||
"chat_id": "uuid-v4",
|
||||
"text": "Let me decompose ",
|
||||
"stream_id": "r1"
|
||||
}
|
||||
```
|
||||
|
||||
**`reasoning_end`** — close marker for the active reasoning stream. WebUI uses this to lock the in-place bubble and switch from the shimmer header to a static collapsed state:
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "reasoning_end",
|
||||
"chat_id": "uuid-v4",
|
||||
"stream_id": "r1"
|
||||
}
|
||||
```
|
||||
|
||||
Reasoning frames only flow when the channel's `showReasoning` is `true` (default) and the model returns reasoning content (DeepSeek-R1 / Kimi / MiMo / OpenAI reasoning models, Anthropic extended thinking, or inline `<think>` / `<thought>` tags). Models without reasoning produce zero `reasoning_delta` frames.
|
||||
|
||||
**`runtime_model_updated`** — broadcast when the gateway runtime model changes, for example after `/model <preset>`:
|
||||
|
||||
```json
|
||||
|
||||
@ -22,6 +22,7 @@ class AgentHookContext:
|
||||
tool_results: list[Any] = field(default_factory=list)
|
||||
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||
streamed_content: bool = False
|
||||
streamed_reasoning: bool = False
|
||||
final_content: str | None = None
|
||||
stop_reason: str | None = None
|
||||
error: str | None = None
|
||||
@ -48,6 +49,17 @@ class AgentHook:
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
pass
|
||||
|
||||
async def emit_reasoning(self, reasoning_content: str | None) -> None:
|
||||
pass
|
||||
|
||||
async def emit_reasoning_end(self) -> None:
|
||||
"""Mark the end of an in-flight reasoning stream.
|
||||
|
||||
Hooks that buffer ``emit_reasoning`` chunks (for in-place UI updates)
|
||||
flush and freeze the rendered group here. One-shot hooks can ignore this call.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
pass
|
||||
|
||||
@ -95,6 +107,12 @@ class CompositeHook(AgentHook):
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
await self._for_each_hook_safe("before_execute_tools", context)
|
||||
|
||||
async def emit_reasoning(self, reasoning_content: str | None) -> None:
|
||||
await self._for_each_hook_safe("emit_reasoning", reasoning_content)
|
||||
|
||||
async def emit_reasoning_end(self) -> None:
|
||||
await self._for_each_hook_safe("emit_reasoning_end")
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._for_each_hook_safe("after_iteration", context)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from contextlib import AsyncExitStack, nullcontext, suppress
|
||||
@ -15,19 +14,14 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent import model_presets as preset_helpers
|
||||
from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.hook import AgentHook, CompositeHook
|
||||
from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent import model_presets as preset_helpers
|
||||
from nanobot.agent.progress_hook import AgentProgressHook
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.ask import (
|
||||
ask_user_options_from_messages,
|
||||
ask_user_outbound,
|
||||
ask_user_tool_result_messages,
|
||||
pending_ask_user_id,
|
||||
)
|
||||
from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
@ -44,12 +38,6 @@ from nanobot.utils.document import extract_documents
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||
from nanobot.utils.image_generation_intent import image_generation_prompt
|
||||
from nanobot.utils.progress_events import (
|
||||
build_tool_event_finish_payloads,
|
||||
build_tool_event_start_payload,
|
||||
invoke_on_progress,
|
||||
on_progress_accepts_tool_events,
|
||||
)
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
from nanobot.utils.webui_titles import mark_webui_session, maybe_generate_webui_title_after_turn
|
||||
|
||||
@ -65,114 +53,6 @@ if TYPE_CHECKING:
|
||||
UNIFIED_SESSION_KEY = "unified:default"
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core hook for the main loop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_loop: AgentLoop,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(reraise=True)
|
||||
self._loop = agent_loop
|
||||
self._on_progress = on_progress
|
||||
self._on_stream = on_stream
|
||||
self._on_stream_end = on_stream_end
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._message_id = message_id
|
||||
self._metadata = metadata or {}
|
||||
self._session_key = session_key
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean) :]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if self._on_stream_end:
|
||||
await self._on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
self._loop._current_iteration = context.iteration
|
||||
logger.debug(
|
||||
"Starting agent loop iteration {} for session {}",
|
||||
context.iteration,
|
||||
self._session_key,
|
||||
)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if self._on_progress:
|
||||
if not self._on_stream and not context.streamed_content:
|
||||
thought = self._loop._strip_think(
|
||||
context.response.content if context.response else None
|
||||
)
|
||||
if thought:
|
||||
await self._on_progress(thought)
|
||||
tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
|
||||
tool_events = [build_tool_event_start_payload(tc) for tc in context.tool_calls]
|
||||
await invoke_on_progress(
|
||||
self._on_progress,
|
||||
tool_hint,
|
||||
tool_hint=True,
|
||||
tool_events=tool_events,
|
||||
)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(
|
||||
self._channel,
|
||||
self._chat_id,
|
||||
self._message_id,
|
||||
self._metadata,
|
||||
session_key=self._session_key,
|
||||
)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
if (
|
||||
self._on_progress
|
||||
and context.tool_calls
|
||||
and context.tool_events
|
||||
and on_progress_accepts_tool_events(self._on_progress)
|
||||
):
|
||||
tool_events = build_tool_event_finish_payloads(context)
|
||||
if tool_events:
|
||||
await invoke_on_progress(
|
||||
self._on_progress,
|
||||
"",
|
||||
tool_hint=False,
|
||||
tool_events=tool_events,
|
||||
)
|
||||
u = context.usage or {}
|
||||
logger.debug(
|
||||
"LLM usage: prompt={} completion={} cached={}",
|
||||
u.get("prompt_tokens", 0),
|
||||
u.get("completion_tokens", 0),
|
||||
u.get("cached_tokens", 0),
|
||||
)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class TurnState(Enum):
|
||||
RESTORE = auto()
|
||||
COMPACT = auto()
|
||||
@ -623,26 +503,11 @@ class AgentLoop:
|
||||
if tool and isinstance(tool, ContextAware):
|
||||
tool.set_context(request_ctx)
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Remove <think>…</think> blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
return strip_think(text) or None
|
||||
|
||||
@staticmethod
|
||||
def _runtime_chat_id(msg: InboundMessage) -> str:
|
||||
"""Return the chat id shown in runtime metadata for the model."""
|
||||
return str(msg.metadata.get("context_chat_id") or msg.chat_id)
|
||||
|
||||
def _tool_hint(self, tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hints with smart abbreviation."""
|
||||
from nanobot.utils.tool_hints import format_tool_hints
|
||||
|
||||
return format_tool_hints(tool_calls, max_length=self.tool_hint_max_length)
|
||||
|
||||
async def _build_bus_progress_callback(
|
||||
self, msg: InboundMessage
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
@ -653,10 +518,16 @@ class AgentLoop:
|
||||
*,
|
||||
tool_hint: bool = False,
|
||||
tool_events: list[dict[str, Any]] | None = None,
|
||||
reasoning: bool = False,
|
||||
reasoning_end: bool = False,
|
||||
) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
if reasoning:
|
||||
meta["_reasoning_delta"] = True
|
||||
if reasoning_end:
|
||||
meta["_reasoning_end"] = True
|
||||
if tool_events:
|
||||
meta["_tool_events"] = tool_events
|
||||
await self.bus.publish_outbound(
|
||||
@ -693,7 +564,6 @@ class AgentLoop:
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session: Session,
|
||||
pending_ask_id: str | None,
|
||||
) -> bool:
|
||||
"""Persist the triggering user message before the turn starts.
|
||||
|
||||
@ -701,7 +571,7 @@ class AgentLoop:
|
||||
"""
|
||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||
if not pending_ask_id and (has_text or media_paths):
|
||||
if has_text or media_paths:
|
||||
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
||||
text = msg.content if isinstance(msg.content, str) else ""
|
||||
session.add_message("user", text, **extra)
|
||||
@ -715,21 +585,9 @@ class AgentLoop:
|
||||
msg: InboundMessage,
|
||||
session: Session,
|
||||
history: list[dict[str, Any]],
|
||||
pending_ask_id: str | None,
|
||||
pending_summary: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the initial message list for the LLM turn."""
|
||||
if pending_ask_id:
|
||||
system_prompt = self.context.build_system_prompt(
|
||||
channel=msg.channel,
|
||||
session_summary=pending_summary,
|
||||
)
|
||||
return ask_user_tool_result_messages(
|
||||
system_prompt,
|
||||
history,
|
||||
pending_ask_id,
|
||||
image_generation_prompt(msg.content, msg.metadata),
|
||||
)
|
||||
return self.context.build_messages(
|
||||
history=history,
|
||||
current_message=image_generation_prompt(msg.content, msg.metadata),
|
||||
@ -813,8 +671,7 @@ class AgentLoop:
|
||||
"""
|
||||
self._sync_subagent_runtime_limits()
|
||||
|
||||
loop_hook = _LoopHook(
|
||||
self,
|
||||
loop_hook = AgentProgressHook(
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
@ -823,6 +680,9 @@ class AgentLoop:
|
||||
message_id=message_id,
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
tool_hint_max_length=self.tool_hint_max_length,
|
||||
set_tool_context=self._set_tool_context,
|
||||
on_iteration=lambda iteration: setattr(self, "_current_iteration", iteration),
|
||||
)
|
||||
hook: AgentHook = (
|
||||
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||
@ -1237,12 +1097,7 @@ class AgentLoop:
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
)
|
||||
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content or "Background task completed.",
|
||||
options,
|
||||
channel,
|
||||
)
|
||||
content = final_content or "Background task completed."
|
||||
outbound_metadata: dict[str, Any] = {}
|
||||
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
||||
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
||||
@ -1252,7 +1107,6 @@ class AgentLoop:
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
buttons=buttons,
|
||||
metadata=outbound_metadata,
|
||||
)
|
||||
|
||||
@ -1365,21 +1219,15 @@ class AgentLoop:
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
meta = dict(msg.metadata or {})
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content,
|
||||
ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
|
||||
msg.channel,
|
||||
)
|
||||
if on_stream is not None and stop_reason not in {"ask_user", "error", "tool_error"}:
|
||||
if on_stream is not None and stop_reason not in {"error", "tool_error"}:
|
||||
meta["_streamed"] = True
|
||||
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=content,
|
||||
content=final_content,
|
||||
media=generated_media,
|
||||
metadata=meta,
|
||||
buttons=buttons,
|
||||
)
|
||||
|
||||
async def _state_restore(self, ctx: TurnContext) -> TurnState:
|
||||
@ -1446,12 +1294,11 @@ class AgentLoop:
|
||||
}
|
||||
ctx.history = ctx.session.get_history(**_hist_kwargs)
|
||||
|
||||
pending_ask_id = pending_ask_user_id(ctx.history)
|
||||
ctx.initial_messages = self._build_initial_messages(
|
||||
ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary
|
||||
ctx.msg, ctx.session, ctx.history, ctx.pending_summary
|
||||
)
|
||||
ctx.user_persisted_early = self._persist_user_message_early(
|
||||
ctx.msg, ctx.session, pending_ask_id
|
||||
ctx.msg, ctx.session
|
||||
)
|
||||
|
||||
if ctx.on_progress is None:
|
||||
|
||||
178
nanobot/agent/progress_hook.py
Normal file
178
nanobot/agent/progress_hook.py
Normal file
@ -0,0 +1,178 @@
|
||||
"""Agent hook that adapts runner events into channel progress UI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.utils.helpers import IncrementalThinkExtractor, strip_think
|
||||
from nanobot.utils.progress_events import (
|
||||
build_tool_event_finish_payloads,
|
||||
build_tool_event_start_payload,
|
||||
invoke_on_progress,
|
||||
on_progress_accepts_tool_events,
|
||||
)
|
||||
from nanobot.utils.tool_hints import format_tool_hints
|
||||
|
||||
|
||||
class AgentProgressHook(AgentHook):
    """Translate runner lifecycle events into user-visible progress signals.

    The hook forwards three kinds of output to the callbacks it was built with:

    * answer text deltas, through ``on_stream`` / ``on_stream_end``;
    * intermediate thoughts, tool hints and tool events, through
      ``on_progress``;
    * reasoning chunks and their end marker, through ``on_progress`` called
      with ``reasoning=True`` / ``reasoning_end=True``.
    """

    def __init__(
        self,
        on_progress: Callable[..., Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
        *,
        channel: str = "cli",
        chat_id: str = "direct",
        message_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        session_key: str | None = None,
        tool_hint_max_length: int = 40,
        set_tool_context: Callable[..., None] | None = None,
        on_iteration: Callable[[int], None] | None = None,
    ) -> None:
        """Store the callbacks and routing data used for one agent turn.

        Args:
            on_progress: Async callback that receives progress text as its
                first positional argument. Reasoning is only sent to it when
                its signature accepts the matching keyword.
            on_stream: Async callback that receives each visible answer delta.
                Its presence is what makes ``wants_streaming()`` return True.
            on_stream_end: Async callback invoked with ``resuming=...`` when a
                streamed segment finishes.
            channel: Passed to ``set_tool_context`` before tools run.
            chat_id: Passed to ``set_tool_context`` before tools run.
            message_id: Passed to ``set_tool_context`` before tools run.
            metadata: Passed to ``set_tool_context``; ``None`` becomes ``{}``.
            session_key: Used in the debug log line and passed to
                ``set_tool_context``.
            tool_hint_max_length: ``max_length`` handed to
                ``format_tool_hints``.
            set_tool_context: Optional callable invoked right before tools are
                executed.
            on_iteration: Optional callable that receives the iteration number
                at the start of every iteration.
        """
        # NOTE(review): ``reraise=True`` presumably lets exceptions raised by
        # this hook propagate to the runner -- confirm against AgentHook.
        super().__init__(reraise=True)
        self._on_progress = on_progress
        self._on_stream = on_stream
        self._on_stream_end = on_stream_end
        self._channel = channel
        self._chat_id = chat_id
        self._message_id = message_id
        self._metadata = metadata or {}
        self._session_key = session_key
        self._tool_hint_max_length = tool_hint_max_length
        self._set_tool_context = set_tool_context
        self._on_iteration = on_iteration
        # Raw (unstripped) text streamed so far in the current segment.
        self._stream_buf = ""
        self._think_extractor = IncrementalThinkExtractor()
        # True between the first published reasoning chunk and its end marker.
        self._reasoning_open = False

    def wants_streaming(self) -> bool:
        """Return True when an ``on_stream`` callback was supplied."""
        return self._on_stream is not None

    @staticmethod
    def _strip_think(text: str | None) -> str | None:
        """Return ``strip_think(text)``, mapping empty input or output to None."""
        if not text:
            return None
        return strip_think(text) or None

    def _tool_hint(self, tool_calls: list[Any]) -> str:
        """Format ``tool_calls`` as a hint string capped at the configured length."""
        return format_tool_hints(tool_calls, max_length=self._tool_hint_max_length)

    @staticmethod
    def _on_progress_accepts(cb: Callable[..., Any], name: str) -> bool:
        """Report whether ``cb`` can be called with keyword argument ``name``.

        Returns True when the signature declares ``name`` or takes ``**kwargs``.
        Returns False when the signature cannot be inspected.
        """
        try:
            sig = inspect.signature(cb)
        except (TypeError, ValueError):
            return False
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
            return True
        return name in sig.parameters

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Handle one raw streamed delta.

        The delta is appended to the segment buffer. The visible increment is
        the part of ``strip_think(buffer)`` that was not there before the
        append; only that part is forwarded to ``on_stream``.
        """
        prev_clean = strip_think(self._stream_buf)
        self._stream_buf += delta
        new_clean = strip_think(self._stream_buf)
        incremental = new_clean[len(prev_clean) :]

        # NOTE(review): feed() appears to return a truthy value when it passed
        # reasoning text to emit_reasoning -- confirm in IncrementalThinkExtractor.
        if await self._think_extractor.feed(self._stream_buf, self.emit_reasoning):
            context.streamed_reasoning = True

        if incremental:
            # Answer text has started; close the reasoning segment so the UI can
            # lock the bubble before the answer renders below it.
            await self.emit_reasoning_end()
            if self._on_stream:
                await self._on_stream(incremental)

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Close any open reasoning segment, notify the caller, reset buffers."""
        await self.emit_reasoning_end()
        if self._on_stream_end:
            await self._on_stream_end(resuming=resuming)
        self._stream_buf = ""
        self._think_extractor.reset()

    async def before_iteration(self, context: AgentHookContext) -> None:
        """Report the iteration number to ``on_iteration`` and log it."""
        if self._on_iteration:
            self._on_iteration(context.iteration)
        logger.debug(
            "Starting agent loop iteration {} for session {}",
            context.iteration,
            self._session_key,
        )

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Publish the pre-tool thought and tool hint, log calls, set tool context."""
        if self._on_progress:
            # When nothing was streamed, the assistant text preceding the tool
            # calls has not reached the user yet; send it as a progress line.
            if not self._on_stream and not context.streamed_content:
                thought = self._strip_think(context.response.content if context.response else None)
                if thought:
                    await self._on_progress(thought)
            tool_hint = self._strip_think(self._tool_hint(context.tool_calls))
            tool_events = [build_tool_event_start_payload(tc) for tc in context.tool_calls]
            await invoke_on_progress(
                self._on_progress,
                tool_hint,
                tool_hint=True,
                tool_events=tool_events,
            )
        for tc in context.tool_calls:
            args_str = json.dumps(tc.arguments, ensure_ascii=False)
            # Arguments are truncated to keep log lines bounded.
            logger.info("Tool call: {}({})", tc.name, args_str[:200])
        if self._set_tool_context:
            self._set_tool_context(
                self._channel,
                self._chat_id,
                self._message_id,
                self._metadata,
                session_key=self._session_key,
            )

    async def emit_reasoning(self, reasoning_content: str | None) -> None:
        """Publish a reasoning chunk; channel plugins decide whether to render.

        Nothing is sent when there is no progress callback, the chunk is
        empty, or the callback does not accept a ``reasoning`` keyword.
        """
        if (
            self._on_progress
            and reasoning_content
            and self._on_progress_accepts(self._on_progress, "reasoning")
        ):
            self._reasoning_open = True
            await self._on_progress(reasoning_content, reasoning=True)

    async def emit_reasoning_end(self) -> None:
        """Close the current reasoning stream segment, if any was open.

        The open flag is always cleared. The end marker is only published when
        a segment was open and the callback accepts ``reasoning_end``; this
        mirrors the ``reasoning`` check in ``emit_reasoning`` so a callback
        that lacks the keyword is skipped instead of raising ``TypeError``.
        """
        was_open = self._reasoning_open
        # Clear before awaiting so a re-entrant call cannot emit a second marker.
        self._reasoning_open = False
        if (
            was_open
            and self._on_progress
            and self._on_progress_accepts(self._on_progress, "reasoning_end")
        ):
            await self._on_progress("", reasoning_end=True)

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Publish tool-finish events (when supported) and log token usage."""
        if (
            self._on_progress
            and context.tool_calls
            and context.tool_events
            and on_progress_accepts_tool_events(self._on_progress)
        ):
            tool_events = build_tool_event_finish_payloads(context)
            if tool_events:
                await invoke_on_progress(
                    self._on_progress,
                    "",
                    tool_hint=False,
                    tool_events=tool_events,
                )
        u = context.usage or {}
        logger.debug(
            "LLM usage: prompt={} completion={} cached={}",
            u.get("prompt_tokens", 0),
            u.get("completion_tokens", 0),
            u.get("cached_tokens", 0),
        )

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Return the final content with think blocks stripped (None if empty)."""
        return self._strip_think(content)
|
||||
@ -13,13 +13,14 @@ from typing import Any
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.helpers import (
|
||||
IncrementalThinkExtractor,
|
||||
build_assistant_message,
|
||||
estimate_message_tokens,
|
||||
estimate_prompt_tokens_chain,
|
||||
extract_reasoning,
|
||||
find_legal_message_start,
|
||||
maybe_persist_tool_result,
|
||||
strip_think,
|
||||
@ -282,23 +283,30 @@ class AgentRunner:
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
self._accumulate_usage(usage, raw_usage)
|
||||
|
||||
reasoning_text, cleaned_content = extract_reasoning(
|
||||
response.reasoning_content,
|
||||
response.thinking_blocks,
|
||||
response.content,
|
||||
)
|
||||
response.content = cleaned_content
|
||||
if reasoning_text and not context.streamed_reasoning:
|
||||
await hook.emit_reasoning(reasoning_text)
|
||||
await hook.emit_reasoning_end()
|
||||
context.streamed_reasoning = True
|
||||
|
||||
if response.should_execute_tools:
|
||||
tool_calls = list(response.tool_calls)
|
||||
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
|
||||
if ask_index is not None:
|
||||
tool_calls = tool_calls[: ask_index + 1]
|
||||
context.tool_calls = list(tool_calls)
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
|
||||
assistant_message = build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in tool_calls],
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
tools_used.extend(tc.name for tc in tool_calls)
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
@ -307,7 +315,7 @@ class AgentRunner:
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
@ -315,7 +323,7 @@ class AgentRunner:
|
||||
|
||||
results, new_events, fatal_error = await self._execute_tools(
|
||||
spec,
|
||||
tool_calls,
|
||||
response.tool_calls,
|
||||
external_lookup_counts,
|
||||
workspace_violation_counts,
|
||||
)
|
||||
@ -323,9 +331,7 @@ class AgentRunner:
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(tool_calls, results):
|
||||
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
||||
continue
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
@ -340,15 +346,6 @@ class AgentRunner:
|
||||
messages.append(tool_message)
|
||||
completed_tool_results.append(tool_message)
|
||||
if fatal_error is not None:
|
||||
if isinstance(fatal_error, AskUserInterrupt):
|
||||
final_content = fatal_error.question
|
||||
stop_reason = "ask_user"
|
||||
context.final_content = final_content
|
||||
context.stop_reason = stop_reason
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
final_content = error
|
||||
stop_reason = "tool_error"
|
||||
@ -621,6 +618,8 @@ class AgentRunner:
|
||||
and getattr(self.provider, "supports_progress_deltas", False) is True
|
||||
)
|
||||
|
||||
progress_state: dict[str, bool] | None = None
|
||||
|
||||
if wants_streaming:
|
||||
async def _stream(delta: str) -> None:
|
||||
if delta:
|
||||
@ -633,6 +632,8 @@ class AgentRunner:
|
||||
)
|
||||
elif wants_progress_streaming:
|
||||
stream_buf = ""
|
||||
think_extractor = IncrementalThinkExtractor()
|
||||
progress_state = {"reasoning_open": False}
|
||||
|
||||
async def _stream_progress(delta: str) -> None:
|
||||
nonlocal stream_buf
|
||||
@ -642,7 +643,15 @@ class AgentRunner:
|
||||
stream_buf += delta
|
||||
new_clean = strip_think(stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
|
||||
if await think_extractor.feed(stream_buf, hook.emit_reasoning):
|
||||
context.streamed_reasoning = True
|
||||
progress_state["reasoning_open"] = True
|
||||
|
||||
if incremental:
|
||||
if progress_state["reasoning_open"]:
|
||||
await hook.emit_reasoning_end()
|
||||
progress_state["reasoning_open"] = False
|
||||
context.streamed_content = True
|
||||
await spec.progress_callback(incremental)
|
||||
|
||||
@ -653,16 +662,20 @@ class AgentRunner:
|
||||
else:
|
||||
coro = self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
if timeout_s is None:
|
||||
return await coro
|
||||
try:
|
||||
return await asyncio.wait_for(coro, timeout=timeout_s)
|
||||
response = (
|
||||
await coro if timeout_s is None
|
||||
else await asyncio.wait_for(coro, timeout=timeout_s)
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: timed out after {timeout_s:g}s",
|
||||
finish_reason="error",
|
||||
error_kind="timeout",
|
||||
)
|
||||
if progress_state and progress_state.get("reasoning_open"):
|
||||
await hook.emit_reasoning_end()
|
||||
return response
|
||||
|
||||
async def _request_finalization_retry(
|
||||
self,
|
||||
@ -724,10 +737,6 @@ class AgentRunner:
|
||||
)
|
||||
tool_results.append(result)
|
||||
batch_results.append(result)
|
||||
if isinstance(result[2], AskUserInterrupt):
|
||||
break
|
||||
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
|
||||
break
|
||||
|
||||
results: list[Any] = []
|
||||
events: list[dict[str, str]] = []
|
||||
@ -799,9 +808,6 @@ class AgentRunner:
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
}
|
||||
if isinstance(exc, AskUserInterrupt):
|
||||
event["status"] = "waiting"
|
||||
return "", event, exc
|
||||
payload = f"Error: {type(exc).__name__}: {exc}"
|
||||
handled = self._classify_violation(
|
||||
raw_text=str(exc),
|
||||
|
||||
@ -1,136 +0,0 @@
|
||||
"""Tool for pausing a turn until the user answers."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||
|
||||
STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"})
|
||||
|
||||
|
||||
class AskUserInterrupt(BaseException):
    """Signal raised to make the runner stop and wait for the user's answer.

    The class derives from ``BaseException`` rather than ``Exception``, so
    ``except Exception`` handlers do not catch it.

    Attributes:
        question: The question text; also passed on as the exception argument.
        options: Answer labels, each converted with ``str()``; labels whose
            string form is empty are dropped.
    """

    def __init__(self, question: str, options: list[str] | None = None) -> None:
        labels: list[str] = []
        for candidate in options or []:
            text = str(candidate)
            if text:
                labels.append(text)
        self.question = question
        self.options = labels
        super().__init__(question)
|
||||
|
||||
|
||||
@tool_parameters(
    tool_parameters_schema(
        question=StringSchema(
            "The question to ask before continuing. Use this only when the task needs the user's answer."
        ),
        options=ArraySchema(
            StringSchema("A possible answer label"),
            description="Optional choices. The user may still reply with free text.",
        ),
        required=["question"],
    )
)
class AskUserTool(Tool):
    """Tool that puts a blocking question to the user.

    Executing the tool never returns a value: it always raises
    ``AskUserInterrupt`` carrying the question and the optional choices.
    """

    @property
    def name(self) -> str:
        """Name under which the tool is exposed."""
        return "ask_user"

    @property
    def description(self) -> str:
        """Usage guidance text for the tool."""
        sentences = (
            "Pause and ask the user a question when their answer is required to continue. ",
            "Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. ",
            "For non-blocking notifications or buttons, use the message tool instead.",
        )
        return "".join(sentences)

    @property
    def exclusive(self) -> bool:
        """Always ``True``.

        NOTE(review): presumably tells the executor not to batch this tool
        with other calls — confirm against the ``Tool`` base class.
        """
        return True

    async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
        """Raise ``AskUserInterrupt`` for ``question``; extra keyword arguments are ignored."""
        raise AskUserInterrupt(question=question, options=options)
|
||||
|
||||
|
||||
def _tool_call_name(tool_call: dict[str, Any]) -> str:
    """Return the tool name stored in a tool-call dict, or ``""`` if absent.

    The nested ``tool_call["function"]["name"]`` form is checked first; the
    flat ``tool_call["name"]`` form is the fallback. Only string values count.
    """
    nested = tool_call.get("function")
    if isinstance(nested, dict):
        nested_name = nested.get("name")
        if isinstance(nested_name, str):
            return nested_name
    flat_name = tool_call.get("name")
    if isinstance(flat_name, str):
        return flat_name
    return ""
|
||||
|
||||
|
||||
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
    """Return the arguments of a tool-call dict as a dict.

    Arguments are read from ``tool_call["function"]["arguments"]`` when
    ``function`` is a dict, otherwise from ``tool_call["arguments"]``. A dict
    is returned as-is; a string is parsed as JSON. Anything else — including
    invalid JSON or JSON that is not an object — yields ``{}``.
    """
    nested = tool_call.get("function")
    holder = nested if isinstance(nested, dict) else tool_call
    payload = holder.get("arguments")
    if isinstance(payload, dict):
        return payload
    if not isinstance(payload, str):
        return {}
    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        return {}
    if isinstance(decoded, dict):
        return decoded
    return {}
|
||||
|
||||
|
||||
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
    """Return the id of the newest unanswered ``ask_user`` tool call, if any.

    Every assistant tool call with a string ``id`` is recorded; a later
    ``tool`` message with the same ``tool_call_id`` marks it answered. Of the
    calls left over, the most recently recorded one named ``ask_user`` wins.
    Returns ``None`` when no such call exists.
    """
    unanswered: dict[str, str] = {}
    for entry in history:
        role = entry.get("role")
        if role == "assistant":
            for call in entry.get("tool_calls") or []:
                if not isinstance(call, dict):
                    continue
                call_id = call.get("id")
                if isinstance(call_id, str):
                    unanswered[call_id] = _tool_call_name(call)
        elif role == "tool":
            answered_id = entry.get("tool_call_id")
            if isinstance(answered_id, str):
                unanswered.pop(answered_id, None)
    # Walk newest-first so the latest outstanding ask_user call is chosen.
    return next(
        (call_id for call_id, name in reversed(unanswered.items()) if name == "ask_user"),
        None,
    )
|
||||
|
||||
|
||||
def ask_user_tool_result_messages(
    system_prompt: str,
    history: list[dict[str, Any]],
    tool_call_id: str,
    content: str,
) -> list[dict[str, Any]]:
    """Build a message list that answers an ``ask_user`` tool call.

    The result is a new list: a system message holding ``system_prompt``,
    then the entries of ``history`` (not copied), then a ``tool`` message
    named ``ask_user`` that carries ``content`` for ``tool_call_id``.
    """
    answer: dict[str, Any] = {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "name": "ask_user",
        "content": content,
    }
    conversation: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
    conversation.extend(history)
    conversation.append(answer)
    return conversation
|
||||
|
||||
|
||||
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
    """Return the answer options of the most recent ``ask_user`` tool call.

    Assistant messages are scanned newest-first, and within a message the
    tool calls are scanned last-first. The first ``ask_user`` call found
    decides the result: its string-valued ``options`` entries are returned,
    or ``[]`` when it has no list-valued ``options``.

    Returns ``[]`` when no ``ask_user`` call exists at all.
    """
    for message in reversed(messages):
        if message.get("role") != "assistant":
            continue
        for tool_call in reversed(message.get("tool_calls") or []):
            if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
                continue
            options = _tool_call_arguments(tool_call).get("options")
            if not isinstance(options, list):
                # The newest question has no choices. Do not keep searching:
                # an older ask_user call's options belong to a different
                # question and would be stale.
                return []
            return [option for option in options if isinstance(option, str)]
    return []
|
||||
|
||||
|
||||
def ask_user_outbound(
    content: str | None,
    options: list[str],
    channel: str,
) -> tuple[str | None, list[list[str]]]:
    """Shape an ``ask_user`` question for delivery on ``channel``.

    Returns a ``(text, buttons)`` pair:

    * no options: ``content`` unchanged and no buttons;
    * channel listed in ``STRUCTURED_BUTTON_CHANNELS``: ``content`` unchanged
      and the options as a single button row;
    * any other channel: the options appended to ``content`` as a numbered
      list (or the list alone when ``content`` is empty) and no buttons.
    """
    if not options:
        return content, []
    if channel in STRUCTURED_BUTTON_CHANNELS:
        return content, [options]
    numbered = [f"{position}. {label}" for position, label in enumerate(options, 1)]
    listing = "\n".join(numbered)
    if content:
        return f"{content}\n\n{listing}", []
    return listing, []
|
||||
@ -28,6 +28,7 @@ class BaseChannel(ABC):
|
||||
transcription_language: str | None = None
|
||||
send_progress: bool = True
|
||||
send_tool_hints: bool = False
|
||||
show_reasoning: bool = True
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
@ -120,6 +121,53 @@ class BaseChannel(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_reasoning_delta(
    self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
) -> None:
    """Receive one streamed chunk of model reasoning/thinking text.

    The base implementation does nothing. Channels that have a low-emphasis
    UI primitive (Slack context block, Telegram expandable blockquote,
    Discord subtext, WebUI italic bubble, ...) override this to render the
    reasoning as a subordinate trace that updates in place.

    The contract mirrors :meth:`send_delta`: ``_reasoning_delta`` marks a
    chunk and ``_reasoning_end`` closes the current reasoning segment.
    Stateful implementations should key their buffers by ``_stream_id``
    rather than by ``chat_id`` alone.
    """
    return None
|
||||
|
||||
async def send_reasoning_end(
    self, chat_id: str, metadata: dict[str, Any] | None = None
) -> None:
    """Receive the end-of-segment marker of a reasoning stream.

    The base implementation does nothing. Channels that buffer
    ``send_reasoning_delta`` chunks for in-place updates use this signal to
    flush and freeze the rendered group; one-shot channels can ignore it.
    """
    return None
|
||||
|
||||
async def send_reasoning(self, msg: OutboundMessage) -> None:
    """Deliver a complete (one-shot) reasoning block.

    The block is replayed through the streaming pair: one
    ``send_reasoning_delta`` call with the whole content, followed by
    ``send_reasoning_end``. Subclasses therefore only need to override the
    delta/end methods, and streamed and one-shot reasoning (e.g.
    DeepSeek-R1's final-response ``reasoning_content``) share one rendering
    path. A message with empty content is ignored.
    """
    if not msg.content:
        return
    # Work on a copy so the caller's metadata dict is never modified.
    delta_meta = {**(msg.metadata or {})}
    if "_reasoning_delta" not in delta_meta:
        delta_meta["_reasoning_delta"] = True
    await self.send_reasoning_delta(msg.chat_id, msg.content, delta_meta)
    # The end marker carries the same metadata minus the delta flag.
    end_meta = dict(delta_meta)
    end_meta.pop("_reasoning_delta", None)
    end_meta["_reasoning_end"] = True
    await self.send_reasoning_end(msg.chat_id, end_meta)
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
"""True when config enables streaming AND this subclass implements send_delta."""
|
||||
|
||||
@ -36,6 +36,7 @@ _SEND_RETRY_DELAYS = (1, 2, 4)
|
||||
_BOOL_CAMEL_ALIASES: dict[str, str] = {
|
||||
"send_progress": "sendProgress",
|
||||
"send_tool_hints": "sendToolHints",
|
||||
"show_reasoning": "showReasoning",
|
||||
}
|
||||
|
||||
class ChannelManager:
|
||||
@ -104,6 +105,9 @@ class ChannelManager:
|
||||
channel.send_tool_hints = self._resolve_bool_override(
|
||||
section, "send_tool_hints", self.config.channels.send_tool_hints,
|
||||
)
|
||||
channel.show_reasoning = self._resolve_bool_override(
|
||||
section, "show_reasoning", self.config.channels.show_reasoning,
|
||||
)
|
||||
self.channels[name] = channel
|
||||
logger.info("{} channel enabled", cls.display_name)
|
||||
except Exception as e:
|
||||
@ -279,6 +283,23 @@ class ChannelManager:
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
if (
|
||||
msg.metadata.get("_reasoning_delta")
|
||||
or msg.metadata.get("_reasoning_end")
|
||||
or msg.metadata.get("_reasoning")
|
||||
):
|
||||
# Reasoning rides its own plugin channel: only delivered
|
||||
# when the destination channel opts in via ``show_reasoning``
|
||||
# and overrides the streaming primitives. Channels without
|
||||
# a low-emphasis UI affordance keep the base no-op and the
|
||||
# content silently drops here. ``_reasoning`` (one-shot)
|
||||
# is accepted for backward compatibility with hooks that
|
||||
# haven't migrated to delta/end yet.
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel is not None and channel.show_reasoning:
|
||||
await self._send_with_retry(channel, msg)
|
||||
continue
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self._should_send_progress(
|
||||
msg.channel, tool_hint=True,
|
||||
@ -329,7 +350,16 @@ class ChannelManager:
|
||||
@staticmethod
|
||||
async def _send_once(channel: BaseChannel, msg: OutboundMessage) -> None:
|
||||
"""Send one outbound message without retry policy."""
|
||||
if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
|
||||
if msg.metadata.get("_reasoning_end"):
|
||||
await channel.send_reasoning_end(msg.chat_id, msg.metadata)
|
||||
elif msg.metadata.get("_reasoning_delta"):
|
||||
await channel.send_reasoning_delta(msg.chat_id, msg.content, msg.metadata)
|
||||
elif msg.metadata.get("_reasoning"):
|
||||
# Back-compat: one-shot reasoning. BaseChannel translates this
|
||||
# to a single delta + end pair so plugins only implement the
|
||||
# streaming primitives.
|
||||
await channel.send_reasoning(msg)
|
||||
elif msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
|
||||
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
|
||||
elif not msg.metadata.get("_streamed"):
|
||||
await channel.send(msg)
|
||||
|
||||
@ -471,7 +471,7 @@ class SlackChannel(BaseChannel):
|
||||
return preview.startswith(_HTML_DOWNLOAD_PREFIXES)
|
||||
|
||||
async def _on_block_action(self, client: SocketModeClient, req: SocketModeRequest) -> None:
|
||||
"""Handle button clicks from ask_user blocks."""
|
||||
"""Handle button clicks from inline action buttons."""
|
||||
await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
|
||||
payload = req.payload or {}
|
||||
actions = payload.get("actions") or []
|
||||
@ -568,7 +568,7 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
@staticmethod
|
||||
def _build_button_blocks(text: str, buttons: list[list[str]]) -> list[dict[str, Any]]:
|
||||
"""Build Slack Block Kit blocks with action buttons for ask_user choices."""
|
||||
"""Build Slack Block Kit blocks with action buttons."""
|
||||
blocks: list[dict[str, Any]] = [
|
||||
{"type": "section", "text": {"type": "mrkdwn", "text": text[:3000]}},
|
||||
]
|
||||
@ -579,7 +579,7 @@ class SlackChannel(BaseChannel):
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": label[:75]},
|
||||
"value": label[:75],
|
||||
"action_id": f"ask_user_{label[:50]}",
|
||||
"action_id": f"btn_{label[:50]}",
|
||||
})
|
||||
if elements:
|
||||
blocks.append({"type": "actions", "elements": elements[:25]})
|
||||
|
||||
@ -55,14 +55,6 @@ def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
|
||||
labels = [label for row in buttons for label in row if label]
|
||||
if not labels:
|
||||
return text
|
||||
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
|
||||
return f"{text}\n\n{fallback}" if text else fallback
|
||||
|
||||
|
||||
class WebSocketConfig(Base):
|
||||
"""WebSocket server channel configuration.
|
||||
|
||||
@ -1468,16 +1460,11 @@ class WebSocketChannel(BaseChannel):
|
||||
await self.send_session_updated(msg.chat_id)
|
||||
return
|
||||
text = msg.content
|
||||
if msg.buttons:
|
||||
text = _append_buttons_as_text(text, msg.buttons)
|
||||
payload: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"chat_id": msg.chat_id,
|
||||
"text": text,
|
||||
}
|
||||
if msg.buttons:
|
||||
payload["buttons"] = msg.buttons
|
||||
payload["button_prompt"] = msg.content
|
||||
if msg.media:
|
||||
payload["media"] = msg.media
|
||||
urls: list[dict[str, str]] = []
|
||||
@ -1500,6 +1487,54 @@ class WebSocketChannel(BaseChannel):
|
||||
for connection in conns:
|
||||
await self._safe_send_to(connection, raw, label=" ")
|
||||
|
||||
async def send_reasoning_delta(
    self,
    chat_id: str,
    delta: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Push one chunk of model reasoning to every subscriber of ``chat_id``.

    Sends a JSON ``reasoning_delta`` event containing ``chat_id``, the chunk
    as ``text`` and, when ``metadata`` has a ``_stream_id``, that id as
    ``stream_id``. Nothing is sent when the chat has no subscribers or
    ``delta`` is empty. The event shape mirrors ``send_delta`` so the WebUI
    can open, update and later close the reasoning block on the matching
    ``reasoning_end`` event.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not (subscribers and delta):
        return
    frame: dict[str, Any] = {
        "event": "reasoning_delta",
        "chat_id": chat_id,
        "text": delta,
    }
    stream_id = (metadata or {}).get("_stream_id")
    if stream_id is not None:
        frame["stream_id"] = stream_id
    # Serialize once and reuse the same payload for every connection.
    encoded = json.dumps(frame, ensure_ascii=False)
    for subscriber in subscribers:
        await self._safe_send_to(subscriber, encoded, label=" reasoning ")
|
||||
|
||||
async def send_reasoning_end(
    self,
    chat_id: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Tell every subscriber of ``chat_id`` that the reasoning segment ended.

    Sends a JSON ``reasoning_end`` event containing ``chat_id`` and, when
    ``metadata`` has a ``_stream_id``, that id as ``stream_id``. Nothing is
    sent when the chat has no subscribers.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return
    frame: dict[str, Any] = {"event": "reasoning_end", "chat_id": chat_id}
    stream_id = (metadata or {}).get("_stream_id")
    if stream_id is not None:
        frame["stream_id"] = stream_id
    # Serialize once and reuse the same payload for every connection.
    encoded = json.dumps(frame, ensure_ascii=False)
    for subscriber in subscribers:
        await self._safe_send_to(subscriber, encoded, label=" reasoning_end ")
|
||||
|
||||
async def send_delta(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
||||
@ -176,13 +176,15 @@ def _print_agent_response(
|
||||
response: str,
|
||||
render_markdown: bool,
|
||||
metadata: dict | None = None,
|
||||
show_header: bool = True,
|
||||
) -> None:
|
||||
"""Render assistant response with consistent terminal styling."""
|
||||
console = _make_console()
|
||||
content = response or ""
|
||||
body = _response_renderable(content, render_markdown, metadata)
|
||||
console.print()
|
||||
console.print(f"[cyan]{__logo__} nanobot[/cyan]")
|
||||
if show_header:
|
||||
console.print()
|
||||
console.print(f"[cyan]{__logo__} nanobot[/cyan]")
|
||||
console.print(body)
|
||||
console.print()
|
||||
|
||||
@ -228,42 +230,70 @@ async def _print_interactive_response(
|
||||
await run_in_terminal(_write)
|
||||
|
||||
|
||||
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
||||
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None, renderer: StreamRenderer | None = None) -> None:
|
||||
"""Print a CLI progress line, pausing the spinner if needed."""
|
||||
if not text.strip():
|
||||
return
|
||||
with thinking.pause() if thinking else nullcontext():
|
||||
console.print(f" [dim]↳ {text}[/dim]")
|
||||
target = renderer.console if renderer else console
|
||||
pause = renderer.pause_spinner() if renderer else (thinking.pause() if thinking else nullcontext())
|
||||
with pause:
|
||||
if renderer:
|
||||
renderer.ensure_header()
|
||||
target.print(f" [dim]↳ {text}[/dim]")
|
||||
|
||||
|
||||
async def _print_interactive_progress_line(text: str, renderer: StreamRenderer | None) -> None:
|
||||
"""Print an interactive progress line, pausing the renderer's spinner if needed."""
|
||||
def _print_cli_reasoning(text: str, thinking: ThinkingSpinner | None, renderer: StreamRenderer | None = None) -> None:
|
||||
"""Print reasoning/thinking content in a distinct style."""
|
||||
if not text.strip():
|
||||
return
|
||||
with renderer.pause() if renderer else nullcontext():
|
||||
await _print_interactive_line(text)
|
||||
target = renderer.console if renderer else console
|
||||
pause = renderer.pause_spinner() if renderer else (thinking.pause() if thinking else nullcontext())
|
||||
with pause:
|
||||
if renderer:
|
||||
renderer.ensure_header()
|
||||
target.print(f"[dim italic]✻ {text}[/dim italic]")
|
||||
|
||||
|
||||
async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None, renderer: StreamRenderer | None = None) -> None:
|
||||
"""Print an interactive progress line, pausing the spinner if needed."""
|
||||
if not text.strip():
|
||||
return
|
||||
if renderer:
|
||||
with renderer.pause_spinner():
|
||||
renderer.ensure_header()
|
||||
renderer.console.print(f" [dim]↳ {text}[/dim]")
|
||||
else:
|
||||
with thinking.pause() if thinking else nullcontext():
|
||||
await _print_interactive_line(text)
|
||||
|
||||
|
||||
async def _maybe_print_interactive_progress(
|
||||
msg: Any,
|
||||
renderer: StreamRenderer | None,
|
||||
thinking: ThinkingSpinner | None,
|
||||
channels_config: Any,
|
||||
renderer: StreamRenderer | None = None,
|
||||
) -> bool:
|
||||
metadata = msg.metadata or {}
|
||||
if metadata.get("_retry_wait"):
|
||||
await _print_interactive_progress_line(msg.content, renderer)
|
||||
await _print_interactive_progress_line(msg.content, thinking, renderer)
|
||||
return True
|
||||
|
||||
if not metadata.get("_progress"):
|
||||
return False
|
||||
|
||||
is_tool_hint = metadata.get("_tool_hint", False)
|
||||
is_reasoning = metadata.get("_reasoning", False) or metadata.get("_reasoning_delta", False)
|
||||
if is_reasoning:
|
||||
if channels_config and not channels_config.show_reasoning:
|
||||
return True
|
||||
_print_cli_reasoning(msg.content, thinking, renderer)
|
||||
return True
|
||||
if channels_config and is_tool_hint and not channels_config.send_tool_hints:
|
||||
return True
|
||||
if channels_config and not is_tool_hint and not channels_config.send_progress:
|
||||
return True
|
||||
|
||||
await _print_interactive_progress_line(msg.content, renderer)
|
||||
await _print_interactive_progress_line(msg.content, thinking, renderer)
|
||||
return True
|
||||
|
||||
|
||||
@ -1064,13 +1094,20 @@ def agent(
|
||||
# Shared reference for progress callbacks
|
||||
_thinking: ThinkingSpinner | None = None
|
||||
|
||||
async def _cli_progress(content: str, *, tool_hint: bool = False, **_kwargs: Any) -> None:
|
||||
ch = agent_loop.channels_config
|
||||
if ch and tool_hint and not ch.send_tool_hints:
|
||||
return
|
||||
if ch and not tool_hint and not ch.send_progress:
|
||||
return
|
||||
_print_cli_progress_line(content, _thinking)
|
||||
def _make_progress(renderer: StreamRenderer | None = None):
|
||||
async def _cli_progress(content: str, *, tool_hint: bool = False, reasoning: bool = False, **_kwargs: Any) -> None:
|
||||
ch = agent_loop.channels_config
|
||||
if reasoning:
|
||||
if ch and not ch.show_reasoning:
|
||||
return
|
||||
_print_cli_reasoning(content, _thinking, renderer)
|
||||
return
|
||||
if ch and tool_hint and not ch.send_tool_hints:
|
||||
return
|
||||
if ch and not tool_hint and not ch.send_progress:
|
||||
return
|
||||
_print_cli_progress_line(content, _thinking, renderer)
|
||||
return _cli_progress
|
||||
|
||||
if message:
|
||||
# Single message mode — direct call, no bus needed
|
||||
@ -1082,16 +1119,20 @@ def agent(
|
||||
)
|
||||
response = await agent_loop.process_direct(
|
||||
message, session_id,
|
||||
on_progress=_cli_progress,
|
||||
on_progress=_make_progress(renderer),
|
||||
on_stream=renderer.on_delta,
|
||||
on_stream_end=renderer.on_end,
|
||||
)
|
||||
if not renderer.streamed:
|
||||
await renderer.close()
|
||||
print_kwargs: dict[str, Any] = {}
|
||||
if renderer.header_printed:
|
||||
print_kwargs["show_header"] = False
|
||||
_print_agent_response(
|
||||
response.content if response else "",
|
||||
render_markdown=markdown,
|
||||
metadata=response.metadata if response else None,
|
||||
**print_kwargs,
|
||||
)
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
@ -1154,6 +1195,7 @@ def agent(
|
||||
msg,
|
||||
renderer,
|
||||
agent_loop.channels_config,
|
||||
renderer,
|
||||
):
|
||||
continue
|
||||
|
||||
@ -1215,8 +1257,14 @@ def agent(
|
||||
if content and not meta.get("_streamed"):
|
||||
if renderer:
|
||||
await renderer.close()
|
||||
print_kwargs: dict[str, Any] = {}
|
||||
if renderer and renderer.header_printed:
|
||||
print_kwargs["show_header"] = False
|
||||
_print_agent_response(
|
||||
content, render_markdown=markdown, metadata=meta,
|
||||
content,
|
||||
render_markdown=markdown,
|
||||
metadata=meta,
|
||||
**print_kwargs,
|
||||
)
|
||||
elif renderer and not renderer.streamed:
|
||||
await renderer.close()
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
"""Streaming renderer for CLI output.
|
||||
|
||||
Uses Rich Live with auto_refresh=False for stable, flicker-free
|
||||
markdown rendering during streaming. Ellipsis mode handles overflow.
|
||||
Uses Rich Live with ``transient=True`` for in-place markdown updates during
|
||||
streaming. After the live display stops, a final clean render is printed
|
||||
so the content persists on screen. ``transient=True`` ensures the live
|
||||
area is erased before ``stop()`` returns, avoiding the duplication bug
|
||||
that plagued earlier approaches.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import time
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
@ -15,6 +18,16 @@ from rich.markdown import Markdown
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
def _clear_current_line(console: Console) -> None:
    """Wipe the current terminal line so persistent output starts clean.

    Writes a carriage return followed by the ANSI erase-line sequence to
    ``console.file`` and flushes it. Does nothing when the file has no
    ``isatty`` method or ``isatty()`` returns a false value.
    """
    stream = console.file
    is_terminal = getattr(stream, "isatty", lambda: False)
    if is_terminal():
        stream.write("\r\x1b[2K")
        stream.flush()
|
||||
|
||||
|
||||
def _make_console() -> Console:
|
||||
"""Create a Console that emits plain text when stdout is not a TTY.
|
||||
|
||||
@ -34,6 +47,7 @@ class ThinkingSpinner:
|
||||
|
||||
def __init__(self, console: Console | None = None, bot_name: str = "nanobot"):
|
||||
c = console or _make_console()
|
||||
self._console = c
|
||||
self._spinner = c.status(f"[dim]{bot_name} is thinking...[/dim]", spinner="dots")
|
||||
self._active = False
|
||||
|
||||
@ -45,6 +59,7 @@ class ThinkingSpinner:
|
||||
def __exit__(self, *exc):
|
||||
self._active = False
|
||||
self._spinner.stop()
|
||||
_clear_current_line(self._console)
|
||||
return False
|
||||
|
||||
def pause(self):
|
||||
@ -55,6 +70,7 @@ class ThinkingSpinner:
|
||||
def _ctx():
|
||||
if self._spinner and self._active:
|
||||
self._spinner.stop()
|
||||
_clear_current_line(self._console)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@ -65,13 +81,14 @@ class ThinkingSpinner:
|
||||
|
||||
|
||||
class StreamRenderer:
|
||||
"""Rich Live streaming with markdown. auto_refresh=False avoids render races.
|
||||
"""Streaming renderer with Rich Live for in-place updates.
|
||||
|
||||
Deltas arrive pre-filtered (no <think> tags) from the agent loop.
|
||||
During streaming: updates content in-place via Rich Live.
|
||||
On end: stops Live (transient=True erases it), then prints final render.
|
||||
|
||||
Flow per round:
|
||||
spinner -> first visible delta -> header + Live renders ->
|
||||
on_end -> Live stops (content stays on screen)
|
||||
spinner -> first delta -> header + Live updates ->
|
||||
on_end -> stop Live + final render
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -86,14 +103,24 @@ class StreamRenderer:
|
||||
self._bot_name = bot_name
|
||||
self._bot_icon = bot_icon
|
||||
self._buf = ""
|
||||
self._live: Live | None = None
|
||||
self._t = 0.0
|
||||
self.streamed = False
|
||||
self._console = _make_console()
|
||||
self._live: Live | None = None
|
||||
self._spinner: ThinkingSpinner | None = None
|
||||
self._header_printed = False
|
||||
self._start_spinner()
|
||||
|
||||
def _render(self):
|
||||
return Markdown(self._buf) if self._md and self._buf else Text(self._buf or "")
|
||||
def _renderable(self):
    """Wrap the current buffer in a Rich renderable.

    Returns ``Markdown`` when markdown rendering is enabled and the buffer
    is non-empty; otherwise ``Text`` (an empty ``Text`` for an empty buffer).
    """
    buffered = self._buf
    if self._md and buffered:
        return Markdown(buffered)
    return Text(buffered or "")
|
||||
|
||||
def _render_str(self) -> str:
    """Return the current buffer as rendered text.

    The renderable is printed to this renderer's console while its output is
    being captured, and the captured text is returned.
    """
    console = self._console
    with console.capture() as captured:
        console.print(self._renderable())
    return captured.get()
|
||||
|
||||
def _start_spinner(self) -> None:
|
||||
if self._show_spinner:
|
||||
@ -105,37 +132,85 @@ class StreamRenderer:
|
||||
self._spinner.__exit__(None, None, None)
|
||||
self._spinner = None
|
||||
|
||||
@property
def console(self) -> Console:
    """The console this renderer prints to.

    Exposed so external print helpers can write to the same console.
    """
    shared_console = self._console
    return shared_console
|
||||
|
||||
@property
def header_printed(self) -> bool:
    """``True`` once the assistant header has been printed for this turn."""
    already_printed = self._header_printed
    return already_printed
|
||||
|
||||
def ensure_header(self) -> None:
    """Stop the spinner and print the assistant header at most once.

    The spinner is stopped on every call, even when the header is already
    out: a turn can print trace rows, restart the spinner while tools run,
    and then receive more answer text that must stop the spinner again.
    """
    self._stop_spinner()
    if not self._header_printed:
        title = self._bot_name
        if self._bot_icon:
            title = f"{self._bot_icon} {self._bot_name}"
        self._console.print()
        self._console.print(f"[cyan]{title}[/cyan]")
        self._header_printed = True
|
||||
|
||||
def pause_spinner(self):
    """Return a context manager that silences transient output.

    On entry an active Live view is stopped and dropped, and the spinner
    (when there is one) is paused for the duration of the block. Trace and
    reasoning lines printed inside the block therefore are not mixed with a
    partially rendered markdown frame.

    The Live view is not recreated on exit. If more answer deltas arrive,
    ``on_delta()`` creates a fresh Live from the existing buffer; if none
    arrive, ``on_end()`` prints the buffered answer once.
    """
    @contextmanager
    def _pause():
        if self._live:
            # Trace/reasoning can arrive after answer streaming has started;
            # stop the transient view before anything else is printed.
            self._live.stop()
            self._live = None
        with self._spinner.pause() if self._spinner else nullcontext():
            yield

    return _pause()
|
||||
|
||||
async def on_delta(self, delta: str) -> None:
|
||||
self.streamed = True
|
||||
self._buf += delta
|
||||
if self._live is None:
|
||||
if not self._buf.strip():
|
||||
return
|
||||
self._stop_spinner()
|
||||
c = _make_console()
|
||||
c.print()
|
||||
header = f"{self._bot_icon} {self._bot_name}" if self._bot_icon else self._bot_name
|
||||
c.print(f"[cyan]{header}[/cyan]")
|
||||
self._live = Live(self._render(), console=c, auto_refresh=False)
|
||||
self.ensure_header()
|
||||
self._live = Live(
|
||||
self._renderable(),
|
||||
console=self._console,
|
||||
auto_refresh=False,
|
||||
transient=True,
|
||||
)
|
||||
self._live.start()
|
||||
now = time.monotonic()
|
||||
if (now - self._t) > 0.15:
|
||||
self._live.update(self._render())
|
||||
self._live.refresh()
|
||||
self._t = now
|
||||
else:
|
||||
self._live.update(self._renderable())
|
||||
self._live.refresh()
|
||||
|
||||
async def on_end(self, *, resuming: bool = False) -> None:
|
||||
if self._live:
|
||||
self._live.update(self._render())
|
||||
# Double-refresh to sync _shape before stop() calls refresh().
|
||||
self._live.refresh()
|
||||
self._live.update(self._renderable())
|
||||
self._live.refresh()
|
||||
self._live.stop()
|
||||
self._live = None
|
||||
self._stop_spinner()
|
||||
if self._buf.strip():
|
||||
# Print final rendered content (persists after Live is gone).
|
||||
out = sys.stdout
|
||||
out.write(self._render_str())
|
||||
out.flush()
|
||||
if resuming:
|
||||
self._buf = ""
|
||||
self._start_spinner()
|
||||
else:
|
||||
_make_console().print()
|
||||
|
||||
def stop_for_input(self) -> None:
|
||||
"""Stop spinner before user input to avoid prompt_toolkit conflicts."""
|
||||
@ -143,7 +218,6 @@ class StreamRenderer:
|
||||
|
||||
def pause(self):
|
||||
"""Context manager: pause spinner for external output. No-op once streaming has started."""
|
||||
from contextlib import nullcontext
|
||||
if self._spinner:
|
||||
return self._spinner.pause()
|
||||
return nullcontext()
|
||||
|
||||
@ -35,6 +35,7 @@ class ChannelsConfig(Base):
|
||||
|
||||
send_progress: bool = True # stream agent's text progress to the channel
|
||||
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
||||
show_reasoning: bool = True # surface model reasoning when channel implements it
|
||||
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
|
||||
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"
|
||||
transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639-1 hint for audio transcription
|
||||
|
||||
@ -25,6 +25,7 @@ FILE_MAX_MESSAGES = 2000
|
||||
_MESSAGE_TIME_PREFIX_RE = re.compile(r"^\[Message Time: [^\]]+\]\n?")
|
||||
_LOCAL_IMAGE_BREADCRUMB_RE = re.compile(r"^\[image: (?:/|~)[^\]]+\]\s*$")
|
||||
_TOOL_CALL_ECHO_RE = re.compile(r'^\s*(?:generate_image|message)\([^)]*\)\s*$')
|
||||
_SESSION_PREVIEW_MAX_CHARS = 120
|
||||
|
||||
|
||||
def _sanitize_assistant_replay_text(content: str) -> str:
|
||||
@ -43,6 +44,27 @@ def _sanitize_assistant_replay_text(content: str) -> str:
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def _text_preview(content: Any) -> str:
|
||||
"""Return compact display text for session lists."""
|
||||
if isinstance(content, str):
|
||||
text = content
|
||||
elif isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
value = block.get("text")
|
||||
if isinstance(value, str):
|
||||
parts.append(value)
|
||||
text = " ".join(parts)
|
||||
else:
|
||||
return ""
|
||||
text = _sanitize_assistant_replay_text(text)
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
if len(text) > _SESSION_PREVIEW_MAX_CHARS:
|
||||
text = text[: _SESSION_PREVIEW_MAX_CHARS - 1].rstrip() + "…"
|
||||
return text
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""A conversation session."""
|
||||
@ -560,7 +582,7 @@ class SessionManager:
|
||||
for path in self.sessions_dir.glob("*.jsonl"):
|
||||
fallback_key = path.stem.replace("_", ":", 1)
|
||||
try:
|
||||
# Read just the metadata line
|
||||
# Read the metadata line and a small preview for WebUI/session lists.
|
||||
with open(path, encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
if first_line:
|
||||
@ -569,11 +591,29 @@ class SessionManager:
|
||||
key = data.get("key") or path.stem.replace("_", ":", 1)
|
||||
metadata = data.get("metadata", {})
|
||||
title = metadata.get("title") if isinstance(metadata, dict) else None
|
||||
preview = ""
|
||||
fallback_preview = ""
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
item = json.loads(line)
|
||||
if item.get("_type") == "metadata":
|
||||
continue
|
||||
text = _text_preview(item.get("content"))
|
||||
if not text:
|
||||
continue
|
||||
if item.get("role") == "user":
|
||||
preview = text
|
||||
break
|
||||
if not fallback_preview and item.get("role") == "assistant":
|
||||
fallback_preview = text
|
||||
preview = preview or fallback_preview
|
||||
sessions.append({
|
||||
"key": key,
|
||||
"created_at": data.get("created_at"),
|
||||
"updated_at": data.get("updated_at"),
|
||||
"title": title if isinstance(title, str) else "",
|
||||
"preview": preview,
|
||||
"path": str(path)
|
||||
})
|
||||
except Exception:
|
||||
@ -588,6 +628,14 @@ class SessionManager:
|
||||
if isinstance(repaired.metadata.get("title"), str)
|
||||
else ""
|
||||
),
|
||||
"preview": next(
|
||||
(
|
||||
text
|
||||
for msg in repaired.messages
|
||||
if (text := _text_preview(msg.get("content")))
|
||||
),
|
||||
"",
|
||||
),
|
||||
"path": str(path)
|
||||
})
|
||||
continue
|
||||
|
||||
@ -11,7 +11,7 @@ Generate a personalized upgrade skill for this workspace.
|
||||
|
||||
Use `read_file` to check if `skills/update/SKILL.md` already exists in the workspace.
|
||||
|
||||
If it exists, use `ask_user` to ask: "An upgrade skill already exists. Reconfigure?" with options ["yes", "no"]. If no, stop here.
|
||||
If it exists, ask the user: "An upgrade skill already exists. Reconfigure?" Wait for the user's reply. If no, stop here.
|
||||
|
||||
## Step 2: Current Version and Install Clues
|
||||
|
||||
@ -38,9 +38,9 @@ answer or confirmation, not from inference alone. If you cannot get a clear
|
||||
answer, stop and ask the user to rerun this setup when they know how nanobot was
|
||||
installed.
|
||||
|
||||
Use `ask_user` for the questions below, one question per call. If `ask_user` is
|
||||
not available or cannot collect the answer, ask in normal chat and stop without
|
||||
writing the skill.
|
||||
Ask the user the questions below, one at a time, in your response text. Wait for
|
||||
the user's reply before proceeding to the next question. If you cannot get a clear
|
||||
answer, stop without writing the skill.
|
||||
|
||||
**Question 1 — Install method:**
|
||||
|
||||
|
||||
@ -71,6 +71,93 @@ def strip_think(text: str) -> str:
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_think(text: str) -> tuple[str | None, str]:
    """Extract thinking content from inline ``<think>`` / ``<thought>`` blocks.

    Returns ``(thinking_text, cleaned_text)``. Only closed blocks are
    extracted. Their stripped bodies are joined with blank lines in the
    order the blocks appear in ``text``, so the thinking text only ever
    grows by appending as more of a stream arrives. ``thinking_text`` is
    ``None`` when no closed block is found. ``cleaned_text`` is always
    ``strip_think(text)``.
    """
    # One ordered pass over both tag names. The backreference makes each
    # opening tag pair with a closing tag of the same name. Scanning the two
    # tag kinds separately would put every <think> body ahead of every
    # <thought> body regardless of where they occur in the text.
    parts = [
        match.group(2).strip()
        for match in re.finditer(r"<(think|thought)>([\s\S]*?)</\1>", text)
    ]
    thinking = "\n\n".join(parts) if parts else None
    return thinking, strip_think(text)
|
||||
|
||||
|
||||
class IncrementalThinkExtractor:
    """Stateful inline ``<think>`` extractor for streaming buffers.

    Streaming providers expose only a single content delta channel. When a
    model embeds reasoning in ``<think>...</think>`` blocks inside that
    channel, callers need to surface the reasoning incrementally as it
    arrives without re-emitting earlier text. This holds the "already
    emitted" cursor so the runner and the loop hook share one shape.
    """

    __slots__ = ("_emitted",)

    def __init__(self) -> None:
        # Full thinking text seen on the previous feed() that found any.
        self._emitted = ""

    def reset(self) -> None:
        """Forget what has been emitted so the next buffer starts fresh."""
        self._emitted = ""

    async def feed(self, buf: str, emit: Any) -> bool:
        """Emit any new thinking text found in ``buf``.

        Returns True if anything was emitted this call. ``emit`` is an
        async callable taking a single string (typically
        ``hook.emit_reasoning``).
        """
        thinking, _ = extract_think(buf)
        if not thinking or thinking == self._emitted:
            return False
        if thinking.startswith(self._emitted):
            # Normal streaming case: only the tail is new.
            new = thinking[len(self._emitted):].strip()
        else:
            # The buffer no longer extends what was emitted (e.g. it was
            # replaced without reset()). Slicing by length would emit an
            # arbitrary mid-string fragment, so treat it all as new.
            new = thinking.strip()
        self._emitted = thinking
        if not new:
            return False
        await emit(new)
        return True
|
||||
|
||||
|
||||
def extract_reasoning(
|
||||
reasoning_content: str | None,
|
||||
thinking_blocks: list[dict[str, Any]] | None,
|
||||
content: str | None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Return ``(reasoning_text, cleaned_content)`` from one model response.
|
||||
|
||||
Single source of truth for "what reasoning did this response carry, and
|
||||
what answer text remains after we peel it out". Fallback order:
|
||||
|
||||
1. Dedicated ``reasoning_content`` (DeepSeek-R1, Kimi, MiMo, OpenAI
|
||||
reasoning models, Bedrock).
|
||||
2. Anthropic ``thinking_blocks``.
|
||||
3. Inline ``<think>`` / ``<thought>`` blocks in ``content``.
|
||||
|
||||
Only one source contributes per response; lower-priority sources are
|
||||
ignored if a higher-priority one is present, but inline ``<think>``
|
||||
tags are still stripped from ``content`` so they never leak into the
|
||||
final answer.
|
||||
"""
|
||||
if reasoning_content:
|
||||
return reasoning_content, strip_think(content) if content else content
|
||||
if thinking_blocks:
|
||||
parts = [
|
||||
tb.get("thinking", "")
|
||||
for tb in thinking_blocks
|
||||
if isinstance(tb, dict) and tb.get("type") == "thinking"
|
||||
]
|
||||
joined = "\n\n".join(p for p in parts if p)
|
||||
return (joined or None), strip_think(content) if content else content
|
||||
if content:
|
||||
return extract_think(content)
|
||||
return None, content
|
||||
|
||||
|
||||
def detect_image_mime(data: bytes) -> str | None:
|
||||
"""Detect image MIME type from magic bytes, ignoring file extension."""
|
||||
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
|
||||
93
tests/agent/conftest.py
Normal file
93
tests/agent/conftest.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Shared fixtures and helpers for agent tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
def make_provider(
    default_model: str = "test-model",
    *,
    max_tokens: int = 4096,
    spec: bool = True,
) -> MagicMock:
    """Build an LLM provider mock preconfigured for agent tests.

    Args:
        default_model: Value returned by ``get_default_model()``.
        max_tokens: ``generation.max_tokens`` on the returned mock.
        spec: When True the mock is restricted to the ``LLMProvider``
            interface; when False a plain ``MagicMock`` is used.
    """
    if spec:
        provider = MagicMock(spec=LLMProvider)
    else:
        provider = MagicMock()

    provider.get_default_model.return_value = default_model
    provider.estimate_prompt_tokens.return_value = (10_000, "test")
    provider.generation = SimpleNamespace(
        max_tokens=max_tokens,
        temperature=0.1,
        reasoning_effort=None,
    )
    return provider
|
||||
|
||||
|
||||
def make_loop(
    tmp_path: Path,
    *,
    model: str = "test-model",
    context_window_tokens: int = 128_000,
    session_ttl_minutes: int = 0,
    max_messages: int = 120,
    unified_session: bool = False,
    mcp_servers: dict | None = None,
    tools_config=None,
    model_presets: dict | None = None,
    hooks: list | None = None,
    provider: MagicMock | None = None,
    patch_deps: bool = False,
) -> AgentLoop:
    """Construct a real ``AgentLoop`` rooted at ``tmp_path`` for tests.

    A fresh ``MessageBus`` is always created, and ``make_provider`` supplies
    the provider when none is passed. ``mcp_servers``, ``tools_config``,
    ``model_presets`` and ``hooks`` are forwarded to ``AgentLoop`` only when
    they are not None.

    Args:
        patch_deps: If True, patch ContextBuilder/SessionManager/SubagentManager
            during construction (needed when workspace has no real files).
    """
    bus = MessageBus()
    if provider is None:
        provider = make_provider(default_model=model)

    loop_kwargs = {
        "bus": bus,
        "provider": provider,
        "workspace": tmp_path,
        "model": model,
        "context_window_tokens": context_window_tokens,
        "session_ttl_minutes": session_ttl_minutes,
        "max_messages": max_messages,
        "unified_session": unified_session,
    }
    optional = {
        "mcp_servers": mcp_servers,
        "tools_config": tools_config,
        "model_presets": model_presets,
        "hooks": hooks,
    }
    loop_kwargs.update(
        {name: value for name, value in optional.items() if value is not None}
    )

    if not patch_deps:
        return AgentLoop(**loop_kwargs)

    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as sub_manager_cls:
                # The patched manager needs an awaitable cancel_by_session.
                sub_manager_cls.return_value.cancel_by_session = AsyncMock(
                    return_value=0
                )
                return AgentLoop(**loop_kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
def loop_factory(tmp_path):
    """Fixture yielding a callable that builds AgentLoop instances.

    The callable forwards its keyword arguments to ``make_loop`` with the
    test's ``tmp_path`` as the workspace.
    """
    return lambda **kwargs: make_loop(tmp_path, **kwargs)
|
||||
@ -1,241 +0,0 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.schema import tool_parameters_schema
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_provider(chat_with_retry):
    """Wrap ``chat_with_retry`` in a provider mock for these tests.

    The streaming entry point delegates to the same coroutine after
    dropping the ``on_content_delta`` keyword argument.
    """

    async def chat_stream_with_retry(**kwargs):
        forwarded = {
            name: value
            for name, value in kwargs.items()
            if name != "on_content_delta"
        }
        return await chat_with_retry(**forwarded)

    mock_provider = MagicMock()
    mock_provider.get_default_model.return_value = "test-model"
    mock_provider.generation = GenerationSettings()
    mock_provider.chat_with_retry = chat_with_retry
    mock_provider.chat_stream_with_retry = chat_stream_with_retry
    return mock_provider
|
||||
|
||||
|
||||
def test_ask_user_tool_schema_and_interrupt():
    """AskUserTool has the expected schema and raises AskUserInterrupt when run."""
    ask_tool = AskUserTool()
    function_schema = ask_tool.to_schema()["function"]
    parameters = function_schema["parameters"]

    assert function_schema["name"] == "ask_user"
    assert "question" in parameters["required"]
    assert parameters["properties"]["options"]["type"] == "array"

    with pytest.raises(AskUserInterrupt) as raised:
        asyncio.run(ask_tool.execute("Continue?", options=["Yes", "No"]))

    interrupt = raised.value
    assert interrupt.question == "Continue?"
    assert interrupt.options == ["Yes", "No"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
    """An ask_user call stops the run before a later tool call in the same turn executes."""

    @tool_parameters(tool_parameters_schema(required=[]))
    class LaterTool(Tool):
        # Set to True by execute() if the tool is run.
        called = False

        @property
        def name(self) -> str:
            return "later"

        @property
        def description(self) -> str:
            return "Should not run after ask_user pauses the turn."

        async def execute(self, **kwargs):
            self.called = True
            return "later result"

    async def chat_with_retry(**kwargs):
        ask_call = ToolCallRequest(
            id="call_ask",
            name="ask_user",
            arguments={"question": "Install this package?", "options": ["Yes", "No"]},
        )
        later_call = ToolCallRequest(id="call_later", name="later", arguments={})
        return LLMResponse(
            content="",
            finish_reason="tool_calls",
            tool_calls=[ask_call, later_call],
        )

    later_tool = LaterTool()
    registry = ToolRegistry()
    for tool in (AskUserTool(), later_tool):
        registry.register(tool)

    runner = AgentRunner(_make_provider(chat_with_retry))
    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "continue"}],
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=16_000,
        concurrent_tools=True,
    )
    result = await runner.run(run_spec)

    assert result.stop_reason == "ask_user"
    assert result.final_content == "Install this package?"
    assert "ask_user" in result.tools_used
    assert later_tool.called is False

    final_message = result.messages[-1]
    assert final_message["role"] == "assistant"
    recorded_calls = [call["function"]["name"] for call in final_message["tool_calls"]]
    assert recorded_calls == ["ask_user"]
    assert all(message.get("name") != "ask_user" for message in result.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
    """On the cli channel the question is rendered as text and the next message answers it.

    The reply must reach the model and the session as the ``ask_user`` tool
    result, not as a separate user message.
    """
    seen_messages: list[list[dict]] = []

    async def chat_with_retry(**kwargs):
        seen_messages.append(kwargs["messages"])
        if len(seen_messages) > 1:
            return LLMResponse(content="Skipped install.", usage={})
        # First call: ask the user instead of answering.
        return LLMResponse(
            content="",
            finish_reason="tool_calls",
            tool_calls=[
                ToolCallRequest(
                    id="call_ask",
                    name="ask_user",
                    arguments={
                        "question": "Install the optional package?",
                        "options": ["Install", "Skip"],
                    },
                )
            ],
        )

    async def on_stream(delta: str) -> None:
        pass

    async def on_stream_end(**kwargs) -> None:
        pass

    def _is_skip_tool_result(message: dict) -> bool:
        return (
            message.get("role") == "tool"
            and message.get("name") == "ask_user"
            and message.get("content") == "Skip"
        )

    loop = AgentLoop(
        bus=MessageBus(),
        provider=_make_provider(chat_with_retry),
        workspace=tmp_path,
        model="test-model",
    )

    question_msg = InboundMessage(
        channel="cli", sender_id="user", chat_id="direct", content="set it up"
    )
    first = await loop._process_message(
        question_msg,
        on_stream=on_stream,
        on_stream_end=on_stream_end,
    )

    assert first is not None
    assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
    assert first.buttons == []
    assert "_streamed" not in first.metadata

    session = loop.sessions.get_or_create("cli:direct")
    assert any(
        message.get("role") == "assistant" and message.get("tool_calls")
        for message in session.messages
    )
    assert not any(
        message.get("role") == "tool" and message.get("name") == "ask_user"
        for message in session.messages
    )

    answer_msg = InboundMessage(
        channel="cli", sender_id="user", chat_id="direct", content="Skip"
    )
    second = await loop._process_message(answer_msg)

    assert second is not None
    assert second.content == "Skipped install."
    assert any(_is_skip_tool_result(message) for message in seen_messages[-1])
    assert not any(
        message.get("role") == "user" and message.get("content") == "Skip"
        for message in session.messages
    )
    assert any(_is_skip_tool_result(message) for message in session.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ask_user_keeps_buttons_for_telegram(tmp_path):
    """On the telegram channel the options are returned as buttons, not inlined text."""

    async def chat_with_retry(**kwargs):
        ask_call = ToolCallRequest(
            id="call_ask",
            name="ask_user",
            arguments={
                "question": "Install the optional package?",
                "options": ["Install", "Skip"],
            },
        )
        return LLMResponse(content="", finish_reason="tool_calls", tool_calls=[ask_call])

    loop = AgentLoop(
        bus=MessageBus(),
        provider=_make_provider(chat_with_retry),
        workspace=tmp_path,
        model="test-model",
    )
    inbound = InboundMessage(
        channel="telegram", sender_id="user", chat_id="123", content="set it up"
    )

    response = await loop._process_message(inbound)

    assert response is not None
    assert response.content == "Install the optional package?"
    assert response.buttons == [["Install", "Skip"]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ask_user_keeps_buttons_for_websocket(tmp_path):
    """On the websocket channel the options are returned as buttons, not inlined text."""

    async def chat_with_retry(**kwargs):
        ask_call = ToolCallRequest(
            id="call_ask",
            name="ask_user",
            arguments={
                "question": "Install the optional package?",
                "options": ["Install", "Skip"],
            },
        )
        return LLMResponse(content="", finish_reason="tool_calls", tool_calls=[ask_call])

    loop = AgentLoop(
        bus=MessageBus(),
        provider=_make_provider(chat_with_retry),
        workspace=tmp_path,
        model="test-model",
    )
    inbound = InboundMessage(
        channel="websocket", sender_id="user", chat_id="123", content="set it up"
    )

    response = await loop._process_message(inbound)

    assert response is not None
    assert response.content == "Install the optional package?"
    assert response.buttons == [["Install", "Skip"]]
|
||||
554
tests/agent/test_autocompact_unit.py
Normal file
554
tests/agent/test_autocompact_unit.py
Normal file
@ -0,0 +1,554 @@
|
||||
"""Direct unit tests for AutoCompact class methods in isolation."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
|
||||
def _make_session(
    key: str = "cli:test",
    messages: list | None = None,
    last_consolidated: int = 0,
    updated_at: datetime | None = None,
    metadata: dict | None = None,
) -> Session:
    """Build a Session for tests, filling in empty messages/metadata when omitted.

    ``updated_at`` is overwritten on the new session only when a value is
    given.
    """
    session = Session(
        key=key,
        messages=messages if messages else [],
        metadata=metadata if metadata else {},
        last_consolidated=last_consolidated,
    )
    if updated_at is None:
        return session
    session.updated_at = updated_at
    return session
|
||||
|
||||
|
||||
def _make_autocompact(
    ttl: int = 15,
    sessions: SessionManager | None = None,
    consolidator: MagicMock | None = None,
) -> AutoCompact:
    """Build an AutoCompact whose collaborators default to mocks.

    A missing ``sessions`` becomes a ``SessionManager``-spec'd mock. A
    missing ``consolidator`` becomes a mock whose async ``archive`` returns
    ``"Summary."``.
    """
    if sessions is None:
        sessions = MagicMock(spec=SessionManager)
    if consolidator is None:
        consolidator = MagicMock(archive=AsyncMock(return_value="Summary."))
    return AutoCompact(
        sessions=sessions,
        consolidator=consolidator,
        session_ttl_minutes=ttl,
    )
|
||||
|
||||
|
||||
def _add_turns(session: Session, turns: int, *, prefix: str = "msg") -> None:
    """Append ``turns`` user/assistant message pairs to ``session``.

    Each message reads ``"<prefix> <role> <index>"``; the user message of a
    pair is added before the assistant one.
    """
    for index in range(turns):
        for role in ("user", "assistant"):
            session.add_message(role, f"{prefix} {role} {index}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __init__
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInit:
    """Check that AutoCompact's constructor stores its arguments and starts empty."""

    def test_stores_ttl(self):
        """``session_ttl_minutes`` is kept on ``_ttl``."""
        compact = _make_autocompact(ttl=30)
        assert compact._ttl == 30

    def test_default_ttl_is_zero(self):
        """An explicit ``ttl=0`` is stored as 0 (the helper passes it through)."""
        compact = _make_autocompact(ttl=0)
        assert compact._ttl == 0

    def test_archiving_set_is_empty(self):
        """``_archiving`` equals an empty set right after construction."""
        compact = _make_autocompact()
        assert compact._archiving == set()

    def test_summaries_dict_is_empty(self):
        """``_summaries`` equals an empty dict right after construction."""
        compact = _make_autocompact()
        assert compact._summaries == {}

    def test_stores_sessions_reference(self):
        """``sessions`` is the very object that was passed in."""
        manager = MagicMock(spec=SessionManager)
        compact = _make_autocompact(sessions=manager)
        assert compact.sessions is manager

    def test_stores_consolidator_reference(self):
        """``consolidator`` is the very object that was passed in."""
        consolidator = MagicMock()
        compact = _make_autocompact(consolidator=consolidator)
        assert compact.consolidator is consolidator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_expired
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsExpired:
    """Edge cases for AutoCompact._is_expired."""

    def test_ttl_zero_always_false(self):
        """With TTL=0 even a year-old timestamp is not expired."""
        compact = _make_autocompact(ttl=0)
        year_ago = datetime.now() - timedelta(days=365)
        assert compact._is_expired(year_ago) is False

    def test_none_timestamp_returns_false(self):
        """A ``None`` timestamp is never expired."""
        compact = _make_autocompact(ttl=15)
        assert compact._is_expired(None) is False

    def test_empty_string_timestamp_returns_false(self):
        """An empty-string timestamp is never expired."""
        compact = _make_autocompact(ttl=15)
        assert compact._is_expired("") is False

    def test_exactly_at_boundary_is_expired(self):
        """A timestamp exactly TTL minutes old counts as expired."""
        compact = _make_autocompact(ttl=15)
        reference = datetime(2026, 1, 1, 12, 0, 0)
        stamp = reference - timedelta(minutes=15)
        assert compact._is_expired(stamp, now=reference) is True

    def test_just_under_boundary_not_expired(self):
        """A timestamp one second short of the TTL is not expired."""
        compact = _make_autocompact(ttl=15)
        reference = datetime(2026, 1, 1, 12, 0, 0)
        stamp = reference - timedelta(minutes=14, seconds=59)
        assert compact._is_expired(stamp, now=reference) is False

    def test_iso_string_parses_correctly(self):
        """An ISO-format string timestamp is accepted and evaluated."""
        compact = _make_autocompact(ttl=15)
        reference = datetime(2026, 1, 1, 12, 0, 0)
        stamp = (reference - timedelta(minutes=20)).isoformat()
        assert compact._is_expired(stamp, now=reference) is True

    def test_custom_now_parameter(self):
        """The ``now`` argument decides the outcome for a fixed timestamp."""
        compact = _make_autocompact(ttl=10)
        stamp = datetime(2026, 1, 1, 10, 0, 0)
        # 9 minutes later → not expired
        nine_minutes_later = datetime(2026, 1, 1, 10, 9, 0)
        assert compact._is_expired(stamp, now=nine_minutes_later) is False
        # 10 minutes later → expired
        ten_minutes_later = datetime(2026, 1, 1, 10, 10, 0)
        assert compact._is_expired(stamp, now=ten_minutes_later) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatSummary:
    """Checks on the text produced by AutoCompact._format_summary."""

    def test_contains_isoformat_timestamp(self):
        """The ISO form of ``last_active`` appears in the output."""
        stamp = datetime(2026, 5, 13, 14, 30, 0)
        formatted = AutoCompact._format_summary("Some text", stamp)
        assert "2026-05-13T14:30:00" in formatted

    def test_contains_summary_text(self):
        """The summary text appears verbatim in the output."""
        stamp = datetime(2026, 1, 1)
        formatted = AutoCompact._format_summary("User discussed Python.", stamp)
        assert "User discussed Python." in formatted

    def test_output_starts_with_label(self):
        """The output begins with the fixed label prefix."""
        stamp = datetime(2026, 1, 1)
        formatted = AutoCompact._format_summary("text", stamp)
        assert formatted.startswith("Previous conversation summary (last active ")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _split_unconsolidated
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSplitUnconsolidated:
    """Checks on how AutoCompact._split_unconsolidated divides a session."""

    @staticmethod
    def _user_messages(count):
        """Return ``count`` user messages with contents ``u0``, ``u1``, ..."""
        return [{"role": "user", "content": f"u{i}"} for i in range(count)]

    def test_empty_session_returns_both_empty(self):
        """A session without messages splits into two empty lists."""
        compact = _make_autocompact()
        session = _make_session(messages=[])
        to_archive, to_keep = compact._split_unconsolidated(session)
        assert to_archive == []
        assert to_keep == []

    def test_all_messages_archivable_when_more_than_suffix(self):
        """Twenty messages give a non-empty archive and a bounded kept suffix."""
        compact = _make_autocompact()
        session = _make_session(messages=self._user_messages(20))
        to_archive, to_keep = compact._split_unconsolidated(session)
        assert len(to_archive) > 0
        assert len(to_keep) <= AutoCompact._RECENT_SUFFIX_MESSAGES

    def test_fewer_messages_than_suffix_returns_empty_archive(self):
        """Three messages are all kept; nothing is archived."""
        compact = _make_autocompact()
        messages = self._user_messages(3)
        session = _make_session(messages=messages)
        to_archive, to_keep = compact._split_unconsolidated(session)
        assert to_archive == []
        assert len(to_keep) == len(messages)

    def test_respects_last_consolidated_offset(self):
        """Messages before ``last_consolidated`` appear in neither output."""
        compact = _make_autocompact()
        # First 10 are already consolidated
        session = _make_session(messages=self._user_messages(20), last_consolidated=10)
        to_archive, to_keep = compact._split_unconsolidated(session)
        # Only the tail of 10 messages is considered for splitting
        tail_contents = {f"u{i}" for i in range(10, 20)}
        assert all(m["content"] in tail_contents for m in to_keep)
        assert all(m["content"] in tail_contents for m in to_archive)

    def test_retain_recent_legal_suffix_keeps_last_n(self):
        """Kept and archived parts together account for every message."""
        compact = _make_autocompact()
        # 20 user messages = 20 messages total, all after last_consolidated=0
        messages = self._user_messages(20)
        session = _make_session(messages=messages)
        to_archive, to_keep = compact._split_unconsolidated(session)
        assert len(to_keep) <= AutoCompact._RECENT_SUFFIX_MESSAGES
        assert len(to_archive) == len(messages) - len(to_keep)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_expired
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckExpired:
    """Checks on when AutoCompact.check_expired calls the scheduler."""

    @staticmethod
    def _stale_timestamp():
        """ISO timestamp 20 minutes in the past (older than the 15-minute TTL used here)."""
        return (datetime.now() - timedelta(minutes=20)).isoformat()

    @staticmethod
    def _run_check(session_infos, *, archiving=(), **check_kwargs):
        """Run check_expired against a mocked session list.

        Returns ``(autocompact, scheduler)`` so callers can assert on both.
        """
        compact = _make_autocompact(ttl=15)
        manager = MagicMock(spec=SessionManager)
        manager.list_sessions.return_value = session_infos
        compact.sessions = manager
        for key in archiving:
            compact._archiving.add(key)
        scheduler = MagicMock()
        compact.check_expired(scheduler, **check_kwargs)
        return compact, scheduler

    def test_empty_sessions_list(self):
        """With no sessions the scheduler is never called."""
        _, scheduler = self._run_check([])
        scheduler.assert_not_called()

    def test_expired_session_schedules_background(self):
        """An expired session is scheduled once and recorded in ``_archiving``."""
        infos = [{"key": "cli:old", "updated_at": self._stale_timestamp()}]
        compact, scheduler = self._run_check(infos)
        scheduler.assert_called_once()
        assert "cli:old" in compact._archiving

    def test_active_session_key_skips(self):
        """A key listed in ``active_session_keys`` is not scheduled."""
        infos = [{"key": "cli:busy", "updated_at": self._stale_timestamp()}]
        _, scheduler = self._run_check(infos, active_session_keys={"cli:busy"})
        scheduler.assert_not_called()

    def test_session_already_in_archiving_skips(self):
        """A key already present in ``_archiving`` is not scheduled again."""
        infos = [{"key": "cli:dup", "updated_at": self._stale_timestamp()}]
        _, scheduler = self._run_check(infos, archiving=("cli:dup",))
        scheduler.assert_not_called()

    def test_session_with_no_key_skips(self):
        """A session info with an empty key is not scheduled."""
        _, scheduler = self._run_check([{"key": "", "updated_at": "old"}])
        scheduler.assert_not_called()

    def test_session_with_missing_key_field_skips(self):
        """A session info lacking the ``key`` field is not scheduled."""
        _, scheduler = self._run_check([{"updated_at": "old"}])
        scheduler.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _archive
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestArchive:
    """Behaviour of the async ``AutoCompact._archive`` method."""

    @staticmethod
    def _twenty_user_turns():
        """Return twenty user messages with contents u0..u19."""
        return [{"role": "user", "content": f"u{idx}"} for idx in range(20)]

    @staticmethod
    def _wire(session, archive_mock):
        """Return (compactor, manager) with *session* served and archive stubbed."""
        compactor = _make_autocompact()
        manager = MagicMock(spec=SessionManager)
        manager.get_or_create.return_value = session
        compactor.sessions = manager
        compactor.consolidator.archive = archive_mock
        return compactor, manager

    @pytest.mark.asyncio
    async def test_empty_session_updates_timestamp_no_archive_call(self):
        """An empty session is saved with a fresh updated_at; archive is not called."""
        blank = _make_session(messages=[])
        compactor, manager = self._wire(blank, AsyncMock(return_value="Summary."))

        await compactor._archive("cli:test")

        compactor.consolidator.archive.assert_not_called()
        manager.save.assert_called_once_with(blank)
        # The timestamp must have been refreshed within the last few seconds.
        assert blank.updated_at > datetime.now() - timedelta(seconds=5)

    @pytest.mark.asyncio
    async def test_archive_returns_empty_string_no_summary_stored(self):
        """An empty-string archive result leaves ``_summaries`` untouched."""
        session = _make_session(messages=self._twenty_user_turns())
        compactor, _ = self._wire(session, AsyncMock(return_value=""))

        await compactor._archive("cli:test")

        assert "cli:test" not in compactor._summaries

    @pytest.mark.asyncio
    async def test_archive_returns_nothing_no_summary_stored(self):
        """A '(nothing)' archive result leaves ``_summaries`` untouched."""
        session = _make_session(messages=self._twenty_user_turns())
        compactor, _ = self._wire(session, AsyncMock(return_value="(nothing)"))

        await compactor._archive("cli:test")

        assert "cli:test" not in compactor._summaries

    @pytest.mark.asyncio
    async def test_archive_exception_caught_key_removed_from_archiving(self):
        """A raising archive call is swallowed and the key is not left in _archiving."""
        session = _make_session(messages=self._twenty_user_turns())
        compactor, _ = self._wire(
            session, AsyncMock(side_effect=RuntimeError("LLM down"))
        )

        # The await itself must complete without propagating the error.
        await compactor._archive("cli:test")

        assert "cli:test" not in compactor._archiving

    @pytest.mark.asyncio
    async def test_successful_archive_stores_summary_in_summaries_and_metadata(self):
        """A successful archive records the summary in _summaries and session metadata."""
        last_active = datetime(2026, 5, 13, 10, 0, 0)
        session = _make_session(
            messages=self._twenty_user_turns(), updated_at=last_active
        )
        compactor, _ = self._wire(
            session, AsyncMock(return_value="User discussed AI.")
        )

        await compactor._archive("cli:test")

        # In-memory entry: (summary text, last-active timestamp).
        cached = compactor._summaries.get("cli:test")
        assert cached is not None
        assert cached[0] == "User discussed AI."
        assert cached[1] == last_active
        # Entry stored in the session metadata.
        persisted = session.metadata.get("_last_summary")
        assert persisted is not None
        assert persisted["text"] == "User discussed AI."
        assert "last_active" in persisted

    @pytest.mark.asyncio
    async def test_finally_block_always_removes_from_archiving(self):
        """A pre-registered key is cleared from _archiving even when archive raises."""
        session = _make_session(messages=self._twenty_user_turns())
        compactor, _ = self._wire(
            session, AsyncMock(side_effect=RuntimeError("fail"))
        )

        # Register the key up front so that its removal is observable.
        compactor._archiving.add("cli:test")
        await compactor._archive("cli:test")
        assert "cli:test" not in compactor._archiving

    @pytest.mark.asyncio
    async def test_finally_removes_from_archiving_on_success(self):
        """A pre-registered key is cleared from _archiving after a successful run."""
        session = _make_session(messages=self._twenty_user_turns())
        compactor, _ = self._wire(session, AsyncMock(return_value="Summary."))

        compactor._archiving.add("cli:test")
        await compactor._archive("cli:test")
        assert "cli:test" not in compactor._archiving
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prepare_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPrepareSession:
    """Behaviour of ``AutoCompact.prepare_session``."""

    @staticmethod
    def _attach_manager(compactor, reloaded):
        """Give *compactor* a mocked session manager that returns *reloaded*."""
        manager = MagicMock(spec=SessionManager)
        manager.get_or_create.return_value = reloaded
        compactor.sessions = manager
        return manager

    def test_key_in_archiving_reloads_session(self):
        """A key present in _archiving forces a reload through get_or_create."""
        compactor = _make_autocompact()
        fresh = _make_session(key="cli:test")
        manager = self._attach_manager(compactor, fresh)
        compactor._archiving.add("cli:test")

        stale = _make_session()
        returned, summary = compactor.prepare_session(stale, "cli:test")

        manager.get_or_create.assert_called_once_with("cli:test")
        assert returned is fresh

    def test_expired_session_reloads(self):
        """An expired session is replaced by the one get_or_create returns."""
        compactor = _make_autocompact(ttl=15)
        fresh = _make_session(key="cli:test", updated_at=datetime.now())
        manager = self._attach_manager(compactor, fresh)

        stale = _make_session(updated_at=datetime.now() - timedelta(minutes=20))
        returned, summary = compactor.prepare_session(stale, "cli:test")

        manager.get_or_create.assert_called_once_with("cli:test")
        assert returned is fresh

    def test_hot_path_summary_from_summaries(self):
        """An entry in _summaries is returned as the formatted summary (hot path)."""
        compactor = _make_autocompact()
        current = _make_session()
        compactor._summaries["cli:test"] = (
            "Hot summary.",
            datetime(2026, 5, 13, 14, 0, 0),
        )

        returned, summary = compactor.prepare_session(current, "cli:test")

        assert returned is current
        assert summary is not None
        assert "Hot summary." in summary
        assert "Previous conversation summary" in summary

    def test_hot_path_pops_summary_one_shot(self):
        """The hot-path entry is consumed: a second call yields no summary."""
        compactor = _make_autocompact()
        current = _make_session()
        compactor._summaries["cli:test"] = ("One-shot.", datetime(2026, 1, 1))

        _, first = compactor.prepare_session(current, "cli:test")
        assert first is not None
        # The entry was popped by the first call, so nothing is left to return.
        _, second = compactor.prepare_session(current, "cli:test")
        assert second is None

    def test_cold_path_summary_from_metadata(self):
        """With _summaries empty, the summary is read from session metadata."""
        compactor = _make_autocompact()
        last_active = datetime(2026, 5, 13, 14, 0, 0)
        stored = {
            "_last_summary": {
                "text": "Cold summary.",
                "last_active": last_active.isoformat(),
            },
        }
        current = _make_session(metadata=stored)

        returned, summary = compactor.prepare_session(current, "cli:test")

        assert returned is current
        assert summary is not None
        assert "Cold summary." in summary

    def test_no_summary_available_returns_none(self):
        """With no summary anywhere, the result is (session, None)."""
        compactor = _make_autocompact()
        current = _make_session()

        returned, summary = compactor.prepare_session(current, "cli:test")

        assert returned is current
        assert summary is None

    def test_cold_path_metadata_not_dict_returns_none(self):
        """A non-dict ``_last_summary`` metadata value yields no summary."""
        compactor = _make_autocompact()
        current = _make_session(metadata={"_last_summary": "not a dict"})

        returned, summary = compactor.prepare_session(current, "cli:test")

        assert returned is current
        assert summary is None

    def test_hot_path_takes_priority_over_metadata(self):
        """An entry in _summaries wins over a summary stored in metadata."""
        compactor = _make_autocompact()
        stored = {
            "_last_summary": {
                "text": "Cold summary.",
                "last_active": datetime(2026, 1, 1).isoformat(),
            },
        }
        current = _make_session(metadata=stored)
        compactor._summaries["cli:test"] = (
            "Hot summary.",
            datetime(2026, 5, 13, 14, 0, 0),
        )

        _, summary = compactor.prepare_session(current, "cli:test")
        assert "Hot summary." in summary
        # After hot path pops, cold path would kick in on next call
|
||||
333
tests/agent/test_context_builder.py
Normal file
333
tests/agent/test_context_builder.py
Normal file
@ -0,0 +1,333 @@
|
||||
"""Tests for ContextBuilder — system prompt and message assembly."""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _builder(tmp_path: Path, **kw) -> ContextBuilder:
    """Create a ContextBuilder with *tmp_path* as workspace, forwarding extra kwargs."""
    instance = ContextBuilder(workspace=tmp_path, **kw)
    return instance
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_runtime_context (static)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildRuntimeContext:
    """Checks on the text block produced by ``ContextBuilder._build_runtime_context``."""

    def test_time_only(self):
        """Without channel or chat id, only the framing and the time appear."""
        rendered = ContextBuilder._build_runtime_context(None, None)
        for marker in ("[Runtime Context", "[/Runtime Context]", "Current Time:"):
            assert marker in rendered
        assert "Channel:" not in rendered

    def test_with_channel_and_chat_id(self):
        """Channel and chat id are both rendered when supplied."""
        rendered = ContextBuilder._build_runtime_context("telegram", "chat123")
        assert "Channel: telegram" in rendered
        assert "Chat ID: chat123" in rendered

    def test_with_sender_id(self):
        """A supplied sender id is rendered."""
        rendered = ContextBuilder._build_runtime_context(
            "cli", "direct", sender_id="user1"
        )
        assert "Sender ID: user1" in rendered

    def test_with_timezone(self):
        """Passing a timezone still yields a Current Time entry."""
        rendered = ContextBuilder._build_runtime_context(
            None, None, timezone="Asia/Shanghai"
        )
        assert "Current Time:" in rendered

    def test_no_channel_no_chat_id_omits_both(self):
        """Neither Channel nor Chat ID appears when both are None."""
        rendered = ContextBuilder._build_runtime_context(None, None)
        assert "Channel:" not in rendered
        assert "Chat ID:" not in rendered

    def test_no_sender_id_omits(self):
        """Sender ID is absent when no sender id is given."""
        rendered = ContextBuilder._build_runtime_context("cli", "direct")
        assert "Sender ID:" not in rendered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_message_content (static)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeMessageContent:
    """Checks on ``ContextBuilder._merge_message_content`` for str/list/None inputs."""

    def test_str_plus_str(self):
        """Two strings are joined with a blank line."""
        merged = ContextBuilder._merge_message_content("hello", "world")
        assert merged == "hello\n\nworld"

    def test_empty_left_plus_str(self):
        """An empty left string contributes nothing."""
        merged = ContextBuilder._merge_message_content("", "world")
        assert merged == "world"

    def test_list_plus_list(self):
        """Two block lists are concatenated in order."""
        first = [{"type": "text", "text": "a"}]
        second = [{"type": "text", "text": "b"}]
        merged = ContextBuilder._merge_message_content(first, second)
        assert len(merged) == 2
        assert merged[0]["text"] == "a"
        assert merged[1]["text"] == "b"

    def test_str_plus_list(self):
        """A left string becomes a block placed before the right-hand blocks."""
        second = [{"type": "text", "text": "b"}]
        merged = ContextBuilder._merge_message_content("hello", second)
        assert len(merged) == 2
        assert merged[0]["text"] == "hello"
        assert merged[1]["text"] == "b"

    def test_list_plus_str(self):
        """A right string becomes a block placed after the left-hand blocks."""
        first = [{"type": "text", "text": "a"}]
        merged = ContextBuilder._merge_message_content(first, "world")
        assert len(merged) == 2
        assert merged[0]["text"] == "a"
        assert merged[1]["text"] == "world"

    def test_none_plus_str(self):
        """None on the left yields a single text block for the right string."""
        merged = ContextBuilder._merge_message_content(None, "hello")
        assert merged == [{"type": "text", "text": "hello"}]

    def test_str_plus_none(self):
        """None on the right yields a single text block for the left string."""
        merged = ContextBuilder._merge_message_content("hello", None)
        assert merged == [{"type": "text", "text": "hello"}]

    def test_none_plus_none(self):
        """Two None inputs produce an empty list."""
        merged = ContextBuilder._merge_message_content(None, None)
        assert merged == []

    def test_list_items_not_dicts_wrapped(self):
        """A non-dict list item is wrapped in a text block."""
        merged = ContextBuilder._merge_message_content(["raw_item"], None)
        assert merged == [{"type": "text", "text": "raw_item"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _load_bootstrap_files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadBootstrapFiles:
    """Checks on ``ContextBuilder._load_bootstrap_files`` against a temp workspace."""

    def test_no_bootstrap_files(self, tmp_path):
        """An empty workspace yields an empty string."""
        assert _builder(tmp_path)._load_bootstrap_files() == ""

    def test_agents_md(self, tmp_path):
        """AGENTS.md is rendered under its own heading together with its content."""
        (tmp_path / "AGENTS.md").write_text("Be helpful.", encoding="utf-8")
        loaded = _builder(tmp_path)._load_bootstrap_files()
        assert "## AGENTS.md" in loaded
        assert "Be helpful." in loaded

    def test_multiple_bootstrap_files(self, tmp_path):
        """Each present bootstrap file contributes a heading and its content."""
        (tmp_path / "AGENTS.md").write_text("Rules.", encoding="utf-8")
        (tmp_path / "SOUL.md").write_text("Soul.", encoding="utf-8")
        loaded = _builder(tmp_path)._load_bootstrap_files()
        for expected in ("## AGENTS.md", "## SOUL.md", "Rules.", "Soul."):
            assert expected in loaded

    def test_all_bootstrap_files(self, tmp_path):
        """Every name in BOOTSTRAP_FILES gets a heading when the file exists."""
        for filename in ContextBuilder.BOOTSTRAP_FILES:
            (tmp_path / filename).write_text(
                f"Content of {filename}", encoding="utf-8"
            )
        loaded = _builder(tmp_path)._load_bootstrap_files()
        for filename in ContextBuilder.BOOTSTRAP_FILES:
            assert f"## {filename}" in loaded

    def test_utf8_content(self, tmp_path):
        """Non-ASCII UTF-8 content is preserved in the output."""
        (tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8")
        loaded = _builder(tmp_path)._load_bootstrap_files()
        assert "用中文回复" in loaded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_template_content (static)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsTemplateContent:
    """Checks on ``ContextBuilder._is_template_content``."""

    @staticmethod
    def _memory_template_or_skip():
        """Return the packaged MEMORY.md template, skipping the test if absent."""
        from importlib.resources import files as pkg_files

        template = pkg_files("nanobot") / "templates" / "memory" / "MEMORY.md"
        if not template.is_file():
            pytest.skip("MEMORY.md template not bundled")
        return template

    def test_nonexistent_template_returns_false(self):
        """An unknown template path is never considered a match."""
        verdict = ContextBuilder._is_template_content("anything", "nonexistent/path.md")
        assert verdict is False

    def test_content_matching_template(self):
        """The template's own text is recognised as template content."""
        template = self._memory_template_or_skip()
        pristine = template.read_text(encoding="utf-8")
        assert ContextBuilder._is_template_content(pristine, "memory/MEMORY.md") is True

    def test_modified_content_returns_false(self):
        """Text that differs from the template is not recognised as template content."""
        self._memory_template_or_skip()
        verdict = ContextBuilder._is_template_content(
            "totally different", "memory/MEMORY.md"
        )
        assert verdict is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_user_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildUserContent:
    """Checks on ``ContextBuilder._build_user_content`` with and without media."""

    # PNG signature followed by 16 zero bytes.
    _PNG_STUB = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16

    def _write_png(self, directory, filename):
        """Write the stub PNG into *directory* and return its path."""
        target = directory / filename
        target.write_bytes(self._PNG_STUB)
        return target

    def test_no_media_returns_string(self, tmp_path):
        """With media=None the text is returned unchanged."""
        content = _builder(tmp_path)._build_user_content("hello", None)
        assert content == "hello"

    def test_empty_media_returns_string(self, tmp_path):
        """With an empty media list the text is returned unchanged."""
        content = _builder(tmp_path)._build_user_content("hello", [])
        assert content == "hello"

    def test_nonexistent_media_file_returns_string(self, tmp_path):
        """A media path that does not exist is ignored."""
        content = _builder(tmp_path)._build_user_content(
            "hello", ["/nonexistent/image.png"]
        )
        assert content == "hello"

    def test_non_image_file_returns_string(self, tmp_path):
        """A plain-text attachment is ignored."""
        document = tmp_path / "doc.txt"
        document.write_text("not an image", encoding="utf-8")
        content = _builder(tmp_path)._build_user_content("hello", [str(document)])
        assert content == "hello"

    def test_valid_image_returns_list(self, tmp_path):
        """A PNG yields an image_url block (base64 data URL) followed by the text."""
        picture = self._write_png(tmp_path, "test.png")
        content = _builder(tmp_path)._build_user_content("hello", [str(picture)])
        assert isinstance(content, list)
        assert len(content) == 2
        image_block, text_block = content[0], content[1]
        assert image_block["type"] == "image_url"
        assert image_block["image_url"]["url"].startswith("data:image/png;base64,")
        assert text_block["type"] == "text"
        assert text_block["text"] == "hello"

    def test_image_meta_includes_path(self, tmp_path):
        """The image block carries a ``_meta`` entry that includes a path."""
        picture = self._write_png(tmp_path, "test.png")
        content = _builder(tmp_path)._build_user_content("hello", [str(picture)])
        assert "_meta" in content[0]
        assert "path" in content[0]["_meta"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_system_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSystemPrompt:
    """Checks on ``ContextBuilder.build_system_prompt``."""

    def test_returns_nonempty_string(self, tmp_path):
        """The prompt is a non-empty string."""
        prompt = _builder(tmp_path).build_system_prompt()
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    def test_includes_identity_section(self, tmp_path):
        """The prompt mentions 'workspace' or 'python' (case-insensitive)."""
        prompt = _builder(tmp_path).build_system_prompt()
        lowered = prompt.lower()
        assert "workspace" in lowered or "python" in lowered

    def test_includes_bootstrap_files(self, tmp_path):
        """Content of AGENTS.md in the workspace appears in the prompt."""
        (tmp_path / "AGENTS.md").write_text(
            "Be helpful and concise.", encoding="utf-8"
        )
        prompt = _builder(tmp_path).build_system_prompt()
        assert "Be helpful and concise." in prompt

    def test_includes_session_summary(self, tmp_path):
        """A supplied session summary is included under its marker."""
        prompt = _builder(tmp_path).build_system_prompt(
            session_summary="Previous chat about Python."
        )
        assert "Previous chat about Python." in prompt
        assert "[Archived Context Summary]" in prompt

    def test_sections_separated_by_separator(self, tmp_path):
        """The '---' separator surrounded by blank lines appears in the prompt."""
        (tmp_path / "AGENTS.md").write_text("Rules.", encoding="utf-8")
        prompt = _builder(tmp_path).build_system_prompt(session_summary="Summary.")
        assert "\n\n---\n\n" in prompt

    def test_no_bootstrap_no_summary(self, tmp_path):
        """Without bootstrap files or a summary, neither marker appears."""
        prompt = _builder(tmp_path).build_system_prompt()
        assert "## AGENTS.md" not in prompt
        assert "[Archived Context Summary]" not in prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildMessages:
    """Checks on ``ContextBuilder.build_messages``."""

    def test_basic_empty_history(self, tmp_path):
        """Empty history yields a system message followed by the user message."""
        assembled = _builder(tmp_path).build_messages([], "hello")
        assert len(assembled) == 2
        system_msg, user_msg = assembled[0], assembled[1]
        assert system_msg["role"] == "system"
        assert user_msg["role"] == "user"
        assert "hello" in str(user_msg["content"])

    def test_runtime_context_injected(self, tmp_path):
        """The last message carries the runtime context block and the user text."""
        assembled = _builder(tmp_path).build_messages(
            [], "hello", channel="cli", chat_id="direct"
        )
        last_text = str(assembled[-1]["content"])
        assert "[Runtime Context" in last_text
        assert "hello" in last_text

    def test_consecutive_same_role_merged(self, tmp_path):
        """A trailing user message in history is merged with the new user message."""
        prior = [{"role": "user", "content": "previous user message"}]
        assembled = _builder(tmp_path).build_messages(prior, "new message")
        assert len(assembled) == 2  # system + merged user
        merged_text = str(assembled[1]["content"])
        assert "previous user message" in merged_text
        assert "new message" in merged_text

    def test_different_role_appended(self, tmp_path):
        """After an assistant message, the new user message is appended separately."""
        prior = [{"role": "assistant", "content": "previous response"}]
        assembled = _builder(tmp_path).build_messages(prior, "new message")
        assert len(assembled) == 3  # system + assistant + user

    def test_media_with_history(self, tmp_path):
        """With an image attached, the final message is a list with an image_url block."""
        picture = tmp_path / "img.png"
        picture.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16)
        prior = [{"role": "assistant", "content": "see this"}]
        assembled = _builder(tmp_path).build_messages(
            prior, "check image", media=[str(picture)]
        )
        final_content = assembled[-1]["content"]
        assert isinstance(final_content, list)
        assert any(block.get("type") == "image_url" for block in final_content)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_tool_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAddToolResult:
    """Checks on ``ContextBuilder.add_tool_result``."""

    def test_appends_tool_message(self, tmp_path):
        """A tool message with id, name and content is appended to the list."""
        conversation = [{"role": "user", "content": "hello"}]
        updated = _builder(tmp_path).add_tool_result(
            conversation, "call_123", "read_file", "file content"
        )
        assert len(updated) == 2
        appended = updated[1]
        assert appended["role"] == "tool"
        assert appended["tool_call_id"] == "call_123"
        assert appended["name"] == "read_file"
        assert appended["content"] == "file content"

    def test_returns_same_list(self, tmp_path):
        """The returned list is the very object that was passed in."""
        conversation = []
        updated = _builder(tmp_path).add_tool_result(conversation, "id", "tool", "ok")
        assert updated is conversation
|
||||
@ -13,6 +13,17 @@ def _ctx() -> AgentHookContext:
|
||||
return AgentHookContext(iteration=0, messages=[])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base AgentHook emit_reasoning: no-op
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_base_hook_emit_reasoning_is_noop():
    """Awaiting ``emit_reasoning`` on a plain AgentHook completes without raising."""
    base_hook = AgentHook()
    await base_hook.emit_reasoning("should not raise")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fan-out: every hook is called in order
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -45,6 +56,9 @@ async def test_composite_fans_out_all_async_methods():
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("before_iteration")
|
||||
|
||||
async def emit_reasoning(self, reasoning_content: str | None) -> None:
|
||||
events.append(f"emit_reasoning:{reasoning_content}")
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
events.append(f"on_stream:{delta}")
|
||||
|
||||
@ -61,6 +75,7 @@ async def test_composite_fans_out_all_async_methods():
|
||||
ctx = _ctx()
|
||||
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.emit_reasoning("thinking...")
|
||||
await hook.on_stream(ctx, "hi")
|
||||
await hook.on_stream_end(ctx, resuming=True)
|
||||
await hook.before_execute_tools(ctx)
|
||||
@ -68,6 +83,7 @@ async def test_composite_fans_out_all_async_methods():
|
||||
|
||||
assert events == [
|
||||
"before_iteration", "before_iteration",
|
||||
"emit_reasoning:thinking...", "emit_reasoning:thinking...",
|
||||
"on_stream:hi", "on_stream:hi",
|
||||
"on_stream_end:True", "on_stream_end:True",
|
||||
"before_execute_tools", "before_execute_tools",
|
||||
@ -120,6 +136,8 @@ async def test_composite_error_isolation_all_async():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def emit_reasoning(self, reasoning_content):
|
||||
raise RuntimeError("err")
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
raise RuntimeError("err")
|
||||
async def before_execute_tools(self, context):
|
||||
@ -128,6 +146,8 @@ async def test_composite_error_isolation_all_async():
|
||||
raise RuntimeError("err")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def emit_reasoning(self, reasoning_content):
|
||||
calls.append("emit_reasoning")
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
calls.append("on_stream_end")
|
||||
async def before_execute_tools(self, context):
|
||||
@ -137,10 +157,11 @@ async def test_composite_error_isolation_all_async():
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
ctx = _ctx()
|
||||
await hook.emit_reasoning("test")
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"]
|
||||
assert calls == ["emit_reasoning", "on_stream_end", "before_execute_tools", "after_iteration"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
301
tests/agent/test_loop_runner_integration.py
Normal file
301
tests/agent/test_loop_runner_integration.py
Normal file
@ -0,0 +1,301 @@
|
||||
"""Tests for AgentLoop integration with AgentRunner: streaming, think-filter, error handling, subagent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
    """Build an AgentLoop on *tmp_path* with a mocked provider.

    ContextBuilder, SessionManager and SubagentManager are patched in
    ``nanobot.agent.loop`` while the loop is constructed; the patched
    subagent manager's ``cancel_by_session`` is an AsyncMock returning 0.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"

    with patch("nanobot.agent.loop.ContextBuilder"), \
         patch("nanobot.agent.loop.SessionManager"), \
         patch("nanobot.agent.loop.SubagentManager") as subagent_cls:
        subagent_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
        agent_loop = AgentLoop(
            bus=message_bus, provider=fake_provider, workspace=tmp_path
        )
    return agent_loop
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path):
    """With max_iterations=2 and a provider that always requests a tool, the fixed limit message is returned."""
    agent_loop = _make_loop(tmp_path)
    # Every provider reply asks for another tool call.
    tool_request = ToolCallRequest(id="call_1", name="list_dir", arguments={})
    agent_loop.provider.chat_with_retry = AsyncMock(
        return_value=LLMResponse(content="working", tool_calls=[tool_request])
    )
    agent_loop.tools.get_definitions = MagicMock(return_value=[])
    agent_loop.tools.execute = AsyncMock(return_value="ok")
    agent_loop.max_iterations = 2

    final_content, _, _, _, _ = await agent_loop._run_agent_loop([])

    expected = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert final_content == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
    """A leading <think> span is not streamed: only 'Hello' is delivered, and the stream ends once with resuming=False."""
    agent_loop = _make_loop(tmp_path)
    streamed: list[str] = []
    end_flags: list[bool] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        # The first chunk is entirely inside the think block.
        await on_content_delta("<think>hidden")
        await on_content_delta("</think>Hello")
        return LLMResponse(
            content="<think>hidden</think>Hello", tool_calls=[], usage={}
        )

    agent_loop.provider.chat_stream_with_retry = fake_stream

    async def on_stream(delta: str) -> None:
        streamed.append(delta)

    async def on_stream_end(*, resuming: bool = False) -> None:
        end_flags.append(resuming)

    final_content, _, _, _, _ = await agent_loop._run_agent_loop(
        [],
        on_stream=on_stream,
        on_stream_end=on_stream_end,
    )

    assert final_content == "Hello"
    assert streamed == ["Hello"]
    assert end_flags == [False]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_stream_filter_hides_partial_trailing_think_prefix(tmp_path):
    """A ``<think>`` tag split across two deltas must not leak into the streamed output."""
    agent_loop = _make_loop(tmp_path)
    streamed_chunks: list[str] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        # The opening tag is cut in half: "<thin" ends chunk one, "k>" starts chunk two.
        for piece in ("Hello <thin", "k>hidden</think>World"):
            await on_content_delta(piece)
        return LLMResponse(content="Hello <think>hidden</think>World", tool_calls=[], usage={})

    async def on_stream(delta: str) -> None:
        streamed_chunks.append(delta)

    agent_loop.provider.chat_stream_with_retry = fake_stream

    final_text, _, _, _, _ = await agent_loop._run_agent_loop([], on_stream=on_stream)

    assert final_text == "Hello World"
    assert streamed_chunks == ["Hello", " World"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_stream_filter_hides_complete_trailing_think_tag(tmp_path):
    """A delta that ends with a complete ``<think>`` tag must stream only the text before it."""
    agent_loop = _make_loop(tmp_path)
    streamed_chunks: list[str] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        # Chunk one ends exactly on the opening tag; chunk two holds the hidden part and the rest.
        for piece in ("Hello <think>", "hidden</think>World"):
            await on_content_delta(piece)
        return LLMResponse(content="Hello <think>hidden</think>World", tool_calls=[], usage={})

    async def on_stream(delta: str) -> None:
        streamed_chunks.append(delta)

    agent_loop.provider.chat_stream_with_retry = fake_stream

    final_text, _, _, _, _ = await agent_loop._run_agent_loop([], on_stream=on_stream)

    assert final_text == "Hello World"
    assert streamed_chunks == ["Hello", " World"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_retries_think_only_final_response(tmp_path):
    """A reply made up solely of a ``<think>`` block triggers a second provider call."""
    agent_loop = _make_loop(tmp_path)
    attempts = 0

    async def fake_chat(**kwargs):
        nonlocal attempts
        attempts += 1
        # Only the very first reply is think-only; every later one is a real answer.
        reply_text = "<think>hidden</think>" if attempts == 1 else "Recovered answer"
        return LLMResponse(content=reply_text, tool_calls=[], usage={})

    agent_loop.provider.chat_with_retry = fake_chat

    final_text, _, _, _, _ = await agent_loop._run_agent_loop([])

    assert final_text == "Recovered answer"
    assert attempts == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streamed_flag_not_set_on_llm_error(tmp_path):
    """An LLM error on a streaming-capable channel must leave ``_streamed`` unset.

    The flag has to stay off so that ChannelManager still delivers the error.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    agent_loop = AgentLoop(
        bus=message_bus,
        provider=fake_provider,
        workspace=tmp_path,
        model="test-model",
    )

    failure = LLMResponse(
        content="503 service unavailable",
        finish_reason="error",
        tool_calls=[],
        usage={},
    )
    # Both the plain and the streaming entry points report the same failure.
    for method_name in ("chat_with_retry", "chat_stream_with_retry"):
        setattr(agent_loop.provider, method_name, AsyncMock(return_value=failure))
    agent_loop.tools.get_definitions = MagicMock(return_value=[])

    inbound = InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="hi")
    outbound = await agent_loop._process_message(
        inbound,
        on_stream=AsyncMock(),
        on_stream_end=AsyncMock(),
    )

    assert outbound is not None
    assert "503" in outbound.content
    streamed_flag = outbound.metadata.get("_streamed")
    assert not streamed_flag, "_streamed must not be set when stop_reason is error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ssrf_soft_block_can_finalize_after_streamed_tool_call(tmp_path):
    """After a tool call is blocked by the safety guard, the next streamed reply becomes the result."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"

    # Turn 1: the model asks to curl a link-local metadata address.
    blocked_call = ToolCallRequest(
        id="call_ssrf",
        name="exec",
        arguments={"command": "curl http://169.254.169.254/latest/meta-data/"},
    )
    first_reply = LLMResponse(content="checking metadata", tool_calls=[blocked_call], usage={})
    # Turn 2: the model gives a plain-text answer with no further tool calls.
    second_reply = LLMResponse(
        content="I cannot access private URLs. Please share the local file.",
        tool_calls=[],
        usage={},
    )
    fake_provider.chat_stream_with_retry = AsyncMock(side_effect=[first_reply, second_reply])

    agent_loop = AgentLoop(
        bus=message_bus,
        provider=fake_provider,
        workspace=tmp_path,
        model="test-model",
    )
    agent_loop.tools.get_definitions = MagicMock(return_value=[])
    agent_loop.tools.prepare_call = MagicMock(return_value=(None, {}, None))
    guard_message = "Error: Command blocked by safety guard (internal/private URL detected)"
    agent_loop.tools.execute = AsyncMock(return_value=guard_message)

    inbound = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="hi")
    outbound = await agent_loop._process_message(
        inbound,
        on_stream=AsyncMock(),
        on_stream_end=AsyncMock(),
    )

    assert outbound is not None
    assert outbound.content == "I cannot access private URLs. Please share the local file."
    assert outbound.metadata.get("_streamed") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
    """A failed turn is persisted as user + placeholder, and the next request replays it in order."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    scripted_replies = [
        LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}),
        LLMResponse(content="Recovered answer", tool_calls=[], usage={}),
    ]
    fake_provider.chat_with_retry = AsyncMock(side_effect=scripted_replies)

    agent_loop = AgentLoop(
        bus=MessageBus(),
        provider=fake_provider,
        workspace=tmp_path,
        model="test-model",
    )
    agent_loop.tools.get_definitions = MagicMock(return_value=[])
    agent_loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    # --- first turn: the provider reports an error ---
    first = await agent_loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question")
    )
    assert first is not None
    assert first.content == "429 rate limit exceeded"

    # Only role/content are compared; any other stored keys are ignored.
    session = agent_loop.sessions.get_or_create("cli:test")
    compared_keys = {"role", "content"}
    persisted = [
        {key: value for key, value in entry.items() if key in compared_keys}
        for entry in session.messages
    ]
    assert persisted == [
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER},
    ]

    # --- second turn: the provider recovers ---
    second = await agent_loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question")
    )
    assert second is not None
    assert second.content == "Recovered answer"

    second_request = fake_provider.chat_with_retry.await_args_list[1].kwargs["messages"]
    non_system = [entry for entry in second_request if entry.get("role") != "system"]
    expected_turns = [
        ("user", "first question"),
        ("assistant", _PERSISTED_MODEL_ERROR_PLACEHOLDER),
        ("user", "second question"),
    ]
    for position, (role, fragment) in enumerate(expected_turns):
        assert non_system[position]["role"] == role
        assert fragment in non_system[position]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
    """A subagent whose model never stops calling tools still announces a result exactly once."""
    from nanobot.agent.subagent import SubagentManager, SubagentStatus
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    # Every reply requests list_dir again, so no final text is ever produced.
    endless_tool_reply = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
    )
    fake_provider.chat_with_retry = AsyncMock(return_value=endless_tool_reply)

    manager = SubagentManager(
        provider=fake_provider,
        workspace=tmp_path,
        bus=message_bus,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    manager._announce_result = AsyncMock()

    async def fake_execute(self, **kwargs):
        return "tool result"

    monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)

    origin = {"channel": "test", "chat_id": "c1"}
    status = SubagentStatus(
        task_id="sub-1",
        label="label",
        task_description="do task",
        started_at=time.monotonic(),
    )
    await manager._run_subagent("sub-1", "do task", "label", origin, status)

    manager._announce_result.assert_awaited_once()
    announced = manager._announce_result.await_args.args
    assert announced[3] == "Task completed but no final response was generated."
    assert announced[5] == "ok"
|
||||
File diff suppressed because it is too large
Load Diff
481
tests/agent/test_runner_core.py
Normal file
481
tests/agent/test_runner_core.py
Normal file
@ -0,0 +1,481 @@
|
||||
"""Tests for core AgentRunner behavior: message passing, iteration limits,
|
||||
timeouts, empty-response handling, usage accumulation, and config passthrough."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_preserves_reasoning_fields_and_tool_results():
    """Reasoning fields and the tool output from turn one must be present in the second request."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    fake_provider = MagicMock(spec=LLMProvider)
    second_request: list[dict] = []
    turns = 0

    async def fake_chat(*, messages, **kwargs):
        nonlocal turns
        turns += 1
        if turns > 1:
            # Snapshot what the runner sends once the tool has been executed.
            second_request[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
            reasoning_content="hidden reasoning",
            thinking_blocks=[{"type": "thinking", "thinking": "step"}],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    tool_registry.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "do task"},
        ],
        tools=tool_registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.final_content == "done"
    assert result.tools_used == ["list_dir"]
    assert result.tool_events == [
        {"name": "list_dir", "status": "ok", "detail": "tool result"}
    ]

    tool_calling_turns = [
        entry for entry in second_request
        if entry.get("role") == "assistant" and entry.get("tool_calls")
    ]
    assert len(tool_calling_turns) == 1
    replayed = tool_calling_turns[0]
    assert replayed["reasoning_content"] == "hidden reasoning"
    assert replayed["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
    assert any(
        entry.get("role") == "tool" and entry.get("content") == "tool result"
        for entry in second_request
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_returns_max_iterations_fallback():
    """Hitting the iteration cap yields the fixed fallback text as the last assistant message."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    # The provider asks for another tool call on every turn.
    endless_tool_reply = LLMResponse(
        content="still working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
    )
    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = AsyncMock(return_value=endless_tool_reply)
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    tool_registry.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tool_registry,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    expected_text = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert result.stop_reason == "max_iterations"
    assert result.final_content == expected_text
    closing_message = result.messages[-1]
    assert closing_message["role"] == "assistant"
    assert closing_message["content"] == result.final_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_times_out_hung_llm_request():
    """With ``llm_timeout_s`` set, a provider call that never returns ends the run as an error."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def hang_forever(**kwargs):
        # Far longer than the configured timeout below.
        await asyncio.sleep(3600)

    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = hang_forever
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    runner = AgentRunner(fake_provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        llm_timeout_s=0.05,
    )
    began = time.monotonic()
    result = await runner.run(spec)
    elapsed = time.monotonic() - began

    assert elapsed < 1.0
    assert result.stop_reason == "error"
    assert "timed out" in (result.final_content or "").lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_replaces_empty_tool_result_with_marker():
    """An empty-string tool result is sent to the model as a "completed with no output" marker."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    fake_provider = MagicMock(spec=LLMProvider)
    second_request: list[dict] = []
    turns = 0

    async def fake_chat(*, messages, **kwargs):
        nonlocal turns
        turns += 1
        if turns > 1:
            second_request[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})],
            usage={},
        )

    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    # The tool succeeds but produces no text at all.
    tool_registry.execute = AsyncMock(return_value="")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.final_content == "done"
    tool_entry = next(entry for entry in second_request if entry.get("role") == "tool")
    assert tool_entry["content"] == "(noop completed with no output)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_retries_empty_final_response_with_summary_prompt():
    """Two empty replies are tolerated before a tool-less finalization call is made."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    fake_provider = MagicMock(spec=LLMProvider)
    recorded: list[dict] = []

    async def fake_chat(*, messages, tools=None, **kwargs):
        recorded.append({"messages": messages, "tools": tools})
        if len(recorded) > 2:
            return LLMResponse(
                content="final answer",
                tool_calls=[],
                usage={"prompt_tokens": 3, "completion_tokens": 7},
            )
        return LLMResponse(
            content=None,
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 1},
        )

    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.final_content == "final answer"
    # The first two calls still offer tools and come back empty; the third is sent without tools.
    assert len(recorded) == 3
    first_call, second_call, third_call = recorded
    assert first_call["tools"] is not None
    assert second_call["tools"] is not None
    assert third_call["tools"] is None
    # Usage is summed over all three calls: 5 + 5 + 3 and 1 + 1 + 7.
    assert result.usage["prompt_tokens"] == 13
    assert result.usage["completion_tokens"] == 9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_uses_specific_message_after_empty_finalization_retry():
    """If every call, finalization included, comes back empty, the run stops as empty_final_response."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE

    async def always_empty(*, messages, **kwargs):
        return LLMResponse(content=None, tool_calls=[], usage={})

    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = always_empty
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE
    assert result.stop_reason == "empty_final_response"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_empty_response_does_not_break_tool_chain():
    """An empty reply in the middle of a tool chain must not end the run.

    Scripted provider sequence: tool_call -> empty -> tool_call -> final text.
    The runner is expected to retry silently and finish normally.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    fake_provider = MagicMock(spec=LLMProvider)
    seen = {"calls": 0}

    async def fake_chat(*, messages, tools=None, **kwargs):
        seen["calls"] += 1
        turn = seen["calls"]
        if turn == 2:
            # The empty reply sandwiched between the two tool calls.
            return LLMResponse(
                content=None,
                tool_calls=[],
                usage={"prompt_tokens": 10, "completion_tokens": 1},
            )
        if turn in (1, 3):
            call_id, target = ("tc1", "a.txt") if turn == 1 else ("tc2", "b.txt")
            return LLMResponse(
                content=None,
                tool_calls=[ToolCallRequest(id=call_id, name="read_file", arguments={"path": target})],
                usage={"prompt_tokens": 10, "completion_tokens": 5},
            )
        return LLMResponse(
            content="Here are the results.",
            tool_calls=[],
            usage={"prompt_tokens": 10, "completion_tokens": 10},
        )

    # The same script serves both the plain and the streaming entry points.
    fake_provider.chat_with_retry = fake_chat
    fake_provider.chat_stream_with_retry = fake_chat

    async def fake_tool(name, args, **kw):
        return "file content"

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = [
        {"type": "function", "function": {"name": "read_file"}}
    ]
    tool_registry.execute = AsyncMock(side_effect=fake_tool)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "read both files"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.final_content == "Here are the results."
    assert result.stop_reason == "completed"
    assert seen["calls"] == 4
    assert "read_file" in result.tools_used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
    """Prompt, completion and cached token counts are summed over every provider call."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    fake_provider = MagicMock(spec=LLMProvider)
    turns = 0

    async def fake_chat(*, messages, **kwargs):
        nonlocal turns
        turns += 1
        if turns != 1:
            return LLMResponse(
                content="done",
                tool_calls=[],
                usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
            )
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
            usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
        )

    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    tool_registry.execute = AsyncMock(return_value="file content")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    # Totals are the sum of the tool-call turn and the final turn.
    totals = result.usage
    assert totals["prompt_tokens"] == 300  # 100 + 200
    assert totals["completion_tokens"] == 30  # 10 + 20
    assert totals["cached_tokens"] == 230  # 80 + 150
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_binds_on_retry_wait_to_retry_callback_not_progress():
    """Regression guard: ``on_retry_wait`` must be the spec's ``retry_wait_callback``.

    An earlier runtime refactor bound provider retry heartbeats to
    ``progress_callback`` instead, which leaked internal retry diagnostics
    such as "Model request failed, retry in 1s" to end-user channels as if
    they were ordinary progress updates.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    seen_kwargs: dict = {}

    async def fake_chat(**kwargs):
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    progress_cb = AsyncMock()
    retry_wait_cb = AsyncMock()

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "hi"},
        ],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=progress_cb,
        retry_wait_callback=retry_wait_cb,
    )
    await AgentRunner(fake_provider).run(spec)

    bound_callback = seen_kwargs["on_retry_wait"]
    assert bound_callback is retry_wait_cb
    assert bound_callback is not progress_cb
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config passthrough tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_temperature_to_provider():
    """The spec's ``temperature`` is forwarded as a keyword to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    seen_kwargs: dict = {}

    async def fake_chat(**kwargs):
        # Record every keyword the runner forwards to the provider.
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        temperature=0.7,
    )
    await AgentRunner(fake_provider).run(spec)

    assert seen_kwargs["temperature"] == 0.7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_max_tokens_to_provider():
    """The spec's ``max_tokens`` is forwarded as a keyword to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    seen_kwargs: dict = {}

    async def fake_chat(**kwargs):
        # Record every keyword the runner forwards to the provider.
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        max_tokens=8192,
    )
    await AgentRunner(fake_provider).run(spec)

    assert seen_kwargs["max_tokens"] == 8192
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_reasoning_effort_to_provider():
    """The spec's ``reasoning_effort`` is forwarded as a keyword to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    seen_kwargs: dict = {}

    async def fake_chat(**kwargs):
        # Record every keyword the runner forwards to the provider.
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        reasoning_effort="high",
    )
    await AgentRunner(fake_provider).run(spec)

    assert seen_kwargs["reasoning_effort"] == "high"
|
||||
171
tests/agent/test_runner_errors.py
Normal file
171
tests/agent/test_runner_errors.py
Normal file
@ -0,0 +1,171 @@
|
||||
"""Tests for AgentRunner error handling: tool errors, LLM errors,
|
||||
session message isolation, and tool result preservation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_returns_structured_tool_error():
    """With ``fail_on_tool_error`` set, a raising tool ends the run with a structured error."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tool_request = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    )
    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = AsyncMock(return_value=tool_request)
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    # The tool itself blows up when executed.
    tool_registry.execute = AsyncMock(side_effect=RuntimeError("boom"))

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tool_registry,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.stop_reason == "tool_error"
    assert result.error == "Error: RuntimeError: boom"
    expected_events = [{"name": "list_dir", "status": "error", "detail": "boom"}]
    assert result.tool_events == expected_events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_llm_error_not_appended_to_session_messages():
    """A reply with ``finish_reason='error'`` must not put its text into the message list.

    The raw error text would otherwise pollute session history; a placeholder
    assistant message is stored in its place.
    """
    from nanobot.agent.runner import (
        AgentRunSpec,
        AgentRunner,
        _PERSISTED_MODEL_ERROR_PLACEHOLDER,
    )

    error_reply = LLMResponse(
        content="429 rate limit exceeded",
        finish_reason="error",
        tool_calls=[],
        usage={},
    )
    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = AsyncMock(return_value=error_reply)
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(fake_provider).run(spec)

    # The caller still sees the real error text ...
    assert result.stop_reason == "error"
    assert result.final_content == "429 rate limit exceeded"
    # ... but no stored assistant message contains it.
    assistant_entries = [entry for entry in result.messages if entry.get("role") == "assistant"]
    leaked = any("429" in (entry.get("content") or "") for entry in assistant_entries)
    assert not leaked, "Error content should not appear in session messages"
    assert assistant_entries[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_tool_error_sets_final_content():
    """A fatal tool error is reported through ``final_content`` as well as ``stop_reason``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def fake_chat(*, messages, **kwargs):
        requested_call = ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})
        return LLMResponse(content="working", tool_calls=[requested_call], usage={})

    fake_provider = MagicMock(spec=LLMProvider)
    fake_provider.chat_with_retry = fake_chat
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    tool_registry.execute = AsyncMock(side_effect=RuntimeError("boom"))

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.final_content == "Error: RuntimeError: boom"
    assert result.stop_reason == "tool_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_tool_error_preserves_tool_results_in_messages():
    """Every requested tool call gets a tool message, even when one of them fails fatally.

    Without this, the session could be left with orphan tool_calls (#2943).
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    fake_provider = MagicMock(spec=LLMProvider)

    async def fake_chat(*, messages, **kwargs):
        requested_calls = [
            ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a"}),
            ToolCallRequest(id="tc2", name="exec", arguments={"cmd": "bad"}),
        ]
        return LLMResponse(content=None, tool_calls=requested_calls, usage={})

    fake_provider.chat_with_retry = fake_chat
    fake_provider.chat_stream_with_retry = fake_chat

    executed: list[str] = []

    async def fake_execute(name, args, **kw):
        executed.append(name)
        # Only the second execution fails; the first one succeeds.
        if len(executed) == 2:
            raise RuntimeError("boom")
        return "file content"

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    tool_registry.execute = AsyncMock(side_effect=fake_execute)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do stuff"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    result = await AgentRunner(fake_provider).run(spec)

    assert result.stop_reason == "tool_error"

    # Both tool results must be present although tc2 raised.
    tool_entries = [entry for entry in result.messages if entry.get("role") == "tool"]
    assert len(tool_entries) == 2
    assert tool_entries[0]["tool_call_id"] == "tc1"
    assert tool_entries[1]["tool_call_id"] == "tc2"

    # The assistant message that issued the tool calls must come before every tool result.
    assistant_position = next(
        index for index, entry in enumerate(result.messages)
        if entry.get("role") == "assistant" and entry.get("tool_calls")
    )
    tool_positions = [
        index for index, entry in enumerate(result.messages) if entry.get("role") == "tool"
    ]
    assert all(position > assistant_position for position in tool_positions)
|
||||
643
tests/agent/test_runner_governance.py
Normal file
643
tests/agent/test_runner_governance.py
Normal file
@ -0,0 +1,643 @@
|
||||
"""Tests for AgentRunner context governance: backfill, orphan cleanup, microcompact, snip_history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
# Default ``max_tool_result_chars`` value from AgentDefaults; passed to the
# AgentRunSpec instances built by the tests in this module.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
    """Build an AgentLoop rooted at *tmp_path* with its collaborators patched.

    ContextBuilder, SessionManager and SubagentManager are patched in
    ``nanobot.agent.loop`` while the loop is constructed; the patched
    subagent manager's ``cancel_by_session`` is an AsyncMock returning 0.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    message_bus = MessageBus()

    # Nested ``with`` blocks enter the patches in the same order as a single
    # comma-separated ``with`` statement would.
    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
                MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
                agent_loop = AgentLoop(bus=message_bus, provider=fake_provider, workspace=tmp_path)
    return agent_loop
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
    """If ``_snip_history`` raises, the provider still receives the raw messages.

    ``_snip_history`` is replaced by a mock that raises ``RuntimeError``; the
    run is expected to finish with the model's answer and the provider is
    expected to have been called with exactly ``initial_messages``.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_messages: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Record what the model was actually sent.
        captured_messages[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
    ]

    runner = AgentRunner(provider)
    # Force the governance step to blow up.
    runner._snip_history = MagicMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    result = await runner.run(AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "done"
    assert captured_messages == initial_messages
|
||||
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
    """``_snip_history`` must return a history that starts system -> user.

    Token estimators are monkeypatched so the prompt is reported as 500
    tokens and each message gets a fixed size keyed by its content.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    tools = MagicMock()
    tools.get_definitions.return_value = []
    runner = AgentRunner(provider)
    messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        {
            "role": "assistant",
            "content": "tool call",
            "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "tool output"},
        {"role": "assistant", "content": "after tool"},
    ]
    spec = AgentRunSpec(
        initial_messages=messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_args, **_kwargs: (500, None))
    token_sizes = {
        "old user": 120,
        "tool call": 120,
        "tool output": 40,
        "after tool": 40,
        "system": 0,
    }
    monkeypatch.setattr(
        "nanobot.agent.runner.estimate_message_tokens",
        lambda msg: token_sizes.get(str(msg.get("content")), 40),
    )

    trimmed = runner._snip_history(spec, messages)

    # The user message is recovered so the sequence is valid for providers
    # that require system -> user (e.g. GLM error 1214).
    assert trimmed[0]["role"] == "system"
    non_system = [m for m in trimmed if m["role"] != "system"]
    # Guard before indexing so an empty result fails with a readable message
    # rather than an IndexError.
    assert non_system, "trimmed should contain at least one non-system message"
    assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}"
|
||||
@pytest.mark.asyncio
async def test_backfill_missing_tool_results_inserts_error():
    """Orphaned tool_use (no matching tool_result) should get a synthetic error.

    The assistant message declares two tool calls but only ``call_a`` has a
    tool result; ``call_b`` is expected to be backfilled with
    ``_BACKFILL_CONTENT`` under the tool name ``read_file``.
    """
    from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT

    messages = [
        {"role": "user", "content": "hi"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"id": "call_a", "type": "function", "function": {"name": "exec", "arguments": "{}"}},
                {"id": "call_b", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
            ],
        },
        {"role": "tool", "tool_call_id": "call_a", "name": "exec", "content": "ok"},
    ]
    result = AgentRunner._backfill_missing_tool_results(messages)
    tool_msgs = [m for m in result if m.get("role") == "tool"]
    assert len(tool_msgs) == 2
    backfilled = [m for m in tool_msgs if m.get("tool_call_id") == "call_b"]
    assert len(backfilled) == 1
    assert backfilled[0]["content"] == _BACKFILL_CONTENT
    assert backfilled[0]["name"] == "read_file"
|
||||
|
||||
|
||||
def test_drop_orphan_tool_results_removes_unmatched_tool_messages():
    """Only the tool message without a matching assistant tool call is dropped."""
    from nanobot.agent.runner import AgentRunner

    def _matched_history(with_orphan):
        # Fresh dicts on every call so input and expectation never share objects.
        history = [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "old user"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
                ],
            },
            {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
        ]
        if with_orphan:
            history.append(
                {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}
            )
        history.append({"role": "assistant", "content": "after tool"})
        return history

    given = _matched_history(with_orphan=True)
    expected = _matched_history(with_orphan=False)

    cleaned = AgentRunner._drop_orphan_tool_results(given)

    assert cleaned == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_backfill_noop_when_complete():
    """A history whose tool calls all have results is returned untouched."""
    from nanobot.agent.runner import AgentRunner

    exec_call = {"id": "call_x", "type": "function", "function": {"name": "exec", "arguments": "{}"}}
    history = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "", "tool_calls": [exec_call]},
        {"role": "tool", "tool_call_id": "call_x", "name": "exec", "content": "done"},
        {"role": "assistant", "content": "all good"},
    ]

    returned = AgentRunner._backfill_missing_tool_results(history)

    # Identity, not equality: nothing to repair means no copy is made.
    assert returned is history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_drops_orphan_tool_results_before_model_request():
    """An orphan tool result is hidden from the model but kept in ``result.messages``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    captured_messages: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Snapshot the request the provider receives.
        captured_messages[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    llm = MagicMock()
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []

    history = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
        {"role": "assistant", "content": "after orphan"},
        {"role": "user", "content": "new prompt"},
    ]
    agent = AgentRunner(llm)
    outcome = await agent.run(AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    sent_tool_ids = [
        sent.get("tool_call_id")
        for sent in captured_messages
        if sent.get("role") == "tool"
    ]
    assert all(tool_id != "call_orphan" for tool_id in sent_tool_ids)
    assert outcome.messages[2]["tool_call_id"] == "call_orphan"
    assert outcome.final_content == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
    """Backfilling an old orphaned tool call must not change what gets persisted.

    The stored session has an assistant tool call with no result. The model
    request is expected to contain one synthetic result for it, while the
    saved session is expected to be the old history plus the new turn only.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _BACKFILL_CONTENT
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    def _missing_call():
        # A fresh dict per use so stored and expected data never share objects.
        return {
            "id": "call_missing",
            "type": "function",
            "function": {"name": "read_file", "arguments": "{}"},
        }

    def _persisted_view(message):
        # Compare only the fields that matter; timestamps and the like are ignored.
        wanted = {"role", "content", "tool_call_id", "name", "tool_calls"}
        return {key: value for key, value in message.items() if key in wanted}

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    reply = LLMResponse(content="new answer", tool_calls=[], usage={})
    provider.chat_with_retry = AsyncMock(return_value=reply)
    provider.chat_stream_with_retry = AsyncMock(return_value=reply)

    loop = AgentLoop(
        bus=MessageBus(),
        provider=provider,
        workspace=tmp_path,
        model="test-model",
    )
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "old user", "timestamp": "2026-01-01T00:00:00"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [_missing_call()],
            "timestamp": "2026-01-01T00:00:01",
        },
        {"role": "assistant", "content": "old tail", "timestamp": "2026-01-01T00:00:02"},
    ]
    loop.sessions.save(session)

    inbound = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new prompt")
    result = await loop._process_message(inbound)

    assert result is not None
    assert result.content == "new answer"

    request_messages = provider.chat_with_retry.await_args.kwargs["messages"]
    synthetic = [
        sent
        for sent in request_messages
        if sent.get("role") == "tool" and sent.get("tool_call_id") == "call_missing"
    ]
    assert len(synthetic) == 1
    assert synthetic[0]["content"] == _BACKFILL_CONTENT

    session_after = loop.sessions.get_or_create("cli:test")
    persisted = [_persisted_view(stored) for stored in session_after.messages]
    assert persisted == [
        {"role": "user", "content": "old user"},
        {"role": "assistant", "content": "", "tool_calls": [_missing_call()]},
        {"role": "assistant", "content": "old tail"},
        {"role": "user", "content": "new prompt"},
        {"role": "assistant", "content": "new answer"},
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_backfill_only_mutates_model_context_not_returned_messages():
    """The synthetic tool result is sent to the model but absent from ``result.messages``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _BACKFILL_CONTENT

    def _missing_call():
        # New dict each time so the input and the expectation stay independent.
        return {
            "id": "call_missing",
            "type": "function",
            "function": {"name": "read_file", "arguments": "{}"},
        }

    def _comparable(message):
        wanted = {"role", "content", "tool_call_id", "name", "tool_calls"}
        return {key: value for key, value in message.items() if key in wanted}

    captured_messages: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        captured_messages[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    llm = MagicMock()
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []

    history = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        {"role": "assistant", "content": "", "tool_calls": [_missing_call()]},
        {"role": "assistant", "content": "old tail"},
        {"role": "user", "content": "new prompt"},
    ]

    agent = AgentRunner(llm)
    outcome = await agent.run(AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    synthetic = [
        sent
        for sent in captured_messages
        if sent.get("role") == "tool" and sent.get("tool_call_id") == "call_missing"
    ]
    assert len(synthetic) == 1
    assert synthetic[0]["content"] == _BACKFILL_CONTENT

    returned = [_comparable(kept) for kept in outcome.messages]
    assert returned == [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        {"role": "assistant", "content": "", "tool_calls": [_missing_call()]},
        {"role": "assistant", "content": "old tail"},
        {"role": "user", "content": "new prompt"},
        {"role": "assistant", "content": "done"},
    ]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Microcompact (stale tool result compaction)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_replaces_old_tool_results():
    """Only the most recent ``_MICROCOMPACT_KEEP_RECENT`` long results survive intact.

    Five more ``read_file`` results than the keep-recent limit are supplied,
    each 600 characters long; the surplus is expected to be replaced by a
    placeholder containing "omitted from context".
    """
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    long_content = "x" * 600
    total = _MICROCOMPACT_KEEP_RECENT + 5

    history: list[dict] = [{"role": "system", "content": "sys"}]
    for idx in range(total):
        call_id = f"c{idx}"
        history.extend((
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "read_file", "arguments": "{}"}}],
            },
            {"role": "tool", "tool_call_id": call_id, "name": "read_file", "content": long_content},
        ))

    compacted_history = AgentRunner._microcompact(history)

    tool_results = [m for m in compacted_history if m.get("role") == "tool"]
    untouched = [m for m in tool_results if m.get("content") == long_content]
    summarized = [m for m in tool_results if "omitted from context" in str(m.get("content", ""))]
    assert len(summarized) == total - _MICROCOMPACT_KEEP_RECENT
    assert len(untouched) == _MICROCOMPACT_KEEP_RECENT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_preserves_short_results():
    """When every tool result is the short string "short", the input list is returned as-is."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    history: list[dict] = []
    for idx in range(_MICROCOMPACT_KEEP_RECENT + 5):
        call_id = f"c{idx}"
        history.extend((
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "exec", "arguments": "{}"}}],
            },
            {"role": "tool", "tool_call_id": call_id, "name": "exec", "content": "short"},
        ))

    returned = AgentRunner._microcompact(history)

    # Identity check: nothing was compacted, so no copy is expected.
    assert returned is history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_skips_non_compactable_tools():
    """Long results from the ``message`` tool are left alone and the list is not copied."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    long_content = "y" * 1000

    history: list[dict] = []
    for idx in range(_MICROCOMPACT_KEEP_RECENT + 5):
        call_id = f"c{idx}"
        history.extend((
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "message", "arguments": "{}"}}],
            },
            {"role": "tool", "tool_call_id": call_id, "name": "message", "content": long_content},
        ))

    returned = AgentRunner._microcompact(history)

    # Same object back means no compactable tool result was found.
    assert returned is history
|
||||
|
||||
|
||||
def test_governance_repairs_orphans_after_snip():
    """After _snip_history clips an assistant+tool_calls, the second
    _drop_orphan_tool_results pass must clean up the resulting orphans.

    The snipped history is written out by hand here; ``_snip_history`` itself
    is not invoked.
    """
    from nanobot.agent.runner import AgentRunner

    # History before snipping, for reference only:
    #   system, user "old msg", assistant tool_calls=[tc_old],
    #   tool result for tc_old, assistant "old answer", user "new msg".
    #
    # Simulate snipping that keeps only the tail: drop the assistant with
    # tool_calls but keep its tool result (orphan).
    snipped = [
        {"role": "system", "content": "system"},
        {"role": "tool", "tool_call_id": "tc_old", "name": "search",
         "content": "old result"},
        {"role": "assistant", "content": "old answer"},
        {"role": "user", "content": "new msg"},
    ]

    cleaned = AgentRunner._drop_orphan_tool_results(snipped)
    # The orphan tool result should be removed.
    assert not any(
        m.get("role") == "tool" and m.get("tool_call_id") == "tc_old"
        for m in cleaned
    )
|
||||
|
||||
|
||||
def test_governance_fallback_still_repairs_orphans():
    """Running the two repair helpers back to back removes an orphan tool result.

    This exercises ``_drop_orphan_tool_results`` followed by
    ``_backfill_missing_tool_results`` directly, the pair the governance
    fallback is meant to apply.
    """
    from nanobot.agent.runner import AgentRunner

    # "orphan_tc" has no assistant tool_call that announces it.
    history = [
        {"role": "user", "content": "hello"},
        {"role": "tool", "tool_call_id": "orphan_tc", "name": "read",
         "content": "stale"},
        {"role": "assistant", "content": "hi"},
    ]

    without_orphans = AgentRunner._drop_orphan_tool_results(history)
    repaired = AgentRunner._backfill_missing_tool_results(without_orphans)

    leftovers = [m for m in repaired if m.get("tool_call_id") == "orphan_tc"]
    assert not leftovers
|
||||
def test_snip_history_preserves_user_message_after_truncation(monkeypatch):
    """``_snip_history`` must keep a user message first after the system prompt.

    Scenario: the only user message sits early in the history, followed by
    two assistant tool-call / tool-result pairs. Token estimates are patched
    so the history is over budget. Providers such as GLM reject a
    system -> assistant sequence (error 1214), so the trimmed result has to
    start with a user message once system messages are set aside.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    def _exec_call(call_id):
        return {"id": call_id, "type": "function", "function": {"name": "exec", "arguments": "{}"}}

    registry = MagicMock()
    registry.get_definitions.return_value = []
    agent = AgentRunner(MagicMock())

    history = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "previous reply"},
        {"role": "user", "content": ".nanobot的同目录"},
        {"role": "assistant", "content": None, "tool_calls": [_exec_call("tc_1")]},
        {"role": "tool", "tool_call_id": "tc_1", "content": "tool output 1"},
        {"role": "assistant", "content": None, "tool_calls": [_exec_call("tc_2")]},
        {"role": "tool", "tool_call_id": "tc_2", "content": "tool output 2"},
    ]

    run_spec = AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    # Report a prompt size above budget so _snip_history activates.
    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_a, **_kw: (500, None))
    # Per-message sizes keyed by content; anything unlisted counts as 100.
    sizes_by_content = {
        "system": 0,
        "previous reply": 200,
        ".nanobot的同目录": 80,
        "tool output 1": 80,
        "tool output 2": 80,
    }
    monkeypatch.setattr(
        "nanobot.agent.runner.estimate_message_tokens",
        lambda msg: sizes_by_content.get(str(msg.get("content")), 100),
    )

    trimmed = agent._snip_history(run_spec, history)

    # The first non-system message MUST be user (not assistant).
    non_system = [m for m in trimmed if m.get("role") != "system"]
    assert non_system, "trimmed should contain at least one non-system message"
    assert non_system[0]["role"] == "user", (
        f"First non-system message must be 'user', got '{non_system[0]['role']}'. "
        f"Roles: {[m['role'] for m in trimmed]}"
    )
|
||||
|
||||
|
||||
def test_snip_history_no_user_at_all_falls_back_gracefully(monkeypatch):
    """With no user message anywhere, ``_snip_history`` must still return a usable list.

    The result is then passed through ``LLMProvider._enforce_role_alternation``
    and, if any non-system message remains, the first one has to be a user or
    tool message.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.providers.base import LLMProvider

    registry = MagicMock()
    registry.get_definitions.return_value = []
    agent = AgentRunner(MagicMock())

    history = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "reply"},
        {"role": "tool", "tool_call_id": "tc_1", "content": "result"},
        {"role": "assistant", "content": "reply 2"},
        {"role": "tool", "tool_call_id": "tc_2", "content": "result 2"},
    ]

    run_spec = AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_a, **_kw: (500, None))
    monkeypatch.setattr(
        "nanobot.agent.runner.estimate_message_tokens",
        lambda msg: 100,
    )

    trimmed = agent._snip_history(run_spec, history)

    # Should not crash. The result should still be a valid list.
    assert isinstance(trimmed, list)
    # Must have at least system.
    system_flags = [m.get("role") == "system" for m in trimmed]
    assert any(system_flags)

    # The role-alternation safety net must be able to repair whatever
    # _snip_history returned.
    fixed = LLMProvider._enforce_role_alternation(trimmed)
    non_system = [m for m in fixed if m["role"] != "system"]
    if non_system:
        assert non_system[0]["role"] in ("user", "tool"), (
            f"Safety net should ensure first non-system is user/tool, got {non_system[0]['role']}"
        )
|
||||
172
tests/agent/test_runner_hooks.py
Normal file
172
tests/agent/test_runner_hooks.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""Tests for AgentRunner hook lifecycle: ordering, streaming deltas,
|
||||
cached-token propagation, and hook context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
# Default ``max_tool_result_chars`` value from AgentDefaults; passed to the
# AgentRunSpec instances built by the tests in this module.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
    """Hook callbacks fire in a fixed order across a tool turn and a final turn.

    The first model turn requests ``list_dir``; the second returns "done".
    The hook upper-cases the final content, so the run is expected to end
    with "DONE".
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    events: list[tuple] = []
    turns = 0

    async def chat_with_retry(**kwargs):
        nonlocal turns
        turns += 1
        if turns != 1:
            return LLMResponse(content="done", tool_calls=[], usage={})
        # First turn only: ask for a tool call.
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
        )

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(return_value="tool result")

    class RecordingHook(AgentHook):
        """Appends one tuple to ``events`` per callback."""

        async def before_iteration(self, context: AgentHookContext) -> None:
            events.append(("before_iteration", context.iteration))

        async def before_execute_tools(self, context: AgentHookContext) -> None:
            requested_names = [tc.name for tc in context.tool_calls]
            events.append(("before_execute_tools", context.iteration, requested_names))

        async def after_iteration(self, context: AgentHookContext) -> None:
            snapshot = (
                "after_iteration",
                context.iteration,
                context.final_content,
                list(context.tool_results),
                list(context.tool_events),
                context.stop_reason,
            )
            events.append(snapshot)

        def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
            events.append(("finalize_content", context.iteration, content))
            if content:
                return content.upper()
            return content

    agent = AgentRunner(llm)
    outcome = await agent.run(AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=RecordingHook(),
    ))

    expected_events = [
        ("before_iteration", 0),
        ("before_execute_tools", 0, ["list_dir"]),
        (
            "after_iteration",
            0,
            None,
            ["tool result"],
            [{"name": "list_dir", "status": "ok", "detail": "tool result"}],
            None,
        ),
        ("before_iteration", 1),
        ("finalize_content", 1, "done"),
        ("after_iteration", 1, "DONE", [], [], "completed"),
    ]
    assert outcome.final_content == "DONE"
    assert events == expected_events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
    """A hook that wants streaming gets each delta and one non-resuming end signal.

    The streaming provider entry point emits "he" and "llo"; the
    non-streaming ``chat_with_retry`` must never be awaited.
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    streamed: list[str] = []
    endings: list[bool] = []

    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        for piece in ("he", "llo"):
            await on_content_delta(piece)
        return LLMResponse(content="hello", tool_calls=[], usage={})

    llm = MagicMock(spec=LLMProvider)
    llm.chat_stream_with_retry = chat_stream_with_retry
    llm.chat_with_retry = AsyncMock()
    registry = MagicMock()
    registry.get_definitions.return_value = []

    class StreamingHook(AgentHook):
        """Opts in to streaming and records what it receives."""

        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            streamed.append(delta)

        async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
            endings.append(resuming)

    agent = AgentRunner(llm)
    outcome = await agent.run(AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=StreamingHook(),
    ))

    assert outcome.final_content == "hello"
    assert streamed == ["he", "llo"]
    assert endings == [False]
    llm.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
    """``context.usage`` seen in ``after_iteration`` carries the provider's cached_tokens."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    captured_usage: list[dict] = []

    class UsageHook(AgentHook):
        """Copies the usage mapping once per iteration."""

        async def after_iteration(self, context: AgentHookContext) -> None:
            captured_usage.append(dict(context.usage))

    async def chat_with_retry(**kwargs):
        reported = {"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}
        return LLMResponse(content="done", tool_calls=[], usage=reported)

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []

    agent = AgentRunner(llm)
    await agent.run(AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=UsageHook(),
    ))

    assert len(captured_usage) == 1
    assert captured_usage[0]["cached_tokens"] == 150
|
||||
1038
tests/agent/test_runner_injections.py
Normal file
1038
tests/agent/test_runner_injections.py
Normal file
File diff suppressed because it is too large
Load Diff
161
tests/agent/test_runner_persistence.py
Normal file
161
tests/agent/test_runner_persistence.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""Tests for tool result persistence: large results, pruning, temp files, cleanup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
    """An oversized tool result is persisted to disk and replaced by a pointer.

    The tool returns 20,000 characters while ``max_tool_result_chars`` is 2048.
    The second provider call must see a tool message that carries the
    ``[tool output persisted]`` marker and mentions ``tool-results``, and the
    output file must exist under ``.nanobot/tool-results/test_runner/``.

    The explicit ``asyncio`` marker matches the async tests in the sibling
    runner test modules and keeps this coroutine test running even when
    pytest-asyncio is not in auto mode.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First turn: ask for a tool call so the runner executes the tool.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Second turn: snapshot what the runner sends back, then finish.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="x" * 20_000)

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        workspace=tmp_path,
        session_key="test:runner",
        max_tool_result_chars=2048,
    ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert "[tool output persisted]" in tool_message["content"]
    assert "tool-results" in tool_message["content"]
    assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a result removes a stale session bucket and keeps a recent one.

    Two buckets are created and one of them, together with its file, is
    back-dated by eight days. After an oversized result is persisted for a new
    session, the back-dated bucket must be gone, the other bucket must remain,
    and the new session's output file must exist.
    """
    from nanobot.utils.helpers import maybe_persist_tool_result

    root = tmp_path / ".nanobot" / "tool-results"
    old_bucket, recent_bucket = root / "old_session", root / "recent_session"
    for bucket in (old_bucket, recent_bucket):
        bucket.mkdir(parents=True)
    old_file = old_bucket / "old.txt"
    old_file.write_text("old", encoding="utf-8")
    (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")

    # Push both the directory and its file eight days into the past.
    eight_days = 8 * 24 * 60 * 60
    stale = time.time() - eight_days
    for path in (old_bucket, old_file):
        os.utime(path, (stale, stale))

    persisted = maybe_persist_tool_result(
        tmp_path, "current:session", "call_big", "x" * 5000, max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert not old_bucket.exists()
    assert recent_bucket.exists()
    assert (root / "current_session" / "call_big.txt").exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """After persisting, the session bucket holds the result and no ``*.tmp`` files."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    maybe_persist_tool_result(
        tmp_path, "current:session", "call_big", "x" * 5000, max_chars=64,
    )

    bucket = tmp_path / ".nanobot" / "tool-results" / "current_session"
    assert (bucket / "call_big.txt").exists()
    leftovers = list(bucket.glob("*.tmp"))
    assert leftovers == []
|
||||
|
||||
|
||||
def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
    """A failing bucket cleanup is logged and does not stop persistence.

    The cleanup helper is patched to raise ``OSError("busy")`` and
    ``logger.exception`` is patched to capture its formatted message. The
    result must still come back persisted and the first captured message must
    mention the failed cleanup.
    """
    from nanobot.utils.helpers import maybe_persist_tool_result

    logged: list[str] = []

    def _failing_cleanup(*_args, **_kwargs):
        raise OSError("busy")

    def _capture_exception(message, *args):
        logged.append(message.format(*args))

    monkeypatch.setattr(
        "nanobot.utils.helpers._cleanup_tool_result_buckets", _failing_cleanup,
    )
    monkeypatch.setattr("nanobot.utils.helpers.logger.exception", _capture_exception)

    persisted = maybe_persist_tool_result(
        tmp_path, "current:session", "call_big", "x" * 5000, max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert logged
    assert "Failed to clean stale tool result buckets" in logged[0]
|
||||
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
    """A persistence failure must not break the run or alter the tool output.

    ``maybe_persist_tool_result`` is patched to raise ``RuntimeError("disk
    full")``. The run must still finish with ``"done"`` and the tool message
    sent on the second provider call must contain the raw ``"tool result"``.

    The explicit ``asyncio`` marker matches the async tests in the sibling
    runner test modules and keeps this coroutine test running even when
    pytest-asyncio is not in auto mode.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First turn: request a tool call so persistence is attempted.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Second turn: snapshot the messages the runner sends back.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
        result = await runner.run(AgentRunSpec(
            initial_messages=[{"role": "user", "content": "do task"}],
            tools=tools,
            model="test-model",
            max_iterations=2,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert tool_message["content"] == "tool result"
|
||||
321
tests/agent/test_runner_reasoning.py
Normal file
321
tests/agent/test_runner_reasoning.py
Normal file
@ -0,0 +1,321 @@
|
||||
"""Tests for AgentRunner reasoning extraction and emission.
|
||||
|
||||
Covers the three sources of model reasoning (dedicated ``reasoning_content``,
|
||||
Anthropic ``thinking_blocks``, inline ``<think>``/``<thought>`` tags) plus
|
||||
the streaming interaction: reasoning and answer streams are independent
|
||||
channels, gated by ``context.streamed_reasoning`` rather than
|
||||
``context.streamed_content``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHook
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
class _RecordingHook(AgentHook):
    """Test hook that records reasoning emissions and counts end markers."""

    def __init__(self) -> None:
        super().__init__()
        # Non-empty reasoning payloads passed to emit_reasoning, in call order.
        self.emitted: list[str] = []
        # Number of emit_reasoning_end calls seen so far.
        self.end_calls = 0

    async def emit_reasoning(self, reasoning_content: str | None) -> None:
        """Record ``reasoning_content`` unless it is ``None`` or empty."""
        if not reasoning_content:
            return
        self.emitted.append(reasoning_content)

    async def emit_reasoning_end(self) -> None:
        """Count one end-of-reasoning notification."""
        self.end_calls = self.end_calls + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_preserves_reasoning_fields_in_assistant_history():
    """The assistant tool-call message keeps its reasoning fields on the next call.

    The first response carries ``reasoning_content`` and ``thinking_blocks``
    alongside a tool call. Both must ride along on the persisted assistant
    message, so the follow-up provider call retains the model's prior
    thinking context.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    second_call_messages: list[dict] = []
    calls_seen = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_seen
        calls_seen += 1
        if calls_seen > 1:
            # Follow-up turn: snapshot the history and finish.
            second_call_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
            reasoning_content="hidden reasoning",
            thinking_blocks=[{"type": "thinking", "thinking": "step"}],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "do task"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == "done"
    assistant_messages = []
    for msg in second_call_messages:
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            assistant_messages.append(msg)
    assert len(assistant_messages) == 1
    kept = assistant_messages[0]
    assert kept["reasoning_content"] == "hidden reasoning"
    assert kept["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_emits_anthropic_thinking_blocks():
    """Two ``thinking_blocks`` entries reach the hook as a single emission.

    The one emitted string must contain the text of both thinking blocks, and
    the final content must be the plain answer.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(**kwargs):
        blocks = [
            {"type": "thinking", "thinking": "Let me analyze this step by step.", "signature": "sig1"},
            {"type": "thinking", "thinking": "After careful consideration.", "signature": "sig2"},
        ]
        return LLMResponse(
            content="The answer is 42.",
            thinking_blocks=blocks,
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    hook = _RecordingHook()
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "question"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=hook,
    )
    result = await runner.run(spec)

    assert result.final_content == "The answer is 42."
    assert len(hook.emitted) == 1
    emitted = hook.emitted[0]
    assert "Let me analyze this" in emitted
    assert "After careful consideration" in emitted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_emits_inline_think_content_as_reasoning():
    """Inline ``<think>...</think>`` text is emitted as reasoning and stripped.

    The final content must be the answer without the think block, and the
    hook must receive exactly one emission containing the think text.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(**kwargs):
        return LLMResponse(
            content="<think>Let me think about this...\nThe answer is 42.</think>The answer is 42.",
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    hook = _RecordingHook()
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "what is the answer?"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=hook,
    )
    result = await runner.run(spec)

    assert result.final_content == "The answer is 42."
    assert len(hook.emitted) == 1
    assert "Let me think about this" in hook.emitted[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_prefers_reasoning_content_over_inline_think():
    """Dedicated ``reasoning_content`` wins over an inline ``<think>`` block.

    Only the dedicated field is emitted to the hook, while the inline block
    is still removed from the final answer.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(**kwargs):
        return LLMResponse(
            content="<think>inline thinking</think>The answer.",
            reasoning_content="dedicated reasoning field",
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    hook = _RecordingHook()
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "question"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=hook,
    )
    result = await runner.run(spec)

    assert result.final_content == "The answer."
    assert hook.emitted == ["dedicated reasoning field"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_emits_reasoning_content_even_when_answer_was_streamed():
    """Streaming the answer must not suppress ``reasoning_content``.

    ``reasoning_content`` arrives only on the final response. The answer
    stream and the reasoning channel are independent, so only whether
    reasoning was already emitted matters.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    provider.supports_progress_deltas = True

    async def chat_stream_with_retry(*, on_content_delta=None, **kwargs):
        if on_content_delta:
            for chunk in ("The ", "answer."):
                await on_content_delta(chunk)
        return LLMResponse(
            content="The answer.",
            reasoning_content="step-by-step deduction",
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider.chat_stream_with_retry = chat_stream_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    progress_calls: list[str] = []

    async def _progress(content: str, **_kwargs):
        progress_calls.append(content)

    hook = _RecordingHook()
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "question"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=hook,
        stream_progress_deltas=True,
        progress_callback=_progress,
    )
    result = await runner.run(spec)

    assert result.final_content == "The answer."
    assert progress_calls, "answer should have streamed via progress callback"
    assert hook.emitted == ["step-by-step deduction"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_does_not_double_emit_when_inline_think_already_streamed():
    """An inline ``<think>`` block seen during streaming is emitted only once.

    The block arrives in the streamed deltas and again in the final response
    content; it must not be re-emitted from the final response, and the
    reasoning stream must be closed at least once.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    provider.supports_progress_deltas = True

    async def chat_stream_with_retry(*, on_content_delta=None, **kwargs):
        if on_content_delta:
            for chunk in ("<think>working...</think>", "The answer."):
                await on_content_delta(chunk)
        return LLMResponse(
            content="<think>working...</think>The answer.",
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider.chat_stream_with_retry = chat_stream_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    async def _progress(content: str, **_kwargs):
        # Progress output is irrelevant here; the callback only has to exist.
        return None

    hook = _RecordingHook()
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "question"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=hook,
        stream_progress_deltas=True,
        progress_callback=_progress,
    )
    result = await runner.run(spec)

    assert result.final_content == "The answer."
    assert hook.emitted == ["working..."]
    assert hook.end_calls >= 1, "reasoning stream must be closed once the answer starts"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_closes_reasoning_stream_after_one_shot_response():
    """A non-streaming response with ``reasoning_content`` emits a delta and one end marker.

    Both are needed so channels can finalize the in-place reasoning bubble.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(**kwargs):
        return LLMResponse(
            content="answer",
            reasoning_content="hidden thought",
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    hook = _RecordingHook()
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "q"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=hook,
    )
    result = await runner.run(spec)

    assert result.final_content == "answer"
    assert hook.emitted == ["hidden thought"]
    assert hook.end_calls == 1
|
||||
244
tests/agent/test_runner_safety.py
Normal file
244
tests/agent/test_runner_safety.py
Normal file
@ -0,0 +1,244 @@
|
||||
"""Tests for AgentRunner security: workspace violations, SSRF, shell guard, throttling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_does_not_abort_on_workspace_violation_anymore():
    """v2 behavior: workspace-bound rejections are *soft* tool errors.

    Previously (PR #3493) any workspace boundary error became a fatal
    RuntimeError that aborted the turn. That silently killed legitimate
    workspace commands once the heuristic guard misfired (#3599 #3605), so
    we now hand the error back to the LLM as a recoverable tool result and
    rely on ``repeated_workspace_violation_error`` to throttle bypass loops.

    The explicit ``asyncio`` marker matches the other async tests in this
    module and keeps this coroutine test running even when pytest-asyncio is
    not in auto mode.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    # Turn 1 asks to read a file outside the workspace; turn 2 gives up politely.
    provider.chat_with_retry = AsyncMock(side_effect=[
        LLMResponse(
            content="trying outside",
            tool_calls=[ToolCallRequest(
                id="call_1", name="read_file", arguments={"path": "/tmp/outside.md"},
            )],
        ),
        LLMResponse(content="ok, telling the user instead", tool_calls=[]),
    ])
    tools = MagicMock()
    tools.get_definitions.return_value = []
    # The tool itself raises, simulating the workspace boundary check.
    tools.execute = AsyncMock(
        side_effect=PermissionError(
            "Path /tmp/outside.md is outside allowed directory /workspace"
        )
    )

    runner = AgentRunner(provider)

    result = await runner.run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert provider.chat_with_retry.await_count == 2, (
        "workspace violation must NOT short-circuit the loop"
    )
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "ok, telling the user instead"
    assert result.tool_events and result.tool_events[0]["status"] == "error"
    # Detail still carries the workspace_violation breadcrumb for telemetry,
    # but the runner did not raise.
    assert "workspace_violation" in result.tool_events[0]["detail"]
|
||||
|
||||
|
||||
def test_is_ssrf_violation_recognizes_private_url_blocks():
    """SSRF rejections are classified separately from workspace boundaries."""
    from nanobot.agent.runner import AgentRunner

    cases = (
        # Private/internal URL blocks are SSRF violations.
        ("Error: Command blocked by safety guard (internal/private URL detected)", True),
        (
            "URL validation failed: Blocked: host resolves to private/internal address 192.168.1.2",
            True,
        ),
        # Workspace-bound markers are NOT classified as SSRF.
        ("Error: Command blocked by safety guard (path outside working dir)", False),
        ("Path /tmp/x is outside allowed directory /ws", False),
        # Deny / allowlist filter messages stay non-fatal too.
        ("Error: Command blocked by deny pattern filter", False),
    )
    for message, expected in cases:
        assert AgentRunner._is_ssrf_violation(message) is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_returns_non_retryable_hint_on_ssrf_violation():
    """SSRF stays blocked, but the LLM gets one more turn to recover.

    The tool message handed back must describe a non-bypassable boundary,
    tell the model not to retry, and point at ``tools.ssrfWhitelist``.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    blocked_attempt = LLMResponse(
        content="curl-ing metadata",
        tool_calls=[ToolCallRequest(
            id="call_ssrf",
            name="exec",
            arguments={"command": "curl http://169.254.169.254"},
        )],
    )
    recovery = LLMResponse(
        content="I cannot access that private URL. Please share local files.",
        tool_calls=[],
    )

    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(side_effect=[blocked_attempt, recovery])
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (internal/private URL detected)",
    )

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert provider.chat_with_retry.await_count == 2
    assert result.stop_reason == "completed"
    assert result.error is None
    assert result.final_content == "I cannot access that private URL. Please share local files."
    assert result.tool_events and result.tool_events[0]["detail"].startswith("ssrf_violation:")
    tool_messages = [m for m in result.messages if m.get("role") == "tool"]
    assert tool_messages
    hint = tool_messages[0]["content"]
    assert "non-bypassable security boundary" in hint
    assert "Do not retry" in hint
    assert "tools.ssrfWhitelist" in hint
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_lets_llm_recover_from_shell_guard_path_outside():
    """Reporter scenario for #3599 / #3605: the guard fires and the agent recovers.

    The shell `_guard_command` heuristic fires on `2>/dev/null`-style
    redirects and other shell idioms. Before v2 that aborted the whole turn
    (silent hang on Telegram per #3605); now the LLM gets the soft error back
    and can finalize on the next iteration.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # await_count is read from the AsyncMock wrapping this function.
        if provider.chat_with_retry.await_count != 1:
            captured_second_call[:] = list(messages)
            return LLMResponse(content="recovered final answer", tool_calls=[])
        return LLMResponse(
            content="trying noisy cleanup",
            tool_calls=[ToolCallRequest(
                id="call_blocked",
                name="exec",
                arguments={"command": "rm scratch.txt 2>/dev/null"},
            )],
        )

    provider.chat_with_retry = AsyncMock(side_effect=chat_with_retry)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (path outside working dir)"
    )

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert provider.chat_with_retry.await_count == 2, (
        "guard hit must NOT short-circuit the loop -- LLM should get a second turn"
    )
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "recovered final answer"
    assert result.tool_events and result.tool_events[0]["status"] == "error"
    # v2: detail keeps the breadcrumb but the runner did not raise.
    assert "workspace_violation" in result.tool_events[0]["detail"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_throttles_repeated_workspace_bypass_attempts():
    """#3493 motivation: stop the LLM bypass loop without aborting the turn.

    The LLM keeps switching tools (read_file -> exec cat -> python -c
    open(...)) against the same outside path. Once the soft retry budget is
    exhausted the runner replaces the tool result with a hard "stop trying"
    message, so the model finally gives up and surfaces the boundary to the
    user.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    bypass_attempts = [
        ToolCallRequest(
            id=f"a{i}", name="exec",
            arguments={"command": f"cat /Users/x/Downloads/01.md # try {i}"},
        )
        for i in range(4)
    ]
    responses: list[LLMResponse] = [
        LLMResponse(content=f"try {i}", tool_calls=[attempt])
        for i, attempt in enumerate(bypass_attempts)
    ]
    responses.append(LLMResponse(content="ok telling user", tool_calls=[]))

    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(side_effect=responses)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (path outside working dir)"
    )

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    # All 4 bypass attempts surface to the LLM (no fatal abort), and the
    # runner finally completes once the LLM stops asking.
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "ok telling user"
    # The third+ attempts must have been escalated -- look at the events.
    escalated = []
    for ev in result.tool_events:
        if ev["status"] == "error" and ev["detail"].startswith("workspace_violation_escalated:"):
            escalated.append(ev)
    assert escalated, (
        "expected at least one escalated workspace_violation event, got: "
        f"{result.tool_events}"
    )
|
||||
181
tests/agent/test_runner_tool_execution.py
Normal file
181
tests/agent/test_runner_tool_execution.py
Normal file
@ -0,0 +1,181 @@
|
||||
"""Tests for AgentRunner tool execution: batching, concurrency, exclusive tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
class _DelayTool(Tool):
    """Fake tool that logs start/end markers around a configurable sleep."""

    def __init__(
        self,
        name: str,
        *,
        delay: float,
        read_only: bool,
        shared_events: list[str],
        exclusive: bool = False,
    ):
        # Identity and scheduling flags exposed through the properties below.
        self._name = name
        self._read_only = read_only
        self._exclusive = exclusive
        # How long execute() sleeps, and the caller-supplied list it logs to.
        self._delay = delay
        self._shared_events = shared_events

    @property
    def name(self) -> str:
        """Return the name given at construction."""
        return self._name

    @property
    def description(self) -> str:
        """Return the tool name, which doubles as its description."""
        return self._name

    @property
    def parameters(self) -> dict:
        """Return an object schema with no properties and nothing required."""
        return {"type": "object", "properties": {}, "required": []}

    @property
    def read_only(self) -> bool:
        """Return the ``read_only`` flag given at construction."""
        return self._read_only

    @property
    def exclusive(self) -> bool:
        """Return the ``exclusive`` flag given at construction."""
        return self._exclusive

    async def execute(self, **kwargs):
        """Log a start marker, sleep for the delay, log an end marker, return the name."""
        label = self._name
        log = self._shared_events
        log.append(f"start:{label}")
        await asyncio.sleep(self._delay)
        log.append(f"end:{label}")
        return label
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
    """Both read-only tools start first and finish before the write tool starts.

    The first two events must be the two read starts, both read ends must
    precede the write start, and the write start/end must be the last two
    events.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tools = ToolRegistry()
    shared_events: list[str] = []
    # (name, delay, read_only), registered in this order.
    for tool_name, delay, read_only in (
        ("read_a", 0.05, True),
        ("read_b", 0.05, True),
        ("write_a", 0.01, False),
    ):
        tools.register(
            _DelayTool(tool_name, delay=delay, read_only=read_only, shared_events=shared_events)
        )

    runner = AgentRunner(MagicMock())
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    calls = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
        ToolCallRequest(id="rw1", name="write_a", arguments={}),
    ]
    await runner._execute_tools(spec, calls, {}, {})

    assert shared_events[:2] == ["start:read_a", "start:read_b"]
    assert "end:read_a" in shared_events and "end:read_b" in shared_events
    write_started = shared_events.index("start:write_a")
    assert shared_events.index("end:read_a") < write_started
    assert shared_events.index("end:read_b") < write_started
    assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_does_not_batch_exclusive_read_only_tools():
    """A read-only tool flagged ``exclusive`` runs strictly between its neighbours."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    events: list[str] = []
    registry = ToolRegistry()
    registry.register(
        _DelayTool("read_a", delay=0.03, read_only=True, shared_events=events)
    )
    registry.register(
        _DelayTool(
            "ddg_like",
            delay=0.01,
            read_only=True,
            shared_events=events,
            exclusive=True,
        )
    )
    registry.register(
        _DelayTool("read_b", delay=0.03, read_only=True, shared_events=events)
    )

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    requested = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ddg1", name="ddg_like", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
    ]
    await AgentRunner(MagicMock())._execute_tools(spec, requested, {}, {})

    position = events.index
    assert events[0] == "start:read_a"
    # The exclusive tool waits for read_a, and read_b waits for the exclusive tool.
    assert position("end:read_a") < position("start:ddg_like")
    assert position("end:ddg_like") < position("start:read_b")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches():
    """The third identical ``web_fetch`` call is not executed and gets a blocked notice.

    The fake provider requests the same URL on its first three calls and then
    returns a final answer; the messages of that final call are captured so
    the tool result for ``call_3`` can be inspected.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    final_messages: list[dict] = []
    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made > 3:
            final_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        fetch = ToolCallRequest(
            id=f"call_{calls_made}",
            name="web_fetch",
            arguments={"url": "https://example.com"},
        )
        return LLMResponse(content="working", tool_calls=[fetch], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="page content")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "research task"}],
        tools=tools,
        model="test-model",
        max_iterations=4,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "done"
    # Only the first two fetches reach the tool.
    assert tools.execute.await_count == 2
    third_call_results = [
        entry
        for entry in final_messages
        if entry.get("role") == "tool" and entry.get("tool_call_id") == "call_3"
    ]
    blocked_tool_message = third_call_results[0]
    assert "repeated external lookup blocked" in blocked_tool_message["content"]
|
||||
@ -43,6 +43,19 @@ def test_list_sessions_includes_metadata_title(tmp_path):
|
||||
assert rows[0]["title"] == "自动生成标题"
|
||||
|
||||
|
||||
def test_list_sessions_includes_user_preview(tmp_path):
    """The first listed row exposes the session key and the user message as ``preview``."""
    user_text = "帮我总结一下 OpenAI 的最新硬件计划"
    store = SessionManager(tmp_path)
    chat = store.get_or_create("websocket:chat-preview")
    for role, text in (
        ("user", user_text),
        ("assistant", "可以,我会先查最新消息。"),
    ):
        chat.add_message(role, text)
    store.save(chat)

    first_row = store.list_sessions()[0]

    assert first_row["key"] == "websocket:chat-preview"
    assert first_row["preview"] == user_text
|
||||
|
||||
|
||||
# --- Original regression test (from PR 2075) ---
|
||||
|
||||
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
|
||||
|
||||
@ -10,6 +10,7 @@ See: https://github.com/HKUDS/nanobot/issues/2966
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
@ -17,42 +18,47 @@ from unittest.mock import MagicMock, patch, AsyncMock
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
@pytest.fixture
def mock_loop():
    """Provide an AgentLoop built with ``__init__`` patched out and mocks attached."""
    # Skip the real constructor; every attribute the tests read is set below.
    with patch.object(AgentLoop, "__init__", lambda self: None):
        loop = AgentLoop()

    loop.sessions = MagicMock()
    for attr_name in ("_pending_queues", "_session_locks", "_active_tasks"):
        setattr(loop, attr_name, {})
    loop._concurrency_gate = None
    loop._RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
    loop._PENDING_USER_TURN_KEY = "pending_user_turn"

    bus = MagicMock()
    bus.publish_outbound = AsyncMock()
    bus.publish_inbound = AsyncMock()
    loop.bus = bus

    commands = MagicMock()
    commands.dispatch_priority = AsyncMock(return_value=None)
    loop.commands = commands
    return loop
|
||||
def _make_provider():
|
||||
"""Create an LLM provider mock with required attributes."""
|
||||
from types import SimpleNamespace
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = SimpleNamespace(max_tokens=4096, temperature=0.1, reasoning_effort=None)
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
return provider
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path) -> AgentLoop:
    """Construct an AgentLoop through its real ``__init__`` with collaborators patched.

    ``ContextBuilder``, ``SessionManager`` and ``SubagentManager`` are patched
    in ``nanobot.agent.loop`` for the duration of construction, and the
    patched subagent manager's ``cancel_by_session`` is an AsyncMock returning 0.
    """
    message_bus = MessageBus()
    llm = _make_provider()
    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as subagent_cls:
                subagent_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
                return AgentLoop(bus=message_bus, provider=llm, workspace=tmp_path)
|
||||
|
||||
|
||||
class TestStopPreservesContext:
|
||||
"""Verify that /stop restores partial context via checkpoint."""
|
||||
|
||||
def test_restore_checkpoint_method_exists(self, mock_loop):
|
||||
def test_restore_checkpoint_method_exists(self, tmp_path):
|
||||
"""AgentLoop should have _restore_runtime_checkpoint."""
|
||||
assert hasattr(mock_loop, "_restore_runtime_checkpoint")
|
||||
loop = _make_loop(tmp_path)
|
||||
assert hasattr(loop, "_restore_runtime_checkpoint")
|
||||
|
||||
def test_checkpoint_key_constant(self, mock_loop):
|
||||
def test_checkpoint_key_constant(self, tmp_path):
|
||||
"""The runtime checkpoint key should be defined."""
|
||||
assert mock_loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint"
|
||||
loop = _make_loop(tmp_path)
|
||||
assert loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint"
|
||||
|
||||
def test_cancel_dispatch_restores_checkpoint(self, mock_loop):
|
||||
def test_cancel_dispatch_restores_checkpoint(self, tmp_path):
|
||||
"""When a task is cancelled, the checkpoint should be restored."""
|
||||
# Create a mock session with a checkpoint
|
||||
loop = _make_loop(tmp_path)
|
||||
session = MagicMock()
|
||||
session.metadata = {
|
||||
"runtime_checkpoint": {
|
||||
@ -74,14 +80,11 @@ class TestStopPreservesContext:
|
||||
session.messages = [
|
||||
{"role": "user", "content": "Search for something"},
|
||||
]
|
||||
mock_loop.sessions.get_or_create.return_value = session
|
||||
loop.sessions.get_or_create.return_value = session
|
||||
|
||||
# The restore method should add checkpoint messages to session history
|
||||
restored = mock_loop._restore_runtime_checkpoint(session)
|
||||
restored = loop._restore_runtime_checkpoint(session)
|
||||
assert restored is True
|
||||
# After restore, session should have more messages
|
||||
assert len(session.messages) > 1
|
||||
# The checkpoint should be cleared
|
||||
assert "runtime_checkpoint" not in session.metadata
|
||||
|
||||
|
||||
|
||||
558
tests/agent/test_subagent_lifecycle.py
Normal file
558
tests/agent/test_subagent_lifecycle.py
Normal file
@ -0,0 +1,558 @@
|
||||
"""Tests for SubagentManager lifecycle — spawn, run, announce, cancel."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunResult
|
||||
from nanobot.agent.subagent import (
|
||||
SubagentManager,
|
||||
SubagentStatus,
|
||||
_SubagentHook,
|
||||
)
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _manager(tmp_path: Path, **kw) -> SubagentManager:
    """Build a SubagentManager with test defaults; keyword arguments override them."""
    mock_provider = MagicMock(spec=LLMProvider)
    mock_provider.get_default_model.return_value = "test-model"
    settings = {
        "provider": mock_provider,
        "workspace": tmp_path,
        "bus": MessageBus(),
        "model": "test-model",
        "max_tool_result_chars": 16_000,
    }
    return SubagentManager(**{**settings, **kw})
|
||||
|
||||
|
||||
def _make_hook_context(**overrides) -> AgentHookContext:
    """Build an AgentHookContext from baseline field values merged with ``overrides``."""
    baseline = {
        "iteration": 1,
        "tool_calls": [],
        "tool_events": [],
        "messages": [],
        "usage": {},
        "error": None,
        "stop_reason": "completed",
        "final_content": "ok",
    }
    return AgentHookContext(**{**baseline, **overrides})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentStatus defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentStatus:
    """Field values of a SubagentStatus built with only its required arguments."""

    def test_defaults(self):
        """Fields that are not passed hold their default values."""
        status = SubagentStatus(
            task_id="abc",
            label="test",
            task_description="do stuff",
            started_at=time.monotonic(),
        )
        assert status.phase == "initializing"
        assert status.iteration == 0
        assert status.tool_events == []
        assert status.usage == {}
        assert status.stop_reason is None
        assert status.error is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetProvider:
    """Effect of ``SubagentManager.set_provider`` on the manager and its runner."""

    def test_updates_provider_model_runner(self, tmp_path):
        """The provider, the model name and the runner's provider are all replaced."""
        manager = _manager(tmp_path)
        replacement = MagicMock(spec=LLMProvider)

        manager.set_provider(replacement, "new-model")

        assert manager.provider is replacement
        assert manager.model == "new-model"
        assert manager.runner.provider is replacement
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# spawn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSpawn:
    """``SubagentManager.spawn``: return text, bookkeeping, labels and cleanup.

    Tests that need a task to stay in flight swap ``runner.run`` for a
    coroutine gated on an ``asyncio.Event``. The gate is always released in a
    ``finally`` block so a failing assertion cannot leave a pending task behind.
    """

    @staticmethod
    def _done_result() -> AgentRunResult:
        """Return the successful run result used by the non-blocking tests."""
        return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

    @staticmethod
    def _install_blocking_run(sm: SubagentManager) -> asyncio.Event:
        """Replace ``sm.runner.run`` with a coroutine that waits on the returned event."""
        gate = asyncio.Event()

        async def _slow_run(spec):
            await gate.wait()
            return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

        sm.runner.run = _slow_run
        return gate

    @staticmethod
    async def _release(gate: asyncio.Event) -> None:
        """Open the gate and give the spawned task time to finish and clean up."""
        gate.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_returns_string_with_task_id(self, tmp_path):
        """The text returned by ``spawn`` contains ``started`` and ``id:``."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=self._done_result())
        result = await sm.spawn("do something")
        assert "started" in result
        assert "id:" in result

    @pytest.mark.asyncio
    async def test_creates_task_in_running_tasks(self, tmp_path):
        """One entry exists in ``_running_tasks`` while blocked, none after release."""
        sm = _manager(tmp_path)
        gate = self._install_blocking_run(sm)
        try:
            await sm.spawn("task", session_key="s1")
            assert len(sm._running_tasks) == 1
        finally:
            await self._release(gate)
        assert len(sm._running_tasks) == 0

    @pytest.mark.asyncio
    async def test_creates_status(self, tmp_path):
        """After the task completes, ``_task_statuses`` is empty again."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=self._done_result())
        await sm.spawn("my task")
        await asyncio.sleep(0.1)
        # Status cleaned up after task completes
        assert len(sm._task_statuses) == 0

    @pytest.mark.asyncio
    async def test_registers_in_session_tasks(self, tmp_path):
        """The session key maps to one task while blocked and is removed afterwards."""
        sm = _manager(tmp_path)
        gate = self._install_blocking_run(sm)
        try:
            await sm.spawn("task", session_key="s1")
            assert "s1" in sm._session_tasks
            assert len(sm._session_tasks["s1"]) == 1
        finally:
            await self._release(gate)
        assert "s1" not in sm._session_tasks

    @pytest.mark.asyncio
    async def test_no_session_key_no_registration(self, tmp_path):
        """Spawning without a session key leaves ``_session_tasks`` empty."""
        sm = _manager(tmp_path)
        gate = self._install_blocking_run(sm)
        try:
            await sm.spawn("task")
            assert len(sm._session_tasks) == 0
        finally:
            await self._release(gate)

    @pytest.mark.asyncio
    async def test_label_defaults_to_truncated_task(self, tmp_path):
        """Without a label, the status label is the first 30 task characters plus ``...``."""
        sm = _manager(tmp_path)
        gate = self._install_blocking_run(sm)
        long_task = "A" * 50
        try:
            await sm.spawn(long_task, session_key="s1")
            status = next(iter(sm._task_statuses.values()))
            assert status.label == long_task[:30] + "..."
        finally:
            await self._release(gate)

    @pytest.mark.asyncio
    async def test_custom_label(self, tmp_path):
        """An explicit ``label`` argument is stored unchanged on the status."""
        sm = _manager(tmp_path)
        gate = self._install_blocking_run(sm)
        try:
            await sm.spawn("task", label="Custom Label", session_key="s1")
            status = next(iter(sm._task_statuses.values()))
            assert status.label == "Custom Label"
        finally:
            await self._release(gate)

    @pytest.mark.asyncio
    async def test_cleanup_callback_removes_all_entries(self, tmp_path):
        """After completion, the task, status and session registries are all empty."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=self._done_result())
        await sm.spawn("task", session_key="s1")
        await asyncio.sleep(0.1)
        assert len(sm._running_tasks) == 0
        assert len(sm._task_statuses) == 0
        assert len(sm._session_tasks) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_subagent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunSubagent:
    """``SubagentManager._run_subagent``: announced outcome and status updates."""

    @staticmethod
    def _fresh_status() -> SubagentStatus:
        """Return a new status record for task ``t1``."""
        return SubagentStatus(
            task_id="t1",
            label="label",
            task_description="do task",
            started_at=time.monotonic(),
        )

    @staticmethod
    async def _run_with_patched_announce(sm, status):
        """Run task ``t1`` with ``_announce_result`` replaced and return that mock."""
        with patch.object(sm, "_announce_result", new_callable=AsyncMock) as announce:
            await sm._run_subagent(
                "t1", "do task", "label",
                {"channel": "cli", "chat_id": "direct"},
                status,
            )
        return announce

    @pytest.mark.asyncio
    async def test_successful_run(self, tmp_path):
        """A completed run is announced exactly once with ``ok``."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=AgentRunResult(
            final_content="Task done!", messages=[], stop_reason="completed",
        ))
        announce = await self._run_with_patched_announce(sm, self._fresh_status())
        announce.assert_called_once()
        assert announce.call_args.args[-2] == "ok"

    @pytest.mark.asyncio
    async def test_tool_error_run(self, tmp_path):
        """A run that stops with ``tool_error`` is announced with ``error``."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=AgentRunResult(
            final_content=None, messages=[], stop_reason="tool_error",
            tool_events=[{"name": "read_file", "status": "error", "detail": "not found"}],
        ))
        announce = await self._run_with_patched_announce(sm, self._fresh_status())
        assert announce.call_args.args[-2] == "error"

    @pytest.mark.asyncio
    async def test_exception_run(self, tmp_path):
        """An exception from the runner is recorded on the status and announced as ``error``."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(side_effect=RuntimeError("LLM down"))
        status = self._fresh_status()
        announce = await self._run_with_patched_announce(sm, status)
        assert status.phase == "error"
        assert "LLM down" in status.error
        assert announce.call_args.args[-2] == "error"

    @pytest.mark.asyncio
    async def test_status_updated_on_success(self, tmp_path):
        """On success the status phase is ``done`` and the stop reason is copied over."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=AgentRunResult(
            final_content="ok", messages=[], stop_reason="completed",
        ))
        status = self._fresh_status()
        await self._run_with_patched_announce(sm, status)
        assert status.phase == "done"
        assert status.stop_reason == "completed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _announce_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnnounceResult:
    """``SubagentManager._announce_result``: shape of the inbound message it publishes."""

    @staticmethod
    async def _announce(tmp_path, result, origin, status, **extra):
        """Announce a result for task ``t1`` and return the list of published messages."""
        sm = _manager(tmp_path)
        published = []
        sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg))
        await sm._announce_result("t1", "label", "task", result, origin, status, **extra)
        return published

    @pytest.mark.asyncio
    async def test_publishes_inbound_message(self, tmp_path):
        """Exactly one system message from ``subagent`` is published with task metadata."""
        published = await self._announce(
            tmp_path, "result text", {"channel": "cli", "chat_id": "direct"}, "ok",
        )

        assert len(published) == 1
        msg = published[0]
        assert msg.channel == "system"
        assert msg.sender_id == "subagent"
        assert msg.metadata["injected_event"] == "subagent_result"
        assert msg.metadata["subagent_task_id"] == "t1"

    @pytest.mark.asyncio
    async def test_session_key_override(self, tmp_path):
        """An origin ``session_key`` becomes the message's ``session_key_override``."""
        origin = {"channel": "telegram", "chat_id": "123", "session_key": "s1"}
        published = await self._announce(tmp_path, "result", origin, "ok")

        assert published[0].session_key_override == "s1"

    @pytest.mark.asyncio
    async def test_session_key_override_fallback(self, tmp_path):
        """Without a ``session_key`` the override is ``<channel>:<chat_id>``."""
        origin = {"channel": "telegram", "chat_id": "123"}
        published = await self._announce(tmp_path, "result", origin, "ok")

        assert published[0].session_key_override == "telegram:123"

    @pytest.mark.asyncio
    async def test_ok_status_text(self, tmp_path):
        """An ``ok`` status yields content mentioning ``completed successfully``."""
        published = await self._announce(
            tmp_path, "result", {"channel": "cli", "chat_id": "direct"}, "ok",
        )

        assert "completed successfully" in published[0].content

    @pytest.mark.asyncio
    async def test_error_status_text(self, tmp_path):
        """An ``error`` status yields content mentioning ``failed``."""
        published = await self._announce(
            tmp_path, "error details", {"channel": "cli", "chat_id": "direct"}, "error",
        )

        assert "failed" in published[0].content

    @pytest.mark.asyncio
    async def test_origin_message_id_in_metadata(self, tmp_path):
        """The ``origin_message_id`` keyword is copied into the message metadata."""
        published = await self._announce(
            tmp_path, "result", {"channel": "cli", "chat_id": "direct"}, "ok",
            origin_message_id="msg-123",
        )

        assert published[0].metadata["origin_message_id"] == "msg-123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_partial_progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatPartialProgress:
    """Text returned by ``SubagentManager._format_partial_progress``."""

    def _make_result(self, tool_events=None, error=None):
        """Build a stand-in result exposing only ``tool_events`` and ``error``."""
        events = tool_events or []
        return MagicMock(tool_events=events, error=error)

    @staticmethod
    def _event(name, status, detail):
        """Return one tool-event dict."""
        return {"name": name, "status": status, "detail": detail}

    def _render(self, tool_events=None, error=None):
        """Format a stand-in result built from the given events and error."""
        result = self._make_result(tool_events=tool_events, error=error)
        return SubagentManager._format_partial_progress(result)

    def test_completed_only(self):
        """Successful events are listed under ``Completed steps:``."""
        text = self._render(tool_events=[
            self._event("read_file", "ok", "file content"),
            self._event("exec", "ok", "output"),
        ])
        assert "Completed steps:" in text
        assert "read_file" in text
        assert "exec" in text

    def test_failure_only(self):
        """A failed event produces a ``Failure:`` section with its detail."""
        text = self._render(tool_events=[
            self._event("read_file", "error", "not found"),
        ])
        assert "Failure:" in text
        assert "not found" in text

    def test_completed_and_failure(self):
        """Mixed events produce both sections."""
        text = self._render(tool_events=[
            self._event("read_file", "ok", "content"),
            self._event("exec", "error", "timeout"),
        ])
        assert "Completed steps:" in text
        assert "Failure:" in text

    def test_limited_to_last_three(self):
        """Of five successful events only the last three names appear."""
        text = self._render(tool_events=[
            self._event(f"tool_{i}", "ok", f"result_{i}") for i in range(5)
        ])
        for kept in ("tool_2", "tool_3", "tool_4"):
            assert kept in text
        for dropped in ("tool_0", "tool_1"):
            assert dropped not in text

    def test_error_without_failure_event(self):
        """The result's error text appears even when no event failed."""
        text = self._render(
            tool_events=[self._event("read_file", "ok", "ok")],
            error="Something went wrong",
        )
        assert "Something went wrong" in text

    def test_empty_events_with_error(self):
        """With no events, the error text still appears."""
        text = self._render(error="Total failure")
        assert "Total failure" in text

    def test_empty_no_error_returns_fallback(self):
        """With neither events nor error, the text contains ``Error``."""
        text = self._render()
        assert "Error" in text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel_by_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelBySession:
    """``SubagentManager.cancel_by_session``: number of tasks reported as cancelled."""

    @pytest.mark.asyncio
    async def test_cancels_running_tasks(self, tmp_path):
        """Two in-flight tasks of one session are both counted as cancelled."""
        sm = _manager(tmp_path)
        block = asyncio.Event()

        async def _slow_run(spec):
            await block.wait()
            return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

        sm.runner.run = _slow_run

        try:
            await sm.spawn("task1", session_key="s1")
            await sm.spawn("task2", session_key="s1")
            assert len(sm._session_tasks.get("s1", set())) == 2

            count = await sm.cancel_by_session("s1")
            assert count == 2
        finally:
            # Always open the gate so a failed assertion cannot strand pending tasks.
            block.set()
            await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_no_tasks_returns_zero(self, tmp_path):
        """An unknown session key yields a count of zero."""
        sm = _manager(tmp_path)
        count = await sm.cancel_by_session("nonexistent")
        assert count == 0

    @pytest.mark.asyncio
    async def test_already_done_not_counted(self, tmp_path):
        """A task that has already finished is not counted."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=AgentRunResult(
            final_content="done", messages=[], stop_reason="completed",
        ))
        await sm.spawn("task1", session_key="s1")
        await asyncio.sleep(0.1)  # Wait for completion

        count = await sm.cancel_by_session("s1")
        assert count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_running_count / get_running_count_by_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunningCounts:
    """``get_running_count`` and ``get_running_count_by_session`` on SubagentManager."""

    @pytest.mark.asyncio
    async def test_running_count_zero(self, tmp_path):
        """A new manager reports zero running tasks."""
        sm = _manager(tmp_path)
        assert sm.get_running_count() == 0

    @pytest.mark.asyncio
    async def test_running_count_tracks_tasks(self, tmp_path):
        """Both counters report two while tasks are blocked; the total is zero after release."""
        sm = _manager(tmp_path)
        block = asyncio.Event()

        async def _slow_run(spec):
            await block.wait()
            return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

        sm.runner.run = _slow_run

        try:
            await sm.spawn("t1", session_key="s1")
            await sm.spawn("t2", session_key="s1")
            assert sm.get_running_count() == 2
            assert sm.get_running_count_by_session("s1") == 2
        finally:
            # Always open the gate so a failed assertion cannot strand pending tasks.
            block.set()
            await asyncio.sleep(0.1)
        assert sm.get_running_count() == 0

    @pytest.mark.asyncio
    async def test_running_count_by_session_nonexistent(self, tmp_path):
        """An unknown session key reports zero running tasks."""
        sm = _manager(tmp_path)
        assert sm.get_running_count_by_session("nonexistent") == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _SubagentHook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentHook:
    """``_SubagentHook`` callbacks: tool logging and status mirroring."""

    @staticmethod
    def _fresh_status() -> SubagentStatus:
        """Return a new status record for task ``t1``."""
        return SubagentStatus(
            task_id="t1",
            label="test",
            task_description="do",
            started_at=time.monotonic(),
        )

    @pytest.mark.asyncio
    async def test_before_execute_tools_logs(self, tmp_path):
        """``before_execute_tools`` accepts a context with one tool call without raising."""
        # ``name`` must be assigned after construction: MagicMock treats the
        # ``name`` constructor keyword specially.
        pending_call = MagicMock(arguments={"path": "/tmp/test"})
        pending_call.name = "read_file"
        ctx = _make_hook_context(tool_calls=[pending_call])
        hook = _SubagentHook("t1")
        # Should not raise
        await hook.before_execute_tools(ctx)

    @pytest.mark.asyncio
    async def test_after_iteration_updates_status(self):
        """Iteration, tool events and usage from the context are copied onto the status."""
        status = self._fresh_status()
        ctx = _make_hook_context(
            iteration=3,
            tool_events=[{"name": "read_file", "status": "ok", "detail": ""}],
            usage={"prompt_tokens": 100},
        )
        await _SubagentHook("t1", status).after_iteration(ctx)
        assert status.iteration == 3
        assert len(status.tool_events) == 1
        assert status.usage == {"prompt_tokens": 100}

    @pytest.mark.asyncio
    async def test_after_iteration_no_status_noop(self):
        """With ``status=None`` the callback completes without raising."""
        ctx = _make_hook_context(iteration=5)
        hook = _SubagentHook("t1", status=None)
        # Should not raise
        await hook.after_iteration(ctx)

    @pytest.mark.asyncio
    async def test_after_iteration_sets_error(self):
        """An error on the context is copied onto the status."""
        status = self._fresh_status()
        ctx = _make_hook_context(error="something broke")
        await _SubagentHook("t1", status).after_iteration(ctx)
        assert status.error == "something broke"
|
||||
228
tests/channels/test_channel_manager_reasoning.py
Normal file
228
tests/channels/test_channel_manager_reasoning.py
Normal file
@ -0,0 +1,228 @@
|
||||
"""Tests for ChannelManager routing of model reasoning content.
|
||||
|
||||
Reasoning is delivered through plugin streaming primitives
|
||||
(``send_reasoning_delta`` / ``send_reasoning_end``) so each channel
|
||||
controls in-place rendering — mirroring the existing answer ``send_delta``
|
||||
/ ``stream_end`` pair. The manager forwards reasoning frames only to
|
||||
channels that opt in via ``channel.show_reasoning``; plugins without a
|
||||
low-emphasis UI primitive keep the base no-op and the content silently
|
||||
drops at dispatch.
|
||||
|
||||
One-shot ``_reasoning`` frames are accepted for back-compat with hooks
|
||||
that haven't migrated yet — ``BaseChannel.send_reasoning`` expands them
|
||||
to a single delta + end pair so plugins only implement the streaming
|
||||
primitives.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class _MockChannel(BaseChannel):
    """Channel stub that forwards send and reasoning calls to AsyncMock recorders."""

    name = "mock"
    display_name = "Mock"

    def __init__(self, config, bus):
        super().__init__(config, bus)
        # One recorder per entry point so tests can assert which one was used.
        self._send_mock = AsyncMock()
        self._delta_mock = AsyncMock()
        self._end_mock = AsyncMock()

    async def start(self):  # pragma: no cover - not exercised
        return None

    async def stop(self):  # pragma: no cover - not exercised
        return None

    async def send(self, msg):
        """Record a regular outbound message."""
        outcome = await self._send_mock(msg)
        return outcome

    async def send_reasoning_delta(self, chat_id, delta, metadata=None):
        """Record a reasoning chunk."""
        outcome = await self._delta_mock(chat_id, delta, metadata)
        return outcome

    async def send_reasoning_end(self, chat_id, metadata=None):
        """Record the end of a reasoning block."""
        outcome = await self._end_mock(chat_id, metadata)
        return outcome
|
||||
|
||||
|
||||
@pytest.fixture
def manager() -> ChannelManager:
    """Provide a ChannelManager with a single ``mock`` channel registered."""
    channel_manager = ChannelManager(Config(), MessageBus())
    mock_channel = _MockChannel({}, channel_manager.bus)
    channel_manager.channels["mock"] = mock_channel
    return channel_manager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_delta_routes_to_send_reasoning_delta(manager):
    """A ``_reasoning_delta`` frame reaches only ``send_reasoning_delta``."""
    channel = manager.channels["mock"]
    frame = OutboundMessage(
        channel="mock",
        chat_id="c1",
        content="step-by-step",
        metadata={"_progress": True, "_reasoning_delta": True, "_stream_id": "r1"},
    )

    await manager._send_once(channel, frame)

    channel._delta_mock.assert_awaited_once()
    forwarded = channel._delta_mock.await_args.args
    assert forwarded[0] == "c1"
    assert forwarded[1] == "step-by-step"
    channel._send_mock.assert_not_awaited()
    channel._end_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_end_routes_to_send_reasoning_end(manager):
    """A `_reasoning_end` message must hit `send_reasoning_end` and not the delta hook."""
    channel = manager.channels["mock"]
    end_metadata = {"_progress": True, "_reasoning_end": True, "_stream_id": "r1"}
    msg = OutboundMessage(
        channel="mock",
        chat_id="c1",
        content="",
        metadata=end_metadata,
    )

    await manager._send_once(channel, msg)

    channel._end_mock.assert_awaited_once()
    channel._delta_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_legacy_one_shot_reasoning_expands_to_delta_plus_end(manager):
    """A bare `_reasoning` message (no delta/end markers) must produce exactly
    one `send_reasoning_delta` await and one `send_reasoning_end` await, so
    producers that still emit the one-shot form look like a complete stream
    segment to the channel."""
    channel = manager.channels["mock"]
    one_shot_metadata = {"_progress": True, "_reasoning": True}
    msg = OutboundMessage(
        channel="mock",
        chat_id="c1",
        content="one-shot reasoning",
        metadata=one_shot_metadata,
    )

    await manager._send_once(channel, msg)

    channel._delta_mock.assert_awaited_once()
    channel._end_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_drops_reasoning_when_channel_opts_out(manager):
    """With `show_reasoning` off, a reasoning delta must reach none of the channel hooks."""
    channel = manager.channels["mock"]
    channel.show_reasoning = False
    hidden = OutboundMessage(
        channel="mock",
        chat_id="c1",
        content="hidden thinking",
        metadata={"_progress": True, "_reasoning_delta": True},
    )
    await manager.bus.publish_outbound(hidden)

    await _pump_one(manager)

    for recorder in (channel._delta_mock, channel._end_mock, channel._send_mock):
        recorder.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_delivers_reasoning_when_channel_opts_in(manager):
    """With `show_reasoning` on, two deltas and one end frame must all reach the channel."""
    channel = manager.channels["mock"]
    channel.show_reasoning = True

    def _frame(content, marker):
        # Every frame belongs to the same stream id "r1".
        return OutboundMessage(
            channel="mock",
            chat_id="c1",
            content=content,
            metadata={"_progress": True, marker: True, "_stream_id": "r1"},
        )

    await manager.bus.publish_outbound(_frame("first ", "_reasoning_delta"))
    await manager.bus.publish_outbound(_frame("second", "_reasoning_delta"))
    await manager.bus.publish_outbound(_frame("", "_reasoning_end"))

    await _pump_one(manager)

    assert channel._delta_mock.await_count == 2
    channel._end_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_silently_drops_reasoning_for_unknown_channel(manager):
    """A reasoning delta addressed to an unregistered channel must not leak to "mock"."""
    stray = OutboundMessage(
        channel="ghost",
        chat_id="c1",
        content="nobody home",
        metadata={"_progress": True, "_reasoning_delta": True},
    )
    await manager.bus.publish_outbound(stray)

    await _pump_one(manager)

    registered = manager.channels["mock"]
    registered._delta_mock.assert_not_awaited()
    registered._send_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_base_channel_reasoning_primitives_are_noop_safe():
    """A subclass that overrides none of the reasoning hooks must get `None` from each."""

    class _Plain(BaseChannel):
        name = "plain"
        display_name = "Plain"

        async def start(self):  # pragma: no cover
            return None

        async def stop(self):  # pragma: no cover
            return None

        async def send(self, msg):  # pragma: no cover
            return None

    channel = _Plain({}, MessageBus())

    delta_result = await channel.send_reasoning_delta("c", "x")
    assert delta_result is None

    end_result = await channel.send_reasoning_end("c")
    assert end_result is None

    # And the one-shot wrapper translates without raising.
    one_shot = OutboundMessage(channel="plain", chat_id="c", content="x", metadata={})
    assert await channel.send_reasoning(one_shot) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_routing_does_not_consult_send_progress(manager):
    """`send_progress=False` combined with `show_reasoning=True` must still
    deliver a reasoning delta: the two switches are independent."""
    channel = manager.channels["mock"]
    channel.send_progress = False
    channel.show_reasoning = True
    delta = OutboundMessage(
        channel="mock",
        chat_id="c1",
        content="still surfaces",
        metadata={"_progress": True, "_reasoning_delta": True},
    )
    await manager.bus.publish_outbound(delta)

    await _pump_one(manager)

    channel._delta_mock.assert_awaited_once()
|
||||
|
||||
|
||||
async def _pump_one(manager: ChannelManager) -> None:
    """Run the outbound dispatcher until the queue is empty, then cancel it.

    The queue size is polled at most 50 times, 10 ms apart; the dispatcher
    task is cancelled afterwards whether or not the queue drained.
    """
    dispatcher = asyncio.create_task(manager._dispatch_outbound())
    polls_left = 50
    while polls_left > 0:
        polls_left -= 1
        await asyncio.sleep(0.01)
        if not manager.bus.outbound.qsize():
            break
    dispatcher.cancel()
    try:
        await dispatcher
    except asyncio.CancelledError:
        # Expected: we cancelled the task ourselves just above.
        return
|
||||
@ -234,13 +234,13 @@ async def test_send_renders_buttons_on_last_message_chunk() -> None:
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Yes"},
|
||||
"value": "Yes",
|
||||
"action_id": "ask_user_Yes",
|
||||
"action_id": "btn_Yes",
|
||||
},
|
||||
{
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "No"},
|
||||
"value": "No",
|
||||
"action_id": "ask_user_No",
|
||||
"action_id": "btn_No",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@ -224,11 +224,9 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["event"] == "message"
|
||||
assert payload["chat_id"] == "chat-1"
|
||||
assert payload["text"] == "hello\n\n1. Yes\n2. No"
|
||||
assert payload["button_prompt"] == "hello"
|
||||
assert payload["text"] == "hello"
|
||||
assert payload["reply_to"] == "m1"
|
||||
assert payload["media"] == ["/tmp/a.png"]
|
||||
assert payload["buttons"] == [["Yes", "No"]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -360,6 +358,87 @@ async def test_send_delta_emits_delta_and_stream_end() -> None:
|
||||
assert second["stream_id"] == "sid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_reasoning_delta_emits_streaming_frame() -> None:
    """One reasoning delta must produce one `reasoning_delta` JSON frame on the socket."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    mock_ws = AsyncMock()
    channel._attach(mock_ws, "chat-1")

    delta_metadata = {"_reasoning_delta": True, "_stream_id": "r1"}
    await channel.send_reasoning_delta("chat-1", "step-by-step thinking", delta_metadata)

    mock_ws.send.assert_awaited_once()
    payload = json.loads(mock_ws.send.await_args.args[0])
    expected_fields = {
        "event": "reasoning_delta",
        "chat_id": "chat-1",
        "text": "step-by-step thinking",
        "stream_id": "r1",
    }
    for field, value in expected_fields.items():
        assert payload[field] == value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_reasoning_end_emits_close_frame() -> None:
    """`send_reasoning_end` must emit exactly one `reasoning_end` frame.

    The frame is compared as a whole dict, so any extra key in the payload
    fails the test.
    """
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
    mock_ws = AsyncMock()
    channel._attach(mock_ws, "chat-1")

    await channel.send_reasoning_end("chat-1", {"_reasoning_end": True, "_stream_id": "r1"})

    # Check the send happened first: without this, a missing frame would show
    # up as an AttributeError on `await_args` (None) instead of a clear failure.
    mock_ws.send.assert_awaited_once()
    payload = json.loads(mock_ws.send.await_args.args[0])
    assert payload == {"event": "reasoning_end", "chat_id": "chat-1", "stream_id": "r1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None:
    """``send_reasoning`` with a one-shot message must put two frames on the
    socket: a ``reasoning_delta`` carrying the text, then a ``reasoning_end``."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    mock_ws = AsyncMock()
    channel._attach(mock_ws, "chat-1")

    one_shot = OutboundMessage(
        channel="websocket",
        chat_id="chat-1",
        content="thinking",
        metadata={"_reasoning": True},
    )
    await channel.send_reasoning(one_shot)

    assert mock_ws.send.await_count == 2
    sent = mock_ws.send.call_args_list
    first = json.loads(sent[0][0][0])
    second = json.loads(sent[1][0][0])
    assert first["event"] == "reasoning_delta"
    assert first["text"] == "thinking"
    assert second["event"] == "reasoning_end"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_reasoning_delta_drops_empty_chunks() -> None:
    """An empty delta string must not result in any socket send."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    mock_ws = AsyncMock()
    channel._attach(mock_ws, "chat-1")

    empty_chunk = ""
    await channel.send_reasoning_delta("chat-1", empty_chunk, {"_reasoning_delta": True})

    mock_ws.send.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_reasoning_without_subscribers_is_noop() -> None:
    """Both reasoning hooks must complete without raising for a chat id that was never attached."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    chat_id = "unattached"

    await channel.send_reasoning_delta(chat_id, "thinking", None)
    await channel.send_reasoning_end(chat_id, None)
    # No subscribers, no exception, no send.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_turn_end_emits_turn_end_event() -> None:
|
||||
bus = MagicMock()
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
from contextlib import nullcontext
|
||||
from io import StringIO
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
@ -96,6 +98,66 @@ def test_print_cli_progress_line_pauses_spinner_before_printing():
|
||||
assert order == ["start", "stop", "print", "start", "stop"]
|
||||
|
||||
|
||||
def test_thinking_spinner_clears_status_line_when_paused():
    """Pausing an active spinner must write the carriage-return + erase-line sequence."""
    tty_stream = StringIO()
    tty_stream.isatty = lambda: True  # type: ignore[method-assign]
    mock_console = MagicMock()
    mock_console.file = tty_stream
    mock_console.status.return_value = MagicMock()

    thinking = stream_mod.ThinkingSpinner(console=mock_console)
    with thinking, thinking.pause():
        pass

    emitted = tty_stream.getvalue()
    assert "\r\x1b[2K" in emitted
|
||||
|
||||
|
||||
def test_stream_renderer_stops_spinner_even_after_header_printed():
    """`ensure_header()` must stop the spinner and clear its line even when
    the header flag is already set."""
    tty_stream = StringIO()
    tty_stream.isatty = lambda: True  # type: ignore[method-assign]
    spinner = MagicMock()
    mock_console = MagicMock()
    mock_console.file = tty_stream
    mock_console.status.return_value = spinner

    with patch.object(stream_mod, "_make_console", return_value=mock_console):
        renderer = stream_mod.StreamRenderer(show_spinner=True)
        # Pretend an earlier delta already printed the header.
        renderer._header_printed = True
        renderer.ensure_header()

    spinner.stop.assert_called_once()
    emitted = tty_stream.getvalue()
    assert "\r\x1b[2K" in emitted
|
||||
|
||||
|
||||
def test_print_cli_progress_line_opens_renderer_header_before_trace():
    """The renderer header must be ensured before the trace line is printed."""
    order: list[str] = []

    def _record_print(*_args, **_kwargs):
        order.append("print")

    def _record_header():
        order.append("header")

    renderer = MagicMock()
    renderer.console.print.side_effect = _record_print
    renderer.ensure_header.side_effect = _record_header
    renderer.pause_spinner.return_value = nullcontext()

    commands._print_cli_progress_line("tool running", None, renderer)

    assert order == ["header", "print"]
|
||||
|
||||
|
||||
def test_print_cli_progress_line_stops_live_before_trace():
    """Printing a trace line must stop the active Live object and clear `_live`."""
    live_frame = MagicMock()
    renderer = stream_mod.StreamRenderer(show_spinner=False)
    renderer._live = live_frame

    commands._print_cli_progress_line("tool running", None, renderer)

    live_frame.stop.assert_called_once()
    assert renderer._live is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_print_interactive_progress_line_pauses_spinner_before_printing():
|
||||
"""Interactive progress output should also pause spinner cleanly."""
|
||||
@ -156,17 +218,65 @@ def test_stream_renderer_stop_for_input_stops_spinner():
|
||||
# Create renderer with mocked console
|
||||
with patch.object(stream_mod, "_make_console", return_value=mock_console):
|
||||
renderer = stream_mod.StreamRenderer(show_spinner=True)
|
||||
|
||||
|
||||
# Verify spinner started
|
||||
spinner.start.assert_called_once()
|
||||
|
||||
|
||||
# Stop for input
|
||||
renderer.stop_for_input()
|
||||
|
||||
|
||||
# Verify spinner stopped
|
||||
spinner.stop.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_end_writes_final_content_to_stdout_after_stopping_live():
    """`on_end()` must stop and drop the Live object, then write the captured
    final render to stdout."""
    live_frame = MagicMock()
    mock_console = MagicMock()
    capture_cm = mock_console.capture.return_value
    capture_cm.__enter__ = MagicMock(
        return_value=MagicMock(get=lambda: "final output\n")
    )
    capture_cm.__exit__ = MagicMock(return_value=False)

    with patch.object(stream_mod, "_make_console", return_value=mock_console):
        renderer = stream_mod.StreamRenderer(show_spinner=False)
        renderer._live = live_frame
        renderer._buf = "final output"

    written: list[str] = []
    with patch("sys.stdout") as mock_stdout:
        # Collect every chunk written to stdout instead of printing it.
        mock_stdout.write = lambda chunk: written.append(chunk)
        mock_stdout.flush = MagicMock()
        await renderer.on_end()

    live_frame.stop.assert_called_once()
    assert renderer._live is None
    assert written == ["final output\n"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_end_resuming_clears_buffer_and_restarts_spinner():
    """`on_end(resuming=True)` must empty the buffer and start the spinner a second time."""
    spinner = MagicMock()
    mock_console = MagicMock()
    mock_console.status.return_value = spinner
    capture_cm = mock_console.capture.return_value
    capture_cm.__enter__ = MagicMock(
        return_value=MagicMock(get=lambda: "")
    )
    capture_cm.__exit__ = MagicMock(return_value=False)

    with patch.object(stream_mod, "_make_console", return_value=mock_console):
        renderer = stream_mod.StreamRenderer(show_spinner=True)
        renderer._buf = "some content"

        await renderer.on_end(resuming=True)

    assert renderer._buf == ""
    # Spinner should have been restarted (start called twice: __init__ + resuming)
    assert spinner.start.call_count == 2
|
||||
|
||||
|
||||
def test_make_console_force_terminal_when_stdout_is_tty():
|
||||
"""Console should set force_terminal=True when stdout is a TTY (rich output)."""
|
||||
import sys
|
||||
|
||||
@ -17,7 +17,7 @@ async def test_interactive_retry_wait_is_rendered_as_progress_even_when_progress
|
||||
metadata={"_retry_wait": True},
|
||||
)
|
||||
|
||||
async def fake_print(text: str, active_thinking: object | None) -> None:
|
||||
async def fake_print(text: str, active_thinking: object | None, renderer=None) -> None:
|
||||
calls.append((text, active_thinking))
|
||||
|
||||
with patch("nanobot.cli.commands._print_interactive_progress_line", side_effect=fake_print):
|
||||
@ -29,3 +29,104 @@ async def test_interactive_retry_wait_is_rendered_as_progress_even_when_progress
|
||||
|
||||
assert handled is True
|
||||
assert calls == [("Model request failed, retry in 2s (attempt 1).", thinking)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_displayed_when_show_reasoning_enabled():
    """A `_reasoning` message must be passed to the reasoning printer when `show_reasoning` is True."""
    calls: list[str] = []

    def _capture(text, thinking, renderer=None):
        calls.append(text)

    channels_config = SimpleNamespace(
        send_progress=True, send_tool_hints=False, show_reasoning=True,
    )
    msg = SimpleNamespace(
        content="Let me think about this...",
        metadata={"_progress": True, "_reasoning": True},
    )

    with patch("nanobot.cli.commands._print_cli_reasoning", side_effect=_capture):
        handled = await commands._maybe_print_interactive_progress(msg, None, channels_config)

    assert handled is True
    assert calls == ["Let me think about this..."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_delta_displayed_when_show_reasoning_enabled():
    """A `_reasoning_delta` message must be passed to the reasoning printer as well."""
    calls: list[str] = []

    def _capture(text, thinking, renderer=None):
        calls.append(text)

    channels_config = SimpleNamespace(
        send_progress=True, send_tool_hints=False, show_reasoning=True,
    )
    msg = SimpleNamespace(
        content="I should search first.",
        metadata={"_progress": True, "_reasoning_delta": True},
    )

    with patch("nanobot.cli.commands._print_cli_reasoning", side_effect=_capture):
        handled = await commands._maybe_print_interactive_progress(msg, None, channels_config)

    assert handled is True
    assert calls == ["I should search first."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_hidden_when_show_reasoning_disabled():
    """With `show_reasoning=False` the message is still reported as handled,
    but the reasoning printer must never be called."""
    msg = SimpleNamespace(
        content="Let me think about this...",
        metadata={"_progress": True, "_reasoning": True},
    )
    channels_config = SimpleNamespace(
        send_progress=True, send_tool_hints=False, show_reasoning=False,
    )

    with patch("nanobot.cli.commands._print_cli_reasoning") as mock_reasoning:
        handled = await commands._maybe_print_interactive_progress(msg, None, channels_config)

    assert handled is True
    mock_reasoning.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_reasoning_progress_not_affected_by_show_reasoning():
    """A plain `_progress` message must still be printed when `show_reasoning` is False."""
    printed: list[str] = []

    async def fake_print(text: str, thinking=None, renderer=None):
        printed.append(text)

    msg = SimpleNamespace(
        content="working on it...",
        metadata={"_progress": True},
    )
    channels_config = SimpleNamespace(
        send_progress=True, send_tool_hints=False, show_reasoning=False,
    )

    with patch("nanobot.cli.commands._print_interactive_progress_line", side_effect=fake_print):
        handled = await commands._maybe_print_interactive_progress(msg, None, channels_config)

    assert handled is True
    assert printed == ["working on it..."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_shown_when_send_progress_disabled():
    """With `send_progress=False` and `show_reasoning=True`, a `_reasoning`
    message must still reach the reasoning printer."""
    calls: list[str] = []

    def _capture(text, thinking, renderer=None):
        calls.append(text)

    channels_config = SimpleNamespace(
        send_progress=False, send_tool_hints=False, show_reasoning=True,
    )
    msg = SimpleNamespace(
        content="Let me think about this...",
        metadata={"_progress": True, "_reasoning": True},
    )

    with patch("nanobot.cli.commands._print_cli_reasoning", side_effect=_capture):
        handled = await commands._maybe_print_interactive_progress(msg, None, channels_config)

    assert handled is True
    assert calls == ["Let me think about this..."]
|
||||
|
||||
@ -405,7 +405,7 @@ def test_loader_registers_same_tools_as_old_hardcoded():
|
||||
registered = loader.load(ctx, registry)
|
||||
|
||||
expected = {
|
||||
"ask_user", "read_file", "write_file", "edit_file", "list_dir",
|
||||
"read_file", "write_file", "edit_file", "list_dir",
|
||||
"glob", "grep", "notebook_edit", "exec", "web_search", "web_fetch",
|
||||
"message", "spawn", "cron",
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from nanobot.utils.helpers import strip_think
|
||||
from nanobot.utils.helpers import extract_reasoning, extract_think, strip_think
|
||||
|
||||
|
||||
class TestStripThinkTag:
|
||||
@ -144,3 +144,130 @@ class TestStripThinkConservativePreserve:
|
||||
def test_literal_channel_marker_in_code_block_preserved(self):
|
||||
text = "Example:\n```\nif line.startswith('<channel|>'):\n skip()\n```"
|
||||
assert strip_think(text) == text
|
||||
|
||||
|
||||
class TestExtractThink:
    """Checks for ``extract_think``, which returns a ``(thinking, clean)`` pair."""

    def test_no_think_tags(self):
        reasoning, remainder = extract_think("Hello World")
        assert reasoning is None
        assert remainder == "Hello World"

    def test_single_think_block(self):
        source = "Hello <think>reasoning content\nhere</think> World"
        reasoning, remainder = extract_think(source)
        assert reasoning == "reasoning content\nhere"
        assert remainder == "Hello World"

    def test_single_thought_block(self):
        source = "Hello <thought>reasoning content</thought> World"
        reasoning, remainder = extract_think(source)
        assert reasoning == "reasoning content"
        assert remainder == "Hello World"

    def test_multiple_think_blocks(self):
        source = "A<think>first</think>B<thought>second</thought>C"
        reasoning, remainder = extract_think(source)
        assert reasoning == "first\n\nsecond"
        assert remainder == "ABC"

    def test_think_only_no_content(self):
        source = "<think>just thinking</think>"
        reasoning, remainder = extract_think(source)
        assert reasoning == "just thinking"
        assert remainder == ""

    def test_unclosed_think_not_extracted(self):
        # Unclosed blocks at start are stripped but NOT extracted
        source = "<think>unclosed thinking..."
        reasoning, remainder = extract_think(source)
        assert reasoning is None
        assert remainder == ""

    def test_empty_think_block(self):
        source = "Hello <think></think> World"
        reasoning, remainder = extract_think(source)
        # Empty blocks result in empty string after strip
        assert reasoning == ""
        assert remainder == "Hello World"

    def test_think_with_whitespace_only(self):
        source = "Hello <think> \n World"
        reasoning, remainder = extract_think(source)
        assert reasoning is None
        assert remainder == "Hello <think> \n World"

    def test_mixed_think_and_thought(self):
        source = "Start<think>first reasoning</think>middle<thought>second reasoning</thought>End"
        reasoning, remainder = extract_think(source)
        assert reasoning == "first reasoning\n\nsecond reasoning"
        assert remainder == "StartmiddleEnd"

    def test_real_world_ollama_response(self):
        source = """<think>
The user is asking about Python list comprehensions.
Let me explain the syntax and give examples.
</think>

List comprehensions in Python provide a concise way to create lists. Here's the syntax:

```python
[expression for item in iterable if condition]
```

For example:
```python
squares = [x**2 for x in range(10)]
```"""
        reasoning, remainder = extract_think(source)
        assert "list comprehensions" in reasoning.lower()
        assert "Let me explain" in reasoning
        assert "List comprehensions in Python" in remainder
        # Neither tag may survive in the cleaned answer.
        assert "<think>" not in remainder
        assert "</think>" not in remainder
|
||||
|
||||
|
||||
class TestExtractReasoning:
    """Checks for ``extract_reasoning(reasoning_content, thinking_blocks, content)``,
    which returns a ``(reasoning, content)`` pair."""

    def test_prefers_reasoning_content_and_strips_inline_think(self):
        # Dedicated field wins; inline tags are still scrubbed from content.
        raw_content = "<think>inline</think>visible answer"
        reasoning, content = extract_reasoning("dedicated", None, raw_content)
        assert reasoning == "dedicated"
        assert content == "visible answer"

    def test_falls_back_to_thinking_blocks(self):
        blocks = [
            {"type": "thinking", "thinking": "step 1"},
            {"type": "thinking", "thinking": "step 2"},
            {"type": "redacted_thinking"},
        ]
        reasoning, content = extract_reasoning(None, blocks, "hello")
        assert reasoning == "step 1\n\nstep 2"
        assert content == "hello"

    def test_falls_back_to_inline_think_tags(self):
        raw_content = "<think>plan</think>answer"
        reasoning, content = extract_reasoning(None, None, raw_content)
        assert reasoning == "plan"
        assert content == "answer"

    def test_no_reasoning_returns_none(self):
        raw_content = "plain answer"
        reasoning, content = extract_reasoning(None, None, raw_content)
        assert reasoning is None
        assert content == "plain answer"

    def test_empty_thinking_blocks_falls_through_to_inline(self):
        raw_content = "<think>plan</think>answer"
        reasoning, content = extract_reasoning(None, [], raw_content)
        assert reasoning == "plan"
        assert content == "answer"
|
||||
|
||||
@ -250,7 +250,6 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
|
||||
key: string;
|
||||
label: string;
|
||||
} | null>(null);
|
||||
const lastSessionsLen = useRef(0);
|
||||
const restartSawDisconnectRef = useRef(false);
|
||||
const [restartToast, setRestartToast] = useState<string | null>(null);
|
||||
const [isRestarting, setIsRestarting] = useState(false);
|
||||
@ -266,13 +265,7 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
|
||||
}
|
||||
}, [desktopSidebarOpen]);
|
||||
|
||||
useEffect(() => {
|
||||
if (activeKey) return;
|
||||
if (sessions.length > 0 && lastSessionsLen.current === 0) {
|
||||
setActiveKey(sessions[0].key);
|
||||
}
|
||||
lastSessionsLen.current = sessions.length;
|
||||
}, [sessions, activeKey]);
|
||||
|
||||
|
||||
const activeSession = useMemo<ChatSummary | null>(() => {
|
||||
if (!activeKey) return null;
|
||||
@ -335,9 +328,8 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
|
||||
setView("chat");
|
||||
setMobileSidebarOpen(false);
|
||||
setActiveKey((current) => {
|
||||
if (current && sessions.some((session) => session.key === current)) {
|
||||
return current;
|
||||
}
|
||||
if (!current) return null;
|
||||
if (sessions.some((session) => session.key === current)) return current;
|
||||
return sessions[0]?.key ?? null;
|
||||
});
|
||||
}, [sessions]);
|
||||
@ -479,18 +471,13 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
|
||||
</Sheet>
|
||||
) : null}
|
||||
|
||||
<main className="flex h-full min-w-0 flex-1 flex-col">
|
||||
{view === "settings" ? (
|
||||
<SettingsView
|
||||
theme={theme}
|
||||
onToggleTheme={toggle}
|
||||
onBackToChat={onBackToChat}
|
||||
onModelNameChange={onModelNameChange}
|
||||
onLogout={onLogout}
|
||||
onRestart={onRestart}
|
||||
isRestarting={isRestarting}
|
||||
/>
|
||||
) : (
|
||||
<main className="relative flex h-full min-w-0 flex-1 flex-col">
|
||||
<div
|
||||
className={cn(
|
||||
"absolute inset-0 flex flex-col",
|
||||
view === "settings" && "invisible pointer-events-none",
|
||||
)}
|
||||
>
|
||||
<ThreadShell
|
||||
session={activeSession}
|
||||
title={headerTitle}
|
||||
@ -502,6 +489,19 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
|
||||
onToggleTheme={toggle}
|
||||
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
||||
/>
|
||||
</div>
|
||||
{view === "settings" && (
|
||||
<div className="absolute inset-0 flex flex-col">
|
||||
<SettingsView
|
||||
theme={theme}
|
||||
onToggleTheme={toggle}
|
||||
onBackToChat={onBackToChat}
|
||||
onModelNameChange={onModelNameChange}
|
||||
onLogout={onLogout}
|
||||
onRestart={onRestart}
|
||||
isRestarting={isRestarting}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</main>
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Check, ChevronRight, Copy, FileIcon, ImageIcon, PlaySquare, Wrench } from "lucide-react";
|
||||
import { Check, ChevronRight, Copy, FileIcon, ImageIcon, PlaySquare, Sparkles, Wrench } from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import { ImageLightbox } from "@/components/ImageLightbox";
|
||||
@ -85,12 +85,18 @@ export function MessageBubble({ message }: MessageBubbleProps) {
|
||||
|
||||
const empty = message.content.trim().length === 0;
|
||||
const media = message.media ?? [];
|
||||
const reasoning = message.role === "assistant" ? message.reasoning ?? "" : "";
|
||||
const reasoningStreaming = !!(message.role === "assistant" && message.reasoningStreaming);
|
||||
const hasReasoning = reasoning.length > 0 || reasoningStreaming;
|
||||
const showAssistantActions = message.role === "assistant" && !message.isStreaming && !empty;
|
||||
return (
|
||||
<div className={cn("w-full text-[15px]", baseAnim)} style={{ lineHeight: "var(--cjk-line-height)" }}>
|
||||
{empty && message.isStreaming ? (
|
||||
{hasReasoning ? (
|
||||
<ReasoningBubble text={reasoning} streaming={reasoningStreaming} hasBodyBelow={!empty} />
|
||||
) : null}
|
||||
{empty && message.isStreaming && !hasReasoning ? (
|
||||
<TypingDots />
|
||||
) : (
|
||||
) : empty && message.isStreaming ? null : (
|
||||
<>
|
||||
<MarkdownText>{message.content}</MarkdownText>
|
||||
{message.isStreaming && <StreamCursor />}
|
||||
@ -380,14 +386,14 @@ interface TraceGroupProps {
|
||||
|
||||
/**
|
||||
* Collapsible group of tool-call / progress breadcrumbs. Defaults to
|
||||
* expanded for discoverability; a single click on the header folds the
|
||||
* group down to a one-line summary so it never dominates the thread.
|
||||
* collapsed because tool traces are supporting evidence, not the answer.
|
||||
* A single click expands the exact calls when the user wants details.
|
||||
*/
|
||||
function TraceGroup({ message, animClass }: TraceGroupProps) {
|
||||
const { t } = useTranslation();
|
||||
const lines = message.traces ?? [message.content];
|
||||
const count = lines.length;
|
||||
const [open, setOpen] = useState(true);
|
||||
const [open, setOpen] = useState(false);
|
||||
return (
|
||||
<div className={cn("w-full", animClass)}>
|
||||
<button
|
||||
@ -433,3 +439,79 @@ function TraceGroup({ message, animClass }: TraceGroupProps) {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ReasoningBubbleProps {
  text: string;
  streaming: boolean;
  hasBodyBelow: boolean;
}

/**
 * Collapsible "thinking" trace rendered above an assistant turn.
 *
 * Open/closed state:
 * - Until the user clicks the header, the bubble is open exactly while
 *   ``streaming`` is true, so it collapses by itself once streaming stops.
 * - After the first click the locally stored toggle decides, regardless of
 *   ``streaming``.
 *
 * While streaming, the header gets the shimmer class and the icon pulses.
 * The body is only rendered when the bubble is open and ``text`` is non-empty.
 */
function ReasoningBubble({ text, streaming, hasBodyBelow }: ReasoningBubbleProps) {
  const { t } = useTranslation();
  const [userToggled, setUserToggled] = useState(false);
  const [openLocal, setOpenLocal] = useState(true);
  // Before any user interaction, visibility simply follows the stream state.
  const open = userToggled ? openLocal : streaming;

  function onToggle() {
    setUserToggled(true);
    // First click flips what is currently shown; later clicks flip the stored value.
    setOpenLocal((previous) => (userToggled ? !previous : !open));
  }

  const headerLabel = streaming
    ? t("message.reasoningStreaming", { defaultValue: "Thinking…" })
    : t("message.reasoning", { defaultValue: "Thinking" });
  const showBody = open && text.length > 0;

  return (
    <div
      className={cn(
        "w-full animate-in fade-in-0 slide-in-from-top-1 duration-200",
        hasBodyBelow && "mb-2",
      )}
    >
      <button
        type="button"
        onClick={onToggle}
        className={cn(
          "group flex w-full items-center gap-2 rounded-md px-2 py-1.5",
          "text-xs text-muted-foreground transition-colors hover:bg-muted/45",
          streaming && "reasoning-shimmer",
        )}
        aria-expanded={open}
        aria-live={streaming ? "polite" : undefined}
      >
        <Sparkles
          className={cn("h-3.5 w-3.5", streaming && "animate-pulse")}
          aria-hidden
        />
        <span className="font-medium">{headerLabel}</span>
        <ChevronRight
          aria-hidden
          className={cn(
            "ml-auto h-3.5 w-3.5 transition-transform duration-200",
            open && "rotate-90",
          )}
        />
      </button>
      {showBody ? (
        <div
          className={cn(
            "mt-1 space-y-0.5 whitespace-pre-wrap break-words border-l border-muted-foreground/20 pl-3",
            "animate-in fade-in-0 slide-in-from-top-1 duration-200",
            "text-[12.5px] italic leading-relaxed text-muted-foreground/85",
          )}
        >
          {text}
        </div>
      ) : null}
    </div>
  );
}
|
||||
|
||||
@ -1,108 +0,0 @@
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { MessageSquareText } from "lucide-react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
/** Props accepted by ``AskUserPrompt``. */
interface AskUserPromptProps {
  /** Question text shown at the top of the card. */
  question: string;
  /** Rows of answer labels; rows are flattened and empty labels dropped. */
  buttons: string[][];
  /** Called with the chosen label, or with the trimmed custom answer. */
  onAnswer: (answer: string) => void;
}

/**
 * Card that presents a question with one button per answer option plus an
 * "Other..." toggle that reveals a free-text field.
 *
 * Renders nothing when no non-empty option remains after flattening
 * ``buttons``. Pressing Enter (without Shift, and not during IME
 * composition) or clicking "Send" submits the trimmed custom answer; blank
 * input is ignored.
 */
export function AskUserPrompt({
  question,
  buttons,
  onAnswer,
}: AskUserPromptProps) {
  const [customOpen, setCustomOpen] = useState(false);
  const [custom, setCustom] = useState("");
  const inputRef = useRef<HTMLTextAreaElement>(null);
  const options = buttons.flat().filter(Boolean);

  // Move focus into the textarea as soon as the custom field is revealed.
  useEffect(() => {
    if (customOpen) {
      inputRef.current?.focus();
    }
  }, [customOpen]);

  const submitCustom = useCallback(() => {
    const answer = custom.trim();
    if (!answer) return;
    onAnswer(answer);
    setCustom("");
    setCustomOpen(false);
  }, [custom, onAnswer]);

  if (options.length === 0) return null;

  return (
    <div
      className={cn(
        "mx-auto mb-2 w-full max-w-[49.5rem] rounded-[16px] border border-primary/30",
        "bg-card/95 p-3 shadow-sm backdrop-blur",
      )}
      role="group"
      aria-label="Question"
    >
      <div className="mb-2 flex items-start gap-2">
        <div className="mt-0.5 rounded-full bg-primary/10 p-1.5 text-primary">
          <MessageSquareText className="h-3.5 w-3.5" aria-hidden />
        </div>
        <p className="min-w-0 flex-1 text-[13.5px] font-medium leading-5 text-foreground">
          {question}
        </p>
      </div>

      <div className="grid gap-1.5 sm:grid-cols-2">
        {options.map((option, index) => (
          <Button
            // Labels are not guaranteed to be unique, so the position is part
            // of the key to keep sibling keys distinct.
            key={`${index}:${option}`}
            type="button"
            variant="outline"
            size="sm"
            onClick={() => onAnswer(option)}
            className="justify-start rounded-[10px] px-3 text-left"
          >
            <span className="truncate">{option}</span>
          </Button>
        ))}
        <Button
          type="button"
          variant="ghost"
          size="sm"
          onClick={() => setCustomOpen((open) => !open)}
          className="justify-start rounded-[10px] px-3 text-muted-foreground"
        >
          Other...
        </Button>
      </div>

      {customOpen ? (
        <div className="mt-2 flex gap-2">
          <textarea
            ref={inputRef}
            value={custom}
            onChange={(event) => setCustom(event.target.value)}
            onKeyDown={(event) => {
              // Shift+Enter keeps inserting newlines; Enter during IME
              // composition must not submit a half-composed answer.
              if (event.key === "Enter" && !event.shiftKey && !event.nativeEvent.isComposing) {
                event.preventDefault();
                submitCustom();
              }
            }}
            rows={1}
            placeholder="Type your own answer..."
            className={cn(
              "min-h-9 flex-1 resize-none rounded-[10px] border border-border/70 bg-background",
              "px-3 py-2 text-[13.5px] leading-5 outline-none placeholder:text-muted-foreground",
              "focus-visible:ring-1 focus-visible:ring-primary/40",
            )}
          />
          <Button type="button" size="sm" onClick={submitCustom} disabled={!custom.trim()}>
            Send
          </Button>
        </div>
      ) : null}
    </div>
  );
}
|
||||
@ -1,6 +1,7 @@
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
@ -77,6 +78,17 @@ const COMMAND_ICONS: Record<string, LucideIcon> = {
|
||||
type ImageAspectRatio = "auto" | "1:1" | "3:4" | "9:16" | "4:3" | "16:9";
|
||||
|
||||
const IMAGE_ASPECT_RATIOS: ImageAspectRatio[] = ["auto", "1:1", "3:4", "9:16", "4:3", "16:9"];
|
||||
const SLASH_PALETTE_GAP_PX = 8;
|
||||
const SLASH_PALETTE_MAX_HEIGHT_PX = 288;
|
||||
const SLASH_PALETTE_MIN_HEIGHT_PX = 144;
|
||||
const SLASH_PALETTE_CHROME_PX = 64;
|
||||
|
||||
type SlashPalettePlacement = "above" | "below";
|
||||
|
||||
interface SlashPaletteLayout {
|
||||
placement: SlashPalettePlacement;
|
||||
maxHeight: number;
|
||||
}
|
||||
|
||||
function slashCommandI18nKey(command: string): string {
|
||||
return command.replace(/^\//, "").replace(/-/g, "_");
|
||||
@ -96,6 +108,24 @@ function scrollNearestOverflowParent(target: EventTarget | null, deltaY: number)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Compute the vertical window in which content near ``el`` can be visible.
 *
 * Starts from the viewport (``0`` .. ``window.innerHeight``) and narrows it
 * by the bounding rect of every ancestor whose computed ``overflow-y``
 * contains ``auto``, ``scroll``, ``hidden`` or ``clip``.
 *
 * @param el Element whose ancestors are inspected; ``el`` itself is not.
 * @returns Top and bottom edges in viewport coordinates.
 */
function getVisibleBounds(el: HTMLElement): { top: number; bottom: number } {
  const bounds = { top: 0, bottom: window.innerHeight };

  for (
    let ancestor = el.parentElement;
    ancestor;
    ancestor = ancestor.parentElement
  ) {
    const { overflowY } = window.getComputedStyle(ancestor);
    // Ancestors that let content overflow visibly do not clip anything.
    if (!/(auto|scroll|hidden|clip)/.test(overflowY)) continue;

    const box = ancestor.getBoundingClientRect();
    bounds.top = Math.max(bounds.top, box.top);
    bounds.bottom = Math.min(bounds.bottom, box.bottom);
  }

  return bounds;
}
|
||||
|
||||
export function ThreadComposer({
|
||||
onSend,
|
||||
disabled,
|
||||
@ -117,6 +147,7 @@ export function ThreadComposer({
|
||||
const [imageAspectRatio, setImageAspectRatio] = useState<ImageAspectRatio>("auto");
|
||||
const [aspectMenuOpen, setAspectMenuOpen] = useState(false);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const formRef = useRef<HTMLFormElement>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
const aspectControlRef = useRef<HTMLDivElement>(null);
|
||||
const chipRefs = useRef(new Map<string, HTMLButtonElement>());
|
||||
@ -221,6 +252,10 @@ export function ThreadComposer({
|
||||
}, [slashCommands, slashQuery, t]);
|
||||
|
||||
const showSlashMenu = filteredSlashCommands.length > 0;
|
||||
const [slashPaletteLayout, setSlashPaletteLayout] = useState<SlashPaletteLayout>({
|
||||
placement: "above",
|
||||
maxHeight: SLASH_PALETTE_MAX_HEIGHT_PX,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
setSelectedCommandIndex(0);
|
||||
@ -232,6 +267,56 @@ export function ThreadComposer({
|
||||
}
|
||||
}, [filteredSlashCommands.length, selectedCommandIndex]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!showSlashMenu) return;
|
||||
|
||||
const dismissOnPointerDown = (event: PointerEvent) => {
|
||||
const target = event.target;
|
||||
if (target instanceof Node && formRef.current?.contains(target)) return;
|
||||
setSlashMenuDismissed(true);
|
||||
};
|
||||
|
||||
document.addEventListener("pointerdown", dismissOnPointerDown, true);
|
||||
return () => {
|
||||
document.removeEventListener("pointerdown", dismissOnPointerDown, true);
|
||||
};
|
||||
}, [showSlashMenu]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (!showSlashMenu) return;
|
||||
|
||||
const updateLayout = () => {
|
||||
const form = formRef.current;
|
||||
if (!form) return;
|
||||
const rect = form.getBoundingClientRect();
|
||||
if (rect.width === 0 && rect.height === 0) return;
|
||||
|
||||
const bounds = getVisibleBounds(form);
|
||||
const spaceAbove = Math.max(0, rect.top - bounds.top - SLASH_PALETTE_GAP_PX);
|
||||
const spaceBelow = Math.max(0, bounds.bottom - rect.bottom - SLASH_PALETTE_GAP_PX);
|
||||
const placement: SlashPalettePlacement =
|
||||
spaceAbove >= SLASH_PALETTE_MIN_HEIGHT_PX || spaceAbove >= spaceBelow
|
||||
? "above"
|
||||
: "below";
|
||||
const available = placement === "above" ? spaceAbove : spaceBelow;
|
||||
const maxHeight = Math.min(SLASH_PALETTE_MAX_HEIGHT_PX, available);
|
||||
|
||||
setSlashPaletteLayout((current) =>
|
||||
current.placement === placement && current.maxHeight === maxHeight
|
||||
? current
|
||||
: { placement, maxHeight },
|
||||
);
|
||||
};
|
||||
|
||||
updateLayout();
|
||||
window.addEventListener("resize", updateLayout);
|
||||
document.addEventListener("scroll", updateLayout, true);
|
||||
return () => {
|
||||
window.removeEventListener("resize", updateLayout);
|
||||
document.removeEventListener("scroll", updateLayout, true);
|
||||
};
|
||||
}, [filteredSlashCommands.length, showSlashMenu]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!aspectMenuOpen) return;
|
||||
|
||||
@ -398,6 +483,7 @@ export function ThreadComposer({
|
||||
|
||||
return (
|
||||
<form
|
||||
ref={formRef}
|
||||
onSubmit={(e) => {
|
||||
e.preventDefault();
|
||||
submit();
|
||||
@ -412,6 +498,7 @@ export function ThreadComposer({
|
||||
<SlashCommandPalette
|
||||
commands={filteredSlashCommands}
|
||||
selectedIndex={selectedCommandIndex}
|
||||
layout={slashPaletteLayout}
|
||||
isHero={isHero}
|
||||
onHover={setSelectedCommandIndex}
|
||||
onChoose={chooseSlashCommand}
|
||||
@ -634,6 +721,7 @@ export function ThreadComposer({
|
||||
interface SlashCommandPaletteProps {
|
||||
commands: SlashCommand[];
|
||||
selectedIndex: number;
|
||||
layout: SlashPaletteLayout;
|
||||
isHero: boolean;
|
||||
onHover: (index: number) => void;
|
||||
onChoose: (command: SlashCommand) => void;
|
||||
@ -695,17 +783,24 @@ function ImageAspectMenu({
|
||||
function SlashCommandPalette({
|
||||
commands,
|
||||
selectedIndex,
|
||||
layout,
|
||||
isHero,
|
||||
onHover,
|
||||
onChoose,
|
||||
}: SlashCommandPaletteProps) {
|
||||
const { t } = useTranslation();
|
||||
const listMaxHeight = Math.max(
|
||||
0,
|
||||
layout.maxHeight - SLASH_PALETTE_CHROME_PX,
|
||||
);
|
||||
return (
|
||||
<div
|
||||
role="listbox"
|
||||
aria-label={t("thread.composer.slash.ariaLabel")}
|
||||
style={{ maxHeight: layout.maxHeight }}
|
||||
className={cn(
|
||||
"absolute bottom-full left-1/2 z-30 mb-2 max-h-[22rem] w-[calc(100%-0.5rem)] -translate-x-1/2 overflow-hidden rounded-[18px] border",
|
||||
"absolute left-1/2 z-30 w-[calc(100%-0.5rem)] -translate-x-1/2 overflow-hidden rounded-[18px] border",
|
||||
layout.placement === "above" ? "bottom-full mb-2" : "top-full mt-2",
|
||||
"border-border/65 bg-popover p-1.5 text-popover-foreground shadow-[0_18px_55px_rgba(15,23,42,0.18)]",
|
||||
"dark:border-white/10 dark:shadow-[0_22px_55px_rgba(0,0,0,0.45)]",
|
||||
isHero ? "max-w-[58rem]" : "max-w-[49.5rem]",
|
||||
@ -714,7 +809,7 @@ function SlashCommandPalette({
|
||||
<div className="px-2 pb-1 pt-1 text-[11px] font-medium tracking-[0.08em] text-muted-foreground/70">
|
||||
{t("thread.composer.slash.label")}
|
||||
</div>
|
||||
<div className="max-h-[18rem] overflow-y-auto pr-0.5">
|
||||
<div className="overflow-y-auto pr-0.5" style={{ maxHeight: listMaxHeight }}>
|
||||
{commands.map((command, index) => {
|
||||
const Icon = COMMAND_ICONS[command.icon] ?? CircleHelp;
|
||||
const selected = index === selectedIndex;
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { MessageBubble } from "@/components/MessageBubble";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { UIMessage } from "@/lib/types";
|
||||
|
||||
interface ThreadMessagesProps {
|
||||
@ -7,10 +8,30 @@ interface ThreadMessagesProps {
|
||||
|
||||
/**
 * Render the thread as a vertical stack of ``MessageBubble`` rows.
 *
 * Spacing is applied per row instead of through a flex gap: the first row
 * gets no top margin, a row gets ``mt-2`` when both it and the previous row
 * are auxiliary (see ``isAuxiliaryRow``), and ``mt-5`` otherwise.
 */
export function ThreadMessages({ messages }: ThreadMessagesProps) {
  return (
    <div className="flex w-full flex-col">
      {messages.map((message, index) => {
        // ``prev`` is undefined for the first row, which keeps it non-compact.
        const prev = messages[index - 1];
        const compact = isAuxiliaryRow(message) && prev && isAuxiliaryRow(prev);
        return (
          <div
            key={message.id}
            className={cn(index > 0 && (compact ? "mt-2" : "mt-5"))}
          >
            <MessageBubble message={message} />
          </div>
        );
      })}
    </div>
  );
}
|
||||
|
||||
/**
 * Tell whether a row is an auxiliary one rather than a regular message.
 *
 * A row is auxiliary when it is a trace row, or when it is an assistant row
 * whose content is empty after trimming and that carries reasoning text or
 * an in-progress reasoning stream.
 */
function isAuxiliaryRow(message: UIMessage): boolean {
  if (message.kind === "trace") return true;
  if (message.role !== "assistant") return false;
  // Any visible answer text makes this a regular assistant reply.
  if (message.content.trim().length !== 0) return false;
  return Boolean(message.reasoning) || Boolean(message.reasoningStreaming);
}
|
||||
|
||||
@ -13,7 +13,6 @@ import {
|
||||
} from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import { AskUserPrompt } from "@/components/thread/AskUserPrompt";
|
||||
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
||||
import { ThreadHeader } from "@/components/thread/ThreadHeader";
|
||||
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
|
||||
@ -105,21 +104,6 @@ export function ThreadShell({
|
||||
dismissStreamError,
|
||||
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd);
|
||||
const showHeroComposer = messages.length === 0 && !loading;
|
||||
const pendingAsk = useMemo(() => {
|
||||
for (let index = messages.length - 1; index >= 0; index -= 1) {
|
||||
const message = messages[index];
|
||||
if (message.kind === "trace") continue;
|
||||
if (message.role === "user") return null;
|
||||
if (message.role === "assistant" && message.buttons?.some((row) => row.length > 0)) {
|
||||
return {
|
||||
question: message.content,
|
||||
buttons: message.buttons,
|
||||
};
|
||||
}
|
||||
if (message.role === "assistant") return null;
|
||||
}
|
||||
return null;
|
||||
}, [messages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!chatId || loading) return;
|
||||
@ -247,13 +231,6 @@ export function ThreadShell({
|
||||
onDismiss={dismissStreamError}
|
||||
/>
|
||||
) : null}
|
||||
{pendingAsk ? (
|
||||
<AskUserPrompt
|
||||
question={pendingAsk.question}
|
||||
buttons={pendingAsk.buttons}
|
||||
onAnswer={send}
|
||||
/>
|
||||
) : null}
|
||||
{session ? (
|
||||
<ThreadComposer
|
||||
onSend={send}
|
||||
@ -283,6 +260,7 @@ export function ThreadShell({
|
||||
}
|
||||
modelLabel={toModelBadgeLabel(modelName)}
|
||||
variant="hero"
|
||||
slashCommands={slashCommands}
|
||||
imageMode={heroImageMode}
|
||||
onImageModeChange={setHeroImageMode}
|
||||
/>
|
||||
|
||||
@ -117,6 +117,34 @@
|
||||
--cjk-line-height: 1.625;
|
||||
}
|
||||
|
||||
/* Reasoning header shimmer: a soft highlight band that sweeps from left to
   right across elements carrying ``.reasoning-shimmer``. Implemented purely
   in CSS and switched off under ``prefers-reduced-motion: reduce``. */
@keyframes reasoning-shimmer-sweep {
  from {
    background-position: -200% 0;
  }
  to {
    background-position: 200% 0;
  }
}

.reasoning-shimmer {
  animation: reasoning-shimmer-sweep 2.2s linear infinite;
  background-repeat: no-repeat;
  /* Twice the element width so the band can travel fully off both edges. */
  background-size: 200% 100%;
  background-image: linear-gradient(
    to right,
    transparent,
    hsl(var(--muted-foreground) / 0.18) 50%,
    transparent
  );
}

@media (prefers-reduced-motion: reduce) {
  .reasoning-shimmer {
    animation: none;
  }
}
|
||||
|
||||
/* Subtle scrollbar that doesn't fight the dark background. */
|
||||
.scrollbar-thin {
|
||||
scrollbar-width: thin;
|
||||
|
||||
@ -18,6 +18,95 @@ interface StreamBuffer {
|
||||
parts: string[];
|
||||
}
|
||||
|
||||
/**
 * Append a reasoning chunk to the assistant turn it belongs to.
 *
 * The list is scanned from the tail:
 *   - a user row or a trace row is a hard boundary and stops the scan;
 *   - other non-assistant rows are skipped;
 *   - the first assistant row reached receives the chunk when it already has
 *     reasoning (open or closed), has answer text, or is still streaming.
 *     This also covers providers that only expose reasoning after the answer
 *     has streamed, so the reasoning stays on the same row as the answer;
 *   - an assistant row with none of those properties stops the scan.
 *
 * When no row qualifies, a new streaming assistant placeholder holding the
 * chunk is appended.
 *
 * @param prev Current message list; never mutated.
 * @param chunk Reasoning text to append.
 * @returns A new list with the chunk merged in or a placeholder appended.
 */
function attachReasoningChunk(prev: UIMessage[], chunk: string): UIMessage[] {
  for (let i = prev.length - 1; i >= 0; i -= 1) {
    const candidate = prev[i];
    // Reasoning after a user turn belongs to the new assistant turn, never
    // to an earlier assistant reply.
    if (candidate.role === "user") break;
    // Reasoning after a trace row (e.g. used tools) belongs to the next
    // assistant iteration, not the one that produced those tool calls.
    if (candidate.kind === "trace") break;
    if (candidate.role !== "assistant") continue;

    const belongsToActiveTurn =
      candidate.reasoningStreaming
      || candidate.reasoning !== undefined
      || candidate.content.length > 0
      || candidate.isStreaming;
    // A finished, empty assistant row with no reasoning is not a target.
    if (!belongsToActiveTurn) break;

    const merged: UIMessage = {
      ...candidate,
      reasoning: (candidate.reasoning ?? "") + chunk,
      reasoningStreaming: true,
    };
    return [...prev.slice(0, i), merged, ...prev.slice(i + 1)];
  }

  return [
    ...prev,
    {
      id: crypto.randomUUID(),
      role: "assistant",
      content: "",
      isStreaming: true,
      reasoning: chunk,
      reasoningStreaming: true,
      createdAt: Date.now(),
    },
  ];
}
|
||||
|
||||
/**
 * Return the id of the trailing assistant placeholder that an incoming
 * answer delta should write into, or ``null`` when a new row is needed.
 *
 * Only the last row is considered. It qualifies when it is an assistant
 * row, is not a trace row, has empty content, and is still flagged
 * ``isStreaming``.
 */
function findActiveAssistantPlaceholder(prev: UIMessage[]): string | null {
  const tail = prev[prev.length - 1];
  if (!tail) return null;

  const adoptable =
    tail.role === "assistant"
    && tail.kind !== "trace"
    && tail.content.length === 0
    && Boolean(tail.isStreaming);

  return adoptable ? tail.id : null;
}
|
||||
|
||||
/**
 * Mark the most recent row that is still streaming reasoning as finished.
 *
 * Only the last row with a truthy ``reasoningStreaming`` is touched. When no
 * such row exists the input array is returned unchanged (same reference).
 */
function closeReasoningStream(prev: UIMessage[]): UIMessage[] {
  let openAt = prev.length - 1;
  while (openAt >= 0 && !prev[openAt].reasoningStreaming) {
    openAt -= 1;
  }
  if (openAt < 0) return prev;

  return prev.map((message, position) =>
    position === openAt ? { ...message, reasoningStreaming: false } : message,
  );
}
|
||||
|
||||
/**
|
||||
* Subscribe to a chat by ID. Returns the in-memory message list for the chat,
|
||||
* a streaming flag, and a ``send`` function. Initial history must be seeded
|
||||
@ -122,27 +211,42 @@ export function useNanobotStream(
|
||||
|
||||
if (ev.event === "delta") {
|
||||
if (suppressStreamUntilTurnEndRef.current) return;
|
||||
const id = buffer.current?.messageId ?? crypto.randomUUID();
|
||||
if (!buffer.current) {
|
||||
buffer.current = { messageId: id, parts: [] };
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id,
|
||||
role: "assistant",
|
||||
content: "",
|
||||
isStreaming: true,
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
]);
|
||||
setIsStreaming(true);
|
||||
}
|
||||
buffer.current.parts.push(ev.text);
|
||||
const combined = buffer.current.parts.join("");
|
||||
const targetId = buffer.current.messageId;
|
||||
setMessages((prev) =>
|
||||
prev.map((m) => (m.id === targetId ? { ...m, content: combined } : m)),
|
||||
);
|
||||
const chunk = ev.text;
|
||||
setIsStreaming(true);
|
||||
setMessages((prev) => {
|
||||
// Reuse an in-flight assistant placeholder (typically created by
|
||||
// ``reasoning_delta``) so the answer renders below its own
|
||||
// thinking trace instead of in a parallel row.
|
||||
const adopted = !buffer.current ? findActiveAssistantPlaceholder(prev) : null;
|
||||
let targetId: string;
|
||||
let next: UIMessage[];
|
||||
if (buffer.current) {
|
||||
targetId = buffer.current.messageId;
|
||||
next = prev;
|
||||
} else if (adopted) {
|
||||
targetId = adopted;
|
||||
buffer.current = { messageId: targetId, parts: [] };
|
||||
next = prev;
|
||||
} else {
|
||||
targetId = crypto.randomUUID();
|
||||
buffer.current = { messageId: targetId, parts: [] };
|
||||
next = [
|
||||
...prev,
|
||||
{
|
||||
id: targetId,
|
||||
role: "assistant",
|
||||
content: "",
|
||||
isStreaming: true,
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
];
|
||||
}
|
||||
buffer.current.parts.push(chunk);
|
||||
const combined = buffer.current.parts.join("");
|
||||
return next.map((m) =>
|
||||
m.id === targetId ? { ...m, content: combined, isStreaming: true } : m,
|
||||
);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@ -159,6 +263,21 @@ export function useNanobotStream(
|
||||
return;
|
||||
}
|
||||
|
||||
if (ev.event === "reasoning_delta") {
|
||||
if (suppressStreamUntilTurnEndRef.current) return;
|
||||
const chunk = ev.text;
|
||||
if (!chunk) return;
|
||||
setMessages((prev) => attachReasoningChunk(prev, chunk));
|
||||
setIsStreaming(true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ev.event === "reasoning_end") {
|
||||
if (suppressStreamUntilTurnEndRef.current) return;
|
||||
setMessages((prev) => closeReasoningStream(prev));
|
||||
return;
|
||||
}
|
||||
|
||||
if (ev.event === "turn_end") {
|
||||
// Definitive signal that the turn is fully complete. Cancel any
|
||||
// pending debounce timer and stop the loading indicator immediately.
|
||||
@ -175,18 +294,22 @@ export function useNanobotStream(
|
||||
return;
|
||||
}
|
||||
|
||||
if (ev.event === "session_updated") {
|
||||
onTurnEnd?.();
|
||||
return;
|
||||
}
|
||||
|
||||
if (ev.event === "message") {
|
||||
if (
|
||||
suppressStreamUntilTurnEndRef.current &&
|
||||
(ev.kind === "tool_hint" || ev.kind === "progress")
|
||||
(ev.kind === "tool_hint" || ev.kind === "progress" || ev.kind === "reasoning")
|
||||
) {
|
||||
return;
|
||||
}
|
||||
// Back-compat: a legacy ``kind: "reasoning"`` message (no streaming
|
||||
// partner) is treated as one complete delta + immediate end so the
|
||||
// bubble renders identically to the streaming path.
|
||||
if (ev.kind === "reasoning") {
|
||||
const line = ev.text;
|
||||
if (!line) return;
|
||||
setMessages((prev) => closeReasoningStream(attachReasoningChunk(prev, line)));
|
||||
return;
|
||||
}
|
||||
// Intermediate agent breadcrumbs (tool-call hints, raw progress).
|
||||
// Attach them to the last trace row if it was the last emitted item
|
||||
// so a sequence of calls collapses into one compact trace group.
|
||||
@ -230,7 +353,7 @@ export function useNanobotStream(
|
||||
// the full turn (all tool calls + final text) is complete.
|
||||
setMessages((prev) => {
|
||||
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
|
||||
const content = ev.buttons?.length ? (ev.button_prompt ?? ev.text) : ev.text;
|
||||
const content = ev.text;
|
||||
return [
|
||||
...filtered,
|
||||
{
|
||||
@ -238,7 +361,6 @@ export function useNanobotStream(
|
||||
role: "assistant",
|
||||
content,
|
||||
createdAt: Date.now(),
|
||||
...(ev.buttons && ev.buttons.length > 0 ? { buttons: ev.buttons } : {}),
|
||||
...(hasMedia ? { media } : {}),
|
||||
},
|
||||
];
|
||||
|
||||
@ -14,6 +14,48 @@ import type { ChatSummary, UIMessage } from "@/lib/types";
|
||||
|
||||
const EMPTY_MESSAGES: UIMessage[] = [];
|
||||
|
||||
type HistoryMessage = Awaited<ReturnType<typeof fetchSessionMessages>>["messages"][number];
|
||||
|
||||
/**
 * Extract displayable reasoning text from a persisted history message.
 *
 * ``reasoning_content`` wins when it is a string that is non-blank after
 * trimming; it is returned as stored, untrimmed. Otherwise the ``thinking``
 * strings found in ``thinking_blocks`` are trimmed, blanks dropped, and the
 * rest joined with a blank line between them.
 *
 * @returns The reasoning text, or ``undefined`` when none is present.
 */
function reasoningFromHistory(message: HistoryMessage): string | undefined {
  const inline = message.reasoning_content;
  if (typeof inline === "string" && inline.trim()) {
    return inline;
  }

  const blocks = message.thinking_blocks;
  if (!Array.isArray(blocks)) return undefined;

  const collected: string[] = [];
  for (const block of blocks) {
    if (!block || typeof block !== "object") continue;
    const thinking = (block as { thinking?: unknown }).thinking;
    if (typeof thinking !== "string") continue;
    const trimmed = thinking.trim();
    if (trimmed) collected.push(trimmed);
  }

  if (collected.length === 0) return undefined;
  return collected.join("\n\n");
}
|
||||
|
||||
/**
 * Format one persisted tool call as a single trace line.
 *
 * The tool name is read from ``function.name`` when that is a string, else
 * from a top-level ``name`` string. The argument text is ``function.arguments``
 * when it is a non-blank string, and is inserted verbatim.
 *
 * @param call Raw tool-call entry of unknown shape.
 * @returns ``name(args)`` or ``name()``; ``null`` when ``call`` is not an
 *   object or no non-empty name is found.
 */
function formatToolCallTrace(call: unknown): string | null {
  if (typeof call !== "object" || call === null) return null;

  const { name: flatName, function: nested } = call as {
    name?: unknown;
    function?: { name?: unknown; arguments?: unknown };
  };

  let toolName = "";
  if (typeof nested?.name === "string") {
    toolName = nested.name;
  } else if (typeof flatName === "string") {
    toolName = flatName;
  }
  if (toolName === "") return null;

  const rawArgs = nested?.arguments;
  // Non-string or blank arguments are rendered as an empty argument list.
  const argText = typeof rawArgs === "string" && rawArgs.trim() ? rawArgs : "";
  return `${toolName}(${argText})`;
}
|
||||
|
||||
/**
 * Build trace lines for the tool calls recorded on a history message.
 *
 * @returns One formatted line per entry of ``tool_calls`` that
 *   ``formatToolCallTrace`` can render; an empty array when ``tool_calls``
 *   is not an array.
 */
function toolTracesFromHistory(message: HistoryMessage): string[] {
  const calls = message.tool_calls;
  if (!Array.isArray(calls)) return [];

  const traces: string[] = [];
  for (const call of calls) {
    const trace = formatToolCallTrace(call);
    // Entries that cannot be formatted are dropped.
    if (trace) traces.push(trace);
  }
  return traces;
}
|
||||
|
||||
/** Sidebar state: fetches the full session list and exposes create / delete actions. */
|
||||
export function useSessions(): {
|
||||
sessions: ChatSummary[];
|
||||
@ -49,6 +91,12 @@ export function useSessions(): {
|
||||
void refresh();
|
||||
}, [refresh]);
|
||||
|
||||
useEffect(() => {
|
||||
return client.onSessionUpdate(() => {
|
||||
void refresh();
|
||||
});
|
||||
}, [client, refresh]);
|
||||
|
||||
const createChat = useCallback(async (): Promise<string> => {
|
||||
const chatId = await client.newChat();
|
||||
const key = `websocket:${chatId}`;
|
||||
@ -143,14 +191,28 @@ export function useSessionHistory(key: string | null): {
|
||||
m.role === "user" && media?.every((item) => item.kind === "image")
|
||||
? media.map((item) => ({ url: item.url, name: item.name }))
|
||||
: undefined;
|
||||
const row: UIMessage = {
|
||||
id: `hist-${idx}`,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
createdAt: m.timestamp ? Date.parse(m.timestamp) : Date.now(),
|
||||
...(images ? { images } : {}),
|
||||
...(media ? { media } : {}),
|
||||
...(m.role === "assistant" && reasoningFromHistory(m)
|
||||
? { reasoning: reasoningFromHistory(m), reasoningStreaming: false }
|
||||
: {}),
|
||||
};
|
||||
const traces = m.role === "assistant" ? toolTracesFromHistory(m) : [];
|
||||
if (traces.length === 0) return [row];
|
||||
return [
|
||||
...(row.content.trim() || row.reasoning || row.media?.length ? [row] : []),
|
||||
{
|
||||
id: `hist-${idx}`,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
id: `hist-${idx}-tools`,
|
||||
role: "tool" as const,
|
||||
kind: "trace" as const,
|
||||
content: traces[traces.length - 1],
|
||||
traces,
|
||||
createdAt: m.timestamp ? Date.parse(m.timestamp) : Date.now(),
|
||||
...(images ? { images } : {}),
|
||||
...(media ? { media } : {}),
|
||||
},
|
||||
];
|
||||
});
|
||||
|
||||
@ -332,6 +332,8 @@
|
||||
"assistantTyping": "Assistant is typing",
|
||||
"toolSingle": "Using a tool",
|
||||
"toolMany": "Used {{count}} tools",
|
||||
"reasoning": "Thinking",
|
||||
"reasoningStreaming": "Thinking…",
|
||||
"imageAttachment": "Image attachment",
|
||||
"copyReply": "Copy reply",
|
||||
"copiedReply": "Copied reply"
|
||||
|
||||
@ -320,6 +320,8 @@
|
||||
"assistantTyping": "助手正在输入",
|
||||
"toolSingle": "正在使用工具",
|
||||
"toolMany": "已使用 {{count}} 个工具",
|
||||
"reasoning": "思考过程",
|
||||
"reasoningStreaming": "正在思考…",
|
||||
"imageAttachment": "图片附件",
|
||||
"copyReply": "复制回复",
|
||||
"copiedReply": "已复制回复"
|
||||
|
||||
@ -89,6 +89,8 @@ export async function fetchSessionMessages(
|
||||
content: string;
|
||||
timestamp?: string;
|
||||
tool_calls?: unknown;
|
||||
reasoning_content?: string | null;
|
||||
thinking_blocks?: unknown;
|
||||
tool_call_id?: string;
|
||||
name?: string;
|
||||
/** Present on ``user`` turns that attached images. Paths have already
|
||||
|
||||
@ -15,6 +15,7 @@ type Unsubscribe = () => void;
|
||||
type EventHandler = (ev: InboundEvent) => void;
|
||||
type StatusHandler = (status: ConnectionStatus) => void;
|
||||
type RuntimeModelHandler = (modelName: string | null, modelPreset?: string | null) => void;
|
||||
type SessionUpdateHandler = (chatId: string) => void;
|
||||
|
||||
/** Structured connection-level errors surfaced to the UI.
|
||||
*
|
||||
@ -60,6 +61,7 @@ export class NanobotClient {
|
||||
private socket: WebSocket | null = null;
|
||||
private statusHandlers = new Set<StatusHandler>();
|
||||
private runtimeModelHandlers = new Set<RuntimeModelHandler>();
|
||||
private sessionUpdateHandlers = new Set<SessionUpdateHandler>();
|
||||
private errorHandlers = new Set<ErrorHandler>();
|
||||
// chat_id -> handlers listening on it
|
||||
private chatHandlers = new Map<string, Set<EventHandler>>();
|
||||
@ -116,6 +118,13 @@ export class NanobotClient {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Register a handler in ``sessionUpdateHandlers``.
 *
 * @returns A function that removes the handler again.
 */
onSessionUpdate(handler: SessionUpdateHandler): Unsubscribe {
  this.sessionUpdateHandlers.add(handler);
  const unsubscribe: Unsubscribe = () => {
    this.sessionUpdateHandlers.delete(handler);
  };
  return unsubscribe;
}
|
||||
|
||||
/** Subscribe to transport-level faults (see :type:`StreamError`). */
|
||||
onError(handler: ErrorHandler): Unsubscribe {
|
||||
this.errorHandlers.add(handler);
|
||||
@ -259,6 +268,11 @@ export class NanobotClient {
|
||||
return;
|
||||
}
|
||||
|
||||
if (parsed.event === "session_updated") {
|
||||
this.emitSessionUpdate(parsed.chat_id);
|
||||
return;
|
||||
}
|
||||
|
||||
const chatId = (parsed as { chat_id?: string }).chat_id;
|
||||
if (chatId) this.dispatch(chatId, parsed);
|
||||
}
|
||||
@ -269,6 +283,12 @@ export class NanobotClient {
|
||||
}
|
||||
}
|
||||
|
||||
/** Call every registered session-update handler with ``chatId``. */
private emitSessionUpdate(chatId: string): void {
  this.sessionUpdateHandlers.forEach((notify) => {
    notify(chatId);
  });
}
|
||||
|
||||
private dispatch(chatId: string, ev: InboundEvent): void {
|
||||
const handlers = this.chatHandlers.get(chatId);
|
||||
if (!handlers) return;
|
||||
|
||||
@ -44,8 +44,13 @@ export interface UIMessage {
|
||||
images?: UIImage[];
|
||||
/** Signed or local UI-renderable media attachments. */
|
||||
media?: UIMediaAttachment[];
|
||||
/** Optional answer choices for a pending ask_user question. */
|
||||
buttons?: string[][];
|
||||
/** Assistant turn: accumulated model reasoning / thinking text. Built up
|
||||
* incrementally from ``reasoning_delta`` frames; finalized when
|
||||
* ``reasoning_end`` arrives. */
|
||||
reasoning?: string;
|
||||
/** True while ``reasoning_delta`` frames are still arriving for this turn.
|
||||
* Drives the shimmer header on ``ReasoningBubble``. */
|
||||
reasoningStreaming?: boolean;
|
||||
}
|
||||
|
||||
export interface ChatSummary {
|
||||
@ -141,12 +146,9 @@ export type InboundEvent =
|
||||
reply_to?: string;
|
||||
media?: string[];
|
||||
media_urls?: Array<{ url: string; name?: string }>;
|
||||
buttons?: string[][];
|
||||
/** Original prompt before the websocket text fallback appends buttons. */
|
||||
button_prompt?: string;
|
||||
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
||||
* generic progress line) rather than a conversational reply. */
|
||||
kind?: "tool_hint" | "progress";
|
||||
kind?: "tool_hint" | "progress" | "reasoning";
|
||||
}
|
||||
| {
|
||||
event: "delta";
|
||||
@ -159,6 +161,17 @@ export type InboundEvent =
|
||||
chat_id: string;
|
||||
stream_id?: string;
|
||||
}
|
||||
| {
|
||||
event: "reasoning_delta";
|
||||
chat_id: string;
|
||||
text: string;
|
||||
stream_id?: string;
|
||||
}
|
||||
| {
|
||||
event: "reasoning_end";
|
||||
chat_id: string;
|
||||
stream_id?: string;
|
||||
}
|
||||
| {
|
||||
event: "runtime_model_updated";
|
||||
model_name: string;
|
||||
|
||||
@ -265,7 +265,7 @@ describe("App layout", () => {
|
||||
expect(screen.queryByDisplayValue("unsaved-brave-key")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("returns from settings to an available chat instead of the blank start page", async () => {
|
||||
it("returns from settings to the blank start page when no session was active", async () => {
|
||||
mockSessions = [
|
||||
{
|
||||
key: "websocket:chat-a",
|
||||
@ -330,10 +330,8 @@ describe("App layout", () => {
|
||||
expect(await screen.findByRole("heading", { name: "General" })).toBeInTheDocument();
|
||||
fireEvent.click(screen.getByRole("button", { name: "Back to chat" }));
|
||||
|
||||
await waitFor(() => expect(document.title).toBe("First chat · nanobot"));
|
||||
const restoredSidebar = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
||||
fireEvent.click(within(restoredSidebar).getByRole("button", { name: /^Second chat$/ }));
|
||||
await waitFor(() => expect(document.title).toBe("Second chat · nanobot"));
|
||||
await waitFor(() => expect(document.title).toBe("nanobot"));
|
||||
expect(screen.getByText("What can I do for you?")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("filters sidebar sessions through the lightweight search row", async () => {
|
||||
|
||||
@ -72,11 +72,12 @@ describe("MessageBubble", () => {
|
||||
render(<MessageBubble message={message} />);
|
||||
const toggle = screen.getByRole("button", { name: /used 2 tools/i });
|
||||
|
||||
expect(screen.getByText('weather("get")')).toBeInTheDocument();
|
||||
expect(screen.getByText('search "hk weather"')).toBeInTheDocument();
|
||||
expect(screen.queryByText('weather("get")')).not.toBeInTheDocument();
|
||||
expect(screen.queryByText('search "hk weather"')).not.toBeInTheDocument();
|
||||
|
||||
fireEvent.click(toggle);
|
||||
expect(screen.queryByText('weather("get")')).not.toBeInTheDocument();
|
||||
expect(screen.getByText('weather("get")')).toBeInTheDocument();
|
||||
expect(screen.getByText('search "hk weather"')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders video media as an inline player", () => {
|
||||
@ -103,6 +104,45 @@ describe("MessageBubble", () => {
|
||||
expect(container.querySelector("video[controls]")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("auto-expands the reasoning trace while streaming with a shimmer header", () => {
|
||||
const message: UIMessage = {
|
||||
id: "a-reasoning-streaming",
|
||||
role: "assistant",
|
||||
content: "",
|
||||
createdAt: Date.now(),
|
||||
reasoning: "Step 1: parse intent. Step 2: compute.",
|
||||
reasoningStreaming: true,
|
||||
};
|
||||
|
||||
const { container } = render(<MessageBubble message={message} />);
|
||||
|
||||
expect(screen.getByText("Thinking…")).toBeInTheDocument();
|
||||
expect(screen.getByText(/Step 1: parse intent\./)).toBeInTheDocument();
|
||||
expect(container.querySelector(".reasoning-shimmer")).toBeInTheDocument();
|
||||
expect(screen.getByRole("button", { name: /thinking/i }).parentElement).not.toHaveClass("mb-2");
|
||||
});
|
||||
|
||||
it("collapses the reasoning section by default once streaming ends", () => {
|
||||
const message: UIMessage = {
|
||||
id: "a-reasoning-done",
|
||||
role: "assistant",
|
||||
content: "The answer is 42.",
|
||||
createdAt: Date.now(),
|
||||
reasoning: "hidden until expanded",
|
||||
reasoningStreaming: false,
|
||||
};
|
||||
|
||||
render(<MessageBubble message={message} />);
|
||||
|
||||
expect(screen.getByText("Thinking")).toBeInTheDocument();
|
||||
expect(screen.getByText("The answer is 42.")).toBeInTheDocument();
|
||||
expect(screen.queryByText("hidden until expanded")).not.toBeInTheDocument();
|
||||
expect(screen.getByRole("button", { name: /thinking/i }).parentElement).toHaveClass("mb-2");
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /thinking/i }));
|
||||
expect(screen.getByText("hidden until expanded")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders assistant image media as a larger generated result", () => {
|
||||
const message: UIMessage = {
|
||||
id: "a-image",
|
||||
|
||||
@ -109,6 +109,25 @@ describe("NanobotClient", () => {
|
||||
expect(handler).toHaveBeenCalledWith("openai/gpt-4.1", "fast");
|
||||
});
|
||||
|
||||
it("dispatches session updates globally", () => {
|
||||
const client = new NanobotClient({
|
||||
url: "ws://test",
|
||||
reconnect: false,
|
||||
socketFactory: (url) => new FakeSocket(url) as unknown as WebSocket,
|
||||
});
|
||||
const globalHandler = vi.fn();
|
||||
const chatHandler = vi.fn();
|
||||
client.onSessionUpdate(globalHandler);
|
||||
client.onChat("chat-title", chatHandler);
|
||||
client.connect();
|
||||
lastSocket().fakeOpen();
|
||||
|
||||
lastSocket().fakeMessage({ event: "session_updated", chat_id: "chat-title" });
|
||||
|
||||
expect(globalHandler).toHaveBeenCalledWith("chat-title");
|
||||
expect(chatHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("resolves newChat() via the server-assigned chat_id", async () => {
|
||||
const client = new NanobotClient({
|
||||
url: "ws://test",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { fireEvent, render, screen, waitFor } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
||||
import type { SlashCommand } from "@/lib/types";
|
||||
@ -19,6 +19,33 @@ const COMMANDS: SlashCommand[] = [
|
||||
argHint: "[n]",
|
||||
},
|
||||
];
|
||||
const ORIGINAL_INNER_HEIGHT = window.innerHeight;
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
Object.defineProperty(window, "innerHeight", {
|
||||
value: ORIGINAL_INNER_HEIGHT,
|
||||
configurable: true,
|
||||
});
|
||||
});
|
||||
|
||||
function rect(init: Partial<DOMRect>): DOMRect {
|
||||
const top = init.top ?? 0;
|
||||
const left = init.left ?? 0;
|
||||
const width = init.width ?? 0;
|
||||
const height = init.height ?? 0;
|
||||
return {
|
||||
x: init.x ?? left,
|
||||
y: init.y ?? top,
|
||||
top,
|
||||
left,
|
||||
width,
|
||||
height,
|
||||
right: init.right ?? left + width,
|
||||
bottom: init.bottom ?? top + height,
|
||||
toJSON: () => ({}),
|
||||
};
|
||||
}
|
||||
|
||||
describe("ThreadComposer", () => {
|
||||
it("renders a readonly hero model composer when provided", () => {
|
||||
@ -74,7 +101,9 @@ describe("ThreadComposer", () => {
|
||||
const input = screen.getByLabelText("Message input");
|
||||
fireEvent.change(input, { target: { value: "/" } });
|
||||
|
||||
expect(screen.getByRole("listbox", { name: "Slash commands" })).toBeInTheDocument();
|
||||
const palette = screen.getByRole("listbox", { name: "Slash commands" });
|
||||
expect(palette).toBeInTheDocument();
|
||||
expect(palette).toHaveStyle({ maxHeight: "288px" });
|
||||
expect(screen.getByRole("option", { name: /\/stop/i })).toHaveAttribute(
|
||||
"aria-selected",
|
||||
"true",
|
||||
@ -92,6 +121,55 @@ describe("ThreadComposer", () => {
|
||||
expect(screen.queryByRole("listbox", { name: "Slash commands" })).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("opens the slash command palette downward when there is more room below", async () => {
|
||||
vi.spyOn(HTMLFormElement.prototype, "getBoundingClientRect").mockReturnValue(
|
||||
rect({ top: 40, bottom: 160, width: 800, height: 120 }),
|
||||
);
|
||||
Object.defineProperty(window, "innerHeight", {
|
||||
value: 330,
|
||||
configurable: true,
|
||||
});
|
||||
render(
|
||||
<ThreadComposer
|
||||
onSend={vi.fn()}
|
||||
placeholder="Ask anything..."
|
||||
slashCommands={COMMANDS}
|
||||
variant="hero"
|
||||
/>,
|
||||
);
|
||||
const input = screen.getByLabelText("Message input");
|
||||
|
||||
fireEvent.change(input, { target: { value: "/" } });
|
||||
|
||||
await waitFor(() => {
|
||||
const palette = screen.getByRole("listbox", { name: "Slash commands" });
|
||||
expect(palette.className).toContain("top-full");
|
||||
expect(palette).toHaveStyle({ maxHeight: "162px" });
|
||||
});
|
||||
});
|
||||
|
||||
it("dismisses the slash command palette on outside click", () => {
|
||||
render(
|
||||
<div>
|
||||
<button type="button">outside</button>
|
||||
<ThreadComposer
|
||||
onSend={vi.fn()}
|
||||
placeholder="Type your message..."
|
||||
slashCommands={COMMANDS}
|
||||
/>
|
||||
</div>,
|
||||
);
|
||||
|
||||
fireEvent.change(screen.getByLabelText("Message input"), {
|
||||
target: { value: "/" },
|
||||
});
|
||||
expect(screen.getByRole("listbox", { name: "Slash commands" })).toBeInTheDocument();
|
||||
|
||||
fireEvent.pointerDown(screen.getByRole("button", { name: "outside" }));
|
||||
|
||||
expect(screen.queryByRole("listbox", { name: "Slash commands" })).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("sends image generation mode with automatic aspect ratio", () => {
|
||||
const onSend = vi.fn();
|
||||
render(
|
||||
|
||||
52
webui/src/tests/thread-messages.test.tsx
Normal file
52
webui/src/tests/thread-messages.test.tsx
Normal file
@ -0,0 +1,52 @@
|
||||
import { render } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import { ThreadMessages } from "@/components/thread/ThreadMessages";
|
||||
import type { UIMessage } from "@/lib/types";
|
||||
|
||||
describe("ThreadMessages", () => {
|
||||
it("uses compact spacing between consecutive auxiliary rows", () => {
|
||||
const messages: UIMessage[] = [
|
||||
{
|
||||
id: "r1",
|
||||
role: "assistant",
|
||||
content: "",
|
||||
reasoning: "thinking",
|
||||
reasoningStreaming: false,
|
||||
isStreaming: true,
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
{
|
||||
id: "t1",
|
||||
role: "tool",
|
||||
kind: "trace",
|
||||
content: "search()",
|
||||
traces: ["search()"],
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
{
|
||||
id: "r2",
|
||||
role: "assistant",
|
||||
content: "",
|
||||
reasoning: "more thinking",
|
||||
reasoningStreaming: false,
|
||||
isStreaming: true,
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
{
|
||||
id: "a1",
|
||||
role: "assistant",
|
||||
content: "final answer",
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const { container } = render(<ThreadMessages messages={messages} />);
|
||||
const rows = Array.from(container.firstElementChild?.children ?? []);
|
||||
|
||||
expect(rows[0]).not.toHaveClass("mt-2", "mt-5");
|
||||
expect(rows[1]).toHaveClass("mt-2");
|
||||
expect(rows[2]).toHaveClass("mt-2");
|
||||
expect(rows[3]).toHaveClass("mt-5");
|
||||
});
|
||||
});
|
||||
@ -573,7 +573,7 @@ describe("ThreadShell", () => {
|
||||
await waitFor(() => expect(screen.getByText("live assistant reply")).toBeInTheDocument());
|
||||
});
|
||||
|
||||
it("does not open slash commands on the blank welcome page", async () => {
|
||||
it("opens slash commands on the blank welcome page", async () => {
|
||||
const client = makeClient();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
@ -583,10 +583,11 @@ describe("ThreadShell", () => {
|
||||
return httpJson({
|
||||
commands: [
|
||||
{
|
||||
command: "/stop",
|
||||
title: "Stop current task",
|
||||
description: "Cancel the active agent turn.",
|
||||
icon: "square",
|
||||
command: "/history",
|
||||
title: "Show conversation history",
|
||||
description: "Print the last N persisted messages.",
|
||||
icon: "history",
|
||||
arg_hint: "[n]",
|
||||
},
|
||||
],
|
||||
});
|
||||
@ -622,7 +623,8 @@ describe("ThreadShell", () => {
|
||||
target: { value: "/" },
|
||||
});
|
||||
|
||||
expect(screen.queryByRole("listbox", { name: "Slash commands" })).not.toBeInTheDocument();
|
||||
expect(screen.getByRole("listbox", { name: "Slash commands" })).toBeInTheDocument();
|
||||
expect(screen.getByRole("option", { name: /\/history/i })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("switches welcome quick actions when image mode is enabled", async () => {
|
||||
@ -809,46 +811,4 @@ describe("ThreadShell", () => {
|
||||
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
|
||||
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders ask_user options above the composer and sends selected answers", async () => {
|
||||
const client = makeClient();
|
||||
const onNewChat = vi.fn().mockResolvedValue("chat-a");
|
||||
|
||||
render(
|
||||
wrap(
|
||||
client,
|
||||
<ThreadShell
|
||||
session={session("chat-a")}
|
||||
title="Chat chat-a"
|
||||
onToggleSidebar={() => {}}
|
||||
onGoHome={() => {}}
|
||||
onNewChat={onNewChat}
|
||||
/>,
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
client._emitChat("chat-a", {
|
||||
event: "message",
|
||||
chat_id: "chat-a",
|
||||
text: "How should I continue?",
|
||||
buttons: [["Short answer", "Detailed answer"]],
|
||||
});
|
||||
});
|
||||
|
||||
expect(screen.getByRole("group", { name: "Question" })).toHaveTextContent(
|
||||
"How should I continue?",
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: "Short answer" }));
|
||||
|
||||
expect(client.sendMessage).toHaveBeenCalledWith(
|
||||
"chat-a",
|
||||
"Short answer",
|
||||
undefined,
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole("group", { name: "Question" })).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -113,6 +113,208 @@ describe("useNanobotStream", () => {
|
||||
expect(result.current.messages[1].kind).toBeUndefined();
|
||||
});
|
||||
|
||||
it("accumulates reasoning_delta chunks on a placeholder until reasoning_end", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-r", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r",
|
||||
text: "Let me think ",
|
||||
});
|
||||
fake.emit("chat-r", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r",
|
||||
text: "step by step.",
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].role).toBe("assistant");
|
||||
expect(result.current.messages[0].reasoning).toBe("Let me think step by step.");
|
||||
expect(result.current.messages[0].reasoningStreaming).toBe(true);
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r", { event: "reasoning_end", chat_id: "chat-r" });
|
||||
});
|
||||
|
||||
expect(result.current.messages[0].reasoningStreaming).toBe(false);
|
||||
expect(result.current.messages[0].reasoning).toBe("Let me think step by step.");
|
||||
});
|
||||
|
||||
it("absorbs a streaming reasoning placeholder into the answer turn that follows", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-r2", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r2", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r2",
|
||||
text: "Plan first.",
|
||||
});
|
||||
fake.emit("chat-r2", { event: "reasoning_end", chat_id: "chat-r2" });
|
||||
fake.emit("chat-r2", {
|
||||
event: "delta",
|
||||
chat_id: "chat-r2",
|
||||
text: "The answer is 42.",
|
||||
});
|
||||
fake.emit("chat-r2", { event: "stream_end", chat_id: "chat-r2" });
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].content).toBe("The answer is 42.");
|
||||
expect(result.current.messages[0].reasoning).toBe("Plan first.");
|
||||
expect(result.current.messages[0].reasoningStreaming).toBe(false);
|
||||
});
|
||||
|
||||
it("ignores empty reasoning_delta frames", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-r3", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r3", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r3",
|
||||
text: "",
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("treats legacy kind=reasoning messages as a complete delta + end pair", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-r4", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r4", {
|
||||
event: "message",
|
||||
chat_id: "chat-r4",
|
||||
text: "one-shot reasoning",
|
||||
kind: "reasoning",
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].reasoning).toBe("one-shot reasoning");
|
||||
expect(result.current.messages[0].reasoningStreaming).toBe(false);
|
||||
});
|
||||
|
||||
it("attaches post-hoc reasoning to the same assistant turn above the answer", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-r5", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r5", {
|
||||
event: "delta",
|
||||
chat_id: "chat-r5",
|
||||
text: "hi~",
|
||||
});
|
||||
fake.emit("chat-r5", { event: "stream_end", chat_id: "chat-r5" });
|
||||
fake.emit("chat-r5", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r5",
|
||||
text: "This reasoning arrived after the answer stream.",
|
||||
});
|
||||
fake.emit("chat-r5", { event: "reasoning_end", chat_id: "chat-r5" });
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].content).toBe("hi~");
|
||||
expect(result.current.messages[0].reasoning).toBe(
|
||||
"This reasoning arrived after the answer stream.",
|
||||
);
|
||||
expect(result.current.messages[0].reasoningStreaming).toBe(false);
|
||||
});
|
||||
|
||||
it("does not attach a new turn's reasoning across the latest user boundary", () => {
|
||||
const fake = fakeClient();
|
||||
const initialMessages = [
|
||||
{
|
||||
id: "a-prev",
|
||||
role: "assistant" as const,
|
||||
content: "Previous answer.",
|
||||
reasoning: "Previous thought.",
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
{
|
||||
id: "u-next",
|
||||
role: "user" as const,
|
||||
content: "Next question",
|
||||
createdAt: Date.now(),
|
||||
},
|
||||
];
|
||||
const { result } = renderHook(
|
||||
() => useNanobotStream("chat-r6", initialMessages),
|
||||
{ wrapper: wrap(fake.client) },
|
||||
);
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r6", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r6",
|
||||
text: "New turn thinking.",
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(3);
|
||||
expect(result.current.messages[0].reasoning).toBe("Previous thought.");
|
||||
expect(result.current.messages[2].role).toBe("assistant");
|
||||
expect(result.current.messages[2].content).toBe("");
|
||||
expect(result.current.messages[2].reasoning).toBe("New turn thinking.");
|
||||
expect(result.current.messages[2].reasoningStreaming).toBe(true);
|
||||
});
|
||||
|
||||
it("does not attach reasoning across a tool trace boundary", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-r7", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-r7", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r7",
|
||||
text: "First reasoning.",
|
||||
});
|
||||
fake.emit("chat-r7", { event: "reasoning_end", chat_id: "chat-r7" });
|
||||
fake.emit("chat-r7", {
|
||||
event: "message",
|
||||
chat_id: "chat-r7",
|
||||
text: "web_search({\"query\":\"OpenClaw\"})",
|
||||
kind: "tool_hint",
|
||||
});
|
||||
fake.emit("chat-r7", {
|
||||
event: "reasoning_delta",
|
||||
chat_id: "chat-r7",
|
||||
text: "Second reasoning.",
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(3);
|
||||
expect(result.current.messages.map((m) => m.kind ?? "message")).toEqual([
|
||||
"message",
|
||||
"trace",
|
||||
"message",
|
||||
]);
|
||||
expect(result.current.messages[0].reasoning).toBe("First reasoning.");
|
||||
expect(result.current.messages[1].traces).toEqual([
|
||||
"web_search({\"query\":\"OpenClaw\"})",
|
||||
]);
|
||||
expect(result.current.messages[2].reasoning).toBe("Second reasoning.");
|
||||
});
|
||||
|
||||
it("attaches assistant media_urls to complete messages", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-m", EMPTY_MESSAGES), {
|
||||
@ -217,29 +419,6 @@ describe("useNanobotStream", () => {
|
||||
expect(result.current.messages[0].content).toBe("long task");
|
||||
});
|
||||
|
||||
it("keeps assistant buttons on complete messages", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-q", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-q", {
|
||||
event: "message",
|
||||
chat_id: "chat-q",
|
||||
text: "How should I continue?\n\n1. Short answer\n2. Detailed answer",
|
||||
button_prompt: "How should I continue?",
|
||||
buttons: [["Short answer", "Detailed answer"]],
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].content).toBe("How should I continue?");
|
||||
expect(result.current.messages[0].buttons).toEqual([
|
||||
["Short answer", "Detailed answer"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("keeps streaming alive across stream_end and completes on turn_end", () => {
|
||||
const fake = fakeClient();
|
||||
const onTurnEnd = vi.fn();
|
||||
@ -298,20 +477,4 @@ describe("useNanobotStream", () => {
|
||||
expect(onTurnEnd).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("refreshes session metadata when the server reports a session update", () => {
|
||||
const fake = fakeClient();
|
||||
const onTurnEnd = vi.fn();
|
||||
renderHook(() => useNanobotStream("chat-title", EMPTY_MESSAGES, false, onTurnEnd), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-title", {
|
||||
event: "session_updated",
|
||||
chat_id: "chat-title",
|
||||
});
|
||||
});
|
||||
|
||||
expect(onTurnEnd).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
@ -17,12 +17,20 @@ vi.mock("@/lib/api", async (importOriginal) => {
|
||||
});
|
||||
|
||||
function fakeClient() {
|
||||
const sessionUpdateHandlers = new Set<(chatId: string) => void>();
|
||||
return {
|
||||
status: "open" as const,
|
||||
defaultChatId: null as string | null,
|
||||
onStatus: () => () => {},
|
||||
onError: () => () => {},
|
||||
onChat: () => () => {},
|
||||
onSessionUpdate: (handler: (chatId: string) => void) => {
|
||||
sessionUpdateHandlers.add(handler);
|
||||
return () => sessionUpdateHandlers.delete(handler);
|
||||
},
|
||||
emitSessionUpdate: (chatId: string) => {
|
||||
for (const handler of sessionUpdateHandlers) handler(chatId);
|
||||
},
|
||||
sendMessage: vi.fn(),
|
||||
newChat: vi.fn(),
|
||||
attach: vi.fn(),
|
||||
@ -87,6 +95,45 @@ describe("useSessions", () => {
|
||||
expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-b"]);
|
||||
});
|
||||
|
||||
it("refreshes sessions when the websocket reports a session update", async () => {
|
||||
vi.mocked(api.listSessions)
|
||||
.mockResolvedValueOnce([
|
||||
{
|
||||
key: "websocket:chat-a",
|
||||
channel: "websocket",
|
||||
chatId: "chat-a",
|
||||
createdAt: "2026-04-16T10:00:00Z",
|
||||
updatedAt: "2026-04-16T10:00:00Z",
|
||||
preview: "",
|
||||
},
|
||||
])
|
||||
.mockResolvedValueOnce([
|
||||
{
|
||||
key: "websocket:chat-a",
|
||||
channel: "websocket",
|
||||
chatId: "chat-a",
|
||||
createdAt: "2026-04-16T10:00:00Z",
|
||||
updatedAt: "2026-04-16T10:01:00Z",
|
||||
title: "生成的小标题",
|
||||
preview: "用户第一句话",
|
||||
},
|
||||
]);
|
||||
const client = fakeClient();
|
||||
|
||||
const { result } = renderHook(() => useSessions(), {
|
||||
wrapper: wrap(client),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.sessions[0]?.title).toBeUndefined());
|
||||
|
||||
act(() => {
|
||||
client.emitSessionUpdate("chat-a");
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.sessions[0]?.title).toBe("生成的小标题"));
|
||||
expect(api.listSessions).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("hydrates media_urls from historical user turns into UIMessage.images", async () => {
|
||||
// Round-trip check for the signed-media replay: the backend emits
|
||||
// ``media_urls`` on a historical user row and the hook must surface them
|
||||
@ -170,6 +217,92 @@ describe("useSessions", () => {
|
||||
]);
|
||||
});
|
||||
|
||||
it("hydrates persisted assistant reasoning into the replayed message", async () => {
|
||||
vi.mocked(api.fetchSessionMessages).mockResolvedValue({
|
||||
key: "websocket:chat-reasoning",
|
||||
created_at: "2026-04-20T10:00:00Z",
|
||||
updated_at: "2026-04-20T10:05:00Z",
|
||||
messages: [
|
||||
{
|
||||
role: "assistant",
|
||||
content: "final answer",
|
||||
timestamp: "2026-04-20T10:00:01Z",
|
||||
reasoning_content: "hidden but persisted reasoning",
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useSessionHistory("websocket:chat-reasoning"), {
|
||||
wrapper: wrap(fakeClient()),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.loading).toBe(false));
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].role).toBe("assistant");
|
||||
expect(result.current.messages[0].content).toBe("final answer");
|
||||
expect(result.current.messages[0].reasoning).toBe("hidden but persisted reasoning");
|
||||
expect(result.current.messages[0].reasoningStreaming).toBe(false);
|
||||
});
|
||||
|
||||
it("hydrates historical assistant tool calls into a replay trace row", async () => {
|
||||
vi.mocked(api.fetchSessionMessages).mockResolvedValue({
|
||||
key: "websocket:chat-tools",
|
||||
created_at: "2026-04-20T10:00:00Z",
|
||||
updated_at: "2026-04-20T10:05:00Z",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "research this",
|
||||
timestamp: "2026-04-20T10:00:00Z",
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: "",
|
||||
timestamp: "2026-04-20T10:00:01Z",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "call-1",
|
||||
type: "function",
|
||||
function: { name: "web_search", arguments: "{\"query\":\"agents\"}" },
|
||||
},
|
||||
{
|
||||
id: "call-2",
|
||||
type: "function",
|
||||
function: { name: "web_fetch", arguments: "{\"url\":\"https://example.com\"}" },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: "tool",
|
||||
content: "tool output that should not render directly",
|
||||
timestamp: "2026-04-20T10:00:02Z",
|
||||
tool_call_id: "call-1",
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: "summary",
|
||||
timestamp: "2026-04-20T10:00:03Z",
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useSessionHistory("websocket:chat-tools"), {
|
||||
wrapper: wrap(fakeClient()),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.loading).toBe(false));
|
||||
|
||||
expect(result.current.messages.map((m) => m.role)).toEqual(["user", "tool", "assistant"]);
|
||||
const trace = result.current.messages[1];
|
||||
expect(trace.kind).toBe("trace");
|
||||
expect(trace.traces).toEqual([
|
||||
"web_search({\"query\":\"agents\"})",
|
||||
"web_fetch({\"url\":\"https://example.com\"})",
|
||||
]);
|
||||
expect(result.current.messages[2].content).toBe("summary");
|
||||
});
|
||||
|
||||
it("flags history with trailing assistant tool calls as still pending", async () => {
|
||||
vi.mocked(api.fetchSessionMessages).mockResolvedValue({
|
||||
key: "websocket:chat-pending",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user