mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-14 23:19:55 +00:00
Merge remote-tracking branch 'origin/main' into nightly
This commit is contained in:
commit
4c684540c5
37
README.md
37
README.md
@ -1053,6 +1053,30 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
|
||||
```
|
||||
|
||||
> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`).
|
||||
>
|
||||
> `custom` is the right choice for providers that expose an OpenAI-compatible **chat completions** API. It does **not** force third-party endpoints onto the OpenAI/Azure **Responses API**.
|
||||
>
|
||||
> If your proxy or gateway is specifically Responses-API-compatible, use the `azure_openai` provider shape instead and point `apiBase` at that endpoint:
|
||||
>
|
||||
> ```json
|
||||
> {
|
||||
> "providers": {
|
||||
> "azure_openai": {
|
||||
> "apiKey": "your-api-key",
|
||||
> "apiBase": "https://api.your-provider.com",
|
||||
> "defaultModel": "your-model-name"
|
||||
> }
|
||||
> },
|
||||
> "agents": {
|
||||
> "defaults": {
|
||||
> "provider": "azure_openai",
|
||||
> "model": "your-model-name"
|
||||
> }
|
||||
> }
|
||||
> }
|
||||
> ```
|
||||
>
|
||||
> In short: **chat-completions-compatible endpoint → `custom`**; **Responses-compatible endpoint → `azure_openai`**.
|
||||
|
||||
</details>
|
||||
|
||||
@ -1858,6 +1882,19 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
|
||||
- Single-message input: each request must contain exactly one `user` message
|
||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||
- No streaming: `stream=true` is not supported
|
||||
- API requests run in the synthetic `api` channel, so the `message` tool does **not** automatically deliver to Telegram/Discord/etc. To proactively send to another chat, call `message` with an explicit `channel` and `chat_id` for an enabled channel.
|
||||
|
||||
Example tool call for cross-channel delivery from an API session:
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "Build finished successfully.",
|
||||
"channel": "telegram",
|
||||
"chat_id": "123456789"
|
||||
}
|
||||
```
|
||||
|
||||
If `channel` points to a channel that is not enabled in your config, nanobot will queue the outbound event but no platform delivery will occur.
|
||||
|
||||
### Endpoints
|
||||
|
||||
|
||||
@ -129,6 +129,7 @@ class AgentLoop:
|
||||
"""
|
||||
|
||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -621,6 +622,8 @@ class AgentLoop:
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
if self._restore_pending_user_turn(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
session, pending = self.auto_compact.prepare_session(session, key)
|
||||
|
||||
@ -656,6 +659,8 @@ class AgentLoop:
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
if self._restore_pending_user_turn(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
session, pending = self.auto_compact.prepare_session(session, key)
|
||||
|
||||
@ -696,6 +701,19 @@ class AgentLoop:
|
||||
)
|
||||
)
|
||||
|
||||
# Persist the triggering user message immediately, before running the
|
||||
# agent loop. If the process is killed mid-turn (OOM, SIGKILL, self-
|
||||
# restart, etc.), the existing runtime_checkpoint preserves the
|
||||
# in-flight assistant/tool state but NOT the user message itself, so
|
||||
# the user's prompt is silently lost on recovery. Saving it up front
|
||||
# makes recovery possible from the session log alone.
|
||||
user_persisted_early = False
|
||||
if isinstance(msg.content, str) and msg.content.strip():
|
||||
session.add_message("user", msg.content)
|
||||
self._mark_pending_user_turn(session)
|
||||
self.sessions.save(session)
|
||||
user_persisted_early = True
|
||||
|
||||
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
on_progress=on_progress or _bus_progress,
|
||||
@ -711,7 +729,10 @@ class AgentLoop:
|
||||
if final_content is None or not final_content.strip():
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
# Skip the already-persisted user message when saving the turn
|
||||
save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
|
||||
self._save_turn(session, all_msgs, save_skip)
|
||||
self._clear_pending_user_turn(session)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
@ -829,6 +850,12 @@ class AgentLoop:
|
||||
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
||||
self.sessions.save(session)
|
||||
|
||||
def _mark_pending_user_turn(self, session: Session) -> None:
    """Set the pending-user-turn flag in *session*'s metadata.

    Only the in-memory metadata mapping is touched; nothing is saved here.
    """
    flags = session.metadata
    flags[self._PENDING_USER_TURN_KEY] = True
|
||||
|
||||
def _clear_pending_user_turn(self, session: Session) -> None:
    """Remove the pending-user-turn flag from *session*'s metadata.

    A no-op when the flag is not present.
    """
    marker = self._PENDING_USER_TURN_KEY
    session.metadata.pop(marker, None)
|
||||
|
||||
def _clear_runtime_checkpoint(self, session: Session) -> None:
|
||||
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
|
||||
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||
@ -895,9 +922,30 @@ class AgentLoop:
|
||||
break
|
||||
session.messages.extend(restored_messages[overlap:])
|
||||
|
||||
self._clear_pending_user_turn(session)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
return True
|
||||
|
||||
def _restore_pending_user_turn(self, session: Session) -> bool:
    """Close a turn whose user message was persisted but never answered.

    Returns:
        ``False`` when the pending-user-turn flag is not set (the session is
        left untouched). ``True`` when the flag was set; in that case the flag
        is cleared and, if the last stored message is a user message, a
        synthetic assistant error message is appended after it.
    """
    from datetime import datetime

    flag_is_set = session.metadata.get(self._PENDING_USER_TURN_KEY)
    if not flag_is_set:
        return False

    dangling_user_message = (
        bool(session.messages) and session.messages[-1].get("role") == "user"
    )
    if dangling_user_message:
        # Pair the orphaned user message with an assistant entry so the
        # history does not end on an unanswered user message.
        placeholder = {
            "role": "assistant",
            "content": "Error: Task interrupted before a response was generated.",
            "timestamp": datetime.now().isoformat(),
        }
        session.messages.append(placeholder)
        session.updated_at = datetime.now()

    self._clear_pending_user_turn(session)
    return True
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
|
||||
@ -134,6 +134,50 @@ class AgentRunner:
|
||||
continue
|
||||
messages.append(injection)
|
||||
|
||||
async def _try_drain_injections(
    self,
    spec: AgentRunSpec,
    messages: list[dict[str, Any]],
    assistant_message: dict[str, Any] | None,
    injection_cycles: int,
    *,
    phase: str = "after error",
    iteration: int | None = None,
) -> tuple[bool, int]:
    """Pull pending injections into *messages* if the cycle budget allows.

    Args:
        spec: Run spec handed to ``_drain_injections`` and ``_emit_checkpoint``.
        messages: Conversation list; mutated in place when injections arrive.
        assistant_message: When not ``None``, appended to *messages* ahead of
            the injected messages.
        injection_cycles: Number of injection cycles already used.
        phase: Label used only in the log line.
        iteration: When not ``None`` (and *assistant_message* is given), a
            ``final_response`` checkpoint is emitted for this iteration.

    Returns:
        ``(True, injection_cycles + 1)`` when injections were appended,
        otherwise ``(False, injection_cycles)``. Nothing is drained once
        ``_MAX_INJECTION_CYCLES`` has been reached.
    """
    budget_left = injection_cycles < _MAX_INJECTION_CYCLES
    # Only touch the injection source while there is budget left.
    pending = await self._drain_injections(spec) if budget_left else None
    if not pending:
        return False, injection_cycles

    used_cycles = injection_cycles + 1

    if assistant_message is not None:
        messages.append(assistant_message)
        if iteration is not None:
            checkpoint = {
                "phase": "final_response",
                "iteration": iteration,
                "model": spec.model,
                "assistant_message": assistant_message,
                "completed_tool_results": [],
                "pending_tool_calls": [],
            }
            await self._emit_checkpoint(spec, checkpoint)

    self._append_injected_messages(messages, pending)
    logger.info(
        "Injected {} follow-up message(s) {} ({}/{})",
        len(pending), phase, used_cycles, _MAX_INJECTION_CYCLES,
    )
    return True, used_cycles
|
||||
|
||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
||||
"""Drain pending user messages via the injection callback.
|
||||
|
||||
@ -287,6 +331,13 @@ class AgentRunner:
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after tool error",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
@ -302,16 +353,12 @@ class AgentRunner:
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
# Checkpoint 1: drain injections after tools, before next LLM call
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
_drained, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after tool execution",
|
||||
)
|
||||
if _drained:
|
||||
had_injections = True
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -379,36 +426,18 @@ class AgentRunner:
|
||||
# Check for mid-turn injections BEFORE signaling stream end.
|
||||
# If injections are found we keep the stream alive (resuming=True)
|
||||
# so streaming channels don't prematurely finalize the card.
|
||||
_injected_after_final = False
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
_injected_after_final = True
|
||||
if assistant_message is not None:
|
||||
messages.append(assistant_message)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after final response ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, assistant_message, injection_cycles,
|
||||
phase="after final response",
|
||||
iteration=iteration,
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=_injected_after_final)
|
||||
await hook.on_stream_end(context, resuming=should_continue)
|
||||
|
||||
if _injected_after_final:
|
||||
if should_continue:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -421,6 +450,13 @@ class AgentRunner:
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after LLM error",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
if is_blank_text(clean):
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
@ -431,6 +467,13 @@ class AgentRunner:
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after empty response",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
|
||||
messages.append(assistant_message or build_assistant_message(
|
||||
@ -467,6 +510,17 @@ class AgentRunner:
|
||||
max_iterations=spec.max_iterations,
|
||||
)
|
||||
self._append_final_message(messages, final_content)
|
||||
# Drain any remaining injections so they are appended to the
|
||||
# conversation history instead of being re-published as
|
||||
# independent inbound messages by _dispatch's finally block.
|
||||
# We ignore should_continue here because the for-loop has already
|
||||
# exhausted all iterations.
|
||||
drained_after_max_iterations, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after max_iterations",
|
||||
)
|
||||
if drained_after_max_iterations:
|
||||
had_injections = True
|
||||
|
||||
return AgentRunResult(
|
||||
final_content=final_content,
|
||||
|
||||
@ -454,7 +454,23 @@ async def connect_mcp_servers(
|
||||
return name, server_stack
|
||||
|
||||
except Exception as e:
|
||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||
hint = ""
|
||||
text = str(e).lower()
|
||||
if any(
|
||||
marker in text
|
||||
for marker in (
|
||||
"parse error",
|
||||
"invalid json",
|
||||
"unexpected token",
|
||||
"jsonrpc",
|
||||
"content-length",
|
||||
)
|
||||
):
|
||||
hint = (
|
||||
" Hint: this looks like stdio protocol pollution. Make sure the MCP server writes "
|
||||
"only JSON-RPC to stdout and sends logs/debug output to stderr instead."
|
||||
)
|
||||
logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint)
|
||||
try:
|
||||
await server_stack.aclose()
|
||||
except Exception:
|
||||
|
||||
@ -68,6 +68,13 @@ class ToolRegistry:
|
||||
params: dict[str, Any],
|
||||
) -> tuple[Tool | None, dict[str, Any], str | None]:
|
||||
"""Resolve, cast, and validate one tool call."""
|
||||
# Guard against invalid parameter types (e.g., list instead of dict)
|
||||
if not isinstance(params, dict) and name in ('write_file', 'read_file'):
|
||||
return None, params, (
|
||||
f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. "
|
||||
"Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")"
|
||||
)
|
||||
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
return None, params, (
|
||||
|
||||
@ -337,6 +337,9 @@ class DingTalkChannel(BaseChannel):
|
||||
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||
return resp.content, filename, content_type or None
|
||||
except httpx.TransportError as e:
|
||||
logger.error("DingTalk media download network error ref={} err={}", media_ref, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
||||
return None, None, None
|
||||
@ -388,6 +391,9 @@ class DingTalkChannel(BaseChannel):
|
||||
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
||||
return None
|
||||
return str(media_id)
|
||||
except httpx.TransportError as e:
|
||||
logger.error("DingTalk media upload network error type={} err={}", media_type, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
||||
return None
|
||||
@ -437,6 +443,9 @@ class DingTalkChannel(BaseChannel):
|
||||
return False
|
||||
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
||||
return True
|
||||
except httpx.TransportError as e:
|
||||
logger.error("DingTalk network error sending message msgKey={} err={}", msg_key, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
||||
return False
|
||||
|
||||
@ -366,6 +366,7 @@ class DiscordChannel(BaseChannel):
|
||||
await client.send_outbound(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
raise
|
||||
finally:
|
||||
if not is_progress:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
@ -280,6 +280,9 @@ class QQChannel(BaseChannel):
|
||||
msg_id=msg_id,
|
||||
content=msg.content.strip(),
|
||||
)
|
||||
except (aiohttp.ClientError, OSError):
|
||||
# Network / transport errors — propagate so ChannelManager can retry
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
|
||||
|
||||
@ -362,7 +365,12 @@ class QQChannel(BaseChannel):
|
||||
|
||||
logger.info("QQ media sent: {}", filename)
|
||||
return True
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
# Network / transport errors — propagate for retry by caller
|
||||
logger.warning("QQ send media network error filename={} err={}", filename, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
# API-level or other non-network errors — return False so send() can fallback
|
||||
logger.error("QQ send media failed filename={} err={}", filename, e)
|
||||
return False
|
||||
|
||||
|
||||
@ -520,7 +520,10 @@ class TelegramChannel(BaseChannel):
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
except BadRequest as e:
|
||||
# Only fall back to plain text on actual HTML parse/format errors.
|
||||
# Network errors (TimedOut, NetworkError) should propagate immediately
|
||||
# to avoid doubling connection demand during pool exhaustion.
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._call_with_retry(
|
||||
@ -567,7 +570,10 @@ class TelegramChannel(BaseChannel):
|
||||
chat_id=int_chat_id, message_id=buf.message_id,
|
||||
text=html, parse_mode="HTML",
|
||||
)
|
||||
except Exception as e:
|
||||
except BadRequest as e:
|
||||
# Only fall back to plain text on actual HTML parse/format errors.
|
||||
# Network errors (TimedOut, NetworkError) should propagate immediately
|
||||
# to avoid doubling connection demand during pool exhaustion.
|
||||
if self._is_not_modified_error(e):
|
||||
logger.debug("Final stream edit already applied for {}", chat_id)
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
|
||||
@ -985,7 +985,43 @@ class WeixinChannel(BaseChannel):
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
||||
except (httpx.TimeoutException, httpx.TransportError) as net_err:
|
||||
# Network/transport errors: do NOT fall back to text —
|
||||
# the text send would also likely fail, and the outer
|
||||
# except will re-raise so ChannelManager retries properly.
|
||||
logger.error(
|
||||
"Network error sending WeChat media {}: {}",
|
||||
media_path,
|
||||
net_err,
|
||||
)
|
||||
raise
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
status_code = (
|
||||
http_err.response.status_code
|
||||
if http_err.response is not None
|
||||
else 0
|
||||
)
|
||||
if status_code >= 500:
|
||||
# Server-side / retryable HTTP error — same as network.
|
||||
logger.error(
|
||||
"Server error ({} {}) sending WeChat media {}: {}",
|
||||
status_code,
|
||||
http_err.response.reason_phrase
|
||||
if http_err.response is not None
|
||||
else "",
|
||||
media_path,
|
||||
http_err,
|
||||
)
|
||||
raise
|
||||
# 4xx client errors are NOT retryable — fall back to text.
|
||||
filename = Path(media_path).name
|
||||
logger.error("Failed to send WeChat media {}: {}", media_path, http_err)
|
||||
await self._send_text(
|
||||
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Non-network errors (format, file-not-found, etc.):
|
||||
# notify the user via text fallback.
|
||||
filename = Path(media_path).name
|
||||
logger.error("Failed to send WeChat media {}: {}", media_path, e)
|
||||
# Notify user about failure via text
|
||||
|
||||
@ -799,7 +799,12 @@ class OpenAICompatProvider(LLMProvider):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _handle_error(e: Exception) -> LLMResponse:
|
||||
def _handle_error(
|
||||
e: Exception,
|
||||
*,
|
||||
spec: ProviderSpec | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> LLMResponse:
|
||||
body = (
|
||||
getattr(e, "doc", None)
|
||||
or getattr(e, "body", None)
|
||||
@ -807,6 +812,15 @@ class OpenAICompatProvider(LLMProvider):
|
||||
)
|
||||
body_text = body if isinstance(body, str) else str(body) if body is not None else ""
|
||||
msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
|
||||
|
||||
text = f"{body_text} {e}".lower()
|
||||
if spec and spec.is_local and ("502" in text or "connection" in text or "refused" in text):
|
||||
msg += (
|
||||
"\nHint: this is a local model endpoint. Check that the local server is reachable at "
|
||||
f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it "
|
||||
"can reach your local Ollama/vLLM service instead of routing localhost through the remote host."
|
||||
)
|
||||
|
||||
response = getattr(e, "response", None)
|
||||
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
|
||||
if retry_after is None:
|
||||
@ -850,7 +864,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
)
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
@ -933,7 +947,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
error_kind="timeout",
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
@ -1,5 +1,12 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
@ -11,6 +18,12 @@ def _mk_loop() -> AgentLoop:
|
||||
return loop
|
||||
|
||||
|
||||
def _make_full_loop(tmp_path: Path) -> AgentLoop:
    """Build an ``AgentLoop`` with a fresh ``MessageBus``, a mocked provider, and *tmp_path* as workspace."""
    model_name = "test-model"
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = model_name
    return AgentLoop(
        bus=MessageBus(),
        provider=fake_provider,
        workspace=tmp_path,
        model=model_name,
    )
|
||||
|
||||
|
||||
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:runtime-only")
|
||||
@ -200,3 +213,98 @@ def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
|
||||
assert session.messages[0]["role"] == "assistant"
|
||||
assert session.messages[1]["tool_call_id"] == "call_done"
|
||||
assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_message_persists_user_message_before_turn_completes(tmp_path: Path) -> None:
    """The user message and pending flag must be stored even when the agent loop raises."""
    session_key = "feishu:c1"
    loop = _make_full_loop(tmp_path)
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]
    loop._run_agent_loop = AsyncMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]

    inbound = InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="persist me")
    with pytest.raises(RuntimeError, match="boom"):
        await loop._process_message(inbound)

    # NOTE(review): invalidate() presumably drops the cached session so the
    # next get_or_create() reloads what was saved — confirm against SessionManager.
    loop.sessions.invalidate(session_key)
    reloaded = loop.sessions.get_or_create(session_key)

    stored_roles = [entry["role"] for entry in reloaded.messages]
    assert stored_roles == ["user"]
    assert reloaded.messages[0]["content"] == "persist me"
    assert reloaded.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
    assert reloaded.updated_at >= reloaded.created_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_message_does_not_duplicate_early_persisted_user_message(tmp_path: Path) -> None:
    """After a successful turn the session holds the user message once, followed by the reply."""
    loop = _make_full_loop(tmp_path)
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    transcript = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "done"},
    ]
    stubbed_outcome = ("done", None, transcript, "stop", False)
    loop._run_agent_loop = AsyncMock(return_value=stubbed_outcome)  # type: ignore[method-assign]

    inbound = InboundMessage(channel="feishu", sender_id="u1", chat_id="c2", content="hello")
    result = await loop._process_message(inbound)

    assert result is not None
    assert result.content == "done"

    session = loop.sessions.get_or_create("feishu:c2")
    # Compare only role/content; stored messages may carry extra keys.
    compared_keys = {"role", "content"}
    stored = [
        {k: v for k, v in entry.items() if k in compared_keys}
        for entry in session.messages
    ]
    assert stored == [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "done"},
    ]
    assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
    """A leftover pending user turn is closed with an error reply before the new turn is stored."""
    session_key = "feishu:c3"
    interrupted_reply = "Error: Task interrupted before a response was generated."

    loop = _make_full_loop(tmp_path)
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]
    loop.provider.chat_with_retry = AsyncMock(return_value=MagicMock())  # unused because _run_agent_loop is stubbed

    # Seed a session that looks like a turn which stored only its user message.
    seeded = loop.sessions.get_or_create(session_key)
    seeded.add_message("user", "old question")
    seeded.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True
    loop.sessions.save(seeded)

    transcript = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old question"},
        {"role": "assistant", "content": interrupted_reply},
        {"role": "user", "content": "new question"},
        {"role": "assistant", "content": "new answer"},
    ]
    stubbed_outcome = ("new answer", None, transcript, "stop", False)
    loop._run_agent_loop = AsyncMock(return_value=stubbed_outcome)  # type: ignore[method-assign]

    inbound = InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question")
    result = await loop._process_message(inbound)

    assert result is not None
    assert result.content == "new answer"

    session = loop.sessions.get_or_create(session_key)
    compared_keys = {"role", "content"}
    stored = [
        {k: v for k, v in entry.items() if k in compared_keys}
        for entry in session.messages
    ]
    assert stored == [
        {"role": "user", "content": "old question"},
        {"role": "assistant", "content": interrupted_reply},
        {"role": "user", "content": "new question"},
        {"role": "assistant", "content": "new answer"},
    ]
    assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
@ -18,6 +18,16 @@ from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_injection_callback(queue: asyncio.Queue):
    """Return an async callback that drains *queue* into a list.

    Each call removes every item currently buffered in *queue* and returns
    them in FIFO order. An empty queue yields ``[]``.
    """

    async def inject_cb():
        drained = []
        # get_nowait() never suspends, so the drain cannot hang if something
        # else empties the queue between an emptiness check and the read.
        while True:
            try:
                drained.append(queue.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    return inject_cb
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -1888,12 +1898,7 @@ async def test_checkpoint1_injects_after_tool_execution():
|
||||
tools.execute = AsyncMock(return_value="file content")
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
# Put a follow-up message in the queue before the run starts
|
||||
await injection_queue.put(
|
||||
@ -1951,12 +1956,7 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
# Inject a follow-up that arrives during the first response
|
||||
await injection_queue.put(
|
||||
@ -2005,12 +2005,7 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
||||
@ -2410,3 +2405,330 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
|
||||
contents = [m.content for m in msgs]
|
||||
assert "leftover-1" in contents
|
||||
assert "leftover-2" in contents
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_fatal_tool_error():
    """Pending injections should be drained even when a fatal tool error occurs."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after error"
    calls_seen = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_seen
        calls_seen += 1
        if calls_seen > 1:
            # Every call after the first answers the injected follow-up normally.
            return LLMResponse(content="reply to follow-up", tool_calls=[], usage={})
        # First call: request a tool whose execution is rigged to raise.
        return LLMResponse(
            content="",
            tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})],
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))

    injection_queue = asyncio.Queue()
    inject_cb = _make_injection_callback(injection_queue)
    await injection_queue.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content=follow_up_text)
    )

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    assert result.final_content == "reply to follow-up"
    # The injection should be in the messages history
    injected = [
        entry for entry in result.messages
        if entry.get("role") == "user" and entry.get("content") == follow_up_text
    ]
    assert len(injected) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_llm_error():
    """A queued follow-up is drained when the LLM reports ``finish_reason="error"``.

    The first LLM turn returns an error response; the test expects the
    pending injection to be consumed and the second turn's answer to become
    the final content.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after LLM error"
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal llm_calls
        llm_calls += 1
        if llm_calls > 1:
            # Later turns: respond normally to the injected follow-up.
            return LLMResponse(content="recovered answer", tool_calls=[], usage={})
        # First turn: simulate an LLM-side failure.
        return LLMResponse(content=None, tool_calls=[], finish_reason="error", usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    injection_queue = asyncio.Queue()
    inject_cb = _make_injection_callback(injection_queue)
    pending = InboundMessage(channel="cli", sender_id="u", chat_id="c", content=follow_up_text)
    await injection_queue.put(pending)

    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "previous response"},
        {"role": "user", "content": "trigger error"},
    ]
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=history,
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await runner.run(spec)

    assert result.had_injections is True
    assert result.final_content == "recovered answer"

    # Substring match: the content is compared via str() containment.
    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and follow_up_text in str(entry.get("content", "")):
            matches += 1
    assert matches == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_empty_final_response():
    """A queued follow-up is drained when the run would end on an empty reply.

    The fake provider returns empty content for the first
    ``_MAX_EMPTY_RETRIES + 1`` calls and a real answer afterwards; the test
    expects that later answer to be the final content.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after empty"
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal llm_calls
        llm_calls += 1
        if llm_calls > _MAX_EMPTY_RETRIES + 1:
            # After retries are exhausted and the injection drained, answer normally.
            return LLMResponse(content="answer after empty", tool_calls=[], usage={})
        return LLMResponse(content="", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    injection_queue = asyncio.Queue()
    inject_cb = _make_injection_callback(injection_queue)
    pending = InboundMessage(channel="cli", sender_id="u", chat_id="c", content=follow_up_text)
    await injection_queue.put(pending)

    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "previous response"},
        {"role": "user", "content": "trigger empty"},
    ]
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=history,
        tools=tools,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await runner.run(spec)

    assert result.had_injections is True
    assert result.final_content == "answer after empty"

    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and follow_up_text in str(entry.get("content", "")):
            matches += 1
    assert matches == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_max_iterations():
    """A queued follow-up is consumed when the run stops at ``max_iterations``.

    Unlike the other error paths, hitting max_iterations cannot continue the
    loop, so the injection is appended to the messages without being
    processed by the LLM. What matters is that it leaves the queue, which
    prevents it from being re-published.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after max iters"
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        # Every turn requests another tool call, so the loop never finishes by itself.
        nonlocal llm_calls
        llm_calls += 1
        request = ToolCallRequest(id=f"c{llm_calls}", name="read_file", arguments={"path": "x"})
        return LLMResponse(content="", tool_calls=[request], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    injection_queue = asyncio.Queue()
    inject_cb = _make_injection_callback(injection_queue)
    pending = InboundMessage(channel="cli", sender_id="u", chat_id="c", content=follow_up_text)
    await injection_queue.put(pending)

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await runner.run(spec)

    assert result.stop_reason == "max_iterations"
    assert result.had_injections is True
    # The injection left the queue (this is what prevents a re-publish).
    assert injection_queue.empty()

    # The injection is appended to the conversation history exactly once.
    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and entry.get("content") == follow_up_text:
            matches += 1
    assert matches == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_set_flag_when_followup_arrives_after_last_iteration():
    """A follow-up queued during the last iteration still sets ``had_injections``.

    A hook enqueues the message on its second ``after_iteration`` call, which
    with ``max_iterations=2`` is the final iteration; the test expects the
    max_iterations exit to drain it into the history.
    """
    from nanobot.agent.hook import AgentHook
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    late_text = "late follow-up after max iters"
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        # Every turn requests another tool call, so the loop never finishes by itself.
        nonlocal llm_calls
        llm_calls += 1
        request = ToolCallRequest(id=f"c{llm_calls}", name="read_file", arguments={"path": "x"})
        return LLMResponse(content="", tool_calls=[request], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    injection_queue = asyncio.Queue()
    inject_cb = _make_injection_callback(injection_queue)

    class InjectOnLastAfterIterationHook(AgentHook):
        """Hook that enqueues one follow-up on its second ``after_iteration`` call."""

        # NOTE(review): AgentHook.__init__ is not invoked here — confirm the
        # base class needs no initialisation of its own.
        def __init__(self) -> None:
            self.after_iteration_calls = 0

        async def after_iteration(self, context) -> None:
            self.after_iteration_calls += 1
            if self.after_iteration_calls != 2:
                return
            late_message = InboundMessage(
                channel="cli",
                sender_id="u",
                chat_id="c",
                content=late_text,
            )
            await injection_queue.put(late_message)

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
        hook=InjectOnLastAfterIterationHook(),
    )
    result = await runner.run(spec)

    assert result.stop_reason == "max_iterations"
    assert result.had_injections is True
    assert injection_queue.empty()

    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and entry.get("content") == late_text:
            matches += 1
    assert matches == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_injection_cycle_cap_on_error_path():
    """Injection cycles stay capped even when every LLM turn reports an error.

    The callback offers a fresh message for the first
    ``_MAX_INJECTION_CYCLES`` drains; the test expects exactly that many
    drains and one more LLM call than drains.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
    from nanobot.bus.events import InboundMessage

    llm_calls = 0
    drains = 0

    async def chat_with_retry(*, messages, **kwargs):
        # Every turn fails, so only injections can keep the loop going.
        nonlocal llm_calls
        llm_calls += 1
        return LLMResponse(content=None, tool_calls=[], finish_reason="error", usage={})

    async def inject_cb():
        nonlocal drains
        drains += 1
        if drains > _MAX_INJECTION_CYCLES:
            return []
        return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drains}")]

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "previous"},
        {"role": "user", "content": "trigger error"},
    ]
    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=history,
        tools=tools,
        model="test-model",
        max_iterations=20,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await runner.run(spec)

    assert result.had_injections is True
    # Cap: _MAX_INJECTION_CYCLES drained rounds plus one final round that breaks.
    assert llm_calls == _MAX_INJECTION_CYCLES + 1
    assert drains == _MAX_INJECTION_CYCLES
|
||||
|
||||
@ -2,7 +2,9 @@ import asyncio
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Check optional dingtalk dependencies before running tests
|
||||
@ -52,6 +54,21 @@ class _FakeHttp:
|
||||
return self._next_response()
|
||||
|
||||
|
||||
class _NetworkErrorHttp:
    """HTTP client stub whose every request fails with ``httpx.ConnectError``.

    Each attempted request is recorded in ``calls`` before the error is
    raised, so tests can assert how many requests were tried.
    """

    def __init__(self) -> None:
        # One dict per attempted request, in call order.
        self.calls: list[dict] = []

    async def post(self, url: str, json=None, headers=None, **kwargs):
        """Record the POST (including payload and headers), then raise."""
        record = {"method": "POST", "url": url, "json": json, "headers": headers}
        self.calls.append(record)
        raise httpx.ConnectError("Connection refused")

    async def get(self, url: str, **kwargs):
        """Record the GET, then raise."""
        record = {"method": "GET", "url": url}
        self.calls.append(record)
        raise httpx.ConnectError("Connection refused")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
||||
@ -298,3 +315,141 @@ async def test_send_media_ref_zips_html_before_upload(tmp_path, monkeypatch) ->
|
||||
|
||||
archive = zipfile.ZipFile(BytesIO(captured["data"]))
|
||||
assert archive.namelist() == ["report.html"]
|
||||
|
||||
|
||||
# ── Exception handling tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_batch_message_propagates_transport_error() -> None:
    """Transport-level failures must escape ``_send_batch_message`` so callers can retry."""
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())
    channel._http = _NetworkErrorHttp()

    payload = {"text": "hello", "title": "Nanobot Reply"}
    with pytest.raises(httpx.ConnectError, match="Connection refused"):
        await channel._send_batch_message("token", "user123", "sampleMarkdown", payload)

    # Exactly one request was attempted, and it was the POST.
    recorded = channel._http.calls
    assert len(recorded) == 1
    assert recorded[0]["method"] == "POST"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_batch_message_returns_false_on_api_error() -> None:
    """API-level failures yield ``False``; a clean response yields ``True``.

    Covers a non-200 status, a 200 with a non-zero ``errcode``, and the
    success case (200 with ``errcode`` 0) as a control.
    """
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())

    # (status code, response body, expected return value)
    scenarios = (
        (400, {"errcode": 400}, False),  # non-200 status -> API error
        (200, {"errcode": 100}, False),  # 200 with non-zero errcode -> API error
        (200, {"errcode": 0}, True),     # 200 with errcode=0 -> success
    )
    for status, body, expected in scenarios:
        channel._http = _FakeHttp(responses=[_FakeResponse(status, body)])
        outcome = await channel._send_batch_message(
            "token", "user123", "sampleMarkdown", {"text": "hello"}
        )
        assert outcome is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_transport_error() -> None:
    """A transport error on the first send must re-raise from ``_send_media_ref``.

    No download, upload or fallback request may follow the failed POST.
    """
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())
    channel._http = _NetworkErrorHttp()

    # An image URL goes down the sampleImageMsg path first.
    image_url = "https://example.com/photo.jpg"
    with pytest.raises(httpx.ConnectError, match="Connection refused"):
        await channel._send_media_ref("token", "user123", image_url)

    # Only the single POST was attempted — no download/upload/fallback.
    recorded = channel._http.calls
    assert len(recorded) == 1
    assert recorded[0]["method"] == "POST"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_download_transport_error() -> None:
    """A transport error while downloading for the fallback must re-raise.

    The image-URL send first fails at the API level (200 with a non-zero
    ``errcode``); the subsequent download then raises, and that error must
    propagate instead of being turned into a silent ``False``.
    """
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())

    class _MixedHttp:
        """POST answers with an API error; GET raises a transport error."""

        def __init__(self) -> None:
            self.calls: list[dict] = []

        async def post(self, url, json=None, headers=None, **kwargs):
            record = {"method": "POST", "url": url}
            self.calls.append(record)
            # API-level failure: 200 with errcode != 0.
            return _FakeResponse(200, {"errcode": 100})

        async def get(self, url, **kwargs):
            record = {"method": "GET", "url": url}
            self.calls.append(record)
            raise httpx.ConnectError("Connection refused")

    channel._http = _MixedHttp()

    image_url = "https://example.com/photo.jpg"
    with pytest.raises(httpx.ConnectError, match="Connection refused"):
        await channel._send_media_ref("token", "user123", image_url)

    # POST (image URL) then GET (download) were attempted, but NOT an upload.
    recorded = channel._http.calls
    assert len(recorded) == 2
    assert recorded[0]["method"] == "POST"
    assert recorded[1]["method"] == "GET"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_upload_transport_error() -> None:
    """A transport error during upload (after a good download) must re-raise."""
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())

    # Minimal JPEG-ish payload: JPEG magic bytes followed by padding.
    image_bytes = b"\xff\xd8\xff\xe0" + b"\x00" * 100

    class _UploadFailsHttp:
        """Download succeeds, the image-URL send fails at API level, upload raises."""

        def __init__(self) -> None:
            self.calls: list[dict] = []

        async def post(self, url, json=None, headers=None, files=None, **kwargs):
            record = {"method": "POST", "url": url}
            self.calls.append(record)
            if "media/upload" not in url:
                # sampleImageMsg send: API error, which triggers the fallback.
                return _FakeResponse(200, {"errcode": 100})
            # Upload endpoint: transport failure.
            raise httpx.ConnectError("Connection refused")

        async def get(self, url, **kwargs):
            record = {"method": "GET", "url": url}
            self.calls.append(record)
            response = _FakeResponse(200)
            response.content = image_bytes
            response.headers = {"content-type": "image/jpeg"}
            return response

    channel._http = _UploadFailsHttp()

    image_url = "https://example.com/photo.jpg"
    with pytest.raises(httpx.ConnectError, match="Connection refused"):
        await channel._send_media_ref("token", "user123", image_url)

    # POST (image URL), GET (download), POST (upload) — and nothing after that.
    attempted = [entry["method"] for entry in channel._http.calls]
    assert attempted == ["POST", "GET", "POST"]
|
||||
|
||||
@ -867,3 +867,100 @@ async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None:
|
||||
assert channel.is_running is False
|
||||
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for the send() exception propagation fix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_re_raises_network_error() -> None:
    """A network error raised while sending must propagate out of ``send()``.

    Propagation is what lets ChannelManager retry the delivery.
    """
    config = DiscordConfig(enabled=True, allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    async def _failing_send_outbound(msg: OutboundMessage) -> None:
        raise ConnectionError("network unreachable")

    client.send_outbound = _failing_send_outbound  # type: ignore[method-assign]

    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")
    with pytest.raises(ConnectionError, match="network unreachable"):
        await channel.send(outbound)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_re_raises_generic_exception() -> None:
    """Any exception raised by ``send_outbound`` must surface from ``send()``, not be swallowed."""
    config = DiscordConfig(enabled=True, allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    async def _failing_send_outbound(msg: OutboundMessage) -> None:
        raise RuntimeError("discord API failure")

    client.send_outbound = _failing_send_outbound  # type: ignore[method-assign]

    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")
    with pytest.raises(RuntimeError, match="discord API failure"):
        await channel.send(outbound)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_still_stops_typing_on_error() -> None:
    """Typing cleanup must still happen when ``send()`` raises.

    A typing task is started and held open on an event; after the failing
    send, the channel's typing-task registry is expected to be empty.
    """
    config = DiscordConfig(enabled=True, allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    # Events used to observe the typing task starting and to let it finish.
    typing_started = asyncio.Event()
    typing_may_finish = asyncio.Event()

    async def slow_typing() -> None:
        typing_started.set()
        await typing_may_finish.wait()

    typing_channel = _FakeChannel(channel_id=123)
    typing_channel.typing_enter_hook = slow_typing
    await channel._start_typing(typing_channel)
    await asyncio.wait_for(typing_started.wait(), timeout=1.0)

    async def _failing_send_outbound(msg: OutboundMessage) -> None:
        raise ConnectionError("timeout")

    client.send_outbound = _failing_send_outbound  # type: ignore[method-assign]

    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")
    with pytest.raises(ConnectionError, match="timeout"):
        await channel.send(outbound)

    # Unblock the typing coroutine and yield once to the event loop.
    typing_may_finish.set()
    await asyncio.sleep(0)

    # The typing task for this chat must have been removed despite the error.
    assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_succeeds_normally() -> None:
    """A successful send hands the message to ``send_outbound`` without raising."""
    config = DiscordConfig(enabled=True, allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    delivered: list[OutboundMessage] = []

    async def _capture_send_outbound(msg: OutboundMessage) -> None:
        delivered.append(msg)

    client.send_outbound = _capture_send_outbound  # type: ignore[method-assign]

    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello world")
    await channel.send(outbound)

    # Exactly one message went through, with content and target intact.
    assert len(delivered) == 1
    only = delivered[0]
    assert only.content == "hello world"
    assert only.chat_id == "123"
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -14,6 +15,8 @@ except ImportError:
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
import aiohttp
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import QQChannel, QQConfig
|
||||
@ -170,3 +173,221 @@ async def test_read_media_bytes_missing_file() -> None:
|
||||
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
|
||||
assert data is None
|
||||
assert filename is None
|
||||
|
||||
|
||||
# -------------------------------------------------------
|
||||
# Tests for _send_media exception handling
|
||||
# -------------------------------------------------------
|
||||
|
||||
def _make_channel_with_local_file(suffix: str = ".png", content: bytes = b"\x89PNG\r\n"):
    """Build a QQChannel backed by a fake client, plus a temp media file.

    Returns a ``(channel, path)`` tuple where ``path`` is the name of a
    temporary file containing ``content``. "user1" is pre-registered in the
    channel's chat-type cache as a c2c chat.
    """
    config = QQConfig(app_id="app", secret="secret", allow_from=["*"])
    channel = QQChannel(config, MessageBus())
    channel._client = _FakeClient()
    channel._chat_type_cache["user1"] = "c2c"

    # delete=False keeps the file on disk after the handle is closed; this
    # helper does not remove it afterwards.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        handle.write(content)
    return channel, handle.name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_network_error_propagates() -> None:
    """An aiohttp network/transport error must re-raise from ``_send_media``, not become False."""
    channel, tmp_path = _make_channel_with_local_file()

    # Every HTTP request made through the fake API fails with a network error.
    failing_request = AsyncMock(side_effect=aiohttp.ServerDisconnectedError("connection lost"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    with pytest.raises(aiohttp.ServerDisconnectedError):
        await channel._send_media(
            chat_id="user1", media_ref=tmp_path, msg_id="msg1", is_group=False
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_client_connector_error_propagates() -> None:
    """``aiohttp.ClientConnectorError`` (DNS failure / connection refused) must re-raise."""
    channel, tmp_path = _make_channel_with_local_file()

    from aiohttp.client_reqrep import ConnectionKey

    # ClientConnectorError needs a connection key and the underlying OSError.
    conn_key = ConnectionKey("api.qq.com", 443, True, None, None, None, None)
    connector_error = aiohttp.ClientConnectorError(
        connection_key=conn_key,
        os_error=OSError("Connection refused"),
    )

    failing_request = AsyncMock(side_effect=connector_error)
    channel._client.api._http = SimpleNamespace(request=failing_request)

    with pytest.raises(aiohttp.ClientConnectorError):
        await channel._send_media(
            chat_id="user1", media_ref=tmp_path, msg_id="msg1", is_group=False
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_oserror_propagates() -> None:
    """A low-level ``OSError`` must re-raise from ``_send_media`` so the send can be retried."""
    channel, tmp_path = _make_channel_with_local_file()

    failing_request = AsyncMock(side_effect=OSError("Network is unreachable"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    with pytest.raises(OSError):
        await channel._send_media(
            chat_id="user1", media_ref=tmp_path, msg_id="msg1", is_group=False
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_api_error_returns_false() -> None:
    """A botpy API-level error must make ``_send_media`` return False rather than raise."""
    channel, tmp_path = _make_channel_with_local_file()

    # botpy API error; per the original note, ServerError is a RuntimeError subclass.
    from botpy.errors import ServerError

    failing_request = AsyncMock(side_effect=ServerError("internal server error"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    outcome = await channel._send_media(
        chat_id="user1", media_ref=tmp_path, msg_id="msg1", is_group=False
    )
    assert outcome is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_generic_runtime_error_returns_false() -> None:
    """A plain ``RuntimeError`` (not a network failure) must make ``_send_media`` return False."""
    channel, tmp_path = _make_channel_with_local_file()

    failing_request = AsyncMock(side_effect=RuntimeError("some API error"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    outcome = await channel._send_media(
        chat_id="user1", media_ref=tmp_path, msg_id="msg1", is_group=False
    )
    assert outcome is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_value_error_returns_false() -> None:
    """A ``ValueError`` (bad API response data) must make ``_send_media`` return False."""
    channel, tmp_path = _make_channel_with_local_file()

    failing_request = AsyncMock(side_effect=ValueError("bad response data"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    outcome = await channel._send_media(
        chat_id="user1", media_ref=tmp_path, msg_id="msg1", is_group=False
    )
    assert outcome is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_timeout_error_propagates() -> None:
    """``aiohttp.ServerTimeoutError`` must re-raise instead of returning False.

    ServerTimeoutError is an ``aiohttp.ClientError`` subclass (it derives
    from ServerConnectionError and ``asyncio.TimeoutError``), so it counts as
    a network/transport failure. Only this aiohttp timeout is exercised
    here; a bare ``asyncio.TimeoutError`` / built-in ``TimeoutError`` is not
    covered by this test.
    """
    channel, tmp_path = _make_channel_with_local_file()

    failing_request = AsyncMock(side_effect=aiohttp.ServerTimeoutError("request timed out"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    with pytest.raises(aiohttp.ServerTimeoutError):
        await channel._send_media(
            chat_id="user1", media_ref=tmp_path, msg_id="msg1", is_group=False
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_fallback_text_on_api_error() -> None:
    """When the media send fails with an API error, ``send()`` must emit a fallback text."""
    channel, tmp_path = _make_channel_with_local_file()

    from botpy.errors import ServerError

    failing_request = AsyncMock(side_effect=ServerError("internal server error"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    outbound = OutboundMessage(
        channel="qq",
        chat_id="user1",
        content="",
        media=[tmp_path],
        metadata={"message_id": "msg1"},
    )
    await channel.send(outbound)

    # Exactly one c2c text message was sent, and it is the failure notice.
    c2c_calls = channel._client.api.c2c_calls
    assert len(c2c_calls) == 1
    fallback_content = c2c_calls[0]["content"]
    assert "Attachment send failed" in fallback_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_propagates_network_error_no_fallback() -> None:
    """A network error during the media send must escape ``send()`` with no fallback text sent."""
    channel, tmp_path = _make_channel_with_local_file()

    failing_request = AsyncMock(side_effect=aiohttp.ServerDisconnectedError("connection lost"))
    channel._client.api._http = SimpleNamespace(request=failing_request)

    outbound = OutboundMessage(
        channel="qq",
        chat_id="user1",
        content="hello",
        media=[tmp_path],
        metadata={"message_id": "msg1"},
    )
    with pytest.raises(aiohttp.ServerDisconnectedError):
        await channel.send(outbound)

    # No c2c text message of any kind was sent.
    assert len(channel._client.api.c2c_calls) == 0
|
||||
|
||||
@ -387,6 +387,84 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
|
||||
assert "123" not in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_stream_end_does_not_fallback_on_network_timeout() -> None:
    """``TimedOut`` during the final HTML edit must propagate with no plain-text fallback."""
    from telegram.error import TimedOut

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)
    # Per the original note, _call_with_retry retries TimedOut up to 3 times,
    # so the mock may be called repeatedly — but every call must use
    # parse_mode="HTML" (no plain-text fallback).
    channel._app.bot.edit_message_text = AsyncMock(side_effect=TimedOut("network timeout"))
    channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)

    with pytest.raises(TimedOut, match="network timeout"):
        await channel.send_delta("123", "", {"_stream_end": True})

    # Every recorded edit used parse_mode="HTML"; a plain-text fallback call
    # would show up here as a call without it.
    edit_calls = channel._app.bot.edit_message_text.call_args_list
    assert all(recorded.kwargs.get("parse_mode") == "HTML" for recorded in edit_calls)
    # The stream buffer is kept (not cleaned up) when the edit fails.
    assert "123" in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_stream_end_does_not_fallback_on_network_error() -> None:
    """``NetworkError`` during the final HTML edit must propagate with no plain-text fallback."""
    from telegram.error import NetworkError

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)
    channel._app.bot.edit_message_text = AsyncMock(side_effect=NetworkError("connection reset"))
    channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)

    with pytest.raises(NetworkError, match="connection reset"):
        await channel.send_delta("123", "", {"_stream_end": True})

    # Every recorded edit used parse_mode="HTML"; a plain-text fallback call
    # would show up here as a call without it.
    edit_calls = channel._app.bot.edit_message_text.call_args_list
    assert all(recorded.kwargs.get("parse_mode") == "HTML" for recorded in edit_calls)
    # The stream buffer is kept (not cleaned up) when the edit fails.
    assert "123" in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_stream_end_falls_back_on_bad_request() -> None:
    """A BadRequest on the HTML edit is retried once as a plain-text edit."""
    from telegram.error import BadRequest

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    # Scripted outcomes: the HTML edit is rejected, the follow-up edit succeeds.
    outcomes = [BadRequest("Can't parse entities"), None]
    edit_mock = AsyncMock(side_effect=outcomes)
    channel._app.bot.edit_message_text = edit_mock
    channel._stream_bufs["123"] = _StreamBuf(text="hello <bad>", message_id=7, last_edit=0.0)

    await channel.send_delta("123", "", {"_stream_end": True})

    # Exactly two edits: the HTML attempt and the plain-text fallback.
    assert edit_mock.call_count == 2

    # The fallback must not carry parse_mode="HTML" (absent or None both count).
    fallback_kwargs = edit_mock.call_args_list[1].kwargs
    assert fallback_kwargs.get("parse_mode") is None

    # A successful flush removes the stream buffer.
    assert "123" not in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_splits_oversized_reply() -> None:
|
||||
"""Final streamed reply exceeding Telegram limit is split into chunks."""
|
||||
@ -1159,3 +1237,159 @@ async def test_on_message_location_with_text() -> None:
|
||||
assert len(handled) == 1
|
||||
assert "meet me here" in handled[0]["content"]
|
||||
assert "[location: 51.5074, -0.1278]" in handled[0]["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for retry amplification fix (issue #3050)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_text_does_not_fallback_on_network_timeout() -> None:
    """A persistent TimedOut must propagate without a plain-text fallback.

    Regression test for issue #3050: before the fix, _send_text caught ALL
    exceptions (including TimedOut) and retried as plain text, doubling
    connection demand during pool exhaustion.
    """
    from telegram.error import TimedOut

    import nanobot.channels.telegram as tg_mod

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    # One entry per send_message invocation.
    attempts: list[dict] = []

    async def always_timeout(**kwargs):
        attempts.append(kwargs)
        raise TimedOut()

    channel._app.bot.send_message = always_timeout

    # Shrink the retry back-off so the test does not sleep for real.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        with pytest.raises(TimedOut):
            await channel._send_text(123, "hello", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # With the fix: only _call_with_retry's 3 HTML attempts.
    # Before the fix: 3 HTML + 3 plain = 6 attempts.
    call_count = len(attempts)
    assert call_count == 3, (
        f"Expected 3 calls (HTML retries only), got {call_count} "
        "(plain-text fallback should not trigger on TimedOut)"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_text_does_not_fallback_on_network_error() -> None:
    """A NetworkError must propagate without a plain-text fallback attempt."""
    from telegram.error import NetworkError

    import nanobot.channels.telegram as tg_mod

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    # One entry per send_message invocation.
    attempts: list[dict] = []

    async def always_network_error(**kwargs):
        attempts.append(kwargs)
        raise NetworkError("Connection reset")

    channel._app.bot.send_message = always_network_error

    # Shrink the retry back-off so the test does not sleep for real.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        with pytest.raises(NetworkError):
            await channel._send_text(123, "hello", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # _call_with_retry does NOT retry NetworkError (only TimedOut/RetryAfter),
    # so it raises after 1 attempt.
    # Before the fix: 1 HTML + 1 plain = 2. After the fix: 1 HTML only.
    call_count = len(attempts)
    assert call_count == 1, (
        f"Expected 1 call (HTML only, no plain fallback), got {call_count}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_text_falls_back_on_bad_request() -> None:
    """A BadRequest on the HTML send is followed by a successful plain-text send."""
    from telegram.error import BadRequest

    import nanobot.channels.telegram as tg_mod

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    bot = channel._app.bot
    real_send = bot.send_message
    html_attempts: list[dict] = []

    async def html_fails(**kwargs):
        # Reject only HTML sends; anything else goes to the fake bot's sender.
        if kwargs.get("parse_mode") != "HTML":
            return await real_send(**kwargs)
        html_attempts.append(kwargs)
        raise BadRequest("Can't parse entities")

    channel._app.bot.send_message = html_fails

    # Shrink the retry back-off so the test does not sleep for real.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        await channel._send_text(123, "hello **world**", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # The single HTML attempt fails with BadRequest, then the fallback succeeds.
    html_call_count = len(html_attempts)
    assert html_call_count == 1, f"Expected 1 HTML attempt, got {html_call_count}"

    delivered = channel._app.bot.sent_messages
    assert len(delivered) == 1
    # The delivered message is the plain-text one: no parse_mode set.
    assert delivered[0].get("parse_mode") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_text_bad_request_plain_fallback_exhausted() -> None:
    """If the HTML send and the plain-text fallback both raise BadRequest, it propagates."""
    from telegram.error import BadRequest

    import nanobot.channels.telegram as tg_mod

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    # One entry per send_message invocation.
    attempts: list[dict] = []

    async def always_bad_request(**kwargs):
        attempts.append(kwargs)
        raise BadRequest("Bad request")

    channel._app.bot.send_message = always_bad_request

    # Shrink the retry back-off so the test does not sleep for real.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        with pytest.raises(BadRequest):
            await channel._send_text(123, "hello", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # _call_with_retry does NOT retry BadRequest (only TimedOut/RetryAfter),
    # so the HTML send fails after 1 attempt and the plain fallback fails after 1 more.
    # Before the fix: 2 total. After the fix: still 2 (BadRequest SHOULD fall back).
    call_count = len(attempts)
    assert call_count == 2, f"Expected 2 calls (1 HTML + 1 plain), got {call_count}"
|
||||
|
||||
@ -1003,3 +1003,185 @@ async def test_download_media_item_non_image_requires_aes_key_even_with_full_url
|
||||
|
||||
assert saved_path is None
|
||||
channel._client.get.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for media-send error classification (network vs non-network errors)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_outbound_msg(chat_id: str = "wx-user", content: str = "", media: list | None = None):
    """Return an OutboundMessage on the "weixin" channel for send() tests.

    ``media`` defaults to an empty list when omitted or falsy; ``metadata``
    is always an empty dict.
    """
    from nanobot.bus.events import OutboundMessage

    attachments = media if media else []
    return OutboundMessage(
        channel="weixin",
        chat_id=chat_id,
        content=content,
        media=attachments,
        metadata={},
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_timeout_error_propagates_without_text_fallback() -> None:
    """httpx.TimeoutException from the media send is re-raised by send().

    The text fallback must not be attempted: during a network problem it
    would fail as well.
    """
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    media_mock = AsyncMock(side_effect=httpx.TimeoutException("timed out"))
    text_mock = AsyncMock()
    channel._send_media_file = media_mock
    channel._send_text = text_mock

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    with pytest.raises(httpx.TimeoutException, match="timed out"):
        await channel.send(outbound)

    # No fallback text message may have been sent.
    text_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_transport_error_propagates_without_text_fallback() -> None:
    """httpx.TransportError from the media send is re-raised without a text fallback."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    media_mock = AsyncMock(side_effect=httpx.TransportError("connection reset"))
    text_mock = AsyncMock()
    channel._send_media_file = media_mock
    channel._send_text = text_mock

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    with pytest.raises(httpx.TransportError, match="connection reset"):
        await channel.send(outbound)

    text_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_5xx_http_status_error_propagates_without_text_fallback() -> None:
    """An httpx.HTTPStatusError carrying a 5xx response is re-raised by send()."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    # Build a 503 response and the matching status error for the media upload.
    upload_request = httpx.Request("POST", "https://example.test/upload")
    fake_response = httpx.Response(status_code=503, request=upload_request)
    server_error = httpx.HTTPStatusError(
        "Service Unavailable",
        request=fake_response.request,
        response=fake_response,
    )

    media_mock = AsyncMock(side_effect=server_error)
    text_mock = AsyncMock()
    channel._send_media_file = media_mock
    channel._send_text = text_mock

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    with pytest.raises(httpx.HTTPStatusError, match="Service Unavailable"):
        await channel.send(outbound)

    text_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_4xx_http_status_error_falls_back_to_text() -> None:
    """An httpx.HTTPStatusError carrying a 4xx response yields a text fallback, not a raise."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    # Build a 400 response and the matching status error for the media upload.
    upload_request = httpx.Request("POST", "https://example.test/upload")
    fake_response = httpx.Response(status_code=400, request=upload_request)
    client_error = httpx.HTTPStatusError(
        "Bad Request",
        request=fake_response.request,
        response=fake_response,
    )

    media_mock = AsyncMock(side_effect=client_error)
    text_mock = AsyncMock()
    channel._send_media_file = media_mock
    channel._send_text = text_mock

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    # 4xx is a client error (non-retryable): send() must complete without raising.
    await channel.send(outbound)

    # The user is told which file failed, via a single text message.
    text_mock.assert_awaited_once_with("wx-user", "[Failed to send: photo.jpg]", "ctx-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_file_not_found_falls_back_to_text() -> None:
    """FileNotFoundError is not a network error, so send() falls back to text."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    missing_file = FileNotFoundError("Media file not found: /tmp/missing.jpg")
    media_mock = AsyncMock(side_effect=missing_file)
    text_mock = AsyncMock()
    channel._send_media_file = media_mock
    channel._send_text = text_mock

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/missing.jpg"])

    # Must complete without raising.
    await channel.send(outbound)

    text_mock.assert_awaited_once_with("wx-user", "[Failed to send: missing.jpg]", "ctx-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_value_error_falls_back_to_text() -> None:
    """ValueError (e.g. an unsupported format) makes send() fall back to text."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    media_mock = AsyncMock(side_effect=ValueError("Unsupported media format"))
    text_mock = AsyncMock()
    channel._send_media_file = media_mock
    channel._send_text = text_mock

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/file.xyz"])

    # Must complete without raising.
    await channel.send(outbound)

    text_mock.assert_awaited_once_with("wx-user", "[Failed to send: file.xyz]", "ctx-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_network_error_does_not_double_api_calls() -> None:
    """A network failure on the media send costs exactly one API attempt.

    The media send is awaited once and no text send follows it, so the
    total is 1 call rather than 2 (media + text fallback).
    """
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    media_mock = AsyncMock(side_effect=httpx.ConnectError("connection refused"))
    text_mock = AsyncMock()
    channel._send_media_file = media_mock
    channel._send_text = text_mock

    outbound = _make_outbound_msg(chat_id="wx-user", content="hello", media=["/tmp/img.png"])

    with pytest.raises(httpx.ConnectError):
        await channel.send(outbound)

    # One media attempt, zero text attempts.
    media_mock.assert_awaited_once()
    text_mock.assert_not_awaited()
|
||||
|
||||
@ -4,6 +4,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
|
||||
def test_custom_provider_parse_handles_empty_choices() -> None:
|
||||
@ -53,3 +54,20 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello world"
|
||||
|
||||
|
||||
def test_local_provider_502_error_includes_reachability_hint() -> None:
    """A 502 reported for the "ollama" spec produces an error result with a reachability hint.

    The error content must mention the local model endpoint, echo the
    configured API base, and mention proxy/tunnel.
    """
    endpoint = "http://localhost:11434/v1"
    spec = find_by_name("ollama")

    # Patch the SDK client class so constructing the provider creates no real client.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider(api_base=endpoint, spec=spec)

    upstream_failure = Exception("Error code: 502")
    result = provider._handle_error(upstream_failure, spec=spec, api_base=endpoint)

    assert result.finish_reason == "error"
    for fragment in ("local model endpoint", endpoint, "proxy/tunnel"):
        assert fragment in result.content
|
||||
|
||||
@ -356,6 +356,33 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
||||
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_connect_mcp_servers_logs_stdio_pollution_hint(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A stdio parse failure while connecting is logged with a stdout/stderr pollution hint."""
    logged: list[str] = []

    def _error(message: str, *args: object) -> None:
        # Render the brace-style template the same way the captured call supplies it.
        logged.append(message.format(*args))

    @asynccontextmanager
    async def _broken_stdio_client(_params: object):
        # Fail on entry, as if non-protocol output preceded the JSON-RPC headers.
        raise RuntimeError("Parse error: Unexpected token 'INFO' before JSON-RPC headers")
        yield  # pragma: no cover

    stdio_module = sys.modules["mcp.client.stdio"]
    monkeypatch.setattr(stdio_module, "stdio_client", _broken_stdio_client)
    monkeypatch.setattr("nanobot.agent.tools.mcp.logger.error", _error)

    registry = ToolRegistry()
    servers = {"gh": MCPServerConfig(command="github-mcp")}
    stacks = await connect_mcp_servers(servers, registry)

    # The broken server yields no connection stack, and an error was logged.
    assert stacks == {}
    assert logged

    last_message = logged[-1]
    for keyword in ("stdio protocol pollution", "stdout", "stderr"):
        assert keyword in last_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_one_failure_does_not_block_others(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@ -47,3 +47,27 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
|
||||
"mcp_fs_list",
|
||||
"mcp_git_status",
|
||||
]
|
||||
|
||||
|
||||
def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
    """prepare_call for read_file with a list payload returns no tool and an actionable error."""
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file"))

    raw_params = ["foo.txt"]
    resolved_tool, echoed_params, message = registry.prepare_call("read_file", raw_params)

    # No tool is returned and the params come back unchanged.
    assert resolved_tool is None
    assert echoed_params == ["foo.txt"]

    # The error explains both what is wrong and how to fix it.
    assert message is not None
    for hint in ("must be a JSON object", "Use named parameters"):
        assert hint in message
|
||||
|
||||
|
||||
def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
    """prepare_call for a tool other than read_file keeps the generic "must be an object" error."""
    registry = ToolRegistry()
    registry.register(_FakeTool("grep"))

    raw_params = ["TODO"]
    resolved_tool, echoed_params, message = registry.prepare_call("grep", raw_params)

    # The tool is still returned alongside the generic validation error.
    assert resolved_tool is not None
    assert echoed_params == ["TODO"]

    expected = "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"
    assert message == expected
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user