diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 776885ecb..0b0164fd0 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -20,6 +20,7 @@ from nanobot.utils.file_edit_events import ( build_file_edit_error_event, build_file_edit_start_event, prepare_file_edit_tracker, + StreamingFileEditTracker, ) from nanobot.utils.helpers import ( IncrementalThinkExtractor, @@ -629,6 +630,24 @@ class AgentRunner: ) progress_state: dict[str, bool] | None = None + live_file_edits: StreamingFileEditTracker | None = None + + if ( + spec.progress_callback is not None + and on_progress_accepts_file_edit_events(spec.progress_callback) + ): + async def _emit_live_file_edits(events: list[dict[str, Any]]) -> None: + await invoke_file_edit_progress(spec.progress_callback, events) + + live_file_edits = StreamingFileEditTracker( + workspace=spec.workspace, + tools=spec.tools, + emit=_emit_live_file_edits, + ) + + async def _tool_call_delta(delta: dict[str, Any]) -> None: + if live_file_edits is not None: + await live_file_edits.update(delta) if wants_streaming: async def _stream(delta: str) -> None: @@ -646,6 +665,7 @@ class AgentRunner: **kwargs, on_content_delta=_stream, on_thinking_delta=_thinking, + on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None, ) elif wants_progress_streaming: stream_buf = "" @@ -675,6 +695,7 @@ class AgentRunner: coro = self.provider.chat_stream_with_retry( **kwargs, on_content_delta=_stream_progress, + on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None, ) else: coro = self.provider.chat_with_retry(**kwargs) @@ -689,6 +710,14 @@ class AgentRunner: await coro if outer_timeout_s is None else await asyncio.wait_for(coro, timeout=outer_timeout_s) ) + if live_file_edits is not None: + await live_file_edits.flush() + if response.should_execute_tools: + live_file_edits.apply_final_call_ids(response.tool_calls) + await live_file_edits.error_unmatched( + response.tool_calls if response.should_execute_tools else [], + "Tool call did not complete.", + ) except asyncio.TimeoutError: if outer_timeout_s is None: return LLMResponse( @@ -907,7 +936,10 @@ class AgentRunner: if file_edit_tracker is not None and progress_callback is not None: await invoke_file_edit_progress( progress_callback, - [build_file_edit_end_event(file_edit_tracker)], + [build_file_edit_end_event( + file_edit_tracker, + params if isinstance(params, dict) else None, + )], ) detail = "" if result is None else str(result) diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index b667853a1..31f2bc2f1 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -590,6 +590,7 @@ class AnthropicProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: kwargs = self._build_kwargs( messages, tools, model, max_tokens, temperature, @@ -598,11 +599,12 @@ class AnthropicProvider(LLMProvider): idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) try: async with self._client.messages.stream(**kwargs) as stream: - if on_content_delta or on_thinking_delta: + if on_content_delta or on_thinking_delta or on_tool_call_delta: # Idle timeout must track *any* SSE chunk (thinking_delta, # tool JSON deltas, etc.), not only text_stream tokens. # Otherwise extended thinking can stall text_stream for minutes # while the connection is healthy (e.g. MiniMax Anthropic). + tool_blocks: dict[int, dict[str, str]] = {} while True: try: chunk = await asyncio.wait_for( @@ -611,7 +613,22 @@ class AnthropicProvider(LLMProvider): ) except StopAsyncIteration: break - if ( + if chunk.type == "content_block_start": + block = getattr(chunk, "content_block", None) + if getattr(block, "type", None) == "tool_use": + index = int(getattr(chunk, "index", 0) or 0) + state = { + "call_id": str(getattr(block, "id", "") or ""), + "name": str(getattr(block, "name", "") or ""), + } + tool_blocks[index] = state + if on_tool_call_delta: + await on_tool_call_delta({ + "index": index, + **state, + "arguments_delta": "", + }) + elif ( chunk.type == "content_block_delta" and getattr(chunk.delta, "type", None) == "thinking_delta" ): @@ -625,6 +642,20 @@ class AnthropicProvider(LLMProvider): text = getattr(chunk.delta, "text", None) or "" if text and on_content_delta: await on_content_delta(text) + elif ( + chunk.type == "content_block_delta" + and getattr(chunk.delta, "type", None) == "input_json_delta" + ): + partial = getattr(chunk.delta, "partial_json", None) or "" + if partial and on_tool_call_delta: + index = int(getattr(chunk, "index", 0) or 0) + state = tool_blocks.get(index, {}) + await on_tool_call_delta({ + "index": index, + "call_id": state.get("call_id", ""), + "name": state.get("name", ""), + "arguments_delta": partial, + }) response = await asyncio.wait_for( stream.get_final_message(), timeout=idle_timeout_s, diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index 918a11ce2..24a65cdfe 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -158,6 +158,7 @@ class AzureOpenAIProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: _ = on_thinking_delta body = self._build_body( @@ -169,7 +170,7 @@ class AzureOpenAIProvider(LLMProvider): try: stream = await self._client.responses.create(**body) content, tool_calls, finish_reason, usage, reasoning_content = ( - await consume_sdk_stream(stream, on_content_delta) + await consume_sdk_stream(stream, on_content_delta, on_tool_call_delta) ) return LLMResponse( content=content or None, diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 98f048db6..87697650a 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -70,11 +70,11 @@ class LLMResponse: @property def should_execute_tools(self) -> bool: - """Tools execute only when has_tool_calls AND finish_reason is ``tool_calls`` / ``stop``. + """Tools execute only when has_tool_calls AND finish_reason is a tool-capable stop. Blocks gateway-injected calls under ``refusal`` / ``content_filter`` / ``error`` (#3220).""" if not self.has_tool_calls: return False - return self.finish_reason in ("tool_calls", "stop") + return self.finish_reason in ("tool_calls", "function_call", "stop") @dataclass(frozen=True) @@ -501,6 +501,7 @@ class LLMProvider(ABC): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: """Stream a chat completion, calling *on_content_delta* for each text chunk. @@ -514,7 +515,7 @@ class LLMProvider(ABC): full content as a single delta. Providers that support native streaming should override this method. """ - _ = on_thinking_delta + _ = on_thinking_delta, on_tool_call_delta response = await self.chat( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, @@ -544,6 +545,7 @@ class LLMProvider(ABC): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, retry_mode: str = "standard", on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: @@ -561,6 +563,7 @@ class LLMProvider(ABC): reasoning_effort=reasoning_effort, tool_choice=tool_choice, on_content_delta=on_content_delta, on_thinking_delta=on_thinking_delta, + on_tool_call_delta=on_tool_call_delta, ) return await self._run_with_retry( self._safe_chat_stream, diff --git a/nanobot/providers/bedrock_provider.py b/nanobot/providers/bedrock_provider.py index b3f4ea572..ff74badbc 100644 --- a/nanobot/providers/bedrock_provider.py +++ b/nanobot/providers/bedrock_provider.py @@ -704,8 +704,9 @@ class BedrockProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: - _ = on_thinking_delta + _ = on_thinking_delta, on_tool_call_delta idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) content_parts: list[str] = [] reasoning_parts: list[str] = [] diff --git a/nanobot/providers/github_copilot_provider.py b/nanobot/providers/github_copilot_provider.py index fdba99ebc..bec7c11e1 100644 --- a/nanobot/providers/github_copilot_provider.py +++ b/nanobot/providers/github_copilot_provider.py @@ -243,6 +243,7 @@ class GitHubCopilotProvider(OpenAICompatProvider): tool_choice: str | dict[str, object] | None = None, on_content_delta: Callable[[str], None] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, object]], Awaitable[None]] | None = None, ): await self._refresh_client_api_key() return await super().chat_stream( @@ -255,4 +256,5 @@ class GitHubCopilotProvider(OpenAICompatProvider): tool_choice=tool_choice, on_content_delta=on_content_delta, on_thinking_delta=on_thinking_delta, + on_tool_call_delta=on_tool_call_delta, ) diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index 38209f59c..523b2a72a 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -40,6 +40,7 @@ class OpenAICodexProvider(LLMProvider): reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: """Shared request logic for both chat() and chat_stream().""" model = model or self.default_model @@ -70,6 +71,7 @@ class OpenAICodexProvider(LLMProvider): content, tool_calls, finish_reason = await _request_codex( DEFAULT_CODEX_URL, headers, body, verify=True, on_content_delta=on_content_delta, + on_tool_call_delta=on_tool_call_delta, ) except Exception as e: if "CERTIFICATE_VERIFY_FAILED" not in str(e): @@ -78,6 +80,7 @@ class OpenAICodexProvider(LLMProvider): content, tool_calls, finish_reason = await _request_codex( DEFAULT_CODEX_URL, headers, body, verify=False, on_content_delta=on_content_delta, + on_tool_call_delta=on_tool_call_delta, ) return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason) except Exception as e: @@ -100,9 +103,18 @@ class OpenAICodexProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: _ = on_thinking_delta - return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice, on_content_delta) + return await self._call_codex( + messages, + tools, + model, + reasoning_effort, + tool_choice, + on_content_delta, + on_tool_call_delta, + ) def get_default_model(self) -> str: return self.default_model @@ -138,6 +150,7 @@ async def _request_codex( body: dict[str, Any], verify: bool, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> tuple[str, list[ToolCallRequest], str]: async with httpx.AsyncClient(timeout=60.0, verify=verify) as client: async with client.stream("POST", url, headers=headers, json=body) as response: @@ -148,7 +161,7 @@ async def _request_codex( _friendly_error(response.status_code, text.decode("utf-8", "ignore")), retry_after=retry_after, ) - return await consume_sse(response, on_content_delta) + return await consume_sse(response, on_content_delta, on_tool_call_delta) def _prompt_cache_key(messages: list[dict[str, Any]]) -> str: diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 2bcb840cd..2f8455416 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -999,6 +999,21 @@ class OpenAICompatProvider(LLMProvider): if fn_prov: buf["fn_prov"] = fn_prov + def _accum_legacy_function_call(function_call: Any) -> None: + """Accumulate legacy ``delta.function_call`` streaming chunks.""" + if not function_call: + return + buf = tc_bufs.setdefault(0, { + "id": "", "name": "", "arguments": "", + "extra_content": None, "prov": None, "fn_prov": None, + }) + fn_name = _get(function_call, "name") + if fn_name: + buf["name"] = str(fn_name) + fn_args = _get(function_call, "arguments") + if fn_args: + buf["arguments"] += str(fn_args) + for chunk in chunks: if isinstance(chunk, str): content_parts.append(chunk) @@ -1029,6 +1044,7 @@ class OpenAICompatProvider(LLMProvider): reasoning_parts.append(text) for idx, tc in enumerate(delta.get("tool_calls") or []): _accum_tc(tc, idx) + _accum_legacy_function_call(delta.get("function_call")) usage = cls._extract_usage(chunk_map) or usage continue @@ -1047,8 +1063,10 @@ class OpenAICompatProvider(LLMProvider): reasoning = getattr(delta, "reasoning", None) if reasoning: reasoning_parts.append(reasoning) - for tc in (delta.tool_calls or []) if delta else []: + for tc in (getattr(delta, "tool_calls", None) or []) if delta else []: _accum_tc(tc, getattr(tc, "index", 0)) + if delta: + _accum_legacy_function_call(getattr(delta, "function_call", None)) return LLMResponse( content="".join(content_parts) or None, @@ -1203,6 +1221,7 @@ class OpenAICompatProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) try: @@ -1226,9 +1245,16 @@ class OpenAICompatProvider(LLMProvider): except StopAsyncIteration: break - content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream( + ( + content, + tool_calls, + finish_reason, + usage, + reasoning_content, + ) = await consume_sdk_stream( _timed_stream(), on_content_delta, + on_tool_call_delta=on_tool_call_delta, ) self._record_responses_success(model, reasoning_effort) return LLMResponse( @@ -1252,6 +1278,12 @@ class OpenAICompatProvider(LLMProvider): messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice, ) + if self._spec and self._spec.name == "zhipu" and tools and on_tool_call_delta: + # Z.AI/GLM keeps streaming tool-call arguments behind an + # explicit provider flag. Pass it through the OpenAI SDK's + # extra_body escape hatch so the usual delta.tool_calls path + # can surface live file-edit progress. + kwargs.setdefault("extra_body", {})["tool_stream"] = True kwargs["stream"] = True kwargs["stream_options"] = {"include_usage": True} stream = await self._client.chat.completions.create(**kwargs) @@ -1279,6 +1311,28 @@ class OpenAICompatProvider(LLMProvider): r_text = self._extract_text_content(reasoning) if r_text: await on_thinking_delta(r_text) + if on_tool_call_delta: + for idx, tool_delta in enumerate( + getattr(delta_obj, "tool_calls", None) or [] + ): + fn = _get(tool_delta, "function") + tool_index = _get(tool_delta, "index") + await on_tool_call_delta({ + "index": tool_index if tool_index is not None else idx, + "call_id": str(_get(tool_delta, "id") or ""), + "name": str(_get(fn, "name") or "") if fn is not None else "", + "arguments_delta": ( + str(_get(fn, "arguments") or "") if fn is not None else "" + ), + }) + function_call = getattr(delta_obj, "function_call", None) + if function_call: + await on_tool_call_delta({ + "index": 0, + "call_id": "", + "name": str(_get(function_call, "name") or ""), + "arguments_delta": str(_get(function_call, "arguments") or ""), + }) return self._parse_chunks(chunks) except asyncio.TimeoutError: return LLMResponse( diff --git a/nanobot/providers/openai_responses/parsing.py b/nanobot/providers/openai_responses/parsing.py index 9e3f0ef02..707652d74 100644 --- a/nanobot/providers/openai_responses/parsing.py +++ b/nanobot/providers/openai_responses/parsing.py @@ -62,6 +62,7 @@ async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], N async def consume_sse( response: httpx.Response, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> tuple[str, list[ToolCallRequest], str]: """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``.""" content = "" @@ -82,6 +83,12 @@ async def consume_sse( "name": item.get("name"), "arguments": item.get("arguments") or "", } + if on_tool_call_delta: + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(item.get("name") or ""), + "arguments_delta": "", + }) elif event_type == "response.output_text.delta": delta_text = event.get("delta") or "" content += delta_text @@ -90,7 +97,14 @@ async def consume_sse( elif event_type == "response.function_call_arguments.delta": call_id = event.get("call_id") if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] += event.get("delta") or "" + delta = event.get("delta") or "" + tool_call_buffers[call_id]["arguments"] += delta + if on_tool_call_delta and delta: + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(tool_call_buffers[call_id].get("name") or ""), + "arguments_delta": str(delta), + }) elif event_type == "response.function_call_arguments.done": call_id = event.get("call_id") if call_id and call_id in tool_call_buffers: @@ -210,6 +224,7 @@ def parse_response_output(response: Any) -> LLMResponse: async def consume_sdk_stream( stream: Any, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]: """Consume an SDK async stream from ``client.responses.create(stream=True)``.""" content = "" @@ -232,6 +247,12 @@ async def consume_sdk_stream( "name": getattr(item, "name", None), "arguments": getattr(item, "arguments", None) or "", } + if on_tool_call_delta: + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(getattr(item, "name", None) or ""), + "arguments_delta": "", + }) elif event_type == "response.output_text.delta": delta_text = getattr(event, "delta", "") or "" content += delta_text @@ -240,7 +261,14 @@ async def consume_sdk_stream( elif event_type == "response.function_call_arguments.delta": call_id = getattr(event, "call_id", None) if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or "" + delta = getattr(event, "delta", "") or "" + tool_call_buffers[call_id]["arguments"] += delta + if on_tool_call_delta and delta: + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(tool_call_buffers[call_id].get("name") or ""), + "arguments_delta": str(delta), + }) elif event_type == "response.function_call_arguments.done": call_id = getattr(event, "call_id", None) if call_id and call_id in tool_call_buffers: diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index 8164aa18d..b5d2f6d73 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -4,13 +4,17 @@ from __future__ import annotations import difflib import json -from dataclasses import dataclass +import re +import time +from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Awaitable, Callable TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"}) _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 +_LIVE_EMIT_INTERVAL_S = 0.18 +_LIVE_EMIT_LINE_STEP = 24 @dataclass(slots=True) @@ -103,6 +107,8 @@ def line_diff_stats(before: str | None, after: str | None) -> tuple[int, int]: """Return ``(added, deleted)`` for a UTF-8 text line-level diff.""" if before is None or after is None: return 0, 0 + if before == "": + return _text_line_count(after), 0 before_lines = before.replace("\r\n", "\n").splitlines() after_lines = after.replace("\r\n", "\n").splitlines() added = 0 @@ -118,6 +124,28 @@ def line_diff_stats(before: str | None, after: str | None) -> tuple[int, int]: return added, deleted +def _text_line_count(text: str) -> int: + if not text: + return 0 + line_count = 0 + last_was_newline = False + last_was_cr = False + for ch in text: + if ch == "\r": + line_count += 1 + last_was_newline = True + last_was_cr = True + elif ch == "\n": + if not last_was_cr: + line_count += 1 + last_was_newline = True + last_was_cr = False + else: + last_was_newline = False + last_was_cr = False + return line_count if last_was_newline else line_count + 1 + + def prepare_file_edit_tracker( *, call_id: str, @@ -160,12 +188,22 @@ def build_file_edit_start_event( ) -def build_file_edit_end_event(tracker: FileEditTracker) -> dict[str, Any]: +def build_file_edit_end_event( + tracker: FileEditTracker, + params: dict[str, Any] | None = None, +) -> dict[str, Any]: after = read_file_snapshot(tracker.path) + counted = False if tracker.before.countable and after.countable: added, deleted = line_diff_stats(tracker.before.text, after.text) + counted = True else: - added, deleted = 0, 0 + predicted_after = _predict_after_text(tracker.tool, params or {}, tracker.before) + if tracker.before.countable and predicted_after is not None: + added, deleted = line_diff_stats(tracker.before.text, predicted_after) + counted = True + else: + added, deleted = 0, 0 return _event_payload( tracker, phase="end", @@ -173,11 +211,14 @@ def build_file_edit_end_event(tracker: FileEditTracker) -> dict[str, Any]: added=added, deleted=deleted, approximate=False, - binary=after.binary or after.oversized or after.unreadable, + binary=(after.binary or after.oversized or after.unreadable) and not counted, ) -def build_file_edit_error_event(tracker: FileEditTracker, error: str | None = None) -> dict[str, Any]: +def build_file_edit_error_event( + tracker: FileEditTracker, + error: str | None = None, +) -> dict[str, Any]: payload = _event_payload( tracker, phase="error", @@ -191,6 +232,427 @@ def build_file_edit_error_event(tracker: FileEditTracker, error: str | None = No return payload +def build_file_edit_live_event( + tracker: FileEditTracker, + *, + added: int, + deleted: int = 0, +) -> dict[str, Any]: + """Build an approximate in-progress event while tool-call arguments stream.""" + return _event_payload( + tracker, + phase="start", + status="editing", + added=added, + deleted=deleted, + approximate=True, + ) + + +def build_file_edit_pending_event( + *, + call_id: str, + tool_name: str, + added: int = 0, + deleted: int = 0, +) -> dict[str, Any]: + """Build an early placeholder before the streamed JSON path is available.""" + return { + "version": 1, + "call_id": str(call_id or ""), + "tool": tool_name, + "path": "", + "phase": "start", + "added": max(0, int(added)), + "deleted": max(0, int(deleted)), + "approximate": True, + "status": "editing", + "pending": True, + } + + +class StreamingFileEditTracker: + """Track file-edit tool arguments while the model is still streaming them. + + Tool execution events only begin after the provider has completed the full + function call. For large ``write_file`` calls, the long wait is usually the + model producing the JSON ``content`` argument. Large ``edit_file`` calls + can have the same wait while ``old_text`` / ``new_text`` stream in. This + tracker converts those argument deltas into approximate WebUI file-edit + events before the final exact diff is available. + """ + + def __init__( + self, + *, + workspace: Path | None, + tools: Any, + emit: Callable[[list[dict[str, Any]]], Awaitable[None]], + ) -> None: + self._workspace = workspace + self._tools = tools + self._emit = emit + self._states: dict[str, _StreamingFileEditState] = {} + + async def update(self, payload: dict[str, Any]) -> None: + key = _stream_key(payload) + if not key: + return + state = self._states.get(key) + if state is None: + state = _StreamingFileEditState(key=key) + self._states[key] = state + + state.apply_delta(payload) + if state.name not in {"write_file", "edit_file"}: + return + if state.path is None: + state.path = _extract_complete_json_string(state.arguments, "path") + if state.path is None: + added, deleted = state.live_diff_counts() + now = time.monotonic() + if state.should_emit_pending(added, deleted, now): + state.mark_pending_emitted(added, deleted, now) + await self._emit([build_file_edit_pending_event( + call_id=state.call_id or state.key, + tool_name=state.name, + added=added, + deleted=deleted, + )]) + return + if state.tracker is None: + tool = self._tools.get(state.name) if hasattr(self._tools, "get") else None + state.tracker = prepare_file_edit_tracker( + call_id=state.call_id or state.key, + tool_name=state.name, + tool=tool, + workspace=self._workspace, + params={"path": state.path}, + ) + if state.tracker is None: + return + + added, deleted = state.live_diff_counts() + now = time.monotonic() + if not state.should_emit(added, deleted, now): + return + state.mark_emitted(added, deleted, now) + await self._emit([build_file_edit_live_event( + state.tracker, + added=added, + deleted=deleted, + )]) + + async def flush(self) -> None: + events: list[dict[str, Any]] = [] + now = time.monotonic() + for state in self._states.values(): + if state.tracker is None: + continue + added, deleted = state.live_diff_counts() + if ( + state.last_emitted_added == added + and state.last_emitted_deleted == deleted + and state.emitted_once + ): + continue + state.mark_emitted(added, deleted, now) + events.append(build_file_edit_live_event( + state.tracker, + added=added, + deleted=deleted, + )) + if events: + await self._emit(events) + + def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None: + """Keep final start/end events keyed to any earlier streamed placeholder.""" + for tool_call in final_tool_calls: + canonical = self.canonical_call_id_for(tool_call) + if canonical: + try: + tool_call.id = canonical + except Exception: + pass + + def canonical_call_id_for(self, tool_call: Any) -> str | None: + for state in self._states.values(): + if state.matches_final_tool_call(tool_call): + return state.call_id or (state.tracker.call_id if state.tracker else None) or state.key + return None + + async def error_unmatched( + self, + final_tool_calls: list[Any], + error: str, + ) -> None: + """Mark streamed edits as failed when no final tool call will run.""" + events: list[dict[str, Any]] = [] + for state in self._states.values(): + if state.tracker is None: + continue + if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls): + continue + events.append(build_file_edit_error_event(state.tracker, error)) + if events: + await self._emit(events) + + +@dataclass(slots=True) +class _StreamingJsonStringField: + key: str + scan_pos: int | None = None + closed: bool = False + escape: bool = False + unicode_remaining: int = 0 + unicode_buffer: str = "" + newline_count: int = 0 + has_chars: bool = False + last_char_newline: bool = False + last_char_cr: bool = False + + @property + def line_count(self) -> int: + if not self.has_chars: + return 0 + return self.newline_count + (0 if self.last_char_newline else 1) + + def reset(self) -> None: + self.scan_pos = None + self.closed = False + self.escape = False + self.unicode_remaining = 0 + self.unicode_buffer = "" + self.newline_count = 0 + self.has_chars = False + self.last_char_newline = False + self.last_char_cr = False + + def scan(self, source: str) -> None: + if self.closed: + return + if self.scan_pos is None: + match = re.search(rf'"{re.escape(self.key)}"\s*:\s*"', source) + if match is None: + return + self.scan_pos = match.end() + i = self.scan_pos + while i < len(source): + ch = source[i] + if self.unicode_remaining > 0: + self.unicode_buffer += ch + self.unicode_remaining -= 1 + if self.unicode_remaining == 0: + try: + decoded = chr(int(self.unicode_buffer, 16)) + except ValueError: + decoded = "x" + self.unicode_buffer = "" + self._mark_char(decoded) + i += 1 + continue + if self.escape: + self.escape = False + if ch == "u": + self.unicode_remaining = 4 + self.unicode_buffer = "" + elif ch == "n": + self._mark_char("\n") + elif ch == "r": + self._mark_char("\r") + else: + self._mark_char(ch) + i += 1 + continue + if ch == "\\": + self.escape = True + i += 1 + continue + if ch == '"': + self.closed = True + i += 1 + break + self._mark_char(ch) + i += 1 + self.scan_pos = i + + def _mark_char(self, ch: str) -> None: + self.has_chars = True + if ch == "\r": + self.newline_count += 1 + self.last_char_newline = True + self.last_char_cr = True + elif ch == "\n": + if not self.last_char_cr: + self.newline_count += 1 + self.last_char_newline = True + self.last_char_cr = False + else: + self.last_char_newline = False + self.last_char_cr = False + + +@dataclass(slots=True) +class _StreamingFileEditState: + key: str + call_id: str = "" + name: str = "" + arguments: str = "" + path: str | None = None + tracker: FileEditTracker | None = None + content: _StreamingJsonStringField = field( + default_factory=lambda: _StreamingJsonStringField("content") + ) + old_text: _StreamingJsonStringField = field( + default_factory=lambda: _StreamingJsonStringField("old_text") + ) + new_text: _StreamingJsonStringField = field( + default_factory=lambda: _StreamingJsonStringField("new_text") + ) + emitted_once: bool = False + last_emitted_added: int = -1 + last_emitted_deleted: int = -1 + last_emit_at: float = 0.0 + pending_emitted: bool = False + last_pending_added: int = -1 + last_pending_deleted: int = -1 + last_pending_at: float = 0.0 + + def apply_delta(self, payload: dict[str, Any]) -> None: + call_id = payload.get("call_id") + if isinstance(call_id, str) and call_id: + self.call_id = call_id + name = payload.get("name") + if isinstance(name, str) and name: + self.name = name + args = payload.get("arguments") + if isinstance(args, str): + self.arguments = args + self.content.reset() + self.old_text.reset() + self.new_text.reset() + return + delta = payload.get("arguments_delta") + if isinstance(delta, str) and delta: + self.arguments += delta + + def live_diff_counts(self) -> tuple[int, int]: + if self.name == "write_file": + self.content.scan(self.arguments) + return self.content.line_count, 0 + if self.name == "edit_file": + self.old_text.scan(self.arguments) + self.new_text.scan(self.arguments) + return self.new_text.line_count, self.old_text.line_count + return 0, 0 + + def should_emit(self, added: int, deleted: int, now: float) -> bool: + if not self.emitted_once: + return True + if added == self.last_emitted_added and deleted == self.last_emitted_deleted: + return False + if max( + abs(added - self.last_emitted_added), + abs(deleted - self.last_emitted_deleted), + ) >= _LIVE_EMIT_LINE_STEP: + return True + return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S + + def mark_emitted(self, added: int, deleted: int, now: float) -> None: + self.emitted_once = True + self.last_emitted_added = added + self.last_emitted_deleted = deleted + self.last_emit_at = now + + def should_emit_pending(self, added: int, deleted: int, now: float) -> bool: + if not self.pending_emitted: + return True + if added == self.last_pending_added and deleted == self.last_pending_deleted: + return False + if max( + abs(added - self.last_pending_added), + abs(deleted - self.last_pending_deleted), + ) >= _LIVE_EMIT_LINE_STEP: + return True + return now - self.last_pending_at >= _LIVE_EMIT_INTERVAL_S + + def mark_pending_emitted(self, added: int, deleted: int, now: float) -> None: + self.pending_emitted = True + self.last_pending_added = added + self.last_pending_deleted = deleted + self.last_pending_at = now + + def matches_final_tool_call(self, tool_call: Any) -> bool: + call_id = getattr(tool_call, "id", None) + canonical = self.call_id or (self.tracker.call_id if self.tracker else "") + if isinstance(call_id, str) and call_id and canonical and call_id == canonical: + return True + name = getattr(tool_call, "name", None) + if name != self.name: + return False + arguments = getattr(tool_call, "arguments", None) + if not isinstance(arguments, dict): + return False + path = arguments.get("path") + if self.path is None and isinstance(path, str) and path: + self.path = path + return True + return isinstance(path, str) and path == self.path + + +def _stream_key(payload: dict[str, Any]) -> str: + index = payload.get("index") + if isinstance(index, int): + return f"idx:{index}" + if isinstance(index, str) and index: + return f"idx:{index}" + call_id = payload.get("call_id") + if isinstance(call_id, str) and call_id: + return f"id:{call_id}" + return "" + + +def _extract_complete_json_string(source: str, key: str) -> str | None: + match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source) + if match is None: + return None + out: list[str] = [] + i = match.end() + escape = False + while i < len(source): + ch = source[i] + if escape: + escape = False + if ch == "n": + out.append("\n") + elif ch == "r": + out.append("\r") + elif ch == "t": + out.append("\t") + elif ch == "u": + digits = source[i + 1:i + 5] + if len(digits) < 4: + return None + try: + out.append(chr(int(digits, 16))) + except ValueError: + return None + i += 4 + else: + out.append(ch) + i += 1 + continue + if ch == "\\": + escape = True + i += 1 + continue + if ch == '"': + return "".join(out) + out.append(ch) + i += 1 + return None + + def _event_payload( tracker: FileEditTracker, *, @@ -206,6 +668,7 @@ def _event_payload( "call_id": tracker.call_id, "tool": tracker.tool, "path": tracker.display_path, + "absolute_path": tracker.path.as_posix(), "phase": phase, "added": max(0, int(added)), "deleted": max(0, int(deleted)), @@ -260,8 +723,14 @@ def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> st return None new_source = params.get("new_source") source = new_source if isinstance(new_source, str) else "" - cell_type = params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code" - mode = params.get("edit_mode") if params.get("edit_mode") in ("replace", "insert", "delete") else "replace" + cell_type = ( + params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code" + ) + mode = ( + params.get("edit_mode") + if params.get("edit_mode") in ("replace", "insert", "delete") + else "replace" + ) if mode == "delete": if 0 <= cell_index < len(cells): cells.pop(cell_index) diff --git a/nanobot/utils/webui_transcript.py b/nanobot/utils/webui_transcript.py index bee71c542..38444dce6 100644 --- a/nanobot/utils/webui_transcript.py +++ b/nanobot/utils/webui_transcript.py @@ -144,6 +144,17 @@ def replay_transcript_to_ui_messages( def _ensure_activity_segment() -> str: return active_activity_segment_id or _new_activity_segment() + def close_activity_for_answer() -> None: + nonlocal active_activity_segment_id, active_file_edit_segment_id + active_activity_segment_id = None + active_file_edit_segment_id = None + + def close_file_edit_phase_before_activity() -> None: + nonlocal active_activity_segment_id, active_file_edit_segment_id + if active_file_edit_segment_id: + active_activity_segment_id = None + active_file_edit_segment_id = None + def attach_reasoning_chunk(prev: list[dict[str, Any]], chunk: str, idx: int) -> None: for i in range(len(prev) - 1, -1, -1): candidate = prev[i] @@ -243,7 +254,7 @@ def replay_transcript_to_ui_messages( return def absorb_complete(extra: dict[str, Any], idx: int) -> None: - nonlocal active_activity_segment_id + nonlocal active_activity_segment_id, active_file_edit_segment_id last = messages[-1] if messages else None if last and is_reasoning_only_placeholder(last): messages[-1] = { @@ -262,35 +273,50 @@ def replay_transcript_to_ui_messages( }, ) active_activity_segment_id = None + active_file_edit_segment_id = None def _file_edit_key(edit: dict[str, Any]) -> str: - return "|".join( - str(edit.get(k) or "") - for k in ("call_id", "tool", "path") - ) + call_id = str(edit.get("call_id") or "") + tool = str(edit.get("tool") or "") + if call_id: + return f"{call_id}|{tool}" + return f"{tool}|{edit.get('path') or ''}" + + def find_file_edit_trace_index( + segment: str | None, + edits: list[dict[str, Any]], + ) -> int | None: + incoming_keys = {_file_edit_key(edit) for edit in edits if isinstance(edit, dict)} + for i in range(len(messages) - 1, -1, -1): + candidate = messages[i] + if candidate.get("role") == "user": + break + if candidate.get("kind") != "trace" or not candidate.get("fileEdits"): + continue + if segment and candidate.get("activitySegmentId") == segment: + return i + existing_edits = candidate.get("fileEdits") + if not isinstance(existing_edits, list): + continue + for existing in existing_edits: + if isinstance(existing, dict) and _file_edit_key(existing) in incoming_keys: + return i + return None def upsert_file_edits(edits: list[dict[str, Any]], idx: int) -> None: nonlocal active_file_edit_segment_id if not edits: return - last = messages[-1] if messages else None - if ( - active_file_edit_segment_id - and last - and last.get("kind") == "trace" - and last.get("fileEdits") - ): - segment = active_file_edit_segment_id - else: - segment = _new_activity_segment(activate=False) + segment = active_file_edit_segment_id + target_index = find_file_edit_trace_index(segment, edits) + if target_index is not None: + last = messages[target_index] + segment = str(last.get("activitySegmentId") or segment or _new_activity_segment(activate=False)) + active_file_edit_segment_id = segment + else: + if not segment: + segment = _new_activity_segment(activate=False) active_file_edit_segment_id = segment - if not ( - last - and last.get("kind") == "trace" - and not last.get("isStreaming") - and last.get("fileEdits") - and last.get("activitySegmentId") == segment - ): messages.append( { "id": _new_id("tr", idx), @@ -303,7 +329,11 @@ def replay_transcript_to_ui_messages( "createdAt": _ts_base + idx, }, ) - last = messages[-1] + target_index = len(messages) - 1 + last = messages[target_index] + if not segment: + segment = _new_activity_segment(activate=False) + active_file_edit_segment_id = segment existing = list(last.get("fileEdits") or []) index_by_key = { _file_edit_key(edit): pos @@ -316,11 +346,14 @@ def replay_transcript_to_ui_messages( key = _file_edit_key(edit) if key in index_by_key: pos = index_by_key[key] - existing[pos] = {**existing[pos], **edit} + merged = {**existing[pos], **edit} + if edit.get("path") and not edit.get("pending"): + merged.pop("pending", None) + existing[pos] = merged else: index_by_key[key] = len(existing) existing.append(dict(edit)) - messages[-1] = { + messages[target_index] = { **last, "fileEdits": existing, "activitySegmentId": last.get("activitySegmentId") or segment, @@ -365,6 +398,7 @@ def replay_transcript_to_ui_messages( chunk = rec.get("text") if not isinstance(chunk, str): continue + close_activity_for_answer() adopted = find_active_placeholder(messages) if buffer_message_id is None else None if buffer_message_id is None: if adopted: @@ -403,6 +437,7 @@ def replay_transcript_to_ui_messages( chunk = rec.get("text") if not isinstance(chunk, str) or not chunk: continue + close_file_edit_phase_before_activity() attach_reasoning_chunk(messages, chunk, idx) continue @@ -424,6 +459,7 @@ def replay_transcript_to_ui_messages( line = rec.get("text") if not isinstance(line, str) or not line: continue + close_file_edit_phase_before_activity() attach_reasoning_chunk(messages, line, idx) close_reasoning(messages) continue diff --git a/tests/agent/test_loop_progress.py b/tests/agent/test_loop_progress.py index 43a691437..cace4e46c 100644 --- a/tests/agent/test_loop_progress.py +++ b/tests/agent/test_loop_progress.py @@ -309,6 +309,100 @@ class TestToolEventProgress: await invoke_file_edit_progress(telegram_progress, edit_events) assert bus.outbound_size == 0 + @pytest.mark.asyncio + async def test_goal_turn_keeps_live_file_edit_progress_for_webui(self, tmp_path: Path) -> None: + """The /goal command rewrites the prompt but must not bypass WebUI file-edit progress.""" + bus = MessageBus() + provider = MagicMock() + provider.supports_progress_deltas = True + provider.get_default_model.return_value = "test-model" + call_count = 0 + target = tmp_path / "goal.txt" + + async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + assert on_tool_call_delta is not None + await on_tool_call_delta({ + "index": 0, + "call_id": "call-goal-write", + "name": "write_file", + "arguments_delta": '{"path":"goal.txt","content":"', + }) + await on_tool_call_delta({ + "index": 0, + "arguments_delta": "one\\ntwo\\nthree\\n", + }) + await on_tool_call_delta({"index": 0, "arguments_delta": '"}'}) + return LLMResponse( + content=None, + tool_calls=[ + ToolCallRequest( + id="call-goal-write", + name="write_file", + arguments={ + "path": "goal.txt", + "content": "one\ntwo\nthree\n", + }, + ) + ], + usage={}, + ) + return LLMResponse(content="Done", tool_calls=[], usage={}) + + async def execute(name: str, params: dict) -> str: + assert name == "write_file" + target.write_text(params["content"], encoding="utf-8") + return "ok" + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[ + {"type": "function", "function": {"name": "write_file"}}, + ]) + loop.tools.prepare_call = MagicMock( + return_value=( + None, + {"path": "goal.txt", "content": "one\ntwo\nthree\n"}, + None, + ), + ) + loop.tools.execute = AsyncMock(side_effect=execute) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + await loop._dispatch(InboundMessage( + channel="websocket", + sender_id="u1", + chat_id="chat1", + content="/goal create goal file", + metadata={"_wants_stream": True}, + )) + + outbound = [] + while bus.outbound_size > 0: + outbound.append(await bus.consume_outbound()) + + edit_events = [ + event + for msg in outbound + for event in msg.metadata.get("_file_edit_events", []) + ] + assert any( + event["status"] == "editing" + and event["approximate"] + and event["added"] == 3 + for event in edit_events + ) + assert any( + event["status"] == "done" + and not event["approximate"] + and event["added"] == 3 + for event in edit_events + ) + provider.chat_with_retry.assert_not_awaited() + @pytest.mark.asyncio async def test_non_streaming_channel_does_not_publish_codex_progress_deltas( self, diff --git a/tests/agent/test_runner_progress_deltas.py b/tests/agent/test_runner_progress_deltas.py index 13d5ea799..27a85ab8a 100644 --- a/tests/agent/test_runner_progress_deltas.py +++ b/tests/agent/test_runner_progress_deltas.py @@ -6,7 +6,7 @@ import pytest from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.config.schema import AgentDefaults -from nanobot.providers.base import LLMResponse +from nanobot.providers.base import LLMResponse, ToolCallRequest _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars @@ -77,3 +77,220 @@ async def test_runner_streams_provider_progress_deltas_by_default(): assert result.final_content == "hello" assert [call.args[0] for call in progress_cb.await_args_list] == ["he", "llo"] provider.chat_with_retry.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_runner_streams_live_write_file_activity_from_tool_argument_deltas(tmp_path): + provider = MagicMock() + provider.supports_progress_deltas = True + call_count = 0 + progress_events: list[dict] = [] + + async def progress_cb(content, *, file_edit_events=None, **kwargs): + if file_edit_events: + progress_events.extend(file_edit_events) + + class Tools: + def get_definitions(self): + return [{"type": "function", "function": {"name": "write_file"}}] + + def get(self, name): + return None + + async def execute(self, name, params): + assert name == "write_file" + assert any(event["approximate"] and event["added"] == 24 for event in progress_events) + target = tmp_path / params["path"] + target.write_text(params["content"], encoding="utf-8") + return "ok" + + async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + assert on_tool_call_delta is not None + await on_tool_call_delta({ + "index": 0, + "call_id": "call-write", + "name": "write_file", + "arguments_delta": '{"path":"big.txt","content":"', + }) + await on_tool_call_delta({"index": 0, "arguments_delta": "line\\n" * 24}) + return LLMResponse( + content=None, + tool_calls=[ + ToolCallRequest( + id="call-write", + name="write_file", + arguments={"path": "big.txt", "content": "line\n" * 24}, + ) + ], + usage={}, + ) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "write a large file"}], + tools=Tools(), + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + progress_callback=progress_cb, + workspace=tmp_path, + )) + + assert result.final_content == "done" + assert any(event["approximate"] and event["added"] == 24 for event in progress_events) + assert any( + not event["approximate"] and event["phase"] == "end" and event["added"] == 24 + for event in progress_events + ) + provider.chat_with_retry.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_runner_streams_live_edit_file_activity_from_tool_argument_deltas(tmp_path): + provider = MagicMock() + provider.supports_progress_deltas = True + call_count = 0 + progress_events: list[dict] = [] + target = tmp_path / "notes.txt" + target.write_text("old\nkeep\n", encoding="utf-8") + + async def progress_cb(content, *, file_edit_events=None, **kwargs): + if file_edit_events: + progress_events.extend(file_edit_events) + + class Tools: + def get_definitions(self): + return [{"type": "function", "function": {"name": "edit_file"}}] + + def get(self, name): + return None + + async def execute(self, name, params): + assert name == "edit_file" + assert any( + event["tool"] == "edit_file" + and event["approximate"] + and event["added"] == 3 + and event["deleted"] == 2 + for event in progress_events + ) + target.write_text(params["new_text"], encoding="utf-8") + return "ok" + + async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + assert on_tool_call_delta is not None + await on_tool_call_delta({ + "index": 0, + "call_id": "call-edit", + "name": "edit_file", + "arguments_delta": ( + '{"path":"notes.txt","old_text":"old\\nkeep\\n","new_text":"' + ), + }) + await on_tool_call_delta({ + "index": 0, + "arguments_delta": "new\\nkeep\\nextra\\n", + }) + await on_tool_call_delta({"index": 0, "arguments_delta": '"}'}) + return LLMResponse( + content=None, + tool_calls=[ + ToolCallRequest( + id="call-edit", + name="edit_file", + arguments={ + "path": "notes.txt", + "old_text": "old\nkeep\n", + "new_text": "new\nkeep\nextra\n", + }, + ) + ], + usage={}, + ) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "edit a file"}], + tools=Tools(), + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + progress_callback=progress_cb, + workspace=tmp_path, + )) + + assert result.final_content == "done" + assert any( + event["tool"] == "edit_file" + and event["approximate"] + and event["added"] == 3 + and event["deleted"] == 2 + for event in progress_events + ) + assert any( + event["tool"] == "edit_file" + and not event["approximate"] + and event["phase"] == "end" + and event["added"] == 2 + and event["deleted"] == 1 + for event in progress_events + ) + provider.chat_with_retry.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_runner_marks_unfinished_live_write_file_activity_failed(tmp_path): + provider = MagicMock() + provider.supports_progress_deltas = True + progress_events: list[dict] = [] + + async def progress_cb(content, *, file_edit_events=None, **kwargs): + if file_edit_events: + progress_events.extend(file_edit_events) + + async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs): + assert on_tool_call_delta is not None + await on_tool_call_delta({ + "index": 0, + "call_id": "call-write", + "name": "write_file", + "arguments_delta": '{"path":"aborted.txt","content":"partial\\n', + }) + return LLMResponse(content="stopped", tool_calls=[], finish_reason="stop", usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + tools = MagicMock() + tools.get_definitions.return_value = [{"type": "function", "function": {"name": "write_file"}}] + tools.get.return_value = None + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "write a large file"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + progress_callback=progress_cb, + workspace=tmp_path, + )) + + assert result.final_content == "stopped" + assert progress_events[-1]["path"] == "aborted.txt" + assert progress_events[-1]["phase"] == "error" + assert progress_events[-1]["status"] == "error" + provider.chat_with_retry.assert_not_awaited() diff --git a/tests/providers/test_anthropic_stream_idle.py b/tests/providers/test_anthropic_stream_idle.py index da4939bf7..d46f291fb 100644 --- a/tests/providers/test_anthropic_stream_idle.py +++ b/tests/providers/test_anthropic_stream_idle.py @@ -129,6 +129,74 @@ async def test_chat_stream_invokes_on_thinking_delta_for_thinking_delta() -> Non assert text_parts == ["X"] +@pytest.mark.asyncio +async def test_chat_stream_invokes_tool_call_delta_for_input_json_delta() -> None: + provider = AnthropicProvider(api_key="sk-test") + provider._client = MagicMock() + + chunks = [ + SimpleNamespace( + type="content_block_start", + index=1, + content_block=SimpleNamespace( + type="tool_use", + id="toolu_1", + name="write_file", + ), + ), + SimpleNamespace( + type="content_block_delta", + index=1, + delta=SimpleNamespace( + type="input_json_delta", + partial_json='{"path":"notes.md","content":"', + ), + ), + SimpleNamespace( + type="content_block_delta", + index=1, + delta=SimpleNamespace(type="input_json_delta", partial_json="line\\n"), + ), + ] + fake = _FakeAsyncStream(chunks) + stream_cm = MagicMock() + stream_cm.__aenter__ = AsyncMock(return_value=fake) + stream_cm.__aexit__ = AsyncMock(return_value=None) + provider._client.messages.stream = MagicMock(return_value=stream_cm) + + deltas: list[dict] = [] + + async def on_tool_delta(delta: dict) -> None: + deltas.append(delta) + + await provider.chat_stream( + messages=[{"role": "user", "content": "write"}], + on_tool_call_delta=on_tool_delta, + ) + + assert deltas == [ + { + "index": 1, + "call_id": "toolu_1", + "name": "write_file", + "arguments_delta": "", + }, + { + "index": 1, + "call_id": "toolu_1", + "name": "write_file", + "arguments_delta": '{"path":"notes.md","content":"', + }, + { + "index": 1, + "call_id": "toolu_1", + "name": "write_file", + "arguments_delta": "line\\n", + }, + ] + fake.get_final_message.assert_awaited_once() + + @pytest.mark.asyncio async def test_chat_stream_without_callback_still_finalizes() -> None: provider = AnthropicProvider(api_key="sk-test") diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 7ae97159c..3acb2e76c 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -164,6 +164,130 @@ def _fake_chat_stream_reasoning_chunks(): return _stream() +def _fake_chat_stream_tool_call_chunks(): + """Mimic OpenAI-compatible streaming tool-call argument deltas.""" + + async def _stream(): + yield SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason=None, + delta=SimpleNamespace( + content=None, + reasoning_content=None, + reasoning=None, + tool_calls=[ + SimpleNamespace( + index=0, + id="call_write", + function=SimpleNamespace( + name="write_file", + arguments='{"path":"notes.md","content":"', + ), + ) + ], + ), + ), + ], + usage=None, + ) + yield SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason=None, + delta=SimpleNamespace( + content=None, + reasoning_content=None, + reasoning=None, + tool_calls=[ + SimpleNamespace( + index=0, + id=None, + function=SimpleNamespace(name=None, arguments='line\\n"}'), + ) + ], + ), + ), + ], + usage=None, + ) + yield SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="tool_calls", + delta=SimpleNamespace( + content=None, + reasoning_content=None, + reasoning=None, + tool_calls=None, + ), + ), + ], + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + return _stream() + + +def _fake_chat_stream_legacy_function_call_chunks(): + """Mimic older OpenAI-compatible ``delta.function_call`` chunks.""" + + async def _stream(): + yield SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason=None, + delta=SimpleNamespace( + content=None, + reasoning_content=None, + reasoning=None, + tool_calls=None, + function_call=SimpleNamespace( + name="write_file", + arguments='{"path":"notes.md","content":"', + ), + ), + ), + ], + usage=None, + ) + yield SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason=None, + delta=SimpleNamespace( + content=None, + reasoning_content=None, + reasoning=None, + tool_calls=None, + function_call=SimpleNamespace( + name=None, + arguments='line\\n"}', + ), + ), + ), + ], + usage=None, + ) + yield SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="function_call", + delta=SimpleNamespace( + content=None, + reasoning_content=None, + reasoning=None, + tool_calls=None, + function_call=None, + ), + ), + ], + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + return _stream() + + @pytest.mark.asyncio async def test_openai_compat_stream_forwards_reasoning_deltas_deepseek_style() -> None: """Regression: DeepSeek-V4 / reasoner expose ``delta.reasoning_content`` during streaming.""" @@ -202,6 +326,98 @@ async def test_openai_compat_stream_forwards_reasoning_deltas_deepseek_style() - mock_chat.assert_awaited_once() +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("provider_name", "model"), + [ + ("openai", "gpt-4o"), + ("deepseek", "deepseek-chat"), + ("minimax", "MiniMax-M2.7"), + ("zhipu", "glm-4.6"), + ], +) +async def test_openai_compat_stream_forwards_tool_call_argument_deltas( + provider_name: str, + model: str, +) -> None: + mock_chat = AsyncMock(return_value=_fake_chat_stream_tool_call_chunks()) + spec = find_by_name(provider_name) + deltas: list[dict] = [] + + async def on_tool_delta(delta: dict) -> None: + deltas.append(delta) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_openai: + client_instance = mock_openai.return_value + client_instance.chat.completions.create = mock_chat + + provider = OpenAICompatProvider( + api_key="sk-test", + default_model=model, + spec=spec, + ) + result = await provider.chat_stream( + messages=[{"role": "user", "content": "write"}], + tools=[{"type": "function", "function": {"name": "write_file"}}], + model=model, + on_tool_call_delta=on_tool_delta, + ) + + assert deltas == [ + { + "index": 0, + "call_id": "call_write", + "name": "write_file", + "arguments_delta": '{"path":"notes.md","content":"', + }, + {"index": 0, "call_id": "", "name": "", "arguments_delta": 'line\\n"}'}, + ] + assert result.tool_calls[0].name == "write_file" + assert result.tool_calls[0].arguments == {"path": "notes.md", "content": "line\n"} + kwargs = mock_chat.await_args.kwargs + if provider_name == "zhipu": + assert kwargs["extra_body"]["tool_stream"] is True + else: + assert kwargs.get("extra_body", {}).get("tool_stream") is None + + +@pytest.mark.asyncio +async def test_openai_compat_stream_forwards_legacy_function_call_argument_deltas() -> None: + mock_chat = AsyncMock(return_value=_fake_chat_stream_legacy_function_call_chunks()) + deltas: list[dict] = [] + + async def on_tool_delta(delta: dict) -> None: + deltas.append(delta) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_openai: + client_instance = mock_openai.return_value + client_instance.chat.completions.create = mock_chat + + provider = OpenAICompatProvider( + api_key="sk-test", + default_model="deepseek-chat", + spec=find_by_name("deepseek"), + ) + result = await provider.chat_stream( + messages=[{"role": "user", "content": "write"}], + tools=[{"type": "function", "function": {"name": "write_file"}}], + model="deepseek-chat", + on_tool_call_delta=on_tool_delta, + ) + + assert deltas == [ + { + "index": 0, + "call_id": "", + "name": "write_file", + "arguments_delta": '{"path":"notes.md","content":"', + }, + {"index": 0, "call_id": "", "name": "", "arguments_delta": 'line\\n"}'}, + ] + assert result.tool_calls[0].name == "write_file" + assert result.tool_calls[0].arguments == {"path": "notes.md", "content": "line\n"} + + class _FakeResponsesError(Exception): def __init__(self, status_code: int, text: str): super().__init__(text) diff --git a/tests/providers/test_llm_response.py b/tests/providers/test_llm_response.py index ca9644dc2..fff0ccaa7 100644 --- a/tests/providers/test_llm_response.py +++ b/tests/providers/test_llm_response.py @@ -44,9 +44,15 @@ class TestShouldExecuteTools: resp = _response("stop") assert resp.should_execute_tools is True + def test_legacy_function_call_reason_executes(self) -> None: + # Older OpenAI-compatible streaming APIs can still use the singular + # function_call finish reason while carrying a tool-call-shaped payload. + resp = _response("function_call") + assert resp.should_execute_tools is True + @pytest.mark.parametrize( "anomalous_reason", - ["refusal", "content_filter", "error", "length", "function_call", ""], + ["refusal", "content_filter", "error", "length", ""], ) def test_tool_calls_under_anomalous_reason_blocked(self, anomalous_reason: str) -> None: # This is the #3220 bug: gateways injecting tool_calls under any of these diff --git a/tests/providers/test_openai_responses.py b/tests/providers/test_openai_responses.py index ce4220655..74a934f85 100644 --- a/tests/providers/test_openai_responses.py +++ b/tests/providers/test_openai_responses.py @@ -453,6 +453,56 @@ class TestConsumeSdkStream: assert tool_calls[0].name == "get_weather" assert tool_calls[0].arguments == {"city": "SF"} + @pytest.mark.asyncio + async def test_tool_call_argument_delta_callback(self): + item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") + item_added.name = "write_file" + ev1 = MagicMock(type="response.output_item.added", item=item_added) + ev2 = MagicMock( + type="response.function_call_arguments.delta", + call_id="c1", + delta='{"path":"a.txt","content":"', + ) + ev3 = MagicMock( + type="response.function_call_arguments.delta", + call_id="c1", + delta='hello\\n', + ) + ev4 = MagicMock( + type="response.function_call_arguments.done", + call_id="c1", + arguments='{"path":"a.txt","content":"hello\\n"}', + ) + item_done = MagicMock( + type="function_call", + call_id="c1", + id="fc1", + arguments='{"path":"a.txt","content":"hello\\n"}', + ) + item_done.name = "write_file" + ev5 = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev6 = MagicMock(type="response.completed", response=resp_obj) + deltas: list[dict] = [] + + async def cb(delta: dict) -> None: + deltas.append(delta) + + async def stream(): + for e in [ev1, ev2, ev3, ev4, ev5, ev6]: + yield e + + await consume_sdk_stream(stream(), on_tool_call_delta=cb) + assert deltas == [ + {"call_id": "c1", "name": "write_file", "arguments_delta": ""}, + { + "call_id": "c1", + "name": "write_file", + "arguments_delta": '{"path":"a.txt","content":"', + }, + {"call_id": "c1", "name": "write_file", "arguments_delta": "hello\\n"}, + ] + @pytest.mark.asyncio async def test_usage_extracted(self): usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15) diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index 6176a5e36..9180032cf 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio from pathlib import Path +from types import SimpleNamespace from nanobot.utils.file_edit_events import ( build_file_edit_end_event, @@ -8,6 +10,7 @@ from nanobot.utils.file_edit_events import ( line_diff_stats, prepare_file_edit_tracker, read_file_snapshot, + StreamingFileEditTracker, ) @@ -20,6 +23,10 @@ def test_line_diff_stats_normalizes_crlf() -> None: assert line_diff_stats("a\r\nb\r\n", "a\nb\nc\n") == (1, 0) +def test_line_diff_stats_counts_new_file_crlf_lines_once() -> None: + assert line_diff_stats("", "a\r\nb\r\n") == (2, 0) + + def test_write_file_start_predicts_and_end_calibrates_exact_diff(tmp_path: Path) -> None: target = tmp_path / "notes.txt" target.write_text("old\nkeep\n", encoding="utf-8") @@ -39,6 +46,7 @@ def test_write_file_start_predicts_and_end_calibrates_exact_diff(tmp_path: Path) "call_id": "call-write", "tool": "write_file", "path": "notes.txt", + "absolute_path": (tmp_path / "notes.txt").as_posix(), "phase": "start", "added": 2, "deleted": 1, @@ -73,6 +81,307 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None: assert (event["added"], event["deleted"]) == (0, 0) +def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None: + target = tmp_path / "large.txt" + params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)} + tracker = prepare_file_edit_tracker( + call_id="call-large", + tool_name="write_file", + tool=None, + workspace=tmp_path, + params=params, + ) + + assert tracker is not None + target.write_text(params["content"], encoding="utf-8") + event = build_file_edit_end_event(tracker, params) + assert event.get("binary") is not True + assert event["added"] == 1 + assert event["deleted"] == 0 + + +def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-live", + "name": "write_file", + "arguments_delta": '{"path":"notes.md","content":"', + }) + await tracker.update({ + "index": 0, + "arguments_delta": "line\\n" * 24, + }) + + asyncio.run(run()) + + assert events[0] == { + "version": 1, + "call_id": "call-live", + "tool": "write_file", + "path": "notes.md", + "absolute_path": (tmp_path / "notes.md").as_posix(), + "phase": "start", + "added": 0, + "deleted": 0, + "approximate": True, + "status": "editing", + } + assert events[-1]["path"] == "notes.md" + assert events[-1]["status"] == "editing" + assert events[-1]["approximate"] is True + assert events[-1]["added"] == 24 + assert events[-1]["deleted"] == 0 + + +def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-live", + "name": "write_file", + "arguments_delta": '{"content":"line\\n', + }) + await tracker.update({ + "index": 0, + "arguments_delta": 'more\\n","path":"late.md"', + }) + + asyncio.run(run()) + + assert events[0] == { + "version": 1, + "call_id": "call-live", + "tool": "write_file", + "path": "", + "phase": "start", + "added": 1, + "deleted": 0, + "approximate": True, + "status": "editing", + "pending": True, + } + assert events[-1]["path"] == "late.md" + assert events[-1].get("pending") is not True + assert events[-1]["added"] == 2 + + +def test_streaming_write_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-live", + "name": "write_file", + "arguments_delta": '{"path":"small.md","content":"one\\n', + }) + await tracker.flush() + + asyncio.run(run()) + assert events + assert events[-1]["path"] == "small.md" + assert events[-1]["added"] == 1 + + +def test_streaming_write_file_tracker_normalizes_crlf_line_counts(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-live", + "name": "write_file", + "arguments_delta": '{"path":"windows.txt","content":"one\\r\\ntwo\\r\\n', + }) + await tracker.flush() + + asyncio.run(run()) + assert events[-1]["path"] == "windows.txt" + assert events[-1]["added"] == 2 + + +def test_streaming_write_file_tracker_counts_unicode_escaped_newlines(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-live", + "name": "write_file", + "arguments_delta": '{"path":"unicode.txt","content":"one\\u000atwo', + }) + await tracker.flush() + + asyncio.run(run()) + assert events[-1]["path"] == "unicode.txt" + assert events[-1]["added"] == 2 + + +def test_streaming_edit_file_tracker_emits_live_line_counts(tmp_path: Path) -> None: + target = tmp_path / "notes.md" + target.write_text("old\nkeep\n", encoding="utf-8") + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-edit", + "name": "edit_file", + "arguments_delta": '{"path":"notes.md","old_text":"old\\nkeep","new_text":"', + }) + await tracker.update({ + "index": 0, + "arguments_delta": "new\\nkeep\\nextra\\n" * 8, + }) + + asyncio.run(run()) + + assert events[0] == { + "version": 1, + "call_id": "call-edit", + "tool": "edit_file", + "path": "notes.md", + "absolute_path": (tmp_path / "notes.md").as_posix(), + "phase": "start", + "added": 0, + "deleted": 2, + "approximate": True, + "status": "editing", + } + assert events[-1]["path"] == "notes.md" + assert events[-1]["status"] == "editing" + assert events[-1]["approximate"] is True + assert events[-1]["added"] == 24 + assert events[-1]["deleted"] == 2 + + +def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "name": "write_file", + "arguments_delta": '{"path":"matched.md","content":"one\\n', + }) + final = SimpleNamespace( + id="provider-final-id", + name="write_file", + arguments={"path": "matched.md", "content": "one\n"}, + ) + tracker.apply_final_call_ids([final]) + assert final.id == "idx:0" + + asyncio.run(run()) + + +def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None: + target = tmp_path / "small.py" + target.write_text("old\n", encoding="utf-8") + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-edit", + "name": "edit_file", + "arguments_delta": '{"path":"small.py","old_text":"old\\n","new_text":"new\\nextra', + }) + await tracker.flush() + + asyncio.run(run()) + assert events + assert events[-1]["path"] == "small.py" + assert events[-1]["added"] == 2 + assert events[-1]["deleted"] == 1 + + +def test_streaming_write_file_tracker_errors_unmatched_live_edits(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-live", + "name": "write_file", + "arguments_delta": '{"path":"aborted.md","content":"one\\n', + }) + await tracker.error_unmatched([], "Tool call did not complete.") + + asyncio.run(run()) + assert events[-1]["path"] == "aborted.md" + assert events[-1]["phase"] == "error" + assert events[-1]["status"] == "error" + + +def test_streaming_write_file_tracker_keeps_matched_final_tool_call(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "idx-only", + "name": "write_file", + "arguments_delta": '{"path":"matched.md","content":"one\\n', + }) + await tracker.error_unmatched([ + SimpleNamespace( + id="final-call", + name="write_file", + arguments={"path": "matched.md", "content": "one\n"}, + ) + ], "Tool call did not complete.") + + asyncio.run(run()) + assert events + assert all(event["status"] == "editing" for event in events) + + def test_untracked_tools_do_not_prepare_file_edit_tracker(tmp_path: Path) -> None: assert prepare_file_edit_tracker( call_id="call-exec", diff --git a/tests/utils/test_webui_transcript.py b/tests/utils/test_webui_transcript.py index f13380f46..42736c9b1 100644 --- a/tests/utils/test_webui_transcript.py +++ b/tests/utils/test_webui_transcript.py @@ -98,6 +98,201 @@ def test_replay_file_edit_event_creates_file_activity(tmp_path, monkeypatch) -> assert msgs[2]["activitySegmentId"] != msgs[1]["activitySegmentId"] +def test_replay_file_edit_progress_merges_after_interleaved_activity(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + key = "websocket:t-file-progress" + for ev in ( + {"event": "user", "chat_id": "t-file-progress", "text": "edit"}, + { + "event": "message", + "chat_id": "t-file-progress", + "text": 'write_file({"path":"foo.txt"})', + "kind": "tool_hint", + }, + { + "event": "file_edit", + "chat_id": "t-file-progress", + "edits": [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "start", + "added": 12, + "deleted": 0, + "approximate": True, + "status": "editing", + }, + ], + }, + { + "event": "message", + "chat_id": "t-file-progress", + "text": "still working", + "kind": "progress", + }, + { + "event": "file_edit", + "chat_id": "t-file-progress", + "edits": [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "end", + "added": 30, + "deleted": 0, + "approximate": False, + "status": "done", + }, + ], + }, + ): + append_transcript_object(key, ev) + + msgs = replay_transcript_to_ui_messages(read_transcript_lines(key)) + file_edit_messages = [msg for msg in msgs if msg.get("fileEdits")] + + assert len(file_edit_messages) == 1 + assert file_edit_messages[0]["fileEdits"] == [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "end", + "added": 30, + "deleted": 0, + "approximate": False, + "status": "done", + }, + ] + + +def test_replay_file_edit_pending_placeholder_upgrades_to_path(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + key = "websocket:t-file-pending" + for ev in ( + {"event": "user", "chat_id": "t-file-pending", "text": "write"}, + { + "event": "file_edit", + "chat_id": "t-file-pending", + "edits": [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "", + "phase": "start", + "added": 1, + "deleted": 0, + "approximate": True, + "status": "editing", + "pending": True, + }, + ], + }, + { + "event": "file_edit", + "chat_id": "t-file-pending", + "edits": [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "start", + "added": 12, + "deleted": 0, + "approximate": True, + "status": "editing", + }, + ], + }, + ): + append_transcript_object(key, ev) + + msgs = replay_transcript_to_ui_messages(read_transcript_lines(key)) + file_edit_messages = [msg for msg in msgs if msg.get("fileEdits")] + + assert len(file_edit_messages) == 1 + assert file_edit_messages[0]["fileEdits"] == [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "start", + "added": 12, + "deleted": 0, + "approximate": True, + "status": "editing", + }, + ] + + +def test_replay_keeps_new_file_edit_after_reasoning_in_order(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + key = "websocket:t-file-order" + for ev in ( + {"event": "user", "chat_id": "t-file-order", "text": "edit"}, + { + "event": "file_edit", + "chat_id": "t-file-order", + "edits": [ + { + "version": 1, + "call_id": "call-one", + "tool": "write_file", + "path": "one.txt", + "phase": "start", + "added": 10, + "deleted": 0, + "approximate": True, + "status": "editing", + }, + ], + }, + {"event": "reasoning_delta", "chat_id": "t-file-order", "text": "Check next."}, + {"event": "reasoning_end", "chat_id": "t-file-order"}, + { + "event": "file_edit", + "chat_id": "t-file-order", + "edits": [ + { + "version": 1, + "call_id": "call-two", + "tool": "write_file", + "path": "two.txt", + "phase": "start", + "added": 20, + "deleted": 0, + "approximate": True, + "status": "editing", + }, + ], + }, + ): + append_transcript_object(key, ev) + + msgs = replay_transcript_to_ui_messages(read_transcript_lines(key)) + + assert [msg.get("fileEdits", [{}])[0].get("path") if msg.get("fileEdits") else msg.get("reasoning") for msg in msgs[1:]] == [ + "one.txt", + "Check next.", + "two.txt", + ] + file_edit_segments = [ + msg.get("activitySegmentId") + for msg in msgs + if msg.get("fileEdits") + ] + assert len(file_edit_segments) == 2 + assert file_edit_segments[0] != file_edit_segments[1] + + def test_build_response_schema(monkeypatch, tmp_path) -> None: from nanobot.utils.webui_transcript import build_webui_thread_response