diff --git a/README.md b/README.md index b376d0991..f593e26ec 100644 --- a/README.md +++ b/README.md @@ -1053,6 +1053,30 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To ``` > For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`). +> +> `custom` is the right choice for providers that expose an OpenAI-compatible **chat completions** API. It does **not** force third-party endpoints onto the OpenAI/Azure **Responses API**. +> +> If your proxy or gateway is specifically Responses-API-compatible, use the `azure_openai` provider shape instead and point `apiBase` at that endpoint: +> +> ```json +> { +> "providers": { +> "azure_openai": { +> "apiKey": "your-api-key", +> "apiBase": "https://api.your-provider.com", +> "defaultModel": "your-model-name" +> } +> }, +> "agents": { +> "defaults": { +> "provider": "azure_openai", +> "model": "your-model-name" +> } +> } +> } +> ``` +> +> In short: **chat-completions-compatible endpoint → `custom`**; **Responses-compatible endpoint → `azure_openai`**. @@ -1858,6 +1882,19 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js - Single-message input: each request must contain exactly one `user` message - Fixed model: omit `model`, or pass the same model shown by `/v1/models` - No streaming: `stream=true` is not supported +- API requests run in the synthetic `api` channel, so the `message` tool does **not** automatically deliver to Telegram/Discord/etc. To proactively send to another chat, call `message` with an explicit `channel` and `chat_id` for an enabled channel. + +Example tool call for cross-channel delivery from an API session: + +```json +{ + "content": "Build finished successfully.", + "channel": "telegram", + "chat_id": "123456789" +} +``` + +If `channel` points to a channel that is not enabled in your config, nanobot will queue the outbound event but no platform delivery will occur. 
### Endpoints diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 5c4fbfc49..39e1ce23a 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -129,6 +129,7 @@ class AgentLoop: """ _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" + _PENDING_USER_TURN_KEY = "pending_user_turn" def __init__( self, @@ -621,6 +622,8 @@ class AgentLoop: session = self.sessions.get_or_create(key) if self._restore_runtime_checkpoint(session): self.sessions.save(session) + if self._restore_pending_user_turn(session): + self.sessions.save(session) session, pending = self.auto_compact.prepare_session(session, key) @@ -656,6 +659,8 @@ class AgentLoop: session = self.sessions.get_or_create(key) if self._restore_runtime_checkpoint(session): self.sessions.save(session) + if self._restore_pending_user_turn(session): + self.sessions.save(session) session, pending = self.auto_compact.prepare_session(session, key) @@ -696,6 +701,19 @@ class AgentLoop: ) ) + # Persist the triggering user message immediately, before running the + # agent loop. If the process is killed mid-turn (OOM, SIGKILL, self- + # restart, etc.), the existing runtime_checkpoint preserves the + # in-flight assistant/tool state but NOT the user message itself, so + # the user's prompt is silently lost on recovery. Saving it up front + # makes recovery possible from the session log alone. 
+ user_persisted_early = False + if isinstance(msg.content, str) and msg.content.strip(): + session.add_message("user", msg.content) + self._mark_pending_user_turn(session) + self.sessions.save(session) + user_persisted_early = True + final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, @@ -711,7 +729,10 @@ class AgentLoop: if final_content is None or not final_content.strip(): final_content = EMPTY_FINAL_RESPONSE_MESSAGE - self._save_turn(session, all_msgs, 1 + len(history)) + # Skip the already-persisted user message when saving the turn + save_skip = 1 + len(history) + (1 if user_persisted_early else 0) + self._save_turn(session, all_msgs, save_skip) + self._clear_pending_user_turn(session) self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) @@ -829,6 +850,12 @@ class AgentLoop: session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload self.sessions.save(session) + def _mark_pending_user_turn(self, session: Session) -> None: + session.metadata[self._PENDING_USER_TURN_KEY] = True + + def _clear_pending_user_turn(self, session: Session) -> None: + session.metadata.pop(self._PENDING_USER_TURN_KEY, None) + def _clear_runtime_checkpoint(self, session: Session) -> None: if self._RUNTIME_CHECKPOINT_KEY in session.metadata: session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None) @@ -895,9 +922,30 @@ class AgentLoop: break session.messages.extend(restored_messages[overlap:]) + self._clear_pending_user_turn(session) self._clear_runtime_checkpoint(session) return True + def _restore_pending_user_turn(self, session: Session) -> bool: + """Close a turn that only persisted the user message before crashing.""" + from datetime import datetime + + if not session.metadata.get(self._PENDING_USER_TURN_KEY): + return False + + if session.messages and session.messages[-1].get("role") == "user": + 
session.messages.append( + { + "role": "assistant", + "content": "Error: Task interrupted before a response was generated.", + "timestamp": datetime.now().isoformat(), + } + ) + session.updated_at = datetime.now() + + self._clear_pending_user_turn(session) + return True + async def process_direct( self, content: str, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index e92d864f2..592af9de2 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -134,6 +134,50 @@ class AgentRunner: continue messages.append(injection) + async def _try_drain_injections( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + assistant_message: dict[str, Any] | None, + injection_cycles: int, + *, + phase: str = "after error", + iteration: int | None = None, + ) -> tuple[bool, int]: + """Drain pending injections. Returns (should_continue, updated_cycles). + + If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES, + append them to *messages* (and emit a checkpoint if *assistant_message* + and *iteration* are both provided) and return (True, cycles+1) so the + caller continues the iteration loop. Otherwise return (False, cycles). 
+ """ + if injection_cycles >= _MAX_INJECTION_CYCLES: + return False, injection_cycles + injections = await self._drain_injections(spec) + if not injections: + return False, injection_cycles + injection_cycles += 1 + if assistant_message is not None: + messages.append(assistant_message) + if iteration is not None: + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [], + }, + ) + self._append_injected_messages(messages, injections) + logger.info( + "Injected {} follow-up message(s) {} ({}/{})", + len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES, + ) + return True, injection_cycles + async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]: """Drain pending user messages via the injection callback. @@ -287,6 +331,13 @@ class AgentRunner: context.error = error context.stop_reason = stop_reason await hook.after_iteration(context) + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after tool error", + ) + if should_continue: + had_injections = True + continue break await self._emit_checkpoint( spec, @@ -302,16 +353,12 @@ class AgentRunner: empty_content_retries = 0 length_recovery_count = 0 # Checkpoint 1: drain injections after tools, before next LLM call - if injection_cycles < _MAX_INJECTION_CYCLES: - injections = await self._drain_injections(spec) - if injections: - had_injections = True - injection_cycles += 1 - self._append_injected_messages(messages, injections) - logger.info( - "Injected {} follow-up message(s) after tool execution ({}/{})", - len(injections), injection_cycles, _MAX_INJECTION_CYCLES, - ) + _drained, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after tool execution", + ) + if _drained: + had_injections = True await 
hook.after_iteration(context) continue @@ -379,36 +426,18 @@ class AgentRunner: # Check for mid-turn injections BEFORE signaling stream end. # If injections are found we keep the stream alive (resuming=True) # so streaming channels don't prematurely finalize the card. - _injected_after_final = False - if injection_cycles < _MAX_INJECTION_CYCLES: - injections = await self._drain_injections(spec) - if injections: - had_injections = True - injection_cycles += 1 - _injected_after_final = True - if assistant_message is not None: - messages.append(assistant_message) - await self._emit_checkpoint( - spec, - { - "phase": "final_response", - "iteration": iteration, - "model": spec.model, - "assistant_message": assistant_message, - "completed_tool_results": [], - "pending_tool_calls": [], - }, - ) - self._append_injected_messages(messages, injections) - logger.info( - "Injected {} follow-up message(s) after final response ({}/{})", - len(injections), injection_cycles, _MAX_INJECTION_CYCLES, - ) + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, assistant_message, injection_cycles, + phase="after final response", + iteration=iteration, + ) + if should_continue: + had_injections = True if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=_injected_after_final) + await hook.on_stream_end(context, resuming=should_continue) - if _injected_after_final: + if should_continue: await hook.after_iteration(context) continue @@ -421,6 +450,13 @@ class AgentRunner: context.error = error context.stop_reason = stop_reason await hook.after_iteration(context) + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after LLM error", + ) + if should_continue: + had_injections = True + continue break if is_blank_text(clean): final_content = EMPTY_FINAL_RESPONSE_MESSAGE @@ -431,6 +467,13 @@ class AgentRunner: context.error = error context.stop_reason = stop_reason await 
hook.after_iteration(context) + should_continue, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after empty response", + ) + if should_continue: + had_injections = True + continue break messages.append(assistant_message or build_assistant_message( @@ -467,6 +510,17 @@ class AgentRunner: max_iterations=spec.max_iterations, ) self._append_final_message(messages, final_content) + # Drain any remaining injections so they are appended to the + # conversation history instead of being re-published as + # independent inbound messages by _dispatch's finally block. + # We ignore should_continue here because the for-loop has already + # exhausted all iterations. + drained_after_max_iterations, injection_cycles = await self._try_drain_injections( + spec, messages, None, injection_cycles, + phase="after max_iterations", + ) + if drained_after_max_iterations: + had_injections = True return AgentRunResult( final_content=final_content, diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 1b5a71322..2aea19279 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -454,7 +454,23 @@ async def connect_mcp_servers( return name, server_stack except Exception as e: - logger.error("MCP server '{}': failed to connect: {}", name, e) + hint = "" + text = str(e).lower() + if any( + marker in text + for marker in ( + "parse error", + "invalid json", + "unexpected token", + "jsonrpc", + "content-length", + ) + ): + hint = ( + " Hint: this looks like stdio protocol pollution. Make sure the MCP server writes " + "only JSON-RPC to stdout and sends logs/debug output to stderr instead." 
+ ) + logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint) try: await server_stack.aclose() except Exception: diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index 99d3ec63a..137038c0c 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -68,6 +68,13 @@ class ToolRegistry: params: dict[str, Any], ) -> tuple[Tool | None, dict[str, Any], str | None]: """Resolve, cast, and validate one tool call.""" + # Guard against invalid parameter types (e.g., list instead of dict) + if not isinstance(params, dict) and name in ('write_file', 'read_file'): + return None, params, ( + f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. " + "Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")" + ) + tool = self._tools.get(name) if not tool: return None, params, ( diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 39b5818bd..a863ba0df 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -337,6 +337,9 @@ class DingTalkChannel(BaseChannel): content_type = (resp.headers.get("content-type") or "").split(";")[0].strip() filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref)) return resp.content, filename, content_type or None + except httpx.TransportError as e: + logger.error("DingTalk media download network error ref={} err={}", media_ref, e) + raise except Exception as e: logger.error("DingTalk media download error ref={} err={}", media_ref, e) return None, None, None @@ -388,6 +391,9 @@ class DingTalkChannel(BaseChannel): logger.error("DingTalk media upload missing media_id body={}", text[:500]) return None return str(media_id) + except httpx.TransportError as e: + logger.error("DingTalk media upload network error type={} err={}", media_type, e) + raise except Exception as e: logger.error("DingTalk media upload error type={} err={}", media_type, e) return None @@ -437,6 +443,9 
@@ class DingTalkChannel(BaseChannel): return False logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key) return True + except httpx.TransportError as e: + logger.error("DingTalk network error sending message msgKey={} err={}", msg_key, e) + raise except Exception as e: logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e) return False diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 6e8c673a3..336b6148d 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -366,6 +366,7 @@ class DiscordChannel(BaseChannel): await client.send_outbound(msg) except Exception as e: logger.error("Error sending Discord message: {}", e) + raise finally: if not is_progress: await self._stop_typing(msg.chat_id) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 484eed6e2..f109f6da6 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -280,6 +280,9 @@ class QQChannel(BaseChannel): msg_id=msg_id, content=msg.content.strip(), ) + except (aiohttp.ClientError, OSError): + # Network / transport errors — propagate so ChannelManager can retry + raise except Exception: logger.exception("Error sending QQ message to chat_id={}", msg.chat_id) @@ -362,7 +365,12 @@ class QQChannel(BaseChannel): logger.info("QQ media sent: {}", filename) return True + except (aiohttp.ClientError, OSError) as e: + # Network / transport errors — propagate for retry by caller + logger.warning("QQ send media network error filename={} err={}", filename, e) + raise except Exception as e: + # API-level or other non-network errors — return False so send() can fallback logger.error("QQ send media failed filename={} err={}", filename, e) return False diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 2dde232b1..f63704aa7 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -520,7 +520,10 @@ class TelegramChannel(BaseChannel): 
reply_parameters=reply_params, **(thread_kwargs or {}), ) - except Exception as e: + except BadRequest as e: + # Only fall back to plain text on actual HTML parse/format errors. + # Network errors (TimedOut, NetworkError) should propagate immediately + # to avoid doubling connection demand during pool exhaustion. logger.warning("HTML parse failed, falling back to plain text: {}", e) try: await self._call_with_retry( @@ -567,7 +570,10 @@ class TelegramChannel(BaseChannel): chat_id=int_chat_id, message_id=buf.message_id, text=html, parse_mode="HTML", ) - except Exception as e: + except BadRequest as e: + # Only fall back to plain text on actual HTML parse/format errors. + # Network errors (TimedOut, NetworkError) should propagate immediately + # to avoid doubling connection demand during pool exhaustion. if self._is_not_modified_error(e): logger.debug("Final stream edit already applied for {}", chat_id) self._stream_bufs.pop(chat_id, None) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 3f87e2203..fbe84bcf8 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -985,7 +985,43 @@ class WeixinChannel(BaseChannel): for media_path in (msg.media or []): try: await self._send_media_file(msg.chat_id, media_path, ctx_token) + except (httpx.TimeoutException, httpx.TransportError) as net_err: + # Network/transport errors: do NOT fall back to text — + # the text send would also likely fail, and the outer + # except will re-raise so ChannelManager retries properly. + logger.error( + "Network error sending WeChat media {}: {}", + media_path, + net_err, + ) + raise + except httpx.HTTPStatusError as http_err: + status_code = ( + http_err.response.status_code + if http_err.response is not None + else 0 + ) + if status_code >= 500: + # Server-side / retryable HTTP error — same as network. 
+                        logger.error(
+                            "Server error ({} {}) sending WeChat media {}: {}",
+                            status_code,
+                            http_err.response.reason_phrase
+                            if http_err.response is not None
+                            else "",
+                            media_path,
+                            http_err,
+                        )
+                        raise
+                    # 4xx client errors are NOT retryable — fall back to text.
+                    filename = Path(media_path).name
+                    logger.error("Failed to send WeChat media {}: {}", media_path, http_err)
+                    await self._send_text(
+                        msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
+                    )
                 except Exception as e:
+                    # Non-network errors (format, file-not-found, etc.):
+                    # notify the user via text fallback.
                     filename = Path(media_path).name
                     logger.error("Failed to send WeChat media {}: {}", media_path, e)
                     # Notify user about failure via text
diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py
index 101ee6c33..4dea2d5fc 100644
--- a/nanobot/providers/openai_compat_provider.py
+++ b/nanobot/providers/openai_compat_provider.py
@@ -799,7 +799,12 @@ class OpenAICompatProvider(LLMProvider):
     @staticmethod
-    def _handle_error(e: Exception) -> LLMResponse:
+    def _handle_error(
+        e: Exception,
+        *,
+        spec: ProviderSpec | None = None,
+        api_base: str | None = None,
+    ) -> LLMResponse:
         body = (
             getattr(e, "doc", None)
             or getattr(e, "body", None)
         )
         body_text = body if isinstance(body, str) else str(body) if body is not None else ""
         msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
+
+        text = f"{body_text} {e}".lower()
+        if spec and spec.is_local and ("502" in text or "connection" in text or "refused" in text):
+            msg += (
+                "\nHint: this is a local model endpoint. Check that the local server is reachable at "
+                f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it "
+                "can reach your local Ollama/vLLM service instead of routing localhost through the remote host."
+ ) + response = getattr(e, "response", None) retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None)) if retry_after is None: @@ -850,7 +864,7 @@ class OpenAICompatProvider(LLMProvider): ) return self._parse(await self._client.chat.completions.create(**kwargs)) except Exception as e: - return self._handle_error(e) + return self._handle_error(e, spec=self._spec, api_base=self.api_base) async def chat_stream( self, @@ -933,7 +947,7 @@ class OpenAICompatProvider(LLMProvider): error_kind="timeout", ) except Exception as e: - return self._handle_error(e) + return self._handle_error(e, spec=self._spec, api_base=self.api_base) def get_default_model(self) -> str: return self.default_model diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index 8a0b54b86..c965ccd8c 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -1,5 +1,12 @@ +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + from nanobot.agent.context import ContextBuilder from nanobot.agent.loop import AgentLoop +from nanobot.bus.events import InboundMessage +from nanobot.bus.queue import MessageBus from nanobot.session.manager import Session @@ -11,6 +18,12 @@ def _mk_loop() -> AgentLoop: return loop +def _make_full_loop(tmp_path: Path) -> AgentLoop: + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") + + def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None: loop = _mk_loop() session = Session(key="test:runtime-only") @@ -200,3 +213,98 @@ def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None: assert session.messages[0]["role"] == "assistant" assert session.messages[1]["tool_call_id"] == "call_done" assert session.messages[2]["tool_call_id"] == "call_pending" + + +@pytest.mark.asyncio +async def 
test_process_message_persists_user_message_before_turn_completes(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + loop._run_agent_loop = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + + msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="persist me") + with pytest.raises(RuntimeError, match="boom"): + await loop._process_message(msg) + + loop.sessions.invalidate("feishu:c1") + persisted = loop.sessions.get_or_create("feishu:c1") + assert [m["role"] for m in persisted.messages] == ["user"] + assert persisted.messages[0]["content"] == "persist me" + assert persisted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True + assert persisted.updated_at >= persisted.created_at + + +@pytest.mark.asyncio +async def test_process_message_does_not_duplicate_early_persisted_user_message(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + loop._run_agent_loop = AsyncMock(return_value=( + "done", + None, + [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "done"}, + ], + "stop", + False, + )) # type: ignore[method-assign] + + result = await loop._process_message( + InboundMessage(channel="feishu", sender_id="u1", chat_id="c2", content="hello") + ) + + assert result is not None + assert result.content == "done" + session = loop.sessions.get_or_create("feishu:c2") + assert [ + {k: v for k, v in m.items() if k in {"role", "content"}} + for m in session.messages + ] == [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "done"}, + ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + + +@pytest.mark.asyncio +async def 
test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + loop.provider.chat_with_retry = AsyncMock(return_value=MagicMock()) # unused because _run_agent_loop is stubbed + + session = loop.sessions.get_or_create("feishu:c3") + session.add_message("user", "old question") + session.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True + loop.sessions.save(session) + + loop._run_agent_loop = AsyncMock(return_value=( + "new answer", + None, + [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "Error: Task interrupted before a response was generated."}, + {"role": "user", "content": "new question"}, + {"role": "assistant", "content": "new answer"}, + ], + "stop", + False, + )) # type: ignore[method-assign] + + result = await loop._process_message( + InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question") + ) + + assert result is not None + assert result.content == "new answer" + session = loop.sessions.get_or_create("feishu:c3") + assert [ + {k: v for k, v in m.items() if k in {"role", "content"}} + for m in session.messages + ] == [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "Error: Task interrupted before a response was generated."}, + {"role": "user", "content": "new question"}, + {"role": "assistant", "content": "new answer"}, + ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index a62457aa8..74025d779 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -18,6 +18,16 @@ from nanobot.providers.base import LLMResponse, ToolCallRequest _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars +def _make_injection_callback(queue: 
asyncio.Queue): + """Return an async callback that drains *queue* into a list of dicts.""" + async def inject_cb(): + items = [] + while not queue.empty(): + items.append(await queue.get()) + return items + return inject_cb + + def _make_loop(tmp_path): from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus @@ -1888,12 +1898,7 @@ async def test_checkpoint1_injects_after_tool_execution(): tools.execute = AsyncMock(return_value="file content") injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) # Put a follow-up message in the queue before the run starts await injection_queue.put( @@ -1951,12 +1956,7 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): tools.get_definitions.return_value = [] injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) # Inject a follow-up that arrives during the first response await injection_queue.put( @@ -2005,12 +2005,7 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup() tools.get_definitions.return_value = [] injection_queue = asyncio.Queue() - - async def inject_cb(): - items = [] - while not injection_queue.empty(): - items.append(await injection_queue.get()) - return items + inject_cb = _make_injection_callback(injection_queue) await injection_queue.put( InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") @@ -2410,3 +2405,330 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path): contents = [m.content for m in msgs] assert "leftover-1" in contents assert "leftover-2" in contents + + +@pytest.mark.asyncio +async def 
test_drain_injections_on_fatal_tool_error(): + """Pending injections should be drained even when a fatal tool error occurs.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})], + usage={}, + ) + # Second call: respond normally to the injected follow-up + return LLMResponse(content="reply to follow-up", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded")) + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + fail_on_tool_error=True, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "reply to follow-up" + # The injection should be in the messages history + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "follow-up after error" + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_llm_error(): + """Pending injections should be drained when the LLM returns an error finish_reason.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + 
async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content=None, + tool_calls=[], + finish_reason="error", + usage={}, + ) + # Second call: respond normally to the injected follow-up + return LLMResponse(content="recovered answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous response"}, + {"role": "user", "content": "trigger error"}, + ], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "recovered answer" + injected = [ + m for m in result.messages + if m.get("role") == "user" and "follow-up after LLM error" in str(m.get("content", "")) + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_empty_final_response(): + """Pending injections should be drained when the runner exits due to empty response.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] <= _MAX_EMPTY_RETRIES + 1: + return LLMResponse(content="", tool_calls=[], usage={}) + # After retries exhausted + injection drain, respond normally + return LLMResponse(content="answer after empty", tool_calls=[], usage={}) + + 
provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous response"}, + {"role": "user", "content": "trigger empty"}, + ], + tools=tools, + model="test-model", + max_iterations=10, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "answer after empty" + injected = [ + m for m in result.messages + if m.get("role") == "user" and "follow-up after empty" in str(m.get("content", "")) + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_on_max_iterations(): + """Pending injections should be drained when the runner hits max_iterations. + + Unlike other error paths, max_iterations cannot continue the loop, so + injections are appended to messages but not processed by the LLM. + The key point is they are consumed from the queue to prevent re-publish. 
+ """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content="", + tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.stop_reason == "max_iterations" + assert result.had_injections is True + # The injection was consumed from the queue (preventing re-publish) + assert injection_queue.empty() + # The injection message is appended to conversation history + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "follow-up after max iters" + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_drain_injections_set_flag_when_followup_arrives_after_last_iteration(): + """Late follow-ups drained in max_iterations should still flip had_injections.""" + from nanobot.agent.hook import AgentHook + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content="", + 
tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + inject_cb = _make_injection_callback(injection_queue) + + class InjectOnLastAfterIterationHook(AgentHook): + def __init__(self) -> None: + self.after_iteration_calls = 0 + + async def after_iteration(self, context) -> None: + self.after_iteration_calls += 1 + if self.after_iteration_calls == 2: + await injection_queue.put( + InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content="late follow-up after max iters", + ) + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + hook=InjectOnLastAfterIterationHook(), + )) + + assert result.stop_reason == "max_iterations" + assert result.had_injections is True + assert injection_queue.empty() + injected = [ + m for m in result.messages + if m.get("role") == "user" and m.get("content") == "late follow-up after max iters" + ] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_injection_cycle_cap_on_error_path(): + """Injection cycles should be capped even when every iteration hits an LLM error.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse( + content=None, + tool_calls=[], + finish_reason="error", + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = 
[] + + drain_count = {"n": 0} + + async def inject_cb(): + drain_count["n"] += 1 + if drain_count["n"] <= _MAX_INJECTION_CYCLES: + return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous"}, + {"role": "user", "content": "trigger error"}, + ], + tools=tools, + model="test-model", + max_iterations=20, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + # Should cap: _MAX_INJECTION_CYCLES drained rounds + 1 final round that breaks + assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 + assert drain_count["n"] == _MAX_INJECTION_CYCLES diff --git a/tests/channels/test_dingtalk_channel.py b/tests/channels/test_dingtalk_channel.py index f743c4e62..86de99bb5 100644 --- a/tests/channels/test_dingtalk_channel.py +++ b/tests/channels/test_dingtalk_channel.py @@ -2,7 +2,9 @@ import asyncio import zipfile from io import BytesIO from types import SimpleNamespace +from unittest.mock import AsyncMock +import httpx import pytest # Check optional dingtalk dependencies before running tests @@ -52,6 +54,21 @@ class _FakeHttp: return self._next_response() +class _NetworkErrorHttp: + """HTTP client stub that raises httpx.TransportError on every request.""" + + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def post(self, url: str, json=None, headers=None, **kwargs): + self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers}) + raise httpx.ConnectError("Connection refused") + + async def get(self, url: str, **kwargs): + self.calls.append({"method": "GET", "url": url}) + raise httpx.ConnectError("Connection refused") + + @pytest.mark.asyncio async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None: config = 
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]) @@ -298,3 +315,141 @@ async def test_send_media_ref_zips_html_before_upload(tmp_path, monkeypatch) -> archive = zipfile.ZipFile(BytesIO(captured["data"])) assert archive.namelist() == ["report.html"] + + +# ── Exception handling tests ────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_send_batch_message_propagates_transport_error() -> None: + """Network/transport errors must re-raise so callers can retry.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + channel._http = _NetworkErrorHttp() + + with pytest.raises(httpx.ConnectError, match="Connection refused"): + await channel._send_batch_message( + "token", + "user123", + "sampleMarkdown", + {"text": "hello", "title": "Nanobot Reply"}, + ) + + # The POST was attempted exactly once + assert len(channel._http.calls) == 1 + assert channel._http.calls[0]["method"] == "POST" + + +@pytest.mark.asyncio +async def test_send_batch_message_returns_false_on_api_error() -> None: + """DingTalk API-level errors (non-200 status, errcode != 0) should return False.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + + # Non-200 status code → API error → return False + channel._http = _FakeHttp(responses=[_FakeResponse(400, {"errcode": 400})]) + result = await channel._send_batch_message( + "token", "user123", "sampleMarkdown", {"text": "hello"} + ) + assert result is False + + # 200 with non-zero errcode → API error → return False + channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 100})]) + result = await channel._send_batch_message( + "token", "user123", "sampleMarkdown", {"text": "hello"} + ) + assert result is False + + # 200 with errcode=0 → success → return True + channel._http = _FakeHttp(responses=[_FakeResponse(200, 
{"errcode": 0})]) + result = await channel._send_batch_message( + "token", "user123", "sampleMarkdown", {"text": "hello"} + ) + assert result is True + + +@pytest.mark.asyncio +async def test_send_media_ref_short_circuits_on_transport_error() -> None: + """When the first send fails with a transport error, _send_media_ref must + re-raise immediately instead of trying download+upload+fallback.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + channel._http = _NetworkErrorHttp() + + # An image URL triggers the sampleImageMsg path first + with pytest.raises(httpx.ConnectError, match="Connection refused"): + await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg") + + # Only one POST should have been attempted — no download/upload/fallback + assert len(channel._http.calls) == 1 + assert channel._http.calls[0]["method"] == "POST" + + +@pytest.mark.asyncio +async def test_send_media_ref_short_circuits_on_download_transport_error() -> None: + """When the image URL send returns an API error (False) but the download + for the fallback hits a transport error, it must re-raise rather than + silently returning False.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + + # First POST (sampleImageMsg) returns API error → False, then GET (download) raises transport error + class _MixedHttp: + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def post(self, url, json=None, headers=None, **kwargs): + self.calls.append({"method": "POST", "url": url}) + # API-level failure: 200 with errcode != 0 + return _FakeResponse(200, {"errcode": 100}) + + async def get(self, url, **kwargs): + self.calls.append({"method": "GET", "url": url}) + raise httpx.ConnectError("Connection refused") + + channel._http = _MixedHttp() + + with pytest.raises(httpx.ConnectError, 
match="Connection refused"): + await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg") + + # Should have attempted POST (image URL) and GET (download), but NOT upload + assert len(channel._http.calls) == 2 + assert channel._http.calls[0]["method"] == "POST" + assert channel._http.calls[1]["method"] == "GET" + + +@pytest.mark.asyncio +async def test_send_media_ref_short_circuits_on_upload_transport_error() -> None: + """When download succeeds but upload hits a transport error, must re-raise.""" + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + + image_bytes = b"\xff\xd8\xff\xe0" + b"\x00" * 100 # minimal JPEG-ish data + + class _UploadFailsHttp: + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def post(self, url, json=None, headers=None, files=None, **kwargs): + self.calls.append({"method": "POST", "url": url}) + # If it's the upload endpoint, raise transport error + if "media/upload" in url: + raise httpx.ConnectError("Connection refused") + # Otherwise (sampleImageMsg), return API error to trigger fallback + return _FakeResponse(200, {"errcode": 100}) + + async def get(self, url, **kwargs): + self.calls.append({"method": "GET", "url": url}) + resp = _FakeResponse(200) + resp.content = image_bytes + resp.headers = {"content-type": "image/jpeg"} + return resp + + channel._http = _UploadFailsHttp() + + with pytest.raises(httpx.ConnectError, match="Connection refused"): + await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg") + + # POST (image URL), GET (download), POST (upload) attempted — no further sends + methods = [c["method"] for c in channel._http.calls] + assert methods == ["POST", "GET", "POST"] diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 3a31a5912..7a39bff2b 100644 --- a/tests/channels/test_discord_channel.py +++ 
b/tests/channels/test_discord_channel.py @@ -867,3 +867,100 @@ async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None: assert channel.is_running is False assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890" assert _FakeDiscordClient.instances[0].proxy_auth is None + + +# --------------------------------------------------------------------------- +# Tests for the send() exception propagation fix +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_re_raises_network_error() -> None: + """Network errors during send must propagate so ChannelManager can retry.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + async def _failing_send_outbound(msg: OutboundMessage) -> None: + raise ConnectionError("network unreachable") + + client.send_outbound = _failing_send_outbound # type: ignore[method-assign] + + with pytest.raises(ConnectionError, match="network unreachable"): + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + +@pytest.mark.asyncio +async def test_send_re_raises_generic_exception() -> None: + """Any exception from send_outbound must propagate, not be swallowed.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + async def _failing_send_outbound(msg: OutboundMessage) -> None: + raise RuntimeError("discord API failure") + + client.send_outbound = _failing_send_outbound # type: ignore[method-assign] + + with pytest.raises(RuntimeError, match="discord API failure"): + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + +@pytest.mark.asyncio +async def 
test_send_still_stops_typing_on_error() -> None: + """Typing cleanup must still run in the finally block even when send raises.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + # Start a typing task so we can verify it gets cleaned up + start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing + await channel._start_typing(typing_channel) + await asyncio.wait_for(start.wait(), timeout=1.0) + + async def _failing_send_outbound(msg: OutboundMessage) -> None: + raise ConnectionError("timeout") + + client.send_outbound = _failing_send_outbound # type: ignore[method-assign] + + with pytest.raises(ConnectionError, match="timeout"): + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + release.set() + await asyncio.sleep(0) + + # Typing should have been cleaned up by the finally block + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_send_succeeds_normally() -> None: + """Successful sends should work without raising.""" + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + sent_messages: list[OutboundMessage] = [] + + async def _capture_send_outbound(msg: OutboundMessage) -> None: + sent_messages.append(msg) + + client.send_outbound = _capture_send_outbound # type: ignore[method-assign] + + msg = OutboundMessage(channel="discord", chat_id="123", content="hello world") + await channel.send(msg) + + assert len(sent_messages) == 1 + assert sent_messages[0].content == "hello world" + assert sent_messages[0].chat_id == "123" diff --git 
a/tests/channels/test_qq_channel.py b/tests/channels/test_qq_channel.py index 729442a13..417648adf 100644 --- a/tests/channels/test_qq_channel.py +++ b/tests/channels/test_qq_channel.py @@ -1,6 +1,7 @@ import tempfile from pathlib import Path from types import SimpleNamespace +from unittest.mock import AsyncMock, patch import pytest @@ -14,6 +15,8 @@ except ImportError: if not QQ_AVAILABLE: pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True) +import aiohttp + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.qq import QQChannel, QQConfig @@ -170,3 +173,221 @@ async def test_read_media_bytes_missing_file() -> None: data, filename = await channel._read_media_bytes("/nonexistent/path/image.png") assert data is None assert filename is None + + +# ------------------------------------------------------- +# Tests for _send_media exception handling +# ------------------------------------------------------- + +def _make_channel_with_local_file(suffix: str = ".png", content: bytes = b"\x89PNG\r\n"): + """Create a QQChannel with a fake client and a temp file for media.""" + channel = QQChannel( + QQConfig(app_id="app", secret="secret", allow_from=["*"]), + MessageBus(), + ) + channel._client = _FakeClient() + channel._chat_type_cache["user1"] = "c2c" + + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + tmp.write(content) + tmp.close() + return channel, tmp.name + + +@pytest.mark.asyncio +async def test_send_media_network_error_propagates() -> None: + """aiohttp.ClientError (network/transport) should re-raise, not return False.""" + channel, tmp_path = _make_channel_with_local_file() + + # Make the base64 upload raise a network error + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=aiohttp.ServerDisconnectedError("connection lost"), + ) + + with pytest.raises(aiohttp.ServerDisconnectedError): + await 
channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_media_client_connector_error_propagates() -> None: + """aiohttp.ClientConnectorError (DNS/connection refused) should re-raise.""" + channel, tmp_path = _make_channel_with_local_file() + + from aiohttp.client_reqrep import ConnectionKey + conn_key = ConnectionKey("api.qq.com", 443, True, None, None, None, None) + connector_error = aiohttp.ClientConnectorError( + connection_key=conn_key, + os_error=OSError("Connection refused"), + ) + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=connector_error, + ) + + with pytest.raises(aiohttp.ClientConnectorError): + await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_media_oserror_propagates() -> None: + """OSError (low-level I/O) should re-raise for retry.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=OSError("Network is unreachable"), + ) + + with pytest.raises(OSError): + await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_media_api_error_returns_false() -> None: + """API-level errors (botpy RuntimeError subclasses) should return False, not raise.""" + channel, tmp_path = _make_channel_with_local_file() + + # Simulate a botpy API error (e.g. 
ServerError is a RuntimeError subclass) + from botpy.errors import ServerError + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=ServerError("internal server error"), + ) + + result = await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + assert result is False + + +@pytest.mark.asyncio +async def test_send_media_generic_runtime_error_returns_false() -> None: + """Generic RuntimeError (not network) should return False.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=RuntimeError("some API error"), + ) + + result = await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + assert result is False + + +@pytest.mark.asyncio +async def test_send_media_value_error_returns_false() -> None: + """ValueError (bad API response data) should return False.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=ValueError("bad response data"), + ) + + result = await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + assert result is False + + +@pytest.mark.asyncio +async def test_send_media_timeout_error_propagates() -> None: + """asyncio.TimeoutError inherits from Exception but not ClientError/OSError. + However, aiohttp.ServerTimeoutError IS a ClientError subclass, so that propagates. 
+ For a plain TimeoutError (which is also OSError in Python 3.11+), it should propagate.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=aiohttp.ServerTimeoutError("request timed out"), + ) + + with pytest.raises(aiohttp.ServerTimeoutError): + await channel._send_media( + chat_id="user1", + media_ref=tmp_path, + msg_id="msg1", + is_group=False, + ) + + +@pytest.mark.asyncio +async def test_send_fallback_text_on_api_error() -> None: + """When _send_media returns False (API error), send() should emit fallback text.""" + channel, tmp_path = _make_channel_with_local_file() + + from botpy.errors import ServerError + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=ServerError("internal server error"), + ) + + await channel.send( + OutboundMessage( + channel="qq", + chat_id="user1", + content="", + media=[tmp_path], + metadata={"message_id": "msg1"}, + ) + ) + + # Should have sent a fallback text message + assert len(channel._client.api.c2c_calls) == 1 + fallback_content = channel._client.api.c2c_calls[0]["content"] + assert "Attachment send failed" in fallback_content + + +@pytest.mark.asyncio +async def test_send_propagates_network_error_no_fallback() -> None: + """When _send_media raises a network error, send() should NOT silently fallback.""" + channel, tmp_path = _make_channel_with_local_file() + + channel._client.api._http = SimpleNamespace() + channel._client.api._http.request = AsyncMock( + side_effect=aiohttp.ServerDisconnectedError("connection lost"), + ) + + with pytest.raises(aiohttp.ServerDisconnectedError): + await channel.send( + OutboundMessage( + channel="qq", + chat_id="user1", + content="hello", + media=[tmp_path], + metadata={"message_id": "msg1"}, + ) + ) + + # No fallback text should have been sent + assert len(channel._client.api.c2c_calls) == 0 diff --git 
a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 5a1964127..8d9431ba6 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -387,6 +387,84 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None: assert "123" not in channel._stream_bufs +@pytest.mark.asyncio +async def test_send_delta_stream_end_does_not_fallback_on_network_timeout() -> None: + """TimedOut during HTML edit should propagate, never fall back to plain text.""" + from telegram.error import TimedOut + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + # _call_with_retry retries TimedOut up to 3 times, so the mock will be called + # multiple times – but all calls must be with parse_mode="HTML" (no plain fallback). + channel._app.bot.edit_message_text = AsyncMock(side_effect=TimedOut("network timeout")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0) + + with pytest.raises(TimedOut, match="network timeout"): + await channel.send_delta("123", "", {"_stream_end": True}) + + # Every call to edit_message_text must have used parse_mode="HTML" — + # no plain-text fallback call should have been made. 
+ for call in channel._app.bot.edit_message_text.call_args_list: + assert call.kwargs.get("parse_mode") == "HTML" + # Buffer should still be present (not cleaned up on error) + assert "123" in channel._stream_bufs + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_does_not_fallback_on_network_error() -> None: + """NetworkError during HTML edit should propagate, never fall back to plain text.""" + from telegram.error import NetworkError + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._app.bot.edit_message_text = AsyncMock(side_effect=NetworkError("connection reset")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0) + + with pytest.raises(NetworkError, match="connection reset"): + await channel.send_delta("123", "", {"_stream_end": True}) + + # Every call to edit_message_text must have used parse_mode="HTML" — + # no plain-text fallback call should have been made. 
+ for call in channel._app.bot.edit_message_text.call_args_list: + assert call.kwargs.get("parse_mode") == "HTML" + # Buffer should still be present (not cleaned up on error) + assert "123" in channel._stream_bufs + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_falls_back_on_bad_request() -> None: + """BadRequest (HTML parse error) should still trigger plain-text fallback.""" + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + # First call (HTML) raises BadRequest, second call (plain) succeeds + channel._app.bot.edit_message_text = AsyncMock( + side_effect=[BadRequest("Can't parse entities"), None] + ) + channel._stream_bufs["123"] = _StreamBuf(text="hello ", message_id=7, last_edit=0.0) + + await channel.send_delta("123", "", {"_stream_end": True}) + + # edit_message_text should have been called twice: once for HTML, once for plain fallback + assert channel._app.bot.edit_message_text.call_count == 2 + # Second call should not use parse_mode="HTML" + second_call_kwargs = channel._app.bot.edit_message_text.call_args_list[1].kwargs + assert "parse_mode" not in second_call_kwargs or second_call_kwargs.get("parse_mode") is None + # Buffer should be cleaned up on success + assert "123" not in channel._stream_bufs + + @pytest.mark.asyncio async def test_send_delta_stream_end_splits_oversized_reply() -> None: """Final streamed reply exceeding Telegram limit is split into chunks.""" @@ -1159,3 +1237,159 @@ async def test_on_message_location_with_text() -> None: assert len(handled) == 1 assert "meet me here" in handled[0]["content"] assert "[location: 51.5074, -0.1278]" in handled[0]["content"] + + +# --------------------------------------------------------------------------- +# Tests for retry amplification fix (issue #3050) +# --------------------------------------------------------------------------- + 
+@pytest.mark.asyncio +async def test_send_text_does_not_fallback_on_network_timeout() -> None: + """TimedOut should propagate immediately, NOT trigger plain-text fallback. + + Before the fix, _send_text caught ALL exceptions (including TimedOut) + and retried as plain text, doubling connection demand during pool + exhaustion — see issue #3050. + """ + from telegram.error import TimedOut + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + call_count = 0 + + async def always_timeout(**kwargs): + nonlocal call_count + call_count += 1 + raise TimedOut() + + channel._app.bot.send_message = always_timeout + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + with pytest.raises(TimedOut): + await channel._send_text(123, "hello", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # With the fix: only _call_with_retry's 3 HTML attempts (no plain fallback). + # Before the fix: 3 HTML + 3 plain = 6 attempts. 
+ assert call_count == 3, ( + f"Expected 3 calls (HTML retries only), got {call_count} " + "(plain-text fallback should not trigger on TimedOut)" + ) + + +@pytest.mark.asyncio +async def test_send_text_does_not_fallback_on_network_error() -> None: + """NetworkError should propagate immediately, NOT trigger plain-text fallback.""" + from telegram.error import NetworkError + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + call_count = 0 + + async def always_network_error(**kwargs): + nonlocal call_count + call_count += 1 + raise NetworkError("Connection reset") + + channel._app.bot.send_message = always_network_error + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + with pytest.raises(NetworkError): + await channel._send_text(123, "hello", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # _call_with_retry does NOT retry NetworkError (only TimedOut/RetryAfter), + # so it raises after 1 attempt. The fix prevents plain-text fallback. + # Before the fix: 1 HTML + 1 plain = 2. After the fix: 1 HTML only. 
+ assert call_count == 1, ( + f"Expected 1 call (HTML only, no plain fallback), got {call_count}" + ) + + +@pytest.mark.asyncio +async def test_send_text_falls_back_on_bad_request() -> None: + """BadRequest (HTML parse error) should still trigger plain-text fallback.""" + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + original_send = channel._app.bot.send_message + html_call_count = 0 + + async def html_fails(**kwargs): + nonlocal html_call_count + if kwargs.get("parse_mode") == "HTML": + html_call_count += 1 + raise BadRequest("Can't parse entities") + return await original_send(**kwargs) + + channel._app.bot.send_message = html_fails + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + await channel._send_text(123, "hello **world**", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # HTML attempt failed with BadRequest → fallback to plain text succeeds. 
+ assert html_call_count == 1, f"Expected 1 HTML attempt, got {html_call_count}" + assert len(channel._app.bot.sent_messages) == 1 + # Plain text send should NOT have parse_mode + assert channel._app.bot.sent_messages[0].get("parse_mode") is None + + +@pytest.mark.asyncio +async def test_send_text_bad_request_plain_fallback_exhausted() -> None: + """When both HTML and plain-text fallback fail with BadRequest, the error propagates.""" + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + call_count = 0 + + async def always_bad_request(**kwargs): + nonlocal call_count + call_count += 1 + raise BadRequest("Bad request") + + channel._app.bot.send_message = always_bad_request + + import nanobot.channels.telegram as tg_mod + orig_delay = tg_mod._SEND_RETRY_BASE_DELAY + tg_mod._SEND_RETRY_BASE_DELAY = 0.01 + try: + with pytest.raises(BadRequest): + await channel._send_text(123, "hello", None, {}) + finally: + tg_mod._SEND_RETRY_BASE_DELAY = orig_delay + + # _call_with_retry does NOT retry BadRequest (only TimedOut/RetryAfter), + # so HTML fails after 1 attempt → fallback to plain also fails after 1 attempt. + # Before the fix: 2 total. After the fix: still 2 (BadRequest SHOULD fallback). 
+ assert call_count == 2, f"Expected 2 calls (1 HTML + 1 plain), got {call_count}" diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 3a847411b..2b455fca6 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1003,3 +1003,185 @@ async def test_download_media_item_non_image_requires_aes_key_even_with_full_url assert saved_path is None channel._client.get.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# Tests for media-send error classification (network vs non-network errors) +# --------------------------------------------------------------------------- + + +def _make_outbound_msg(chat_id: str = "wx-user", content: str = "", media: list | None = None): + """Build a minimal OutboundMessage-like object for send() tests.""" + from nanobot.bus.events import OutboundMessage + + return OutboundMessage( + channel="weixin", + chat_id=chat_id, + content=content, + media=media or [], + metadata={}, + ) + + +@pytest.mark.asyncio +async def test_send_media_timeout_error_propagates_without_text_fallback() -> None: + """httpx.TimeoutException during media send must re-raise immediately, + NOT fall back to _send_text (which would also fail during network issues).""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock(side_effect=httpx.TimeoutException("timed out")) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + with pytest.raises(httpx.TimeoutException, match="timed out"): + await channel.send(msg) + + # _send_text must NOT have been called as a fallback + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_media_transport_error_propagates_without_text_fallback() -> None: + """httpx.TransportError during media send 
must re-raise immediately.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=httpx.TransportError("connection reset") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + with pytest.raises(httpx.TransportError, match="connection reset"): + await channel.send(msg) + + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_media_5xx_http_status_error_propagates_without_text_fallback() -> None: + """httpx.HTTPStatusError with a 5xx status must re-raise immediately.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + + fake_response = httpx.Response( + status_code=503, + request=httpx.Request("POST", "https://example.test/upload"), + ) + channel._send_media_file = AsyncMock( + side_effect=httpx.HTTPStatusError( + "Service Unavailable", request=fake_response.request, response=fake_response + ) + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + with pytest.raises(httpx.HTTPStatusError, match="Service Unavailable"): + await channel.send(msg) + + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_media_4xx_http_status_error_falls_back_to_text() -> None: + """httpx.HTTPStatusError with a 4xx status should fall back to text, not re-raise.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + + fake_response = httpx.Response( + status_code=400, + request=httpx.Request("POST", "https://example.test/upload"), + ) + channel._send_media_file = AsyncMock( + side_effect=httpx.HTTPStatusError( + "Bad Request", request=fake_response.request, response=fake_response + ) + ) 
+ channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"]) + + # Should NOT raise — 4xx is a client error, non-retryable + await channel.send(msg) + + # _send_text should have been called with the fallback message + channel._send_text.assert_awaited_once_with( + "wx-user", "[Failed to send: photo.jpg]", "ctx-1" + ) + + +@pytest.mark.asyncio +async def test_send_media_file_not_found_falls_back_to_text() -> None: + """FileNotFoundError (a non-network error) should fall back to text.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=FileNotFoundError("Media file not found: /tmp/missing.jpg") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/missing.jpg"]) + + # Should NOT raise + await channel.send(msg) + + channel._send_text.assert_awaited_once_with( + "wx-user", "[Failed to send: missing.jpg]", "ctx-1" + ) + + +@pytest.mark.asyncio +async def test_send_media_value_error_falls_back_to_text() -> None: + """ValueError (e.g. unsupported format) should fall back to text.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=ValueError("Unsupported media format") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/file.xyz"]) + + # Should NOT raise + await channel.send(msg) + + channel._send_text.assert_awaited_once_with( + "wx-user", "[Failed to send: file.xyz]", "ctx-1" + ) + + +@pytest.mark.asyncio +async def test_send_media_network_error_does_not_double_api_calls() -> None: + """During network issues, media send should make exactly 1 API call attempt, + not 2 (media + text fallback). 
Verify total call count.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-1" + channel._send_media_file = AsyncMock( + side_effect=httpx.ConnectError("connection refused") + ) + channel._send_text = AsyncMock() + + msg = _make_outbound_msg(chat_id="wx-user", content="hello", media=["/tmp/img.png"]) + + with pytest.raises(httpx.ConnectError): + await channel.send(msg) + + # _send_media_file called once, _send_text never called + channel._send_media_file.assert_awaited_once() + channel._send_text.assert_not_awaited() diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index d2a9f4247..85314dc79 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import patch from nanobot.providers.openai_compat_provider import OpenAICompatProvider +from nanobot.providers.registry import find_by_name def test_custom_provider_parse_handles_empty_choices() -> None: @@ -53,3 +54,20 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: assert result.finish_reason == "stop" assert result.content == "hello world" + + +def test_local_provider_502_error_includes_reachability_hint() -> None: + spec = find_by_name("ollama") + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider(api_base="http://localhost:11434/v1", spec=spec) + + result = provider._handle_error( + Exception("Error code: 502"), + spec=spec, + api_base="http://localhost:11434/v1", + ) + + assert result.finish_reason == "error" + assert "local model endpoint" in result.content + assert "http://localhost:11434/v1" in result.content + assert "proxy/tunnel" in result.content diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index da90c4d0d..a133f53db 100644 --- a/tests/tools/test_mcp_tool.py +++ 
b/tests/tools/test_mcp_tool.py @@ -356,6 +356,33 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries( assert "Available wrapped names: mcp_test_demo" in warnings[-1] +@pytest.mark.asyncio +async def test_connect_mcp_servers_logs_stdio_pollution_hint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + messages: list[str] = [] + + def _error(message: str, *args: object) -> None: + messages.append(message.format(*args)) + + @asynccontextmanager + async def _broken_stdio_client(_params: object): + raise RuntimeError("Parse error: Unexpected token 'INFO' before JSON-RPC headers") + yield # pragma: no cover + + monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _broken_stdio_client) + monkeypatch.setattr("nanobot.agent.tools.mcp.logger.error", _error) + + registry = ToolRegistry() + stacks = await connect_mcp_servers({"gh": MCPServerConfig(command="github-mcp")}, registry) + + assert stacks == {} + assert messages + assert "stdio protocol pollution" in messages[-1] + assert "stdout" in messages[-1] + assert "stderr" in messages[-1] + + @pytest.mark.asyncio async def test_connect_mcp_servers_one_failure_does_not_block_others( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/tools/test_tool_registry.py b/tests/tools/test_tool_registry.py index 5b259119e..f9e8ce5e1 100644 --- a/tests/tools/test_tool_registry.py +++ b/tests/tools/test_tool_registry.py @@ -47,3 +47,27 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None: "mcp_fs_list", "mcp_git_status", ] + + +def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("read_file")) + + tool, params, error = registry.prepare_call("read_file", ["foo.txt"]) + + assert tool is None + assert params == ["foo.txt"] + assert error is not None + assert "must be a JSON object" in error + assert "Use named parameters" in error + + +def 
test_prepare_call_other_tools_keep_generic_object_validation() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("grep")) + + tool, params, error = registry.prepare_call("grep", ["TODO"]) + + assert tool is not None + assert params == ["TODO"] + assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"