mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 08:02:30 +00:00
feat(webui): stream live file edit events
This commit is contained in:
parent
d4ade8f680
commit
7e2dbdef7d
@ -20,6 +20,7 @@ from nanobot.utils.file_edit_events import (
|
||||
build_file_edit_error_event,
|
||||
build_file_edit_start_event,
|
||||
prepare_file_edit_tracker,
|
||||
StreamingFileEditTracker,
|
||||
)
|
||||
from nanobot.utils.helpers import (
|
||||
IncrementalThinkExtractor,
|
||||
@ -629,6 +630,24 @@ class AgentRunner:
|
||||
)
|
||||
|
||||
progress_state: dict[str, bool] | None = None
|
||||
live_file_edits: StreamingFileEditTracker | None = None
|
||||
|
||||
if (
|
||||
spec.progress_callback is not None
|
||||
and on_progress_accepts_file_edit_events(spec.progress_callback)
|
||||
):
|
||||
async def _emit_live_file_edits(events: list[dict[str, Any]]) -> None:
|
||||
await invoke_file_edit_progress(spec.progress_callback, events)
|
||||
|
||||
live_file_edits = StreamingFileEditTracker(
|
||||
workspace=spec.workspace,
|
||||
tools=spec.tools,
|
||||
emit=_emit_live_file_edits,
|
||||
)
|
||||
|
||||
async def _tool_call_delta(delta: dict[str, Any]) -> None:
|
||||
if live_file_edits is not None:
|
||||
await live_file_edits.update(delta)
|
||||
|
||||
if wants_streaming:
|
||||
async def _stream(delta: str) -> None:
|
||||
@ -646,6 +665,7 @@ class AgentRunner:
|
||||
**kwargs,
|
||||
on_content_delta=_stream,
|
||||
on_thinking_delta=_thinking,
|
||||
on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None,
|
||||
)
|
||||
elif wants_progress_streaming:
|
||||
stream_buf = ""
|
||||
@ -675,6 +695,7 @@ class AgentRunner:
|
||||
coro = self.provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=_stream_progress,
|
||||
on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None,
|
||||
)
|
||||
else:
|
||||
coro = self.provider.chat_with_retry(**kwargs)
|
||||
@ -689,6 +710,14 @@ class AgentRunner:
|
||||
await coro if outer_timeout_s is None
|
||||
else await asyncio.wait_for(coro, timeout=outer_timeout_s)
|
||||
)
|
||||
if live_file_edits is not None:
|
||||
await live_file_edits.flush()
|
||||
if response.should_execute_tools:
|
||||
live_file_edits.apply_final_call_ids(response.tool_calls)
|
||||
await live_file_edits.error_unmatched(
|
||||
response.tool_calls if response.should_execute_tools else [],
|
||||
"Tool call did not complete.",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if outer_timeout_s is None:
|
||||
return LLMResponse(
|
||||
@ -907,7 +936,10 @@ class AgentRunner:
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_end_event(file_edit_tracker)],
|
||||
[build_file_edit_end_event(
|
||||
file_edit_tracker,
|
||||
params if isinstance(params, dict) else None,
|
||||
)],
|
||||
)
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
|
||||
@ -590,6 +590,7 @@ class AnthropicProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
@ -598,11 +599,12 @@ class AnthropicProvider(LLMProvider):
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
async with self._client.messages.stream(**kwargs) as stream:
|
||||
if on_content_delta or on_thinking_delta:
|
||||
if on_content_delta or on_thinking_delta or on_tool_call_delta:
|
||||
# Idle timeout must track *any* SSE chunk (thinking_delta,
|
||||
# tool JSON deltas, etc.), not only text_stream tokens.
|
||||
# Otherwise extended thinking can stall text_stream for minutes
|
||||
# while the connection is healthy (e.g. MiniMax Anthropic).
|
||||
tool_blocks: dict[int, dict[str, str]] = {}
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
@ -611,7 +613,22 @@ class AnthropicProvider(LLMProvider):
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
if (
|
||||
if chunk.type == "content_block_start":
|
||||
block = getattr(chunk, "content_block", None)
|
||||
if getattr(block, "type", None) == "tool_use":
|
||||
index = int(getattr(chunk, "index", 0) or 0)
|
||||
state = {
|
||||
"call_id": str(getattr(block, "id", "") or ""),
|
||||
"name": str(getattr(block, "name", "") or ""),
|
||||
}
|
||||
tool_blocks[index] = state
|
||||
if on_tool_call_delta:
|
||||
await on_tool_call_delta({
|
||||
"index": index,
|
||||
**state,
|
||||
"arguments_delta": "",
|
||||
})
|
||||
elif (
|
||||
chunk.type == "content_block_delta"
|
||||
and getattr(chunk.delta, "type", None) == "thinking_delta"
|
||||
):
|
||||
@ -625,6 +642,20 @@ class AnthropicProvider(LLMProvider):
|
||||
text = getattr(chunk.delta, "text", None) or ""
|
||||
if text and on_content_delta:
|
||||
await on_content_delta(text)
|
||||
elif (
|
||||
chunk.type == "content_block_delta"
|
||||
and getattr(chunk.delta, "type", None) == "input_json_delta"
|
||||
):
|
||||
partial = getattr(chunk.delta, "partial_json", None) or ""
|
||||
if partial and on_tool_call_delta:
|
||||
index = int(getattr(chunk, "index", 0) or 0)
|
||||
state = tool_blocks.get(index, {})
|
||||
await on_tool_call_delta({
|
||||
"index": index,
|
||||
"call_id": state.get("call_id", ""),
|
||||
"name": state.get("name", ""),
|
||||
"arguments_delta": partial,
|
||||
})
|
||||
response = await asyncio.wait_for(
|
||||
stream.get_final_message(),
|
||||
timeout=idle_timeout_s,
|
||||
|
||||
@ -158,6 +158,7 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
_ = on_thinking_delta
|
||||
body = self._build_body(
|
||||
@ -169,7 +170,7 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
try:
|
||||
stream = await self._client.responses.create(**body)
|
||||
content, tool_calls, finish_reason, usage, reasoning_content = (
|
||||
await consume_sdk_stream(stream, on_content_delta)
|
||||
await consume_sdk_stream(stream, on_content_delta, on_tool_call_delta)
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content or None,
|
||||
|
||||
@ -70,11 +70,11 @@ class LLMResponse:
|
||||
|
||||
@property
|
||||
def should_execute_tools(self) -> bool:
|
||||
"""Tools execute only when has_tool_calls AND finish_reason is ``tool_calls`` / ``stop``.
|
||||
"""Tools execute only when has_tool_calls AND finish_reason is a tool-capable stop.
|
||||
Blocks gateway-injected calls under ``refusal`` / ``content_filter`` / ``error`` (#3220)."""
|
||||
if not self.has_tool_calls:
|
||||
return False
|
||||
return self.finish_reason in ("tool_calls", "stop")
|
||||
return self.finish_reason in ("tool_calls", "function_call", "stop")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -501,6 +501,7 @@ class LLMProvider(ABC):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion, calling *on_content_delta* for each text chunk.
|
||||
|
||||
@ -514,7 +515,7 @@ class LLMProvider(ABC):
|
||||
full content as a single delta. Providers that support native
|
||||
streaming should override this method.
|
||||
"""
|
||||
_ = on_thinking_delta
|
||||
_ = on_thinking_delta, on_tool_call_delta
|
||||
response = await self.chat(
|
||||
messages=messages, tools=tools, model=model,
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
@ -544,6 +545,7 @@ class LLMProvider(ABC):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
retry_mode: str = "standard",
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
@ -561,6 +563,7 @@ class LLMProvider(ABC):
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
on_thinking_delta=on_thinking_delta,
|
||||
on_tool_call_delta=on_tool_call_delta,
|
||||
)
|
||||
return await self._run_with_retry(
|
||||
self._safe_chat_stream,
|
||||
|
||||
@ -704,8 +704,9 @@ class BedrockProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
_ = on_thinking_delta
|
||||
_ = on_thinking_delta, on_tool_call_delta
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
content_parts: list[str] = []
|
||||
reasoning_parts: list[str] = []
|
||||
|
||||
@ -243,6 +243,7 @@ class GitHubCopilotProvider(OpenAICompatProvider):
|
||||
tool_choice: str | dict[str, object] | None = None,
|
||||
on_content_delta: Callable[[str], None] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, object]], Awaitable[None]] | None = None,
|
||||
):
|
||||
await self._refresh_client_api_key()
|
||||
return await super().chat_stream(
|
||||
@ -255,4 +256,5 @@ class GitHubCopilotProvider(OpenAICompatProvider):
|
||||
tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
on_thinking_delta=on_thinking_delta,
|
||||
on_tool_call_delta=on_tool_call_delta,
|
||||
)
|
||||
|
||||
@ -40,6 +40,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Shared request logic for both chat() and chat_stream()."""
|
||||
model = model or self.default_model
|
||||
@ -70,6 +71,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
DEFAULT_CODEX_URL, headers, body, verify=True,
|
||||
on_content_delta=on_content_delta,
|
||||
on_tool_call_delta=on_tool_call_delta,
|
||||
)
|
||||
except Exception as e:
|
||||
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
|
||||
@ -78,6 +80,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
DEFAULT_CODEX_URL, headers, body, verify=False,
|
||||
on_content_delta=on_content_delta,
|
||||
on_tool_call_delta=on_tool_call_delta,
|
||||
)
|
||||
return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
|
||||
except Exception as e:
|
||||
@ -100,9 +103,18 @@ class OpenAICodexProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
_ = on_thinking_delta
|
||||
return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice, on_content_delta)
|
||||
return await self._call_codex(
|
||||
messages,
|
||||
tools,
|
||||
model,
|
||||
reasoning_effort,
|
||||
tool_choice,
|
||||
on_content_delta,
|
||||
on_tool_call_delta,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
@ -138,6 +150,7 @@ async def _request_codex(
|
||||
body: dict[str, Any],
|
||||
verify: bool,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=body) as response:
|
||||
@ -148,7 +161,7 @@ async def _request_codex(
|
||||
_friendly_error(response.status_code, text.decode("utf-8", "ignore")),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
return await consume_sse(response, on_content_delta)
|
||||
return await consume_sse(response, on_content_delta, on_tool_call_delta)
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
|
||||
@ -999,6 +999,21 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if fn_prov:
|
||||
buf["fn_prov"] = fn_prov
|
||||
|
||||
def _accum_legacy_function_call(function_call: Any) -> None:
|
||||
"""Accumulate legacy ``delta.function_call`` streaming chunks."""
|
||||
if not function_call:
|
||||
return
|
||||
buf = tc_bufs.setdefault(0, {
|
||||
"id": "", "name": "", "arguments": "",
|
||||
"extra_content": None, "prov": None, "fn_prov": None,
|
||||
})
|
||||
fn_name = _get(function_call, "name")
|
||||
if fn_name:
|
||||
buf["name"] = str(fn_name)
|
||||
fn_args = _get(function_call, "arguments")
|
||||
if fn_args:
|
||||
buf["arguments"] += str(fn_args)
|
||||
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk, str):
|
||||
content_parts.append(chunk)
|
||||
@ -1029,6 +1044,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
reasoning_parts.append(text)
|
||||
for idx, tc in enumerate(delta.get("tool_calls") or []):
|
||||
_accum_tc(tc, idx)
|
||||
_accum_legacy_function_call(delta.get("function_call"))
|
||||
usage = cls._extract_usage(chunk_map) or usage
|
||||
continue
|
||||
|
||||
@ -1047,8 +1063,10 @@ class OpenAICompatProvider(LLMProvider):
|
||||
reasoning = getattr(delta, "reasoning", None)
|
||||
if reasoning:
|
||||
reasoning_parts.append(reasoning)
|
||||
for tc in (delta.tool_calls or []) if delta else []:
|
||||
for tc in (getattr(delta, "tool_calls", None) or []) if delta else []:
|
||||
_accum_tc(tc, getattr(tc, "index", 0))
|
||||
if delta:
|
||||
_accum_legacy_function_call(getattr(delta, "function_call", None))
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
@ -1203,6 +1221,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
@ -1226,9 +1245,16 @@ class OpenAICompatProvider(LLMProvider):
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream(
|
||||
(
|
||||
content,
|
||||
tool_calls,
|
||||
finish_reason,
|
||||
usage,
|
||||
reasoning_content,
|
||||
) = await consume_sdk_stream(
|
||||
_timed_stream(),
|
||||
on_content_delta,
|
||||
on_tool_call_delta=on_tool_call_delta,
|
||||
)
|
||||
self._record_responses_success(model, reasoning_effort)
|
||||
return LLMResponse(
|
||||
@ -1252,6 +1278,12 @@ class OpenAICompatProvider(LLMProvider):
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
if self._spec and self._spec.name == "zhipu" and tools and on_tool_call_delta:
|
||||
# Z.AI/GLM keeps streaming tool-call arguments behind an
|
||||
# explicit provider flag. Pass it through the OpenAI SDK's
|
||||
# extra_body escape hatch so the usual delta.tool_calls path
|
||||
# can surface live file-edit progress.
|
||||
kwargs.setdefault("extra_body", {})["tool_stream"] = True
|
||||
kwargs["stream"] = True
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
stream = await self._client.chat.completions.create(**kwargs)
|
||||
@ -1279,6 +1311,28 @@ class OpenAICompatProvider(LLMProvider):
|
||||
r_text = self._extract_text_content(reasoning)
|
||||
if r_text:
|
||||
await on_thinking_delta(r_text)
|
||||
if on_tool_call_delta:
|
||||
for idx, tool_delta in enumerate(
|
||||
getattr(delta_obj, "tool_calls", None) or []
|
||||
):
|
||||
fn = _get(tool_delta, "function")
|
||||
tool_index = _get(tool_delta, "index")
|
||||
await on_tool_call_delta({
|
||||
"index": tool_index if tool_index is not None else idx,
|
||||
"call_id": str(_get(tool_delta, "id") or ""),
|
||||
"name": str(_get(fn, "name") or "") if fn is not None else "",
|
||||
"arguments_delta": (
|
||||
str(_get(fn, "arguments") or "") if fn is not None else ""
|
||||
),
|
||||
})
|
||||
function_call = getattr(delta_obj, "function_call", None)
|
||||
if function_call:
|
||||
await on_tool_call_delta({
|
||||
"index": 0,
|
||||
"call_id": "",
|
||||
"name": str(_get(function_call, "name") or ""),
|
||||
"arguments_delta": str(_get(function_call, "arguments") or ""),
|
||||
})
|
||||
return self._parse_chunks(chunks)
|
||||
except asyncio.TimeoutError:
|
||||
return LLMResponse(
|
||||
|
||||
@ -62,6 +62,7 @@ async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], N
|
||||
async def consume_sse(
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
"""Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``."""
|
||||
content = ""
|
||||
@ -82,6 +83,12 @@ async def consume_sse(
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
if on_tool_call_delta:
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
"name": str(item.get("name") or ""),
|
||||
"arguments_delta": "",
|
||||
})
|
||||
elif event_type == "response.output_text.delta":
|
||||
delta_text = event.get("delta") or ""
|
||||
content += delta_text
|
||||
@ -90,7 +97,14 @@ async def consume_sse(
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
||||
delta = event.get("delta") or ""
|
||||
tool_call_buffers[call_id]["arguments"] += delta
|
||||
if on_tool_call_delta and delta:
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
||||
"arguments_delta": str(delta),
|
||||
})
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
@ -210,6 +224,7 @@ def parse_response_output(response: Any) -> LLMResponse:
|
||||
async def consume_sdk_stream(
|
||||
stream: Any,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
|
||||
"""Consume an SDK async stream from ``client.responses.create(stream=True)``."""
|
||||
content = ""
|
||||
@ -232,6 +247,12 @@ async def consume_sdk_stream(
|
||||
"name": getattr(item, "name", None),
|
||||
"arguments": getattr(item, "arguments", None) or "",
|
||||
}
|
||||
if on_tool_call_delta:
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
"name": str(getattr(item, "name", None) or ""),
|
||||
"arguments_delta": "",
|
||||
})
|
||||
elif event_type == "response.output_text.delta":
|
||||
delta_text = getattr(event, "delta", "") or ""
|
||||
content += delta_text
|
||||
@ -240,7 +261,14 @@ async def consume_sdk_stream(
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = getattr(event, "call_id", None)
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
|
||||
delta = getattr(event, "delta", "") or ""
|
||||
tool_call_buffers[call_id]["arguments"] += delta
|
||||
if on_tool_call_delta and delta:
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
||||
"arguments_delta": str(delta),
|
||||
})
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = getattr(event, "call_id", None)
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
|
||||
@ -4,13 +4,17 @@ from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
|
||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"})
|
||||
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
|
||||
_LIVE_EMIT_INTERVAL_S = 0.18
|
||||
_LIVE_EMIT_LINE_STEP = 24
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -103,6 +107,8 @@ def line_diff_stats(before: str | None, after: str | None) -> tuple[int, int]:
|
||||
"""Return ``(added, deleted)`` for a UTF-8 text line-level diff."""
|
||||
if before is None or after is None:
|
||||
return 0, 0
|
||||
if before == "":
|
||||
return _text_line_count(after), 0
|
||||
before_lines = before.replace("\r\n", "\n").splitlines()
|
||||
after_lines = after.replace("\r\n", "\n").splitlines()
|
||||
added = 0
|
||||
@ -118,6 +124,28 @@ def line_diff_stats(before: str | None, after: str | None) -> tuple[int, int]:
|
||||
return added, deleted
|
||||
|
||||
|
||||
def _text_line_count(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
line_count = 0
|
||||
last_was_newline = False
|
||||
last_was_cr = False
|
||||
for ch in text:
|
||||
if ch == "\r":
|
||||
line_count += 1
|
||||
last_was_newline = True
|
||||
last_was_cr = True
|
||||
elif ch == "\n":
|
||||
if not last_was_cr:
|
||||
line_count += 1
|
||||
last_was_newline = True
|
||||
last_was_cr = False
|
||||
else:
|
||||
last_was_newline = False
|
||||
last_was_cr = False
|
||||
return line_count if last_was_newline else line_count + 1
|
||||
|
||||
|
||||
def prepare_file_edit_tracker(
|
||||
*,
|
||||
call_id: str,
|
||||
@ -160,12 +188,22 @@ def build_file_edit_start_event(
|
||||
)
|
||||
|
||||
|
||||
def build_file_edit_end_event(tracker: FileEditTracker) -> dict[str, Any]:
|
||||
def build_file_edit_end_event(
|
||||
tracker: FileEditTracker,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
after = read_file_snapshot(tracker.path)
|
||||
counted = False
|
||||
if tracker.before.countable and after.countable:
|
||||
added, deleted = line_diff_stats(tracker.before.text, after.text)
|
||||
counted = True
|
||||
else:
|
||||
added, deleted = 0, 0
|
||||
predicted_after = _predict_after_text(tracker.tool, params or {}, tracker.before)
|
||||
if tracker.before.countable and predicted_after is not None:
|
||||
added, deleted = line_diff_stats(tracker.before.text, predicted_after)
|
||||
counted = True
|
||||
else:
|
||||
added, deleted = 0, 0
|
||||
return _event_payload(
|
||||
tracker,
|
||||
phase="end",
|
||||
@ -173,11 +211,14 @@ def build_file_edit_end_event(tracker: FileEditTracker) -> dict[str, Any]:
|
||||
added=added,
|
||||
deleted=deleted,
|
||||
approximate=False,
|
||||
binary=after.binary or after.oversized or after.unreadable,
|
||||
binary=(after.binary or after.oversized or after.unreadable) and not counted,
|
||||
)
|
||||
|
||||
|
||||
def build_file_edit_error_event(tracker: FileEditTracker, error: str | None = None) -> dict[str, Any]:
|
||||
def build_file_edit_error_event(
|
||||
tracker: FileEditTracker,
|
||||
error: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload = _event_payload(
|
||||
tracker,
|
||||
phase="error",
|
||||
@ -191,6 +232,427 @@ def build_file_edit_error_event(tracker: FileEditTracker, error: str | None = No
|
||||
return payload
|
||||
|
||||
|
||||
def build_file_edit_live_event(
    tracker: FileEditTracker,
    *,
    added: int,
    deleted: int = 0,
) -> dict[str, Any]:
    """Return an approximate "still editing" event for *tracker*.

    Used while the tool-call arguments are still streaming in: the payload is
    produced by ``_event_payload`` with phase ``start``, status ``editing``
    and ``approximate=True``, carrying the caller-supplied line counts.

    Args:
        tracker: Tracker describing the file edit the event belongs to.
        added: Line count to report as added so far.
        deleted: Line count to report as deleted so far.
    """
    live_fields: dict[str, Any] = {
        "phase": "start",
        "status": "editing",
        "added": added,
        "deleted": deleted,
        "approximate": True,
    }
    return _event_payload(tracker, **live_fields)
|
||||
|
||||
|
||||
def build_file_edit_pending_event(
    *,
    call_id: str,
    tool_name: str,
    added: int = 0,
    deleted: int = 0,
) -> dict[str, Any]:
    """Return a placeholder event for an edit whose path is not known yet.

    The payload has an empty ``path``, is flagged ``pending`` and
    ``approximate``, and reports phase ``start`` / status ``editing``.

    Args:
        call_id: Identifier of the streamed tool call; a falsy value becomes
            the empty string.
        tool_name: Name of the tool being streamed, copied into ``tool``.
        added: Line count to report as added; coerced to ``int`` and clamped
            at zero.
        deleted: Line count to report as deleted; coerced to ``int`` and
            clamped at zero.
    """
    event: dict[str, Any] = {"version": 1}
    event["call_id"] = str(call_id or "")
    event["tool"] = tool_name
    event["path"] = ""
    event["phase"] = "start"
    # Negative counts are never reported to the UI.
    event["added"] = max(0, int(added))
    event["deleted"] = max(0, int(deleted))
    event["approximate"] = True
    event["status"] = "editing"
    event["pending"] = True
    return event
|
||||
|
||||
|
||||
class StreamingFileEditTracker:
    """Turn streamed tool-call argument deltas into approximate file-edit events.

    Tool execution events only start once the provider has finished the whole
    function call. For big ``write_file`` calls most of the wait is the model
    generating the JSON ``content`` argument, and big ``edit_file`` calls wait
    similarly on ``old_text`` / ``new_text``. This tracker watches those
    argument deltas and emits approximate WebUI file-edit events before the
    final exact diff exists.
    """

    def __init__(
        self,
        *,
        workspace: Path | None,
        tools: Any,
        emit: Callable[[list[dict[str, Any]]], Awaitable[None]],
    ) -> None:
        """Store the collaborators and start with no per-call stream state.

        Args:
            workspace: Passed through to ``prepare_file_edit_tracker``.
            tools: Tool registry; only consulted via ``.get(name)`` when it
                exposes a ``get`` attribute.
            emit: Async callback that receives each batch of event dicts.
        """
        self._workspace = workspace
        self._tools = tools
        self._emit = emit
        # One streaming state per stream key (as computed by _stream_key).
        self._states: dict[str, _StreamingFileEditState] = {}

    async def update(self, payload: dict[str, Any]) -> None:
        """Apply one tool-call delta and emit a pending or live event if due."""
        stream_key = _stream_key(payload)
        if not stream_key:
            return

        state = self._states.get(stream_key)
        if state is None:
            state = self._states[stream_key] = _StreamingFileEditState(key=stream_key)
        state.apply_delta(payload)

        # Only these two tools get live progress while streaming.
        if state.name not in ("write_file", "edit_file"):
            return

        if state.path is None:
            state.path = _extract_complete_json_string(state.arguments, "path")
            if state.path is None:
                # The path has not fully streamed yet: show a placeholder.
                await self._emit_pending(state)
                return

        if state.tracker is None:
            registry = self._tools
            tool = registry.get(state.name) if hasattr(registry, "get") else None
            state.tracker = prepare_file_edit_tracker(
                call_id=state.call_id or state.key,
                tool_name=state.name,
                tool=tool,
                workspace=self._workspace,
                params={"path": state.path},
            )
            if state.tracker is None:
                return

        added, deleted = state.live_diff_counts()
        now = time.monotonic()
        if state.should_emit(added, deleted, now):
            state.mark_emitted(added, deleted, now)
            live_event = build_file_edit_live_event(
                state.tracker,
                added=added,
                deleted=deleted,
            )
            await self._emit([live_event])

    async def _emit_pending(self, state: "_StreamingFileEditState") -> None:
        """Emit a path-less placeholder event when the state says one is due."""
        added, deleted = state.live_diff_counts()
        now = time.monotonic()
        if not state.should_emit_pending(added, deleted, now):
            return
        state.mark_pending_emitted(added, deleted, now)
        placeholder = build_file_edit_pending_event(
            call_id=state.call_id or state.key,
            tool_name=state.name,
            added=added,
            deleted=deleted,
        )
        await self._emit([placeholder])

    async def flush(self) -> None:
        """Emit one batch with the latest counts for every tracked edit.

        States without a tracker are ignored, as are states whose current
        counts equal the counts already emitted.
        """
        batch: list[dict[str, Any]] = []
        now = time.monotonic()
        for state in self._states.values():
            if state.tracker is None:
                continue
            added, deleted = state.live_diff_counts()
            already_reported = (
                state.last_emitted_added == added
                and state.last_emitted_deleted == deleted
                and state.emitted_once
            )
            if already_reported:
                continue
            state.mark_emitted(added, deleted, now)
            batch.append(
                build_file_edit_live_event(state.tracker, added=added, deleted=deleted)
            )
        if batch:
            await self._emit(batch)

    def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None:
        """Keep final start/end events keyed to any earlier streamed placeholder."""
        for tool_call in final_tool_calls:
            canonical = self.canonical_call_id_for(tool_call)
            if not canonical:
                continue
            try:
                tool_call.id = canonical
            except Exception:
                # Best effort: a tool call that rejects assignment keeps its id.
                pass

    def canonical_call_id_for(self, tool_call: Any) -> str | None:
        """Return the streamed id for *tool_call*, or ``None`` if nothing matches.

        The first state whose ``matches_final_tool_call`` accepts the call
        wins; its id is the streamed ``call_id``, else the tracker's
        ``call_id``, else the stream key.
        """
        matched = next(
            (
                state
                for state in self._states.values()
                if state.matches_final_tool_call(tool_call)
            ),
            None,
        )
        if matched is None:
            return None
        if matched.call_id:
            return matched.call_id
        tracker_id = matched.tracker.call_id if matched.tracker else None
        return tracker_id or matched.key

    async def error_unmatched(
        self,
        final_tool_calls: list[Any],
        error: str,
    ) -> None:
        """Mark streamed edits as failed when no final tool call will run."""
        failures = [
            build_file_edit_error_event(state.tracker, error)
            for state in self._states.values()
            if state.tracker is not None
            and not any(
                state.matches_final_tool_call(tool_call)
                for tool_call in final_tool_calls
            )
        ]
        if failures:
            await self._emit(failures)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StreamingJsonStringField:
|
||||
key: str
|
||||
scan_pos: int | None = None
|
||||
closed: bool = False
|
||||
escape: bool = False
|
||||
unicode_remaining: int = 0
|
||||
unicode_buffer: str = ""
|
||||
newline_count: int = 0
|
||||
has_chars: bool = False
|
||||
last_char_newline: bool = False
|
||||
last_char_cr: bool = False
|
||||
|
||||
@property
|
||||
def line_count(self) -> int:
|
||||
if not self.has_chars:
|
||||
return 0
|
||||
return self.newline_count + (0 if self.last_char_newline else 1)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.scan_pos = None
|
||||
self.closed = False
|
||||
self.escape = False
|
||||
self.unicode_remaining = 0
|
||||
self.unicode_buffer = ""
|
||||
self.newline_count = 0
|
||||
self.has_chars = False
|
||||
self.last_char_newline = False
|
||||
self.last_char_cr = False
|
||||
|
||||
def scan(self, source: str) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
if self.scan_pos is None:
|
||||
match = re.search(rf'"{re.escape(self.key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
return
|
||||
self.scan_pos = match.end()
|
||||
i = self.scan_pos
|
||||
while i < len(source):
|
||||
ch = source[i]
|
||||
if self.unicode_remaining > 0:
|
||||
self.unicode_buffer += ch
|
||||
self.unicode_remaining -= 1
|
||||
if self.unicode_remaining == 0:
|
||||
try:
|
||||
decoded = chr(int(self.unicode_buffer, 16))
|
||||
except ValueError:
|
||||
decoded = "x"
|
||||
self.unicode_buffer = ""
|
||||
self._mark_char(decoded)
|
||||
i += 1
|
||||
continue
|
||||
if self.escape:
|
||||
self.escape = False
|
||||
if ch == "u":
|
||||
self.unicode_remaining = 4
|
||||
self.unicode_buffer = ""
|
||||
elif ch == "n":
|
||||
self._mark_char("\n")
|
||||
elif ch == "r":
|
||||
self._mark_char("\r")
|
||||
else:
|
||||
self._mark_char(ch)
|
||||
i += 1
|
||||
continue
|
||||
if ch == "\\":
|
||||
self.escape = True
|
||||
i += 1
|
||||
continue
|
||||
if ch == '"':
|
||||
self.closed = True
|
||||
i += 1
|
||||
break
|
||||
self._mark_char(ch)
|
||||
i += 1
|
||||
self.scan_pos = i
|
||||
|
||||
def _mark_char(self, ch: str) -> None:
|
||||
self.has_chars = True
|
||||
if ch == "\r":
|
||||
self.newline_count += 1
|
||||
self.last_char_newline = True
|
||||
self.last_char_cr = True
|
||||
elif ch == "\n":
|
||||
if not self.last_char_cr:
|
||||
self.newline_count += 1
|
||||
self.last_char_newline = True
|
||||
self.last_char_cr = False
|
||||
else:
|
||||
self.last_char_newline = False
|
||||
self.last_char_cr = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
class _StreamingFileEditState:
    """Accumulated state for one tool call whose arguments are still streaming.

    Holds the raw argument text received so far, incremental scanners for
    the ``content`` / ``old_text`` / ``new_text`` string fields, and the
    bookkeeping used to throttle how often progress is re-emitted.
    """

    key: str
    call_id: str = ""
    name: str = ""
    # Raw (still JSON-encoded) argument text accumulated from deltas.
    arguments: str = ""
    path: str | None = None
    tracker: FileEditTracker | None = None
    content: _StreamingJsonStringField = field(
        default_factory=lambda: _StreamingJsonStringField("content")
    )
    old_text: _StreamingJsonStringField = field(
        default_factory=lambda: _StreamingJsonStringField("old_text")
    )
    new_text: _StreamingJsonStringField = field(
        default_factory=lambda: _StreamingJsonStringField("new_text")
    )
    # Throttle bookkeeping for regular emissions.
    emitted_once: bool = False
    last_emitted_added: int = -1
    last_emitted_deleted: int = -1
    last_emit_at: float = 0.0
    # Throttle bookkeeping for "pending" emissions.
    pending_emitted: bool = False
    last_pending_added: int = -1
    last_pending_deleted: int = -1
    last_pending_at: float = 0.0

    def apply_delta(self, payload: dict[str, Any]) -> None:
        """Fold one streamed tool-call delta into this state.

        Non-empty string ``call_id`` / ``name`` values overwrite the stored
        ones. A string ``arguments`` replaces the accumulated text and
        resets all field scanners; otherwise a non-empty string
        ``arguments_delta`` is appended.
        """
        incoming_id = payload.get("call_id")
        if isinstance(incoming_id, str) and incoming_id:
            self.call_id = incoming_id

        incoming_name = payload.get("name")
        if isinstance(incoming_name, str) and incoming_name:
            self.name = incoming_name

        full_arguments = payload.get("arguments")
        if isinstance(full_arguments, str):
            self.arguments = full_arguments
            # The text was replaced wholesale, so scan positions are stale.
            for scanner in (self.content, self.old_text, self.new_text):
                scanner.reset()
            return

        fragment = payload.get("arguments_delta")
        if isinstance(fragment, str) and fragment:
            self.arguments += fragment

    def live_diff_counts(self) -> tuple[int, int]:
        """Return ``(added, deleted)`` line counts from the text streamed so far.

        ``write_file`` reports the ``content`` line count as added and zero
        deleted; ``edit_file`` reports ``new_text`` lines as added and
        ``old_text`` lines as deleted; any other tool yields ``(0, 0)``.
        """
        tool = self.name
        if tool == "write_file":
            self.content.scan(self.arguments)
            return self.content.line_count, 0
        if tool == "edit_file":
            for scanner in (self.old_text, self.new_text):
                scanner.scan(self.arguments)
            return self.new_text.line_count, self.old_text.line_count
        return 0, 0

    def should_emit(self, added: int, deleted: int, now: float) -> bool:
        """Decide whether a regular progress event is due for these counts."""
        if not self.emitted_once:
            return True
        previous = (self.last_emitted_added, self.last_emitted_deleted)
        if (added, deleted) == previous:
            return False
        largest_jump = max(
            abs(added - self.last_emitted_added),
            abs(deleted - self.last_emitted_deleted),
        )
        if largest_jump >= _LIVE_EMIT_LINE_STEP:
            return True
        # Small change: emit only once enough time has passed.
        return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S

    def mark_emitted(self, added: int, deleted: int, now: float) -> None:
        """Remember the counts and time of the latest regular emission."""
        self.emitted_once = True
        self.last_emitted_added, self.last_emitted_deleted = added, deleted
        self.last_emit_at = now

    def should_emit_pending(self, added: int, deleted: int, now: float) -> bool:
        """Decide whether a pending progress event is due for these counts."""
        if not self.pending_emitted:
            return True
        previous = (self.last_pending_added, self.last_pending_deleted)
        if (added, deleted) == previous:
            return False
        largest_jump = max(
            abs(added - self.last_pending_added),
            abs(deleted - self.last_pending_deleted),
        )
        if largest_jump >= _LIVE_EMIT_LINE_STEP:
            return True
        # Small change: emit only once enough time has passed.
        return now - self.last_pending_at >= _LIVE_EMIT_INTERVAL_S

    def mark_pending_emitted(self, added: int, deleted: int, now: float) -> None:
        """Remember the counts and time of the latest pending emission."""
        self.pending_emitted = True
        self.last_pending_added, self.last_pending_deleted = added, deleted
        self.last_pending_at = now

    def matches_final_tool_call(self, tool_call: Any) -> bool:
        """Tell whether ``tool_call`` is the finished form of this stream.

        Equal non-empty call ids match immediately. Otherwise the tool name
        must match and the ``path`` argument must equal the stored path;
        when no path is stored yet, a non-empty string path is adopted
        (stored on this state) and counts as a match.
        """
        final_id = getattr(tool_call, "id", None)
        known_id = self.call_id or (self.tracker.call_id if self.tracker else "")
        if isinstance(final_id, str) and final_id and known_id and final_id == known_id:
            return True

        if getattr(tool_call, "name", None) != self.name:
            return False

        final_arguments = getattr(tool_call, "arguments", None)
        if not isinstance(final_arguments, dict):
            return False

        final_path = final_arguments.get("path")
        if not isinstance(final_path, str):
            return False
        if self.path is None and final_path:
            self.path = final_path
            return True
        return final_path == self.path
|
||||
|
||||
|
||||
def _stream_key(payload: dict[str, Any]) -> str:
|
||||
index = payload.get("index")
|
||||
if isinstance(index, int):
|
||||
return f"idx:{index}"
|
||||
if isinstance(index, str) and index:
|
||||
return f"idx:{index}"
|
||||
call_id = payload.get("call_id")
|
||||
if isinstance(call_id, str) and call_id:
|
||||
return f"id:{call_id}"
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_complete_json_string(source: str, key: str) -> str | None:
|
||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
return None
|
||||
out: list[str] = []
|
||||
i = match.end()
|
||||
escape = False
|
||||
while i < len(source):
|
||||
ch = source[i]
|
||||
if escape:
|
||||
escape = False
|
||||
if ch == "n":
|
||||
out.append("\n")
|
||||
elif ch == "r":
|
||||
out.append("\r")
|
||||
elif ch == "t":
|
||||
out.append("\t")
|
||||
elif ch == "u":
|
||||
digits = source[i + 1:i + 5]
|
||||
if len(digits) < 4:
|
||||
return None
|
||||
try:
|
||||
out.append(chr(int(digits, 16)))
|
||||
except ValueError:
|
||||
return None
|
||||
i += 4
|
||||
else:
|
||||
out.append(ch)
|
||||
i += 1
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
i += 1
|
||||
continue
|
||||
if ch == '"':
|
||||
return "".join(out)
|
||||
out.append(ch)
|
||||
i += 1
|
||||
return None
|
||||
|
||||
|
||||
def _event_payload(
|
||||
tracker: FileEditTracker,
|
||||
*,
|
||||
@ -206,6 +668,7 @@ def _event_payload(
|
||||
"call_id": tracker.call_id,
|
||||
"tool": tracker.tool,
|
||||
"path": tracker.display_path,
|
||||
"absolute_path": tracker.path.as_posix(),
|
||||
"phase": phase,
|
||||
"added": max(0, int(added)),
|
||||
"deleted": max(0, int(deleted)),
|
||||
@ -260,8 +723,14 @@ def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> st
|
||||
return None
|
||||
new_source = params.get("new_source")
|
||||
source = new_source if isinstance(new_source, str) else ""
|
||||
cell_type = params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
|
||||
mode = params.get("edit_mode") if params.get("edit_mode") in ("replace", "insert", "delete") else "replace"
|
||||
cell_type = (
|
||||
params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
|
||||
)
|
||||
mode = (
|
||||
params.get("edit_mode")
|
||||
if params.get("edit_mode") in ("replace", "insert", "delete")
|
||||
else "replace"
|
||||
)
|
||||
if mode == "delete":
|
||||
if 0 <= cell_index < len(cells):
|
||||
cells.pop(cell_index)
|
||||
|
||||
@ -144,6 +144,17 @@ def replay_transcript_to_ui_messages(
|
||||
def _ensure_activity_segment() -> str:
    """Return the open activity segment id, starting a new segment if none is open."""
    if active_activity_segment_id:
        return active_activity_segment_id
    return _new_activity_segment()
|
||||
|
||||
def close_activity_for_answer() -> None:
|
||||
nonlocal active_activity_segment_id, active_file_edit_segment_id
|
||||
active_activity_segment_id = None
|
||||
active_file_edit_segment_id = None
|
||||
|
||||
def close_file_edit_phase_before_activity() -> None:
|
||||
nonlocal active_activity_segment_id, active_file_edit_segment_id
|
||||
if active_file_edit_segment_id:
|
||||
active_activity_segment_id = None
|
||||
active_file_edit_segment_id = None
|
||||
|
||||
def attach_reasoning_chunk(prev: list[dict[str, Any]], chunk: str, idx: int) -> None:
|
||||
for i in range(len(prev) - 1, -1, -1):
|
||||
candidate = prev[i]
|
||||
@ -243,7 +254,7 @@ def replay_transcript_to_ui_messages(
|
||||
return
|
||||
|
||||
def absorb_complete(extra: dict[str, Any], idx: int) -> None:
|
||||
nonlocal active_activity_segment_id
|
||||
nonlocal active_activity_segment_id, active_file_edit_segment_id
|
||||
last = messages[-1] if messages else None
|
||||
if last and is_reasoning_only_placeholder(last):
|
||||
messages[-1] = {
|
||||
@ -262,35 +273,50 @@ def replay_transcript_to_ui_messages(
|
||||
},
|
||||
)
|
||||
active_activity_segment_id = None
|
||||
active_file_edit_segment_id = None
|
||||
|
||||
def _file_edit_key(edit: dict[str, Any]) -> str:
|
||||
return "|".join(
|
||||
str(edit.get(k) or "")
|
||||
for k in ("call_id", "tool", "path")
|
||||
)
|
||||
call_id = str(edit.get("call_id") or "")
|
||||
tool = str(edit.get("tool") or "")
|
||||
if call_id:
|
||||
return f"{call_id}|{tool}"
|
||||
return f"{tool}|{edit.get('path') or ''}"
|
||||
|
||||
def find_file_edit_trace_index(
    segment: str | None,
    edits: list[dict[str, Any]],
) -> int | None:
    """Find the trace message that incoming file edits should be merged into.

    Walks ``messages`` from newest to oldest and stops at the first user
    message. Among trace messages that already carry file edits, returns the
    index of the first one that belongs to ``segment`` or that holds an edit
    with the same key as one of ``edits``. Returns ``None`` if none is found.
    """
    wanted_keys = {_file_edit_key(edit) for edit in edits if isinstance(edit, dict)}
    for position in reversed(range(len(messages))):
        entry = messages[position]
        if entry.get("role") == "user":
            # Never search past the start of the current turn.
            return None
        if entry.get("kind") != "trace" or not entry.get("fileEdits"):
            continue
        if segment and entry.get("activitySegmentId") == segment:
            return position
        recorded = entry.get("fileEdits")
        if not isinstance(recorded, list):
            continue
        if any(
            isinstance(item, dict) and _file_edit_key(item) in wanted_keys
            for item in recorded
        ):
            return position
    return None
|
||||
|
||||
def upsert_file_edits(edits: list[dict[str, Any]], idx: int) -> None:
|
||||
nonlocal active_file_edit_segment_id
|
||||
if not edits:
|
||||
return
|
||||
last = messages[-1] if messages else None
|
||||
if (
|
||||
active_file_edit_segment_id
|
||||
and last
|
||||
and last.get("kind") == "trace"
|
||||
and last.get("fileEdits")
|
||||
):
|
||||
segment = active_file_edit_segment_id
|
||||
else:
|
||||
segment = _new_activity_segment(activate=False)
|
||||
segment = active_file_edit_segment_id
|
||||
target_index = find_file_edit_trace_index(segment, edits)
|
||||
if target_index is not None:
|
||||
last = messages[target_index]
|
||||
segment = str(last.get("activitySegmentId") or segment or _new_activity_segment(activate=False))
|
||||
active_file_edit_segment_id = segment
|
||||
else:
|
||||
if not segment:
|
||||
segment = _new_activity_segment(activate=False)
|
||||
active_file_edit_segment_id = segment
|
||||
if not (
|
||||
last
|
||||
and last.get("kind") == "trace"
|
||||
and not last.get("isStreaming")
|
||||
and last.get("fileEdits")
|
||||
and last.get("activitySegmentId") == segment
|
||||
):
|
||||
messages.append(
|
||||
{
|
||||
"id": _new_id("tr", idx),
|
||||
@ -303,7 +329,11 @@ def replay_transcript_to_ui_messages(
|
||||
"createdAt": _ts_base + idx,
|
||||
},
|
||||
)
|
||||
last = messages[-1]
|
||||
target_index = len(messages) - 1
|
||||
last = messages[target_index]
|
||||
if not segment:
|
||||
segment = _new_activity_segment(activate=False)
|
||||
active_file_edit_segment_id = segment
|
||||
existing = list(last.get("fileEdits") or [])
|
||||
index_by_key = {
|
||||
_file_edit_key(edit): pos
|
||||
@ -316,11 +346,14 @@ def replay_transcript_to_ui_messages(
|
||||
key = _file_edit_key(edit)
|
||||
if key in index_by_key:
|
||||
pos = index_by_key[key]
|
||||
existing[pos] = {**existing[pos], **edit}
|
||||
merged = {**existing[pos], **edit}
|
||||
if edit.get("path") and not edit.get("pending"):
|
||||
merged.pop("pending", None)
|
||||
existing[pos] = merged
|
||||
else:
|
||||
index_by_key[key] = len(existing)
|
||||
existing.append(dict(edit))
|
||||
messages[-1] = {
|
||||
messages[target_index] = {
|
||||
**last,
|
||||
"fileEdits": existing,
|
||||
"activitySegmentId": last.get("activitySegmentId") or segment,
|
||||
@ -365,6 +398,7 @@ def replay_transcript_to_ui_messages(
|
||||
chunk = rec.get("text")
|
||||
if not isinstance(chunk, str):
|
||||
continue
|
||||
close_activity_for_answer()
|
||||
adopted = find_active_placeholder(messages) if buffer_message_id is None else None
|
||||
if buffer_message_id is None:
|
||||
if adopted:
|
||||
@ -403,6 +437,7 @@ def replay_transcript_to_ui_messages(
|
||||
chunk = rec.get("text")
|
||||
if not isinstance(chunk, str) or not chunk:
|
||||
continue
|
||||
close_file_edit_phase_before_activity()
|
||||
attach_reasoning_chunk(messages, chunk, idx)
|
||||
continue
|
||||
|
||||
@ -424,6 +459,7 @@ def replay_transcript_to_ui_messages(
|
||||
line = rec.get("text")
|
||||
if not isinstance(line, str) or not line:
|
||||
continue
|
||||
close_file_edit_phase_before_activity()
|
||||
attach_reasoning_chunk(messages, line, idx)
|
||||
close_reasoning(messages)
|
||||
continue
|
||||
|
||||
@ -309,6 +309,100 @@ class TestToolEventProgress:
|
||||
await invoke_file_edit_progress(telegram_progress, edit_events)
|
||||
assert bus.outbound_size == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_goal_turn_keeps_live_file_edit_progress_for_webui(self, tmp_path: Path) -> None:
    """The /goal command rewrites the prompt but must not bypass WebUI file-edit progress."""
    bus = MessageBus()
    target = tmp_path / "goal.txt"
    stream_calls: list[int] = []

    provider = MagicMock()
    provider.supports_progress_deltas = True
    provider.get_default_model.return_value = "test-model"

    async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs):
        stream_calls.append(len(stream_calls) + 1)
        if len(stream_calls) != 1:
            # Every later round simply finishes the turn.
            return LLMResponse(content="Done", tool_calls=[], usage={})

        assert on_tool_call_delta is not None
        streamed_deltas = [
            {
                "index": 0,
                "call_id": "call-goal-write",
                "name": "write_file",
                "arguments_delta": '{"path":"goal.txt","content":"',
            },
            {"index": 0, "arguments_delta": "one\\ntwo\\nthree\\n"},
            {"index": 0, "arguments_delta": '"}'},
        ]
        for delta in streamed_deltas:
            await on_tool_call_delta(delta)
        final_call = ToolCallRequest(
            id="call-goal-write",
            name="write_file",
            arguments={"path": "goal.txt", "content": "one\ntwo\nthree\n"},
        )
        return LLMResponse(content=None, tool_calls=[final_call], usage={})

    async def execute(name: str, params: dict) -> str:
        assert name == "write_file"
        target.write_text(params["content"], encoding="utf-8")
        return "ok"

    provider.chat_stream_with_retry = chat_stream_with_retry
    provider.chat_with_retry = AsyncMock()

    loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(
        return_value=[{"type": "function", "function": {"name": "write_file"}}]
    )
    prepared_params = {"path": "goal.txt", "content": "one\ntwo\nthree\n"}
    loop.tools.prepare_call = MagicMock(return_value=(None, prepared_params, None))
    loop.tools.execute = AsyncMock(side_effect=execute)
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    inbound = InboundMessage(
        channel="websocket",
        sender_id="u1",
        chat_id="chat1",
        content="/goal create goal file",
        metadata={"_wants_stream": True},
    )
    await loop._dispatch(inbound)

    outbound = []
    while bus.outbound_size > 0:
        outbound.append(await bus.consume_outbound())

    edit_events = []
    for msg in outbound:
        edit_events.extend(msg.metadata.get("_file_edit_events", []))

    # A live, approximate estimate must have been published while streaming...
    saw_live_estimate = any(
        event["status"] == "editing" and event["approximate"] and event["added"] == 3
        for event in edit_events
    )
    assert saw_live_estimate
    # ...followed by the exact result once the tool finished.
    saw_exact_result = any(
        event["status"] == "done" and not event["approximate"] and event["added"] == 3
        for event in edit_events
    )
    assert saw_exact_result
    provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_channel_does_not_publish_codex_progress_deltas(
|
||||
self,
|
||||
|
||||
@ -6,7 +6,7 @@ import pytest
|
||||
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
@ -77,3 +77,220 @@ async def test_runner_streams_provider_progress_deltas_by_default():
|
||||
assert result.final_content == "hello"
|
||||
assert [call.args[0] for call in progress_cb.await_args_list] == ["he", "llo"]
|
||||
provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_streams_live_write_file_activity_from_tool_argument_deltas(tmp_path):
    """Streamed write_file argument deltas yield approximate events before the exact one."""
    provider = MagicMock()
    provider.supports_progress_deltas = True
    stream_calls: list[int] = []
    progress_events: list[dict] = []

    def _saw_live_estimate() -> bool:
        return any(event["approximate"] and event["added"] == 24 for event in progress_events)

    async def progress_cb(content, *, file_edit_events=None, **kwargs):
        if not file_edit_events:
            return
        progress_events.extend(file_edit_events)

    class Tools:
        def get_definitions(self):
            return [{"type": "function", "function": {"name": "write_file"}}]

        def get(self, name):
            return None

        async def execute(self, name, params):
            assert name == "write_file"
            # The live estimate must already exist when the tool actually runs.
            assert _saw_live_estimate()
            destination = tmp_path / params["path"]
            destination.write_text(params["content"], encoding="utf-8")
            return "ok"

    async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs):
        stream_calls.append(len(stream_calls) + 1)
        if len(stream_calls) != 1:
            return LLMResponse(content="done", tool_calls=[], usage={})

        assert on_tool_call_delta is not None
        opening_delta = {
            "index": 0,
            "call_id": "call-write",
            "name": "write_file",
            "arguments_delta": '{"path":"big.txt","content":"',
        }
        await on_tool_call_delta(opening_delta)
        await on_tool_call_delta({"index": 0, "arguments_delta": "line\\n" * 24})
        final_call = ToolCallRequest(
            id="call-write",
            name="write_file",
            arguments={"path": "big.txt", "content": "line\n" * 24},
        )
        return LLMResponse(content=None, tool_calls=[final_call], usage={})

    provider.chat_stream_with_retry = chat_stream_with_retry
    provider.chat_with_retry = AsyncMock()

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "write a large file"}],
        tools=Tools(),
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=progress_cb,
        workspace=tmp_path,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "done"
    assert _saw_live_estimate()
    saw_exact_end = any(
        not event["approximate"] and event["phase"] == "end" and event["added"] == 24
        for event in progress_events
    )
    assert saw_exact_end
    provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_streams_live_edit_file_activity_from_tool_argument_deltas(tmp_path):
    """Streamed edit_file argument deltas yield approximate add/delete counts, then exact ones."""
    provider = MagicMock()
    provider.supports_progress_deltas = True
    stream_calls: list[int] = []
    progress_events: list[dict] = []
    target = tmp_path / "notes.txt"
    target.write_text("old\nkeep\n", encoding="utf-8")

    def _saw_live_estimate() -> bool:
        return any(
            event["tool"] == "edit_file"
            and event["approximate"]
            and event["added"] == 3
            and event["deleted"] == 2
            for event in progress_events
        )

    async def progress_cb(content, *, file_edit_events=None, **kwargs):
        if not file_edit_events:
            return
        progress_events.extend(file_edit_events)

    class Tools:
        def get_definitions(self):
            return [{"type": "function", "function": {"name": "edit_file"}}]

        def get(self, name):
            return None

        async def execute(self, name, params):
            assert name == "edit_file"
            # The live estimate must already exist when the tool actually runs.
            assert _saw_live_estimate()
            target.write_text(params["new_text"], encoding="utf-8")
            return "ok"

    async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs):
        stream_calls.append(len(stream_calls) + 1)
        if len(stream_calls) != 1:
            return LLMResponse(content="done", tool_calls=[], usage={})

        assert on_tool_call_delta is not None
        streamed_deltas = [
            {
                "index": 0,
                "call_id": "call-edit",
                "name": "edit_file",
                "arguments_delta": (
                    '{"path":"notes.txt","old_text":"old\\nkeep\\n","new_text":"'
                ),
            },
            {"index": 0, "arguments_delta": "new\\nkeep\\nextra\\n"},
            {"index": 0, "arguments_delta": '"}'},
        ]
        for delta in streamed_deltas:
            await on_tool_call_delta(delta)
        final_call = ToolCallRequest(
            id="call-edit",
            name="edit_file",
            arguments={
                "path": "notes.txt",
                "old_text": "old\nkeep\n",
                "new_text": "new\nkeep\nextra\n",
            },
        )
        return LLMResponse(content=None, tool_calls=[final_call], usage={})

    provider.chat_stream_with_retry = chat_stream_with_retry
    provider.chat_with_retry = AsyncMock()

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "edit a file"}],
        tools=Tools(),
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=progress_cb,
        workspace=tmp_path,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "done"
    assert _saw_live_estimate()
    saw_exact_end = any(
        event["tool"] == "edit_file"
        and not event["approximate"]
        and event["phase"] == "end"
        and event["added"] == 2
        and event["deleted"] == 1
        for event in progress_events
    )
    assert saw_exact_end
    provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_marks_unfinished_live_write_file_activity_failed(tmp_path):
    """A streamed write_file call that never becomes a final tool call ends as an error event."""
    progress_events: list[dict] = []

    async def progress_cb(content, *, file_edit_events=None, **kwargs):
        if not file_edit_events:
            return
        progress_events.extend(file_edit_events)

    async def chat_stream_with_retry(*, on_tool_call_delta=None, **kwargs):
        assert on_tool_call_delta is not None
        # Stream the start of a write_file call, then stop without a tool call.
        partial_delta = {
            "index": 0,
            "call_id": "call-write",
            "name": "write_file",
            "arguments_delta": '{"path":"aborted.txt","content":"partial\\n',
        }
        await on_tool_call_delta(partial_delta)
        return LLMResponse(content="stopped", tool_calls=[], finish_reason="stop", usage={})

    provider = MagicMock()
    provider.supports_progress_deltas = True
    provider.chat_stream_with_retry = chat_stream_with_retry
    provider.chat_with_retry = AsyncMock()

    tools = MagicMock()
    tools.get_definitions.return_value = [{"type": "function", "function": {"name": "write_file"}}]
    tools.get.return_value = None

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "write a large file"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=progress_cb,
        workspace=tmp_path,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "stopped"
    last_event = progress_events[-1]
    assert last_event["path"] == "aborted.txt"
    assert last_event["phase"] == "error"
    assert last_event["status"] == "error"
    provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
@ -129,6 +129,74 @@ async def test_chat_stream_invokes_on_thinking_delta_for_thinking_delta() -> Non
|
||||
assert text_parts == ["X"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_stream_invokes_tool_call_delta_for_input_json_delta() -> None:
    """tool_use block starts and input_json_delta chunks are forwarded as tool-call deltas."""
    provider = AnthropicProvider(api_key="sk-test")
    provider._client = MagicMock()

    def _json_delta(partial: str) -> SimpleNamespace:
        return SimpleNamespace(
            type="content_block_delta",
            index=1,
            delta=SimpleNamespace(type="input_json_delta", partial_json=partial),
        )

    block_start = SimpleNamespace(
        type="content_block_start",
        index=1,
        content_block=SimpleNamespace(type="tool_use", id="toolu_1", name="write_file"),
    )
    chunks = [
        block_start,
        _json_delta('{"path":"notes.md","content":"'),
        _json_delta("line\\n"),
    ]
    fake = _FakeAsyncStream(chunks)
    stream_cm = MagicMock()
    stream_cm.__aenter__ = AsyncMock(return_value=fake)
    stream_cm.__aexit__ = AsyncMock(return_value=None)
    provider._client.messages.stream = MagicMock(return_value=stream_cm)

    deltas: list[dict] = []

    async def on_tool_delta(delta: dict) -> None:
        deltas.append(delta)

    await provider.chat_stream(
        messages=[{"role": "user", "content": "write"}],
        on_tool_call_delta=on_tool_delta,
    )

    # One delta for the block start (empty arguments) and one per JSON fragment.
    expected_fragments = ("", '{"path":"notes.md","content":"', "line\\n")
    expected = [
        {
            "index": 1,
            "call_id": "toolu_1",
            "name": "write_file",
            "arguments_delta": fragment,
        }
        for fragment in expected_fragments
    ]
    assert deltas == expected
    fake.get_final_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_without_callback_still_finalizes() -> None:
|
||||
provider = AnthropicProvider(api_key="sk-test")
|
||||
|
||||
@ -164,6 +164,130 @@ def _fake_chat_stream_reasoning_chunks():
|
||||
return _stream()
|
||||
|
||||
|
||||
def _fake_chat_stream_tool_call_chunks():
|
||||
"""Mimic OpenAI-compatible streaming tool-call argument deltas."""
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason=None,
|
||||
delta=SimpleNamespace(
|
||||
content=None,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
tool_calls=[
|
||||
SimpleNamespace(
|
||||
index=0,
|
||||
id="call_write",
|
||||
function=SimpleNamespace(
|
||||
name="write_file",
|
||||
arguments='{"path":"notes.md","content":"',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
yield SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason=None,
|
||||
delta=SimpleNamespace(
|
||||
content=None,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
tool_calls=[
|
||||
SimpleNamespace(
|
||||
index=0,
|
||||
id=None,
|
||||
function=SimpleNamespace(name=None, arguments='line\\n"}'),
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
yield SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason="tool_calls",
|
||||
delta=SimpleNamespace(
|
||||
content=None,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
tool_calls=None,
|
||||
),
|
||||
),
|
||||
],
|
||||
usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
def _fake_chat_stream_legacy_function_call_chunks():
|
||||
"""Mimic older OpenAI-compatible ``delta.function_call`` chunks."""
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason=None,
|
||||
delta=SimpleNamespace(
|
||||
content=None,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
tool_calls=None,
|
||||
function_call=SimpleNamespace(
|
||||
name="write_file",
|
||||
arguments='{"path":"notes.md","content":"',
|
||||
),
|
||||
),
|
||||
),
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
yield SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason=None,
|
||||
delta=SimpleNamespace(
|
||||
content=None,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
tool_calls=None,
|
||||
function_call=SimpleNamespace(
|
||||
name=None,
|
||||
arguments='line\\n"}',
|
||||
),
|
||||
),
|
||||
),
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
yield SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason="function_call",
|
||||
delta=SimpleNamespace(
|
||||
content=None,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
tool_calls=None,
|
||||
function_call=None,
|
||||
),
|
||||
),
|
||||
],
|
||||
usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compat_stream_forwards_reasoning_deltas_deepseek_style() -> None:
|
||||
"""Regression: DeepSeek-V4 / reasoner expose ``delta.reasoning_content`` during streaming."""
|
||||
@ -202,6 +326,98 @@ async def test_openai_compat_stream_forwards_reasoning_deltas_deepseek_style() -
|
||||
mock_chat.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("provider_name", "model"),
    [
        ("openai", "gpt-4o"),
        ("deepseek", "deepseek-chat"),
        ("minimax", "MiniMax-M2.7"),
        ("zhipu", "glm-4.6"),
    ],
)
async def test_openai_compat_stream_forwards_tool_call_argument_deltas(
    provider_name: str,
    model: str,
) -> None:
    """Streaming tool-call argument fragments reach ``on_tool_call_delta`` for each provider."""
    mock_chat = AsyncMock(return_value=_fake_chat_stream_tool_call_chunks())
    spec = find_by_name(provider_name)
    deltas: list[dict] = []

    async def on_tool_delta(delta: dict) -> None:
        deltas.append(delta)

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_openai:
        mock_openai.return_value.chat.completions.create = mock_chat

        provider = OpenAICompatProvider(
            api_key="sk-test",
            default_model=model,
            spec=spec,
        )
        result = await provider.chat_stream(
            messages=[{"role": "user", "content": "write"}],
            tools=[{"type": "function", "function": {"name": "write_file"}}],
            model=model,
            on_tool_call_delta=on_tool_delta,
        )

    opening_delta = {
        "index": 0,
        "call_id": "call_write",
        "name": "write_file",
        "arguments_delta": '{"path":"notes.md","content":"',
    }
    continuation_delta = {"index": 0, "call_id": "", "name": "", "arguments_delta": 'line\\n"}'}
    assert deltas == [opening_delta, continuation_delta]

    final_call = result.tool_calls[0]
    assert final_call.name == "write_file"
    assert final_call.arguments == {"path": "notes.md", "content": "line\n"}

    # Only zhipu is expected to send the tool_stream flag in extra_body.
    request_kwargs = mock_chat.await_args.kwargs
    if provider_name == "zhipu":
        assert request_kwargs["extra_body"]["tool_stream"] is True
    else:
        assert request_kwargs.get("extra_body", {}).get("tool_stream") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_compat_stream_forwards_legacy_function_call_argument_deltas() -> None:
    """Legacy ``delta.function_call`` fragments are forwarded as tool-call deltas at index 0."""
    mock_chat = AsyncMock(return_value=_fake_chat_stream_legacy_function_call_chunks())
    deltas: list[dict] = []

    async def on_tool_delta(delta: dict) -> None:
        deltas.append(delta)

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_openai:
        mock_openai.return_value.chat.completions.create = mock_chat

        provider = OpenAICompatProvider(
            api_key="sk-test",
            default_model="deepseek-chat",
            spec=find_by_name("deepseek"),
        )
        result = await provider.chat_stream(
            messages=[{"role": "user", "content": "write"}],
            tools=[{"type": "function", "function": {"name": "write_file"}}],
            model="deepseek-chat",
            on_tool_call_delta=on_tool_delta,
        )

    # Legacy chunks carry no call id, so both deltas report an empty one.
    opening_delta = {
        "index": 0,
        "call_id": "",
        "name": "write_file",
        "arguments_delta": '{"path":"notes.md","content":"',
    }
    continuation_delta = {"index": 0, "call_id": "", "name": "", "arguments_delta": 'line\\n"}'}
    assert deltas == [opening_delta, continuation_delta]

    final_call = result.tool_calls[0]
    assert final_call.name == "write_file"
    assert final_call.arguments == {"path": "notes.md", "content": "line\n"}
|
||||
|
||||
|
||||
class _FakeResponsesError(Exception):
|
||||
def __init__(self, status_code: int, text: str):
|
||||
super().__init__(text)
|
||||
|
||||
@ -44,9 +44,15 @@ class TestShouldExecuteTools:
|
||||
resp = _response("stop")
|
||||
assert resp.should_execute_tools is True
|
||||
|
||||
def test_legacy_function_call_reason_executes(self) -> None:
    """A ``function_call`` finish reason must still allow tool execution."""
    # Some older OpenAI-compatible streaming APIs report the singular
    # ``function_call`` finish reason alongside a tool-call-shaped payload.
    legacy_response = _response("function_call")
    should_run = legacy_response.should_execute_tools
    assert should_run is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"anomalous_reason",
|
||||
["refusal", "content_filter", "error", "length", "function_call", ""],
|
||||
["refusal", "content_filter", "error", "length", ""],
|
||||
)
|
||||
def test_tool_calls_under_anomalous_reason_blocked(self, anomalous_reason: str) -> None:
|
||||
# This is the #3220 bug: gateways injecting tool_calls under any of these
|
||||
|
||||
@ -453,6 +453,56 @@ class TestConsumeSdkStream:
|
||||
assert tool_calls[0].name == "get_weather"
|
||||
assert tool_calls[0].arguments == {"city": "SF"}
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_call_argument_delta_callback(self):
    """Check the deltas ``consume_sdk_stream`` passes to ``on_tool_call_delta``.

    The expected list holds one entry for the ``output_item.added`` event
    (empty argument text) and one per ``function_call_arguments.delta``
    event; the ``done`` and ``completed`` events add nothing.
    """
    final_arguments = '{"path":"a.txt","content":"hello\\n"}'

    def function_call_item(arguments: str) -> MagicMock:
        # ``name`` is reserved by the MagicMock constructor, so set it afterwards.
        item = MagicMock(type="function_call", call_id="c1", id="fc1", arguments=arguments)
        item.name = "write_file"
        return item

    def arguments_delta(fragment: str) -> MagicMock:
        return MagicMock(
            type="response.function_call_arguments.delta",
            call_id="c1",
            delta=fragment,
        )

    completed_response = MagicMock(status="completed", usage=None, output=[])
    sdk_events = [
        MagicMock(type="response.output_item.added", item=function_call_item("")),
        arguments_delta('{"path":"a.txt","content":"'),
        arguments_delta('hello\\n'),
        MagicMock(
            type="response.function_call_arguments.done",
            call_id="c1",
            arguments=final_arguments,
        ),
        MagicMock(type="response.output_item.done", item=function_call_item(final_arguments)),
        MagicMock(type="response.completed", response=completed_response),
    ]
    received: list[dict] = []

    async def cb(delta: dict) -> None:
        received.append(delta)

    async def stream():
        for sdk_event in sdk_events:
            yield sdk_event

    await consume_sdk_stream(stream(), on_tool_call_delta=cb)

    expected = [
        {"call_id": "c1", "name": "write_file", "arguments_delta": ""},
        {
            "call_id": "c1",
            "name": "write_file",
            "arguments_delta": '{"path":"a.txt","content":"',
        },
        {"call_id": "c1", "name": "write_file", "arguments_delta": "hello\\n"},
    ]
    assert received == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_extracted(self):
|
||||
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.utils.file_edit_events import (
|
||||
build_file_edit_end_event,
|
||||
@ -8,6 +10,7 @@ from nanobot.utils.file_edit_events import (
|
||||
line_diff_stats,
|
||||
prepare_file_edit_tracker,
|
||||
read_file_snapshot,
|
||||
StreamingFileEditTracker,
|
||||
)
|
||||
|
||||
|
||||
@ -20,6 +23,10 @@ def test_line_diff_stats_normalizes_crlf() -> None:
|
||||
assert line_diff_stats("a\r\nb\r\n", "a\nb\nc\n") == (1, 0)
|
||||
|
||||
|
||||
def test_line_diff_stats_counts_new_file_crlf_lines_once() -> None:
    """Two CRLF-terminated lines added to empty content count as two additions."""
    stats = line_diff_stats("", "a\r\nb\r\n")
    assert stats == (2, 0)
|
||||
|
||||
|
||||
def test_write_file_start_predicts_and_end_calibrates_exact_diff(tmp_path: Path) -> None:
|
||||
target = tmp_path / "notes.txt"
|
||||
target.write_text("old\nkeep\n", encoding="utf-8")
|
||||
@ -39,6 +46,7 @@ def test_write_file_start_predicts_and_end_calibrates_exact_diff(tmp_path: Path)
|
||||
"call_id": "call-write",
|
||||
"tool": "write_file",
|
||||
"path": "notes.txt",
|
||||
"absolute_path": (tmp_path / "notes.txt").as_posix(),
|
||||
"phase": "start",
|
||||
"added": 2,
|
||||
"deleted": 1,
|
||||
@ -73,6 +81,307 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None:
|
||||
assert (event["added"], event["deleted"]) == (0, 0)
|
||||
|
||||
|
||||
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
    """A write of just over 2 MiB still ends with an exact, non-binary line count."""
    oversized_content = "x" * (2 * 1024 * 1024 + 1)
    params = {"path": "large.txt", "content": oversized_content}
    destination = tmp_path / "large.txt"

    tracker = prepare_file_edit_tracker(
        call_id="call-large",
        tool_name="write_file",
        tool=None,
        workspace=tmp_path,
        params=params,
    )
    assert tracker is not None

    # Perform the write the tool would have done before building the end event.
    destination.write_text(params["content"], encoding="utf-8")
    end_event = build_file_edit_end_event(tracker, params)

    assert end_event.get("binary") is not True
    assert end_event["added"] == 1
    assert end_event["deleted"] == 0
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) -> None:
    """Streaming a ``write_file`` call emits a start event, then a 24-line approximate count."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        opening_delta = {
            "index": 0,
            "call_id": "call-live",
            "name": "write_file",
            "arguments_delta": '{"path":"notes.md","content":"',
        }
        content_delta = {
            "index": 0,
            "arguments_delta": "line\\n" * 24,
        }
        for delta in (opening_delta, content_delta):
            await tracker.update(delta)

    asyncio.run(drive())

    expected_start = {
        "version": 1,
        "call_id": "call-live",
        "tool": "write_file",
        "path": "notes.md",
        "absolute_path": (tmp_path / "notes.md").as_posix(),
        "phase": "start",
        "added": 0,
        "deleted": 0,
        "approximate": True,
        "status": "editing",
    }
    assert captured[0] == expected_start

    latest = captured[-1]
    assert latest["path"] == "notes.md"
    assert latest["status"] == "editing"
    assert latest["approximate"] is True
    assert latest["added"] == 24
    assert latest["deleted"] == 0
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
    """Content streamed before the path yields a pending event that later gains the path."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        # The arguments put "content" first, so the path is unknown at the first delta.
        content_first = {
            "index": 0,
            "call_id": "call-live",
            "name": "write_file",
            "arguments_delta": '{"content":"line\\n',
        }
        path_last = {
            "index": 0,
            "arguments_delta": 'more\\n","path":"late.md"',
        }
        for delta in (content_first, path_last):
            await tracker.update(delta)

    asyncio.run(drive())

    expected_pending = {
        "version": 1,
        "call_id": "call-live",
        "tool": "write_file",
        "path": "",
        "phase": "start",
        "added": 1,
        "deleted": 0,
        "approximate": True,
        "status": "editing",
        "pending": True,
    }
    assert captured[0] == expected_pending

    latest = captured[-1]
    assert latest["path"] == "late.md"
    assert latest.get("pending") is not True
    assert latest["added"] == 2
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None:
    """After ``flush()``, a single streamed line is reported as one added line."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        single_line_delta = {
            "index": 0,
            "call_id": "call-live",
            "name": "write_file",
            "arguments_delta": '{"path":"small.md","content":"one\\n',
        }
        await tracker.update(single_line_delta)
        await tracker.flush()

    asyncio.run(drive())

    assert captured
    latest = captured[-1]
    assert latest["path"] == "small.md"
    assert latest["added"] == 1
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_normalizes_crlf_line_counts(tmp_path: Path) -> None:
    """Two escaped CRLF-terminated lines in streamed content count as two added lines."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        crlf_delta = {
            "index": 0,
            "call_id": "call-live",
            "name": "write_file",
            "arguments_delta": '{"path":"windows.txt","content":"one\\r\\ntwo\\r\\n',
        }
        await tracker.update(crlf_delta)
        await tracker.flush()

    asyncio.run(drive())

    latest = captured[-1]
    assert latest["path"] == "windows.txt"
    assert latest["added"] == 2
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_counts_unicode_escaped_newlines(tmp_path: Path) -> None:
    """A ``\\u000a`` escape in streamed content is counted as a line break."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        unicode_escape_delta = {
            "index": 0,
            "call_id": "call-live",
            "name": "write_file",
            "arguments_delta": '{"path":"unicode.txt","content":"one\\u000atwo',
        }
        await tracker.update(unicode_escape_delta)
        await tracker.flush()

    asyncio.run(drive())

    latest = captured[-1]
    assert latest["path"] == "unicode.txt"
    assert latest["added"] == 2
|
||||
|
||||
|
||||
def test_streaming_edit_file_tracker_emits_live_line_counts(tmp_path: Path) -> None:
    """Streaming an ``edit_file`` call reports deletions at start and additions as text arrives."""
    existing = tmp_path / "notes.md"
    existing.write_text("old\nkeep\n", encoding="utf-8")
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        opening_delta = {
            "index": 0,
            "call_id": "call-edit",
            "name": "edit_file",
            "arguments_delta": '{"path":"notes.md","old_text":"old\\nkeep","new_text":"',
        }
        replacement_delta = {
            "index": 0,
            "arguments_delta": "new\\nkeep\\nextra\\n" * 8,
        }
        for delta in (opening_delta, replacement_delta):
            await tracker.update(delta)

    asyncio.run(drive())

    expected_start = {
        "version": 1,
        "call_id": "call-edit",
        "tool": "edit_file",
        "path": "notes.md",
        "absolute_path": (tmp_path / "notes.md").as_posix(),
        "phase": "start",
        "added": 0,
        "deleted": 2,
        "approximate": True,
        "status": "editing",
    }
    assert captured[0] == expected_start

    latest = captured[-1]
    assert latest["path"] == "notes.md"
    assert latest["status"] == "editing"
    assert latest["approximate"] is True
    assert latest["added"] == 24
    assert latest["deleted"] == 2
|
||||
|
||||
|
||||
def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Path) -> None:
    """A delta streamed without ``call_id`` makes the final call's id become ``idx:0``."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        # This delta deliberately omits "call_id".
        anonymous_delta = {
            "index": 0,
            "name": "write_file",
            "arguments_delta": '{"path":"matched.md","content":"one\\n',
        }
        await tracker.update(anonymous_delta)

        final_call = SimpleNamespace(
            id="provider-final-id",
            name="write_file",
            arguments={"path": "matched.md", "content": "one\n"},
        )
        tracker.apply_final_call_ids([final_call])
        assert final_call.id == "idx:0"

    asyncio.run(drive())
|
||||
|
||||
|
||||
def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None:
    """After ``flush()``, a short ``edit_file`` stream reports two added and one deleted line."""
    existing = tmp_path / "small.py"
    existing.write_text("old\n", encoding="utf-8")
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        short_edit_delta = {
            "index": 0,
            "call_id": "call-edit",
            "name": "edit_file",
            "arguments_delta": '{"path":"small.py","old_text":"old\\n","new_text":"new\\nextra',
        }
        await tracker.update(short_edit_delta)
        await tracker.flush()

    asyncio.run(drive())

    assert captured
    latest = captured[-1]
    assert latest["path"] == "small.py"
    assert latest["added"] == 2
    assert latest["deleted"] == 1
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_errors_unmatched_live_edits(tmp_path: Path) -> None:
    """With no final tool calls, ``error_unmatched`` ends the live edit in an error event."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        aborted_delta = {
            "index": 0,
            "call_id": "call-live",
            "name": "write_file",
            "arguments_delta": '{"path":"aborted.md","content":"one\\n',
        }
        await tracker.update(aborted_delta)
        no_final_calls: list = []
        await tracker.error_unmatched(no_final_calls, "Tool call did not complete.")

    asyncio.run(drive())

    latest = captured[-1]
    assert latest["path"] == "aborted.md"
    assert latest["phase"] == "error"
    assert latest["status"] == "error"
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_keeps_matched_final_tool_call(tmp_path: Path) -> None:
    """``error_unmatched`` emits no error when a final call is passed for the streamed edit.

    The final call uses a different id from the streamed one, yet every
    emitted event must still have status ``editing``.
    """
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        streamed_delta = {
            "index": 0,
            "call_id": "idx-only",
            "name": "write_file",
            "arguments_delta": '{"path":"matched.md","content":"one\\n',
        }
        await tracker.update(streamed_delta)

        final_call = SimpleNamespace(
            id="final-call",
            name="write_file",
            arguments={"path": "matched.md", "content": "one\n"},
        )
        await tracker.error_unmatched([final_call], "Tool call did not complete.")

    asyncio.run(drive())

    assert captured
    assert not any(emitted["status"] != "editing" for emitted in captured)
|
||||
|
||||
|
||||
def test_untracked_tools_do_not_prepare_file_edit_tracker(tmp_path: Path) -> None:
|
||||
assert prepare_file_edit_tracker(
|
||||
call_id="call-exec",
|
||||
|
||||
@ -98,6 +98,201 @@ def test_replay_file_edit_event_creates_file_activity(tmp_path, monkeypatch) ->
|
||||
assert msgs[2]["activitySegmentId"] != msgs[1]["activitySegmentId"]
|
||||
|
||||
|
||||
def test_replay_file_edit_progress_merges_after_interleaved_activity(tmp_path, monkeypatch) -> None:
    """Start and end events for one call replay as a single message holding the end state.

    A progress message sits between the two ``file_edit`` events; the replay
    must still produce exactly one message with ``fileEdits``.
    """
    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    chat_id = "t-file-progress"
    key = "websocket:t-file-progress"

    def write_edit(*, phase: str, added: int, approximate: bool, status: str) -> dict:
        # Build a fresh dict each time so expected values never alias transcript input.
        return {
            "version": 1,
            "call_id": "call-write",
            "tool": "write_file",
            "path": "foo.txt",
            "phase": phase,
            "added": added,
            "deleted": 0,
            "approximate": approximate,
            "status": status,
        }

    def file_edit_event(edit: dict) -> dict:
        return {"event": "file_edit", "chat_id": chat_id, "edits": [edit]}

    transcript = (
        {"event": "user", "chat_id": chat_id, "text": "edit"},
        {
            "event": "message",
            "chat_id": chat_id,
            "text": 'write_file({"path":"foo.txt"})',
            "kind": "tool_hint",
        },
        file_edit_event(write_edit(phase="start", added=12, approximate=True, status="editing")),
        {
            "event": "message",
            "chat_id": chat_id,
            "text": "still working",
            "kind": "progress",
        },
        file_edit_event(write_edit(phase="end", added=30, approximate=False, status="done")),
    )
    for entry in transcript:
        append_transcript_object(key, entry)

    replayed = replay_transcript_to_ui_messages(read_transcript_lines(key))
    with_file_edits = [message for message in replayed if message.get("fileEdits")]

    assert len(with_file_edits) == 1
    assert with_file_edits[0]["fileEdits"] == [
        write_edit(phase="end", added=30, approximate=False, status="done"),
    ]
|
||||
|
||||
|
||||
def test_replay_file_edit_pending_placeholder_upgrades_to_path(tmp_path, monkeypatch) -> None:
    """A pending, path-less edit is replaced by the later event for the same call that has a path."""
    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    chat_id = "t-file-pending"
    key = "websocket:t-file-pending"

    def start_edit(path: str, added: int, *, pending: bool = False) -> dict:
        edit = {
            "version": 1,
            "call_id": "call-write",
            "tool": "write_file",
            "path": path,
            "phase": "start",
            "added": added,
            "deleted": 0,
            "approximate": True,
            "status": "editing",
        }
        if pending:
            # Only the placeholder event carries the "pending" marker.
            edit["pending"] = True
        return edit

    def file_edit_event(edit: dict) -> dict:
        return {"event": "file_edit", "chat_id": chat_id, "edits": [edit]}

    transcript = (
        {"event": "user", "chat_id": chat_id, "text": "write"},
        file_edit_event(start_edit("", 1, pending=True)),
        file_edit_event(start_edit("foo.txt", 12)),
    )
    for entry in transcript:
        append_transcript_object(key, entry)

    replayed = replay_transcript_to_ui_messages(read_transcript_lines(key))
    with_file_edits = [message for message in replayed if message.get("fileEdits")]

    assert len(with_file_edits) == 1
    assert with_file_edits[0]["fileEdits"] == [start_edit("foo.txt", 12)]
|
||||
|
||||
|
||||
def test_replay_keeps_new_file_edit_after_reasoning_in_order(tmp_path, monkeypatch) -> None:
    """Two edits separated by reasoning replay in order, each in its own activity segment."""
    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    chat_id = "t-file-order"
    key = "websocket:t-file-order"

    def file_edit_event(call_id: str, path: str, added: int) -> dict:
        return {
            "event": "file_edit",
            "chat_id": chat_id,
            "edits": [
                {
                    "version": 1,
                    "call_id": call_id,
                    "tool": "write_file",
                    "path": path,
                    "phase": "start",
                    "added": added,
                    "deleted": 0,
                    "approximate": True,
                    "status": "editing",
                },
            ],
        }

    transcript = (
        {"event": "user", "chat_id": chat_id, "text": "edit"},
        file_edit_event("call-one", "one.txt", 10),
        {"event": "reasoning_delta", "chat_id": chat_id, "text": "Check next."},
        {"event": "reasoning_end", "chat_id": chat_id},
        file_edit_event("call-two", "two.txt", 20),
    )
    for entry in transcript:
        append_transcript_object(key, entry)

    msgs = replay_transcript_to_ui_messages(read_transcript_lines(key))

    def label(message: dict):
        # File-edit messages are labelled by their first path, others by reasoning text.
        edits = message.get("fileEdits")
        if edits:
            return edits[0].get("path")
        return message.get("reasoning")

    # The first replayed message is skipped, as in the original slice ``msgs[1:]``.
    assert [label(message) for message in msgs[1:]] == [
        "one.txt",
        "Check next.",
        "two.txt",
    ]

    segment_ids = []
    for message in msgs:
        if message.get("fileEdits"):
            segment_ids.append(message.get("activitySegmentId"))

    assert len(segment_ids) == 2
    assert segment_ids[0] != segment_ids[1]
|
||||
|
||||
|
||||
def test_build_response_schema(monkeypatch, tmp_path) -> None:
|
||||
from nanobot.utils.webui_transcript import build_webui_thread_response
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user