diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index bc807092e..81cc393b8 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -33,7 +33,6 @@ from nanobot.config.schema import AgentDefaults, ModelPresetConfig from nanobot.providers.base import LLMProvider from nanobot.providers.factory import ProviderSnapshot from nanobot.session.goal_state import ( - goal_state_ws_blob, runner_wall_llm_timeout_s, ) from nanobot.session.manager import Session, SessionManager @@ -44,8 +43,11 @@ from nanobot.utils.helpers import truncate_text as truncate_text_fn from nanobot.utils.image_generation_intent import image_generation_prompt from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant -from nanobot.utils.webui_titles import mark_webui_session, maybe_generate_webui_title_after_turn -from nanobot.utils.webui_turn_helpers import publish_turn_run_status +from nanobot.utils.webui_turn_helpers import ( + WebuiTurnCoordinator, + build_bus_progress_callback, + mark_webui_session, +) if TYPE_CHECKING: from nanobot.config.schema import ( @@ -237,6 +239,11 @@ class AgentLoop: self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) self.sessions = session_manager or SessionManager(workspace) + self._webui_turns = WebuiTurnCoordinator( + bus=self.bus, + sessions=self.sessions, + schedule_background=lambda coro: self._schedule_background(coro), + ) self.tools = ToolRegistry() # One file-read/write tracker per logical session. The tool registry is # shared by this loop, so tools resolve the active state via contextvars. @@ -524,34 +531,7 @@ class AgentLoop: self, msg: InboundMessage ) -> Callable[..., Awaitable[None]]: """Build a progress callback that publishes to the message bus.""" - - async def _bus_progress( - content: str, - *, - tool_hint: bool = False, - tool_events: list[dict[str, Any]] | None = None, - reasoning: bool = False, - reasoning_end: bool = False, - ) -> None: - meta = dict(msg.metadata or {}) - meta["_progress"] = True - meta["_tool_hint"] = tool_hint - if reasoning: - meta["_reasoning_delta"] = True - if reasoning_end: - meta["_reasoning_end"] = True - if tool_events: - meta["_tool_events"] = tool_events - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content=content, - metadata=meta, - ) - ) - - return _bus_progress + return build_bus_progress_callback(self.bus, msg) async def _build_retry_wait_callback( self, msg: InboundMessage @@ -938,38 +918,12 @@ class AgentLoop: content="", metadata=msg.metadata or {}, )) if msg.channel == "websocket": - # Signal that the turn is fully complete (all tools executed, - # final text streamed). This lets WS clients know when to - # definitively stop the loading indicator. turn_lat = self._pending_turn_latency_ms.pop(session_key, None) - turn_metadata: dict[str, Any] = {**msg.metadata, "_turn_end": True} - if turn_lat is not None: - turn_metadata["latency_ms"] = int(turn_lat) - sess_turn = self.sessions.get_or_create(session_key) - turn_metadata["goal_state"] = goal_state_ws_blob(sess_turn.metadata) - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", metadata=turn_metadata, - )) - if msg.metadata.get("webui") is True: - async def _generate_title_and_notify() -> None: - generated = await maybe_generate_webui_title_after_turn( - channel=msg.channel, - metadata=msg.metadata, - sessions=self.sessions, - session_key=session_key, - provider=self.provider, - model=self.model, - ) - if generated: - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content="", - metadata={**msg.metadata, "_session_updated": True}, - )) - - self._schedule_background(_generate_title_and_notify()) + await self._webui_turns.handle_turn_end( + msg, + session_key=session_key, + latency_ms=turn_lat, + ) except asyncio.CancelledError: logger.info("Task cancelled for session {}", session_key) # Preserve partial context from the interrupted turn so @@ -1021,8 +975,9 @@ class AgentLoop: "Re-published {} leftover message(s) to bus for session {}", leftover, session_key, ) - await publish_turn_run_status(self.bus, msg, "idle") + await self._webui_turns.publish_run_status(msg, "idle") self._pending_turn_latency_ms.pop(session_key, None) + self._webui_turns.discard(session_key) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" @@ -1338,6 +1293,12 @@ class AgentLoop: "include_timestamps": True, } ctx.history = ctx.session.get_history(**_hist_kwargs) + self._webui_turns.capture_title_context( + ctx.session_key, + ctx.msg, + self.provider, + self.model, + ) ctx.initial_messages = self._build_initial_messages( ctx.msg, ctx.session, ctx.history, ctx.pending_summary @@ -1354,7 +1315,7 @@ class AgentLoop: return "ok" async def _state_run(self, ctx: TurnContext) -> str: - await publish_turn_run_status(self.bus, ctx.msg, "running") + await self._webui_turns.publish_run_status(ctx.msg, "running") result = await self._run_agent_loop( ctx.initial_messages, on_progress=ctx.on_progress, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 56482f75b..64345822a 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -15,6 +15,12 @@ from loguru import logger from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.utils.file_edit_events import ( + build_file_edit_end_event, + build_file_edit_error_event, + build_file_edit_start_event, + prepare_file_edit_tracker, +) from nanobot.utils.helpers import ( IncrementalThinkExtractor, build_assistant_message, @@ -26,6 +32,7 @@ from nanobot.utils.helpers import ( strip_think, truncate_text, ) +from nanobot.utils.progress_events import invoke_file_edit_progress from nanobot.utils.prompt_templates import render_template from nanobot.utils.runtime import ( EMPTY_FINAL_RESPONSE_MESSAGE, @@ -813,6 +820,21 @@ class AgentRunner: return prep_error + hint, event, ( RuntimeError(prep_error) if spec.fail_on_tool_error else None ) + file_edit_tracker = prepare_file_edit_tracker( + call_id=tool_call.id, + tool_name=tool_call.name, + tool=tool, + workspace=spec.workspace, + params=params if isinstance(params, dict) else None, + ) + if file_edit_tracker is not None and spec.progress_callback is not None: + await invoke_file_edit_progress( + spec.progress_callback, + [build_file_edit_start_event( + file_edit_tracker, + params if isinstance(params, dict) else None, + )], + ) try: if tool is not None: result = await tool.execute(**params) @@ -821,6 +843,11 @@ class AgentRunner: except asyncio.CancelledError: raise except BaseException as exc: + if file_edit_tracker is not None and spec.progress_callback is not None: + await invoke_file_edit_progress( + spec.progress_callback, + [build_file_edit_error_event(file_edit_tracker, str(exc))], + ) event = { "name": tool_call.name, "status": "error", @@ -842,6 +869,11 @@ class AgentRunner: return payload, event, None if isinstance(result, str) and result.startswith("Error"): + if file_edit_tracker is not None and spec.progress_callback is not None: + await invoke_file_edit_progress( + spec.progress_callback, + [build_file_edit_error_event(file_edit_tracker, result)], + ) event = { "name": tool_call.name, "status": "error", @@ -860,6 +892,12 @@ class AgentRunner: return result + hint, event, RuntimeError(result) return result + hint, event, None + if file_edit_tracker is not None and spec.progress_callback is not None: + await invoke_file_edit_progress( + spec.progress_callback, + [build_file_edit_end_event(file_edit_tracker)], + ) + detail = "" if result is None else str(result) detail = detail.replace("\n", " ").strip() if not detail: diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 86a33c8b7..0202bd33d 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -1606,6 +1606,7 @@ class WebSocketChannel(BaseChannel): if not conns: if ( msg.metadata.get("_progress") + or msg.metadata.get("_file_edit_events") or msg.metadata.get("_turn_end") or msg.metadata.get("_session_updated") or msg.metadata.get("_goal_status") @@ -1638,7 +1639,22 @@ class WebSocketChannel(BaseChannel): await self.send_turn_end(msg.chat_id, latency_ms=lat_i, goal_state=gs_blob) return if msg.metadata.get("_session_updated"): - await self.send_session_updated(msg.chat_id) + scope = msg.metadata.get("_session_update_scope") + await self.send_session_updated( + msg.chat_id, + scope=scope if isinstance(scope, str) else None, + ) + return + if msg.metadata.get("_file_edit_events"): + payload: dict[str, Any] = { + "event": "file_edit", + "chat_id": msg.chat_id, + "edits": msg.metadata["_file_edit_events"], + } + self._try_append_webui_transcript(msg.chat_id, payload) + raw = json.dumps(payload, ensure_ascii=False) + for connection in conns: + await self._safe_send_to(connection, raw, label=" ") return text = msg.content payload: dict[str, Any] = { @@ -1805,12 +1821,14 @@ class WebSocketChannel(BaseChannel): for connection in conns: await self._safe_send_to(connection, raw, label=" goal_status ") - async def send_session_updated(self, chat_id: str) -> None: + async def send_session_updated(self, chat_id: str, *, scope: str | None = None) -> None: """Notify clients that session metadata changed outside the main turn.""" conns = list(self._subs.get(chat_id, ())) if not conns: return body: dict[str, Any] = {"event": "session_updated", "chat_id": chat_id} + if scope: + body["scope"] = scope raw = json.dumps(body, ensure_ascii=False) for connection in conns: await self._safe_send_to(connection, raw, label=" session_updated ") diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py new file mode 100644 index 000000000..8164aa18d --- /dev/null +++ b/nanobot/utils/file_edit_events.py @@ -0,0 +1,311 @@ +"""File-edit activity helpers for WebUI progress events.""" + +from __future__ import annotations + +import difflib +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"}) +_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 + + +@dataclass(slots=True) +class FileSnapshot: + path: Path + exists: bool + text: str | None + unreadable: bool = False + binary: bool = False + oversized: bool = False + + @property + def countable(self) -> bool: + return ( + self.text is not None + and not self.binary + and not self.oversized + and not self.unreadable + ) + + +@dataclass(slots=True) +class FileEditTracker: + call_id: str + tool: str + path: Path + display_path: str + before: FileSnapshot + + +def is_file_edit_tool(tool_name: str | None) -> bool: + return bool(tool_name) and tool_name in TRACKED_FILE_EDIT_TOOLS + + +def resolve_file_edit_path( + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> Path | None: + """Resolve the target file path after tool argument preparation.""" + if not isinstance(params, dict): + return None + raw_path = params.get("path") + if not isinstance(raw_path, str) or not raw_path.strip(): + return None + resolver = getattr(tool, "_resolve", None) + if callable(resolver): + try: + resolved = resolver(raw_path) + if isinstance(resolved, Path): + return resolved + if resolved: + return Path(resolved) + except Exception: + return None + if workspace is None: + return Path(raw_path).expanduser().resolve() + return (workspace / raw_path).expanduser().resolve() + + +def display_file_edit_path(path: Path, workspace: Path | None) -> str: + if workspace is not None: + try: + return path.resolve().relative_to(workspace.resolve()).as_posix() + except Exception: + pass + return path.as_posix() + + +def read_file_snapshot(path: Path, *, max_bytes: int = _MAX_SNAPSHOT_BYTES) -> FileSnapshot: + try: + if not path.exists() or not path.is_file(): + return FileSnapshot(path=path, exists=False, text="") + size = path.stat().st_size + if size > max_bytes: + return FileSnapshot(path=path, exists=True, text=None, oversized=True) + raw = path.read_bytes() + except OSError: + return FileSnapshot(path=path, exists=path.exists(), text=None, unreadable=True) + if b"\x00" in raw: + return FileSnapshot(path=path, exists=True, text=None, binary=True) + try: + text = raw.decode("utf-8") + except UnicodeDecodeError: + return FileSnapshot(path=path, exists=True, text=None, binary=True) + return FileSnapshot(path=path, exists=True, text=text.replace("\r\n", "\n")) + + +def line_diff_stats(before: str | None, after: str | None) -> tuple[int, int]: + """Return ``(added, deleted)`` for a UTF-8 text line-level diff.""" + if before is None or after is None: + return 0, 0 + before_lines = before.replace("\r\n", "\n").splitlines() + after_lines = after.replace("\r\n", "\n").splitlines() + added = 0 + deleted = 0 + matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False) + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + if tag == "equal": + continue + if tag in ("replace", "delete"): + deleted += i2 - i1 + if tag in ("replace", "insert"): + added += j2 - j1 + return added, deleted + + +def prepare_file_edit_tracker( + *, + call_id: str, + tool_name: str, + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> FileEditTracker | None: + if not is_file_edit_tool(tool_name): + return None + path = resolve_file_edit_path(tool, workspace, params) + if path is None: + return None + before = read_file_snapshot(path) + return FileEditTracker( + call_id=str(call_id or ""), + tool=tool_name, + path=path, + display_path=display_file_edit_path(path, workspace), + before=before, + ) + + +def build_file_edit_start_event( + tracker: FileEditTracker, + params: dict[str, Any] | None, +) -> dict[str, Any]: + predicted_after = _predict_after_text(tracker.tool, params or {}, tracker.before) + if tracker.before.countable and predicted_after is not None: + added, deleted = line_diff_stats(tracker.before.text, predicted_after) + else: + added, deleted = 0, 0 + return _event_payload( + tracker, + phase="start", + status="editing", + added=added, + deleted=deleted, + approximate=True, + ) + + +def build_file_edit_end_event(tracker: FileEditTracker) -> dict[str, Any]: + after = read_file_snapshot(tracker.path) + if tracker.before.countable and after.countable: + added, deleted = line_diff_stats(tracker.before.text, after.text) + else: + added, deleted = 0, 0 + return _event_payload( + tracker, + phase="end", + status="done", + added=added, + deleted=deleted, + approximate=False, + binary=after.binary or after.oversized or after.unreadable, + ) + + +def build_file_edit_error_event(tracker: FileEditTracker, error: str | None = None) -> dict[str, Any]: + payload = _event_payload( + tracker, + phase="error", + status="error", + added=0, + deleted=0, + approximate=False, + ) + if error: + payload["error"] = error.strip()[:240] + return payload + + +def _event_payload( + tracker: FileEditTracker, + *, + phase: str, + status: str, + added: int, + deleted: int, + approximate: bool, + binary: bool = False, +) -> dict[str, Any]: + payload: dict[str, Any] = { + "version": 1, + "call_id": tracker.call_id, + "tool": tracker.tool, + "path": tracker.display_path, + "phase": phase, + "added": max(0, int(added)), + "deleted": max(0, int(deleted)), + "approximate": bool(approximate), + "status": status, + } + if binary: + payload["binary"] = True + return payload + + +def _predict_after_text( + tool_name: str, + params: dict[str, Any], + before: FileSnapshot, +) -> str | None: + if not before.countable: + return None + before_text = before.text or "" + if tool_name == "write_file": + content = params.get("content") + return content if isinstance(content, str) else "" + if tool_name == "edit_file": + old_text = params.get("old_text") + new_text = params.get("new_text") + if not isinstance(old_text, str) or not isinstance(new_text, str): + return None + replace_all = bool(params.get("replace_all")) + if old_text == "": + return new_text if not before.exists else before_text + if old_text in before_text: + if replace_all: + return before_text.replace(old_text, new_text) + return before_text.replace(old_text, new_text, 1) + return None + if tool_name == "notebook_edit": + return _predict_notebook_after_text(params, before_text) + return None + + +def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None: + try: + nb = json.loads(before_text) if before_text.strip() else _empty_notebook() + except Exception: + return None + cells = nb.get("cells") + if not isinstance(cells, list): + return None + try: + cell_index = int(params.get("cell_index", 0)) + except (TypeError, ValueError): + return None + new_source = params.get("new_source") + source = new_source if isinstance(new_source, str) else "" + cell_type = params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code" + mode = params.get("edit_mode") if params.get("edit_mode") in ("replace", "insert", "delete") else "replace" + if mode == "delete": + if 0 <= cell_index < len(cells): + cells.pop(cell_index) + else: + return None + elif mode == "insert": + insert_at = min(max(cell_index + 1, 0), len(cells)) + cells.insert(insert_at, _new_notebook_cell(source, str(cell_type))) + else: + if not (0 <= cell_index < len(cells)): + return None + cell = cells[cell_index] + if not isinstance(cell, dict): + return None + cell["source"] = source + cell["cell_type"] = cell_type + if cell_type == "code": + cell.setdefault("outputs", []) + cell.setdefault("execution_count", None) + else: + cell.pop("outputs", None) + cell.pop("execution_count", None) + nb["cells"] = cells + try: + return json.dumps(nb, indent=1, ensure_ascii=False) + except Exception: + return None + + +def _empty_notebook() -> dict[str, Any]: + return { + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, + "language_info": {"name": "python"}, + }, + "cells": [], + } + + +def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]: + cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}} + if cell_type == "code": + cell["outputs"] = [] + cell["execution_count"] = None + return cell diff --git a/nanobot/utils/progress_events.py b/nanobot/utils/progress_events.py index 10a282b99..ccf125ec4 100644 --- a/nanobot/utils/progress_events.py +++ b/nanobot/utils/progress_events.py @@ -10,13 +10,21 @@ from nanobot.agent.hook import AgentHookContext def on_progress_accepts_tool_events(cb: Callable[..., Any]) -> bool: + return _on_progress_accepts(cb, "tool_events") + + +def on_progress_accepts_file_edit_events(cb: Callable[..., Any]) -> bool: + return _on_progress_accepts(cb, "file_edit_events") + + +def _on_progress_accepts(cb: Callable[..., Any], name: str) -> bool: try: sig = inspect.signature(cb) except (TypeError, ValueError): return False if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): return True - return "tool_events" in sig.parameters + return name in sig.parameters async def invoke_on_progress( @@ -32,6 +40,15 @@ async def invoke_on_progress( await on_progress(content, tool_hint=tool_hint) +async def invoke_file_edit_progress( + on_progress: Callable[..., Awaitable[None]], + file_edit_events: list[dict[str, Any]], +) -> None: + if not file_edit_events or not on_progress_accepts_file_edit_events(on_progress): + return + await on_progress("", file_edit_events=file_edit_events) + + def build_tool_event_start_payload(tool_call: Any) -> dict[str, Any]: return { "version": 1, diff --git a/nanobot/utils/webui_titles.py b/nanobot/utils/webui_titles.py deleted file mode 100644 index 2d363f926..000000000 --- a/nanobot/utils/webui_titles.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Helpers for WebUI chat title generation.""" - -from __future__ import annotations - -import re -from typing import Any - -from loguru import logger - -from nanobot.providers.base import LLMProvider -from nanobot.session.manager import Session, SessionManager -from nanobot.utils.helpers import truncate_text - -WEBUI_SESSION_METADATA_KEY = "webui" -WEBUI_TITLE_METADATA_KEY = "title" -WEBUI_TITLE_USER_EDITED_METADATA_KEY = "title_user_edited" -TITLE_MAX_CHARS = 60 - - -def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool: - """Persist a WebUI marker only when the inbound websocket frame opted in.""" - if metadata.get(WEBUI_SESSION_METADATA_KEY) is not True: - return False - session.metadata[WEBUI_SESSION_METADATA_KEY] = True - return True - - -def clean_generated_title(raw: str | None) -> str: - text = (raw or "").strip() - if not text: - return "" - text = re.sub(r"^\s*(title|标题)\s*[::]\s*", "", text, flags=re.IGNORECASE) - text = text.strip().strip("\"'`“”‘’") - text = re.sub(r"\s+", " ", text).strip() - text = text.rstrip("。.!!??,,;;:") - if len(text) > TITLE_MAX_CHARS: - text = text[: TITLE_MAX_CHARS - 1].rstrip() + "…" - return text - - -def _title_inputs(session: Session) -> tuple[str, str]: - user_text = "" - assistant_text = "" - for message in session.messages: - role = message.get("role") - content = message.get("content") - if not isinstance(content, str) or not content.strip(): - continue - if role == "user" and not user_text: - user_text = content.strip() - elif role == "assistant" and not assistant_text: - assistant_text = content.strip() - if user_text and assistant_text: - break - return user_text, assistant_text - - -async def maybe_generate_webui_title( - *, - sessions: SessionManager, - session_key: str, - provider: LLMProvider, - model: str, -) -> bool: - """Generate and persist a short title for WebUI-owned sessions only.""" - session = sessions.get_or_create(session_key) - if session.metadata.get(WEBUI_SESSION_METADATA_KEY) is not True: - return False - if session.metadata.get(WEBUI_TITLE_USER_EDITED_METADATA_KEY) is True: - return False - current_title = session.metadata.get(WEBUI_TITLE_METADATA_KEY) - if isinstance(current_title, str) and current_title.strip(): - return False - - user_text, assistant_text = _title_inputs(session) - if not user_text: - return False - - prompt = ( - "Generate a concise title for this chat.\n" - "Rules:\n" - "- Use the same language as the user when practical.\n" - "- 3 to 8 words.\n" - "- No quotes.\n" - "- No punctuation at the end.\n" - "- Return only the title.\n\n" - f"User: {truncate_text(user_text, 1_000)}" - ) - if assistant_text: - prompt += f"\nAssistant: {truncate_text(assistant_text, 1_000)}" - - try: - response = await provider.chat_with_retry( - [ - { - "role": "system", - "content": ( - "You write short, neutral chat titles. " - "Return only the title text." - ), - }, - {"role": "user", "content": prompt}, - ], - tools=None, - model=model, - max_tokens=32, - temperature=0.2, - retry_mode="standard", - ) - except Exception: - logger.debug("Failed to generate webui session title for {}", session_key, exc_info=True) - return False - - title = clean_generated_title(response.content) - if not title or title.lower().startswith("error"): - return False - session.metadata[WEBUI_TITLE_METADATA_KEY] = title - sessions.save(session) - return True - - -async def maybe_generate_webui_title_after_turn( - *, - channel: str, - metadata: dict[str, Any], - sessions: SessionManager, - session_key: str, - provider: LLMProvider, - model: str, -) -> bool: - if channel != "websocket" or metadata.get(WEBUI_SESSION_METADATA_KEY) is not True: - return False - return await maybe_generate_webui_title( - sessions=sessions, - session_key=session_key, - provider=provider, - model=model, - ) diff --git a/nanobot/utils/webui_transcript.py b/nanobot/utils/webui_transcript.py index dde0e9168..bee71c542 100644 --- a/nanobot/utils/webui_transcript.py +++ b/nanobot/utils/webui_transcript.py @@ -125,11 +125,25 @@ def replay_transcript_to_ui_messages( buffer_message_id: str | None = None buffer_parts: list[str] = [] suppress_until_turn_end = False + active_activity_segment_id: str | None = None + active_file_edit_segment_id: str | None = None + activity_segment_counter = 0 _ts_base = int(time.time() * 1000) def _new_id(prefix: str, idx: int) -> str: return f"{prefix}-{idx}-{uuid.uuid4().hex[:8]}" + def _new_activity_segment(*, activate: bool = True) -> str: + nonlocal active_activity_segment_id, activity_segment_counter + activity_segment_counter += 1 + segment_id = f"activity-{activity_segment_counter}" + if activate: + active_activity_segment_id = segment_id + return segment_id + + def _ensure_activity_segment() -> str: + return active_activity_segment_id or _new_activity_segment() + def attach_reasoning_chunk(prev: list[dict[str, Any]], chunk: str, idx: int) -> None: for i in range(len(prev) - 1, -1, -1): candidate = prev[i] @@ -151,12 +165,19 @@ def replay_transcript_to_ui_messages( **candidate, "reasoning": (str(candidate.get("reasoning") or "")) + chunk, "reasoningStreaming": True, + "activitySegmentId": candidate.get("activitySegmentId") or _ensure_activity_segment(), } return if not has_answer and candidate.get("isStreaming"): - prev[i] = {**candidate, "reasoning": chunk, "reasoningStreaming": True} + prev[i] = { + **candidate, + "reasoning": chunk, + "reasoningStreaming": True, + "activitySegmentId": candidate.get("activitySegmentId") or _ensure_activity_segment(), + } return break + segment = _ensure_activity_segment() prev.append( { "id": _new_id("as", idx), @@ -165,6 +186,7 @@ def replay_transcript_to_ui_messages( "isStreaming": True, "reasoning": chunk, "reasoningStreaming": True, + "activitySegmentId": segment, "createdAt": _ts_base + idx, }, ) @@ -221,6 +243,7 @@ def replay_transcript_to_ui_messages( return def absorb_complete(extra: dict[str, Any], idx: int) -> None: + nonlocal active_activity_segment_id last = messages[-1] if messages else None if last and is_reasoning_only_placeholder(last): messages[-1] = { @@ -238,10 +261,76 @@ def replay_transcript_to_ui_messages( **extra, }, ) + active_activity_segment_id = None + + def _file_edit_key(edit: dict[str, Any]) -> str: + return "|".join( + str(edit.get(k) or "") + for k in ("call_id", "tool", "path") + ) + + def upsert_file_edits(edits: list[dict[str, Any]], idx: int) -> None: + nonlocal active_file_edit_segment_id + if not edits: + return + last = messages[-1] if messages else None + if ( + active_file_edit_segment_id + and last + and last.get("kind") == "trace" + and last.get("fileEdits") + ): + segment = active_file_edit_segment_id + else: + segment = _new_activity_segment(activate=False) + active_file_edit_segment_id = segment + if not ( + last + and last.get("kind") == "trace" + and not last.get("isStreaming") + and last.get("fileEdits") + and last.get("activitySegmentId") == segment + ): + messages.append( + { + "id": _new_id("tr", idx), + "role": "tool", + "kind": "trace", + "content": "", + "traces": [], + "fileEdits": [], + "activitySegmentId": segment, + "createdAt": _ts_base + idx, + }, + ) + last = messages[-1] + existing = list(last.get("fileEdits") or []) + index_by_key = { + _file_edit_key(edit): pos + for pos, edit in enumerate(existing) + if isinstance(edit, dict) + } + for edit in edits: + if not isinstance(edit, dict): + continue + key = _file_edit_key(edit) + if key in index_by_key: + pos = index_by_key[key] + existing[pos] = {**existing[pos], **edit} + else: + index_by_key[key] = len(existing) + existing.append(dict(edit)) + messages[-1] = { + **last, + "fileEdits": existing, + "activitySegmentId": last.get("activitySegmentId") or segment, + } for idx, rec in enumerate(lines): ev = rec.get("event") if ev == "user": + active_activity_segment_id = None + active_file_edit_segment_id = None text = rec.get("text") text_s = text if isinstance(text, str) else "" media_paths = rec.get("media_paths") @@ -264,6 +353,12 @@ def replay_transcript_to_ui_messages( messages.append(row) continue + if ev == "file_edit": + raw_edits = rec.get("edits") + if isinstance(raw_edits, list): + upsert_file_edits([e for e in raw_edits if isinstance(e, dict)], idx) + continue + if ev == "delta": if suppress_until_turn_end: continue @@ -338,14 +433,21 @@ def replay_transcript_to_ui_messages( trace_lines = structured if structured else ([text] if isinstance(text, str) and text else []) if not trace_lines: continue + segment = _ensure_activity_segment() last = messages[-1] if messages else None - if last and last.get("kind") == "trace" and not last.get("isStreaming"): + if ( + last + and last.get("kind") == "trace" + and not last.get("isStreaming") + and (last.get("activitySegmentId") in (None, segment)) + ): prev_traces = list(last.get("traces") or [last.get("content")]) merged_traces = prev_traces + trace_lines messages[-1] = { **last, "traces": merged_traces, "content": trace_lines[-1], + "activitySegmentId": last.get("activitySegmentId") or segment, } else: messages.append( @@ -355,6 +457,7 @@ def replay_transcript_to_ui_messages( "kind": "trace", "content": trace_lines[-1], "traces": trace_lines, + "activitySegmentId": segment, "createdAt": _ts_base + idx, }, ) @@ -389,6 +492,8 @@ def replay_transcript_to_ui_messages( if ev == "turn_end": suppress_until_turn_end = False + active_activity_segment_id = None + active_file_edit_segment_id = None for i, m in enumerate(messages): if m.get("isStreaming"): messages[i] = {**m, "isStreaming": False} diff --git a/nanobot/utils/webui_turn_helpers.py b/nanobot/utils/webui_turn_helpers.py index 3fbca3729..10403852f 100644 --- a/nanobot/utils/webui_turn_helpers.py +++ b/nanobot/utils/webui_turn_helpers.py @@ -6,15 +6,161 @@ AgentLoop uses these without importing a concrete channel plugin; only from __future__ import annotations +import re import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field from typing import Any +from loguru import logger + from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMProvider +from nanobot.session.goal_state import goal_state_ws_blob +from nanobot.session.manager import Session, SessionManager +from nanobot.utils.helpers import truncate_text + +WEBUI_SESSION_METADATA_KEY = "webui" +WEBUI_TITLE_METADATA_KEY = "title" +WEBUI_TITLE_USER_EDITED_METADATA_KEY = "title_user_edited" +TITLE_MAX_CHARS = 60 +TITLE_GENERATION_MAX_TOKENS = 96 +TITLE_GENERATION_REASONING_EFFORT = "none" # Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the # gateway process stays up; cleared on idle/stop and implicitly dropped on restart. _WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {} +TitleContext = tuple[LLMProvider, str] + + +def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool: + """Persist a WebUI marker only when the inbound websocket frame opted in.""" + if metadata.get(WEBUI_SESSION_METADATA_KEY) is not True: + return False + session.metadata[WEBUI_SESSION_METADATA_KEY] = True + return True + + +def clean_generated_title(raw: str | None) -> str: + text = (raw or "").strip() + if not text: + return "" + text = re.sub(r"^\s*(title|标题)\s*[::]\s*", "", text, flags=re.IGNORECASE) + text = text.strip().strip("\"'`“”‘’") + text = re.sub(r"\s+", " ", text).strip() + text = text.rstrip("。.!!??,,;;:") + if len(text) > TITLE_MAX_CHARS: + text = text[: TITLE_MAX_CHARS - 1].rstrip() + "…" + return text + + +def _title_inputs(session: Session) -> tuple[str, str]: + user_text = "" + assistant_text = "" + for message in session.messages: + if message.get("_command") is True: + continue + role = message.get("role") + content = message.get("content") + if not isinstance(content, str) or not content.strip(): + continue + if role == "user" and not user_text: + user_text = content.strip() + elif role == "assistant" and not assistant_text: + assistant_text = content.strip() + if user_text and assistant_text: + break + return user_text, assistant_text + + +async def maybe_generate_webui_title( + *, + sessions: SessionManager, + session_key: str, + provider: LLMProvider, + model: str, +) -> bool: + """Generate and persist a short title for WebUI-owned sessions only.""" + session = sessions.get_or_create(session_key) + if session.metadata.get(WEBUI_SESSION_METADATA_KEY) is not True: + return False + if session.metadata.get(WEBUI_TITLE_USER_EDITED_METADATA_KEY) is True: + return False + current_title = session.metadata.get(WEBUI_TITLE_METADATA_KEY) + if isinstance(current_title, str) and current_title.strip(): + return False + + user_text, assistant_text = _title_inputs(session) + if not user_text: + return False + + prompt = ( + "Generate a concise title for this chat.\n" + "Rules:\n" + "- Use the same language as the user when practical.\n" + "- 3 to 8 words.\n" + "- No quotes.\n" + "- No punctuation at the end.\n" + "- Return only the title.\n\n" + f"User: {truncate_text(user_text, 1_000)}" + ) + if assistant_text: + prompt += f"\nAssistant: {truncate_text(assistant_text, 1_000)}" + + try: + response = await provider.chat_with_retry( + [ + { + "role": "system", + "content": ( + "You write short, neutral chat titles. " + "Return only the title text." + ), + }, + {"role": "user", "content": prompt}, + ], + tools=None, + model=model, + max_tokens=TITLE_GENERATION_MAX_TOKENS, + temperature=0.2, + reasoning_effort=TITLE_GENERATION_REASONING_EFFORT, + retry_mode="standard", + ) + except Exception: + logger.debug("Failed to generate webui session title for {}", session_key, exc_info=True) + return False + + title = clean_generated_title(response.content) + if not title or title.lower().startswith("error"): + logger.debug( + "WebUI title generation returned no usable title for {} (finish_reason={})", + session_key, + response.finish_reason, + ) + return False + session.metadata[WEBUI_TITLE_METADATA_KEY] = title + sessions.save(session) + return True + + +async def maybe_generate_webui_title_after_turn( + *, + channel: str, + metadata: dict[str, Any], + sessions: SessionManager, + session_key: str, + provider: LLMProvider, + model: str, +) -> bool: + if channel != "websocket" or metadata.get(WEBUI_SESSION_METADATA_KEY) is not True: + return False + return await maybe_generate_webui_title( + sessions=sessions, + session_key=session_key, + provider=provider, + model=model, + ) def websocket_turn_wall_started_at(chat_id: str) -> float | None: @@ -46,3 +192,125 @@ async def publish_turn_run_status(bus: MessageBus, msg: InboundMessage, status: metadata=meta, ), ) + + +def build_bus_progress_callback( + bus: MessageBus, + msg: InboundMessage, +) -> Callable[..., Awaitable[None]]: + """Return the bus progress callback for agent runtime events.""" + + async def _bus_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict[str, Any]] | None = None, + file_edit_events: list[dict[str, Any]] | None = None, + reasoning: bool = False, + reasoning_end: bool = False, + ) -> None: + if file_edit_events and msg.channel != "websocket": + return + meta = dict(msg.metadata or {}) + meta["_progress"] = True + meta["_tool_hint"] = tool_hint + if reasoning: + meta["_reasoning_delta"] = True + if reasoning_end: + meta["_reasoning_end"] = True + if tool_events: + meta["_tool_events"] = tool_events + if file_edit_events: + meta["_file_edit_events"] = file_edit_events + await bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=content, + metadata=meta, + ) + ) + + return _bus_progress + + +@dataclass +class WebuiTurnCoordinator: + """Own the WebUI/WebSocket wire details that hang off AgentLoop turns.""" + + bus: MessageBus + sessions: SessionManager + schedule_background: Callable[[Awaitable[None]], None] + _title_contexts: dict[str, TitleContext] = field(default_factory=dict) + + def capture_title_context( + self, + session_key: str, + msg: InboundMessage, + provider: LLMProvider, + model: str, + ) -> None: + if msg.channel == "websocket" and msg.metadata.get("webui") is True: + self._title_contexts[session_key] = (provider, model) + + def discard(self, session_key: str) -> None: + self._title_contexts.pop(session_key, None) + + async def publish_run_status(self, msg: InboundMessage, status: str) -> None: + await publish_turn_run_status(self.bus, msg, status) + + async def handle_turn_end( + self, + msg: InboundMessage, + *, + session_key: str, + latency_ms: int | None, + ) -> None: + if msg.channel != "websocket": + return + + turn_metadata: dict[str, Any] = {**msg.metadata, "_turn_end": True} + if latency_ms is not None: + turn_metadata["latency_ms"] = int(latency_ms) + session = self.sessions.get_or_create(session_key) + turn_metadata["goal_state"] = goal_state_ws_blob(session.metadata) + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata=turn_metadata, + )) + self._schedule_title_update(msg, session_key=session_key) + + def _schedule_title_update(self, msg: InboundMessage, *, session_key: str) -> None: + title_context = self._title_contexts.pop(session_key, None) + if msg.metadata.get("webui") is not True or title_context is None: + return + + title_provider, title_model = title_context + + async def _generate_title_and_notify( + provider: LLMProvider = title_provider, + model: str = title_model, + ) -> None: + generated = await maybe_generate_webui_title_after_turn( + channel=msg.channel, + metadata=msg.metadata, + sessions=self.sessions, + session_key=session_key, + provider=provider, + model=model, + ) + if generated: + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata={ + **msg.metadata, + "_session_updated": True, + "_session_update_scope": "metadata", + }, + )) + + self.schedule_background(_generate_title_and_notify()) diff --git a/tests/agent/test_loop_progress.py b/tests/agent/test_loop_progress.py index fcf6198c1..b1b33612f 100644 --- a/tests/agent/test_loop_progress.py +++ b/tests/agent/test_loop_progress.py @@ -82,6 +82,96 @@ class TestToolEventProgress: ), ] + @pytest.mark.asyncio + async def test_write_file_emits_file_edit_progress(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + target = tmp_path / "foo.txt" + target.write_text("old\n", encoding="utf-8") + tool_call = ToolCallRequest( + id="call-write", + name="write_file", + arguments={"path": "foo.txt", "content": "new\nextra\n"}, + ) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.prepare_call = MagicMock( + return_value=(None, {"path": "foo.txt", "content": "new\nextra\n"}, None), + ) + + async def execute(name: str, params: dict) -> str: + target.write_text(params["content"], encoding="utf-8") + return "ok" + + loop.tools.execute = AsyncMock(side_effect=execute) + file_events: list[dict] = [] + + async def on_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict] | None = None, + file_edit_events: list[dict] | None = None, + ) -> None: + if file_edit_events: + file_events.extend(file_edit_events) + + final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + + assert final_content == "Done" + assert [event["phase"] for event in file_events] == ["start", "end"] + assert file_events[0] == { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "start", + "added": 2, + "deleted": 1, + "approximate": True, + "status": "editing", + } + assert file_events[1]["status"] == "done" + assert file_events[1]["approximate"] is False + assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1) + + @pytest.mark.asyncio + async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call-exec", + name="exec", + arguments={"command": "printf hi > foo.txt"}, + ) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.prepare_call = MagicMock( + return_value=(None, {"command": "printf hi > foo.txt"}, None), + ) + loop.tools.execute = AsyncMock(return_value="ok") + file_events: list[dict] = [] + + async def on_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict] | None = None, + file_edit_events: list[dict] | None = None, + ) -> None: + if file_edit_events: + file_events.extend(file_edit_events) + + await loop._run_agent_loop([], on_progress=on_progress) + + assert file_events == [] + @pytest.mark.asyncio async def test_bus_progress_forwards_tool_events_to_outbound_metadata(self, tmp_path: Path) -> None: """When run() handles a bus message, _tool_events lands in OutboundMessage metadata.""" @@ -130,6 +220,42 @@ class TestToolEventProgress: assert finish["phase"] == "end" assert finish["result"] == "file.txt" + @pytest.mark.asyncio + async def test_bus_progress_forwards_file_edit_events_for_websocket_only(self, tmp_path: Path) -> None: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + edit_events = [{ + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "start", + "added": 1, + "deleted": 0, + "approximate": True, + "status": "editing", + }] + + websocket_progress = await loop._build_bus_progress_callback(InboundMessage( + channel="websocket", + sender_id="u1", + chat_id="chat1", + content="edit", + )) + await websocket_progress("", file_edit_events=edit_events) + outbound = await bus.consume_outbound() + assert outbound.metadata["_file_edit_events"] == edit_events + + telegram_progress = await loop._build_bus_progress_callback(InboundMessage( + channel="telegram", + sender_id="u1", + chat_id="chat2", + content="edit", + )) + await telegram_progress("", file_edit_events=edit_events) + assert bus.outbound_size == 0 + @pytest.mark.asyncio async def test_non_streaming_channel_does_not_publish_codex_progress_deltas( self, @@ -353,8 +479,93 @@ class TestToolEventProgress: assert session_updated is not None assert (session_updated.metadata or {}).get("_session_updated") is True + assert (session_updated.metadata or {}).get("_session_update_scope") == "metadata" assert provider.chat_with_retry.await_count == 2 + @pytest.mark.asyncio + async def test_webui_title_generation_uses_turn_model_snapshot( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Done", tool_calls=[])) + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + captured: dict[str, object] = {} + + async def fake_title_after_turn(**kwargs: object) -> bool: + captured.update(kwargs) + return False + + monkeypatch.setattr( + "nanobot.utils.webui_turn_helpers.maybe_generate_webui_title_after_turn", + fake_title_after_turn, + ) + scheduled_title: list[object] = [] + + def schedule_background(coro: object) -> None: + name = getattr(coro, "__qualname__", "") + if "_generate_title_and_notify" in name: + scheduled_title.append(coro) + elif hasattr(coro, "close"): + coro.close() + + loop._schedule_background = schedule_background # type: ignore[method-assign] + + await loop._dispatch(InboundMessage( + channel="websocket", + sender_id="u1", + chat_id="chat1", + content="say hello", + metadata={"webui": True}, + )) + + assert len(scheduled_title) == 1 + loop.provider = MagicMock() + loop.model = "switched-after-turn" + + await scheduled_title[0] # type: ignore[misc] + + assert captured["provider"] is provider + assert captured["model"] == "test-model" + + @pytest.mark.asyncio + async def test_webui_command_turn_does_not_schedule_title_generation( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Done", tool_calls=[])) + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + + async def fake_title_after_turn(**_kwargs: object) -> bool: + raise AssertionError("command-only turns should not generate titles") + + monkeypatch.setattr( + "nanobot.utils.webui_turn_helpers.maybe_generate_webui_title_after_turn", + fake_title_after_turn, + ) + scheduled: list[object] = [] + loop._schedule_background = scheduled.append # type: ignore[method-assign] + + await loop._dispatch(InboundMessage( + channel="websocket", + sender_id="u1", + chat_id="chat1", + content="/model", + metadata={"webui": True}, + )) + + assert scheduled == [] + @pytest.mark.asyncio async def test_non_websocket_dispatch_does_not_publish_turn_end_marker(self, tmp_path: Path) -> None: bus = MessageBus() diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index ed78e7192..105291347 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -11,7 +11,9 @@ from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse from nanobot.session.goal_state import GOAL_STATE_KEY from nanobot.session.manager import Session -from nanobot.utils.webui_titles import ( +from nanobot.utils.webui_turn_helpers import ( + TITLE_GENERATION_MAX_TOKENS, + TITLE_GENERATION_REASONING_EFFORT, WEBUI_SESSION_METADATA_KEY, WEBUI_TITLE_METADATA_KEY, maybe_generate_webui_title, @@ -55,6 +57,11 @@ async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Pat assert generated is True assert session.metadata[WEBUI_TITLE_METADATA_KEY] == "优化 WebUI 侧边栏" loop.provider.chat_with_retry.assert_awaited_once() + assert loop.provider.chat_with_retry.await_args.kwargs["max_tokens"] == TITLE_GENERATION_MAX_TOKENS + assert ( + loop.provider.chat_with_retry.await_args.kwargs["reasoning_effort"] + == TITLE_GENERATION_REASONING_EFFORT + ) @pytest.mark.asyncio @@ -79,6 +86,31 @@ async def test_generate_webui_title_skips_plain_websocket_sessions(tmp_path: Pat loop.provider.chat_with_retry.assert_not_awaited() +@pytest.mark.asyncio +async def test_generate_webui_title_ignores_command_only_sessions(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + session = loop.sessions.get_or_create("websocket:command-title") + session.metadata[WEBUI_SESSION_METADATA_KEY] = True + session.add_message("user", "/model deep", _command=True) + session.add_message( + "assistant", + "Switched model preset to `deep`.\n- Model: `deepseek-v4-pro`", + _command=True, + ) + loop.sessions.save(session) + + generated = await maybe_generate_webui_title( + sessions=loop.sessions, + session_key="websocket:command-title", + provider=loop.provider, + model=loop.model, + ) + + assert generated is False + assert WEBUI_TITLE_METADATA_KEY not in session.metadata + loop.provider.chat_with_retry.assert_not_awaited() + + def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None: loop = _mk_loop() session = Session(key="test:runtime-only") diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index 2fa7285fb..c6f9d66a3 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -370,6 +370,55 @@ async def test_send_progress_includes_structured_tool_events() -> None: ] +@pytest.mark.asyncio +async def test_send_file_edit_progress_uses_file_edit_event() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._attach(mock_ws, "chat-1") + + await channel.send(OutboundMessage( + channel="websocket", + chat_id="chat-1", + content="", + metadata={ + "_progress": True, + "_file_edit_events": [ + { + "version": 1, + "phase": "start", + "call_id": "call-1", + "tool": "write_file", + "path": "src/app.py", + "added": 12, + "deleted": 2, + "approximate": True, + "status": "editing", + } + ], + }, + )) + + payload = json.loads(mock_ws.send.await_args.args[0]) + assert payload == { + "event": "file_edit", + "chat_id": "chat-1", + "edits": [ + { + "version": 1, + "phase": "start", + "call_id": "call-1", + "tool": "write_file", + "path": "src/app.py", + "added": 12, + "deleted": 2, + "approximate": True, + "status": "editing", + } + ], + } + + @pytest.mark.asyncio async def test_send_progress_includes_agent_ui_blob() -> None: bus = MagicMock() @@ -758,6 +807,25 @@ async def test_send_session_updated_emits_session_updated_event() -> None: assert body == {"event": "session_updated", "chat_id": "chat-1"} +@pytest.mark.asyncio +async def test_send_session_updated_includes_scope_when_present() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._attach(mock_ws, "chat-1") + + await channel.send(OutboundMessage( + channel="websocket", + chat_id="chat-1", + content="", + metadata={"_session_updated": True, "_session_update_scope": "metadata"}, + )) + + mock_ws.send.assert_awaited_once() + body = json.loads(mock_ws.send.await_args.args[0]) + assert body == {"event": "session_updated", "chat_id": "chat-1", "scope": "metadata"} + + @pytest.mark.asyncio async def test_send_non_connection_closed_exception_is_raised() -> None: bus = MagicMock() diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py new file mode 100644 index 000000000..6176a5e36 --- /dev/null +++ b/tests/utils/test_file_edit_events.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from pathlib import Path + +from nanobot.utils.file_edit_events import ( + build_file_edit_end_event, + build_file_edit_start_event, + line_diff_stats, + prepare_file_edit_tracker, + read_file_snapshot, +) + + +def test_line_diff_stats_counts_replacements_insertions_and_deletions() -> None: + added, deleted = line_diff_stats("a\nb\nc\n", "a\nB\nc\nd\n") + assert (added, deleted) == (2, 1) + + +def test_line_diff_stats_normalizes_crlf() -> None: + assert line_diff_stats("a\r\nb\r\n", "a\nb\nc\n") == (1, 0) + + +def test_write_file_start_predicts_and_end_calibrates_exact_diff(tmp_path: Path) -> None: + target = tmp_path / "notes.txt" + target.write_text("old\nkeep\n", encoding="utf-8") + params = {"path": "notes.txt", "content": "new\nkeep\nextra\n"} + tracker = prepare_file_edit_tracker( + call_id="call-write", + tool_name="write_file", + tool=None, + workspace=tmp_path, + params=params, + ) + + assert tracker is not None + start = build_file_edit_start_event(tracker, params) + assert start == { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "notes.txt", + "phase": "start", + "added": 2, + "deleted": 1, + "approximate": True, + "status": "editing", + } + + target.write_text("new\nkeep\nextra\n", encoding="utf-8") + end = build_file_edit_end_event(tracker) + assert end["phase"] == "end" + assert end["status"] == "done" + assert end["approximate"] is False + assert (end["added"], end["deleted"]) == (2, 1) + + +def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None: + target = tmp_path / "data.bin" + target.write_bytes(b"\x00\x01before") + tracker = prepare_file_edit_tracker( + call_id="call-bin", + tool_name="edit_file", + tool=None, + workspace=tmp_path, + params={"path": "data.bin", "old_text": "before", "new_text": "after"}, + ) + + assert tracker is not None + assert not read_file_snapshot(target).countable + target.write_bytes(b"\x00\x01after") + event = build_file_edit_end_event(tracker) + assert event["binary"] is True + assert (event["added"], event["deleted"]) == (0, 0) + + +def test_untracked_tools_do_not_prepare_file_edit_tracker(tmp_path: Path) -> None: + assert prepare_file_edit_tracker( + call_id="call-exec", + tool_name="exec", + tool=None, + workspace=tmp_path, + params={"path": "created-by-shell.txt"}, + ) is None diff --git a/tests/utils/test_webui_transcript.py b/tests/utils/test_webui_transcript.py index 419abbfcd..f13380f46 100644 --- a/tests/utils/test_webui_transcript.py +++ b/tests/utils/test_webui_transcript.py @@ -42,6 +42,62 @@ def test_replay_delta_and_turn_end(tmp_path, monkeypatch) -> None: assert msgs[1]["latencyMs"] == 42 +def test_replay_file_edit_event_creates_file_activity(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + key = "websocket:t-file" + for ev in ( + {"event": "user", "chat_id": "t-file", "text": "edit"}, + { + "event": "message", + "chat_id": "t-file", + "text": 'write_file({"path":"foo.txt"})', + "kind": "tool_hint", + }, + { + "event": "file_edit", + "chat_id": "t-file", + "edits": [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "end", + "added": 2, + "deleted": 1, + "approximate": False, + "status": "done", + }, + ], + }, + ): + append_transcript_object(key, ev) + + msgs = replay_transcript_to_ui_messages(read_transcript_lines(key)) + + assert len(msgs) == 3 + assert msgs[1]["kind"] == "trace" + assert msgs[1]["traces"] == ['write_file({"path":"foo.txt"})'] + assert "fileEdits" not in msgs[1] + assert msgs[2]["kind"] == "trace" + assert msgs[2]["traces"] == [] + assert msgs[2]["fileEdits"] == [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "end", + "added": 2, + "deleted": 1, + "approximate": False, + "status": "done", + }, + ] + assert msgs[2]["activitySegmentId"] + assert msgs[2]["activitySegmentId"] != msgs[1]["activitySegmentId"] + + def test_build_response_schema(monkeypatch, tmp_path) -> None: from nanobot.utils.webui_transcript import build_webui_thread_response