diff --git a/README.md b/README.md index 828b56477..8a8c864d0 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,8 @@ - [Configuration](#️-configuration) - [Multiple Instances](#-multiple-instances) - [CLI Reference](#-cli-reference) +- [Python SDK](#-python-sdk) +- [OpenAI-Compatible API](#-openai-compatible-api) - [Docker](#-docker) - [Linux Service](#-linux-service) - [Project Structure](#-project-structure) @@ -1541,6 +1543,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo | `nanobot agent` | Interactive chat mode | | `nanobot agent --no-markdown` | Show plain-text replies | | `nanobot agent --logs` | Show runtime logs during chat | +| `nanobot serve` | Start the OpenAI-compatible API | | `nanobot gateway` | Start the gateway | | `nanobot status` | Show status | | `nanobot provider login openai-codex` | OAuth login for providers | @@ -1569,6 +1572,110 @@ The agent can also manage this file itself — ask it to "add a periodic task" a +## 🐍 Python SDK + +Use nanobot as a library — no CLI, no gateway, just Python: + +```python +from nanobot import Nanobot + +bot = Nanobot.from_config() +result = await bot.run("Summarize the README") +print(result.content) +``` + +Each call carries a `session_key` for conversation isolation — different keys get independent history: + +```python +await bot.run("hi", session_key="user-alice") +await bot.run("hi", session_key="task-42") +``` + +Add lifecycle hooks to observe or customize the agent: + +```python +from nanobot.agent import AgentHook, AgentHookContext + +class AuditHook(AgentHook): + async def before_execute_tools(self, ctx: AgentHookContext) -> None: + for tc in ctx.tool_calls: + print(f"[tool] {tc.name}") + +result = await bot.run("Hello", hooks=[AuditHook()]) +``` + +See [docs/PYTHON_SDK.md](docs/PYTHON_SDK.md) for the full SDK reference. 
+ +## 🔌 OpenAI-Compatible API + +nanobot can expose a minimal OpenAI-compatible endpoint for local integrations: + +```bash +pip install "nanobot-ai[api]" +nanobot serve +``` + +By default, the API binds to `127.0.0.1:8900`. You can change this in `config.json`. + +### Behavior + +- Session isolation: pass `"session_id"` in the request body to isolate conversations; omit for a shared default session (`api:default`) +- Single-message input: each request must contain exactly one `user` message +- Fixed model: omit `model`, or pass the same model shown by `/v1/models` +- No streaming: `stream=true` is not supported + +### Endpoints + +- `GET /health` +- `GET /v1/models` +- `POST /v1/chat/completions` + +### curl + +```bash +curl http://127.0.0.1:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": "hi"}], + "session_id": "my-session" + }' +``` + +### Python (`requests`) + +```python +import requests + +resp = requests.post( + "http://127.0.0.1:8900/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "session_id": "my-session", # optional: isolate conversation + }, + timeout=120, +) +resp.raise_for_status() +print(resp.json()["choices"][0]["message"]["content"]) +``` + +### Python (`openai`) + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://127.0.0.1:8900/v1", + api_key="dummy", +) + +resp = client.chat.completions.create( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": "hi"}], + extra_body={"session_id": "my-session"}, # optional: isolate conversation +) +print(resp.choices[0].message.content) +``` + ## 🐳 Docker > [!TIP] diff --git a/core_agent_lines.sh b/core_agent_lines.sh index d35207cb4..0891347d5 100755 --- a/core_agent_lines.sh +++ b/core_agent_lines.sh @@ -1,5 +1,6 @@ #!/bin/bash -# Count core agent lines (excluding channels/, cli/, providers/ adapters) +# Count core agent lines (excluding channels/, cli/, api/, 
providers/ adapters, +# and the high-level Python SDK facade) cd "$(dirname "$0")" || exit 1 echo "nanobot core agent line count" @@ -15,7 +16,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l) printf " %-16s %5s lines\n" "(root)" "$root" echo "" -total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l) +total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" ! -path "nanobot/nanobot.py" | xargs cat | wc -l) echo " Core total: $total lines" echo "" -echo " (excludes: channels/, cli/, command/, providers/, skills/)" +echo " (excludes: channels/, cli/, api/, command/, providers/, skills/, nanobot.py)" diff --git a/docs/PYTHON_SDK.md b/docs/PYTHON_SDK.md new file mode 100644 index 000000000..357722e5e --- /dev/null +++ b/docs/PYTHON_SDK.md @@ -0,0 +1,136 @@ +# Python SDK + +Use nanobot programmatically — load config, run the agent, get results. + +## Quick Start + +```python +import asyncio +from nanobot import Nanobot + +async def main(): + bot = Nanobot.from_config() + result = await bot.run("What time is it in Tokyo?") + print(result.content) + +asyncio.run(main()) +``` + +## API + +### `Nanobot.from_config(config_path?, *, workspace?)` + +Create a `Nanobot` from a config file. + +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `config_path` | `str \| Path \| None` | `None` | Path to `config.json`. Defaults to `~/.nanobot/config.json`. | +| `workspace` | `str \| Path \| None` | `None` | Override workspace directory from config. | + +Raises `FileNotFoundError` if an explicit path doesn't exist. + +### `await bot.run(message, *, session_key?, hooks?)` + +Run the agent once. Returns a `RunResult`. 
+ +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `message` | `str` | *(required)* | The user message to process. | +| `session_key` | `str` | `"sdk:default"` | Session identifier for conversation isolation. Different keys get independent history. | +| `hooks` | `list[AgentHook] \| None` | `None` | Lifecycle hooks for this run only. | + +```python +# Isolated sessions — each user gets independent conversation history +await bot.run("hi", session_key="user-alice") +await bot.run("hi", session_key="user-bob") +``` + +### `RunResult` + +| Field | Type | Description | +|-------|------|-------------| +| `content` | `str` | The agent's final text response. | +| `tools_used` | `list[str]` | Tool names invoked during the run. | +| `messages` | `list[dict]` | Raw message history (for debugging). | + +## Hooks + +Hooks let you observe or modify the agent loop without touching internals. + +Subclass `AgentHook` and override any method: + +| Method | When | +|--------|------| +| `before_iteration(ctx)` | Before each LLM call | +| `on_stream(ctx, delta)` | On each streamed token | +| `on_stream_end(ctx, *, resuming)` | When streaming finishes (`resuming=True` means tool calls follow) | +| `before_execute_tools(ctx)` | Before tool execution (inspect `ctx.tool_calls`) | +| `after_iteration(ctx)` | After each LLM response (read it from `ctx.response`) | +| `finalize_content(ctx, content)` | Transform final output text | + +### Example: Audit Hook + +```python +from nanobot.agent import AgentHook, AgentHookContext + +class AuditHook(AgentHook): + def __init__(self): + self.calls = [] + + async def before_execute_tools(self, ctx: AgentHookContext) -> None: + for tc in ctx.tool_calls: + self.calls.append(tc.name) + print(f"[audit] {tc.name}({tc.arguments})") + +hook = AuditHook() +result = await bot.run("List files in /tmp", hooks=[hook]) +print(f"Tools used: {hook.calls}") +``` + +### Composing Hooks + +Pass multiple hooks — they run in order, errors in one don't block others: + +```python +result = await 
bot.run("hi", hooks=[AuditHook(), MetricsHook()]) +``` + +Under the hood this uses `CompositeHook` for fan-out with error isolation. + +### `finalize_content` Pipeline + +Unlike the async methods (fan-out), `finalize_content` is a pipeline — each hook's output feeds the next: + +```python +class Censor(AgentHook): + def finalize_content(self, ctx, content): + return content.replace("secret", "***") if content else content +``` + +## Full Example + +```python +import asyncio +from nanobot import Nanobot +from nanobot.agent import AgentHook, AgentHookContext + +class TimingHook(AgentHook): + async def before_iteration(self, ctx: AgentHookContext) -> None: + import time + ctx.metadata["_t0"] = time.time() + + async def after_iteration(self, ctx: AgentHookContext) -> None: + import time + elapsed = time.time() - ctx.metadata.get("_t0", 0) + print(f"[timing] iteration took {elapsed:.2f}s") + +async def main(): + bot = Nanobot.from_config(workspace="/my/project") + result = await bot.run( + "Explain the main function", + hooks=[TimingHook()], + ) + print(result.content) + +asyncio.run(main()) +``` diff --git a/nanobot/__init__.py b/nanobot/__init__.py index 07efd09cf..11833c696 100644 --- a/nanobot/__init__.py +++ b/nanobot/__init__.py @@ -4,3 +4,7 @@ nanobot - A lightweight AI agent framework __version__ = "0.1.4.post6" __logo__ = "🐈" + +from nanobot.nanobot import Nanobot, RunResult + +__all__ = ["Nanobot", "RunResult"] diff --git a/nanobot/agent/__init__.py b/nanobot/agent/__init__.py index f9ba8b87a..7d3ab2af4 100644 --- a/nanobot/agent/__init__.py +++ b/nanobot/agent/__init__.py @@ -1,8 +1,19 @@ """Agent core module.""" from nanobot.agent.context import ContextBuilder +from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.loop import AgentLoop from nanobot.agent.memory import MemoryStore from nanobot.agent.skills import SkillsLoader +from nanobot.agent.subagent import SubagentManager -__all__ = ["AgentLoop", "ContextBuilder", 
"MemoryStore", "SkillsLoader"] +__all__ = [ + "AgentHook", + "AgentHookContext", + "AgentLoop", + "CompositeHook", + "ContextBuilder", + "MemoryStore", + "SkillsLoader", + "SubagentManager", +] diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py index 368c46aa2..97ec7a07d 100644 --- a/nanobot/agent/hook.py +++ b/nanobot/agent/hook.py @@ -5,6 +5,8 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Any +from loguru import logger + from nanobot.providers.base import LLMResponse, ToolCallRequest @@ -47,3 +49,60 @@ class AgentHook: def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return content + + +class CompositeHook(AgentHook): + """Fan-out hook that delegates to an ordered list of hooks. + + Error isolation: async methods catch and log per-hook exceptions + so a faulty custom hook cannot crash the agent loop. + ``finalize_content`` is a pipeline (no isolation — bugs should surface). + """ + + __slots__ = ("_hooks",) + + def __init__(self, hooks: list[AgentHook]) -> None: + self._hooks = list(hooks) + + def wants_streaming(self) -> bool: + return any(h.wants_streaming() for h in self._hooks) + + async def before_iteration(self, context: AgentHookContext) -> None: + for h in self._hooks: + try: + await h.before_iteration(context) + except Exception: + logger.exception("AgentHook.before_iteration error in {}", type(h).__name__) + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + for h in self._hooks: + try: + await h.on_stream(context, delta) + except Exception: + logger.exception("AgentHook.on_stream error in {}", type(h).__name__) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + for h in self._hooks: + try: + await h.on_stream_end(context, resuming=resuming) + except Exception: + logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__) + + async def before_execute_tools(self, context: 
AgentHookContext) -> None: + for h in self._hooks: + try: + await h.before_execute_tools(context) + except Exception: + logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__) + + async def after_iteration(self, context: AgentHookContext) -> None: + for h in self._hooks: + try: + await h.after_iteration(context) + except Exception: + logger.exception("AgentHook.after_iteration error in {}", type(h).__name__) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + for h in self._hooks: + content = h.finalize_content(context, content) + return content diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 63ee92ca5..50fef58fd 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger from nanobot.agent.context import ContextBuilder -from nanobot.agent.hook import AgentHook, AgentHookContext +from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.subagent import SubagentManager @@ -37,6 +37,120 @@ if TYPE_CHECKING: from nanobot.cron.service import CronService +class _LoopHook(AgentHook): + """Core lifecycle hook for the main agent loop. + + Handles streaming delta relay, progress reporting, tool-call logging, + and think-tag stripping for the built-in agent path. 
+ """ + + def __init__( + self, + agent_loop: AgentLoop, + on_progress: Callable[..., Awaitable[None]] | None = None, + on_stream: Callable[[str], Awaitable[None]] | None = None, + on_stream_end: Callable[..., Awaitable[None]] | None = None, + *, + channel: str = "cli", + chat_id: str = "direct", + message_id: str | None = None, + ) -> None: + self._loop = agent_loop + self._on_progress = on_progress + self._on_stream = on_stream + self._on_stream_end = on_stream_end + self._channel = channel + self._chat_id = chat_id + self._message_id = message_id + self._stream_buf = "" + + def wants_streaming(self) -> bool: + return self._on_stream is not None + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + from nanobot.utils.helpers import strip_think + + prev_clean = strip_think(self._stream_buf) + self._stream_buf += delta + new_clean = strip_think(self._stream_buf) + incremental = new_clean[len(prev_clean):] + if incremental and self._on_stream: + await self._on_stream(incremental) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + if self._on_stream_end: + await self._on_stream_end(resuming=resuming) + self._stream_buf = "" + + async def before_execute_tools(self, context: AgentHookContext) -> None: + if self._on_progress: + if not self._on_stream: + thought = self._loop._strip_think( + context.response.content if context.response else None + ) + if thought: + await self._on_progress(thought) + tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls)) + await self._on_progress(tool_hint, tool_hint=True) + for tc in context.tool_calls: + args_str = json.dumps(tc.arguments, ensure_ascii=False) + logger.info("Tool call: {}({})", tc.name, args_str[:200]) + self._loop._set_tool_context(self._channel, self._chat_id, self._message_id) + + async def after_iteration(self, context: AgentHookContext) -> None: + u = context.usage or {} + logger.debug( + "LLM usage: prompt={} completion={} 
cached={}", + u.get("prompt_tokens", 0), + u.get("completion_tokens", 0), + u.get("cached_tokens", 0), + ) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return self._loop._strip_think(content) + + +class _LoopHookChain(AgentHook): + """Run the core loop hook first, then best-effort extra hooks. + + This preserves the historical failure behavior of ``_LoopHook`` while still + letting user-supplied hooks opt into ``CompositeHook`` isolation. + """ + + __slots__ = ("_primary", "_extras") + + def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None: + self._primary = primary + self._extras = CompositeHook(extra_hooks) + + def wants_streaming(self) -> bool: + return self._primary.wants_streaming() or self._extras.wants_streaming() + + async def before_iteration(self, context: AgentHookContext) -> None: + await self._primary.before_iteration(context) + await self._extras.before_iteration(context) + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + await self._primary.on_stream(context, delta) + await self._extras.on_stream(context, delta) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + await self._primary.on_stream_end(context, resuming=resuming) + await self._extras.on_stream_end(context, resuming=resuming) + + async def before_execute_tools(self, context: AgentHookContext) -> None: + await self._primary.before_execute_tools(context) + await self._extras.before_execute_tools(context) + + async def after_iteration(self, context: AgentHookContext) -> None: + await self._primary.after_iteration(context) + await self._extras.after_iteration(context) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + content = self._primary.finalize_content(context, content) + return self._extras.finalize_content(context, content) + + class AgentLoop: """ The agent loop is the core processing engine. 
@@ -68,6 +182,7 @@ class AgentLoop: mcp_servers: dict | None = None, channels_config: ChannelsConfig | None = None, timezone: str | None = None, + hooks: list[AgentHook] | None = None, ): from nanobot.config.schema import ExecToolConfig, WebSearchConfig @@ -85,6 +200,7 @@ class AgentLoop: self.restrict_to_workspace = restrict_to_workspace self._start_time = time.time() self._last_usage: dict[str, int] = {} + self._extra_hooks: list[AgentHook] = hooks or [] self.context = ContextBuilder(workspace, timezone=timezone) self.sessions = session_manager or SessionManager(workspace) @@ -217,52 +333,27 @@ class AgentLoop: ``resuming=True`` means tool calls follow (spinner should restart); ``resuming=False`` means this is the final response. """ - loop_self = self - - class _LoopHook(AgentHook): - def __init__(self) -> None: - self._stream_buf = "" - - def wants_streaming(self) -> bool: - return on_stream is not None - - async def on_stream(self, context: AgentHookContext, delta: str) -> None: - from nanobot.utils.helpers import strip_think - - prev_clean = strip_think(self._stream_buf) - self._stream_buf += delta - new_clean = strip_think(self._stream_buf) - incremental = new_clean[len(prev_clean):] - if incremental and on_stream: - await on_stream(incremental) - - async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: - if on_stream_end: - await on_stream_end(resuming=resuming) - self._stream_buf = "" - - async def before_execute_tools(self, context: AgentHookContext) -> None: - if on_progress: - if not on_stream: - thought = loop_self._strip_think(context.response.content if context.response else None) - if thought: - await on_progress(thought) - tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls)) - await on_progress(tool_hint, tool_hint=True) - for tc in context.tool_calls: - args_str = json.dumps(tc.arguments, ensure_ascii=False) - logger.info("Tool call: {}({})", tc.name, args_str[:200]) - 
loop_self._set_tool_context(channel, chat_id, message_id) - - def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: - return loop_self._strip_think(content) + loop_hook = _LoopHook( + self, + on_progress=on_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, + channel=channel, + chat_id=chat_id, + message_id=message_id, + ) + hook: AgentHook = ( + _LoopHookChain(loop_hook, self._extra_hooks) + if self._extra_hooks + else loop_hook + ) result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, model=self.model, max_iterations=self.max_iterations, - hook=_LoopHook(), + hook=hook, error_message="Sorry, I encountered an error calling the AI model.", concurrent_tools=True, )) @@ -321,25 +412,25 @@ class AgentLoop: return f"{stream_base_id}:{stream_segment}" async def on_stream(delta: str) -> None: + meta = dict(msg.metadata or {}) + meta["_stream_delta"] = True + meta["_stream_id"] = _current_stream_id() await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=delta, - metadata={ - "_stream_delta": True, - "_stream_id": _current_stream_id(), - }, + metadata=meta, )) async def on_stream_end(*, resuming: bool = False) -> None: nonlocal stream_segment + meta = dict(msg.metadata or {}) + meta["_stream_end"] = True + meta["_resuming"] = resuming + meta["_stream_id"] = _current_stream_id() await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content="", - metadata={ - "_stream_end": True, - "_resuming": resuming, - "_stream_id": _current_stream_id(), - }, + metadata=meta, )) stream_segment += 1 diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index d6242a6b4..4fec539dd 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -60,7 +60,7 @@ class AgentRunner: messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] - usage = {"prompt_tokens": 0, 
"completion_tokens": 0} + usage: dict[str, int] = {} error: str | None = None stop_reason = "completed" tool_events: list[dict[str, str]] = [] @@ -92,13 +92,15 @@ class AgentRunner: response = await self.provider.chat_with_retry(**kwargs) raw_usage = response.usage or {} - usage = { - "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0), - "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0), - } context.response = response - context.usage = usage + context.usage = raw_usage context.tool_calls = list(response.tool_calls) + # Accumulate standard fields into result usage. + usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0) + usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0) + cached = raw_usage.get("cached_tokens") + if cached: + usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached) if response.has_tool_calls: if hook.wants_streaming(): diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 5266fc8b1..9d936f034 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -21,6 +21,21 @@ from nanobot.config.schema import ExecToolConfig from nanobot.providers.base import LLMProvider +class _SubagentHook(AgentHook): + """Logging-only hook for subagent execution.""" + + def __init__(self, task_id: str) -> None: + self._task_id = task_id + + async def before_execute_tools(self, context: AgentHookContext) -> None: + for tool_call in context.tool_calls: + args_str = json.dumps(tool_call.arguments, ensure_ascii=False) + logger.debug( + "Subagent [{}] executing: {} with arguments: {}", + self._task_id, tool_call.name, args_str, + ) + + class SubagentManager: """Manages background subagent execution.""" @@ -100,33 +115,28 @@ class SubagentManager: tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) tools.register(EditFileTool(workspace=self.workspace, 
allowed_dir=allowed_dir)) tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir)) - tools.register(ExecTool( - working_dir=str(self.workspace), - timeout=self.exec_config.timeout, - restrict_to_workspace=self.restrict_to_workspace, - path_append=self.exec_config.path_append, - )) + if self.exec_config.enable: + tools.register(ExecTool( + working_dir=str(self.workspace), + timeout=self.exec_config.timeout, + restrict_to_workspace=self.restrict_to_workspace, + path_append=self.exec_config.path_append, + )) tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy)) tools.register(WebFetchTool(proxy=self.web_proxy)) - + system_prompt = self._build_subagent_prompt() messages: list[dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": task}, ] - class _SubagentHook(AgentHook): - async def before_execute_tools(self, context: AgentHookContext) -> None: - for tool_call in context.tool_calls: - args_str = json.dumps(tool_call.arguments, ensure_ascii=False) - logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str) - result = await self.runner.run(AgentRunSpec( initial_messages=messages, tools=tools, model=self.model, max_iterations=15, - hook=_SubagentHook(), + hook=_SubagentHook(task_id), max_iterations_message="Task completed but no final response was generated.", error_message=None, fail_on_tool_error=True, @@ -213,7 +223,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). 
Do not men lines.append("Failure:") lines.append(f"- {result.error}") return "\n".join(lines) or (result.error or "Error: subagent execution failed.") - + def _build_subagent_prompt(self) -> str: """Build a focused system prompt for the subagent.""" from nanobot.agent.context import ContextBuilder diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 9989af55f..f2aba0b97 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -74,7 +74,7 @@ class CronTool(Tool): "enum": ["add", "list", "remove"], "description": "Action to perform", }, - "message": {"type": "string", "description": "Reminder message (for add)"}, + "message": {"type": "string", "description": "Instruction for the agent to execute when the job triggers (e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"}, "every_seconds": { "type": "integer", "description": "Interval in seconds (for recurring tasks)", @@ -97,6 +97,11 @@ class CronTool(Tool): f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}." 
), }, + "deliver": { + "type": "boolean", + "description": "Whether to deliver the execution result to the user channel (default true)", + "default": True + }, "job_id": {"type": "string", "description": "Job ID (for remove)"}, }, "required": ["action"], @@ -111,12 +116,13 @@ class CronTool(Tool): tz: str | None = None, at: str | None = None, job_id: str | None = None, + deliver: bool = True, **kwargs: Any, ) -> str: if action == "add": if self._in_cron_context.get(): return "Error: cannot schedule new jobs from within a cron job execution" - return self._add_job(message, every_seconds, cron_expr, tz, at) + return self._add_job(message, every_seconds, cron_expr, tz, at, deliver) elif action == "list": return self._list_jobs() elif action == "remove": @@ -130,6 +136,7 @@ class CronTool(Tool): cron_expr: str | None, tz: str | None, at: str | None, + deliver: bool = True, ) -> str: if not message: return "Error: message is required for add" @@ -171,7 +178,7 @@ class CronTool(Tool): name=message[:30], schedule=schedule, message=message, - deliver=True, + deliver=deliver, channel=self._channel, to=self._chat_id, delete_after_run=delete_after, diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index c1c3e79a2..51533333e 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -170,7 +170,11 @@ async def connect_mcp_servers( timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, ) -> httpx.AsyncClient: - merged_headers = {**(cfg.headers or {}), **(headers or {})} + merged_headers = { + "Accept": "application/json, text/event-stream", + **(cfg.headers or {}), + **(headers or {}), + } return httpx.AsyncClient( headers=merged_headers or None, follow_redirects=True, diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index c8d50cf1e..3ac813248 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -86,7 +86,15 @@ class MessageTool(Tool): ) -> str: channel = channel or 
self._default_channel chat_id = chat_id or self._default_chat_id - message_id = message_id or self._default_message_id + # Only inherit default message_id when targeting the same channel+chat. + # Cross-chat sends must not carry the original message_id, because + # some channels (e.g. Feishu) use it to determine the target + # conversation via their Reply API, which would route the message + # to the wrong chat entirely. + if channel == self._default_channel and chat_id == self._default_chat_id: + message_id = message_id or self._default_message_id + else: + message_id = None if not channel or not chat_id: return "Error: No target channel/chat specified" @@ -101,7 +109,7 @@ class MessageTool(Tool): media=media or [], metadata={ "message_id": message_id, - }, + } if message_id else {}, ) try: diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index ed552b33e..b051edffc 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -186,7 +186,9 @@ class ExecTool(Tool): @staticmethod def _extract_absolute_paths(command: str) -> list[str]: - win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\... + # Windows: match drive-root paths like `C:\` as well as `C:\path\to\file` + # NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted. 
+ win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command) posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~ return win_paths + posix_paths + home_paths diff --git a/nanobot/api/__init__.py b/nanobot/api/__init__.py new file mode 100644 index 000000000..f0c504cc1 --- /dev/null +++ b/nanobot/api/__init__.py @@ -0,0 +1 @@ +"""OpenAI-compatible HTTP API for nanobot.""" diff --git a/nanobot/api/server.py b/nanobot/api/server.py new file mode 100644 index 000000000..9494b6e31 --- /dev/null +++ b/nanobot/api/server.py @@ -0,0 +1,193 @@ +"""OpenAI-compatible HTTP API server for a fixed nanobot session. + +Provides /v1/chat/completions and /v1/models endpoints. +All requests route to a single persistent API session. +""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from typing import Any + +from aiohttp import web +from loguru import logger + +API_SESSION_KEY = "api:default" +API_CHAT_ID = "default" + + +# --------------------------------------------------------------------------- +# Response helpers +# --------------------------------------------------------------------------- + +def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response: + return web.json_response( + {"error": {"message": message, "type": err_type, "code": status}}, + status=status, + ) + + +def _chat_completion_response(content: str, model: str) -> dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +def _response_text(value: Any) -> str: + """Normalize process_direct output 
to plain assistant text.""" + if value is None: + return "" + if hasattr(value, "content"): + return str(getattr(value, "content") or "") + return str(value) + + +# --------------------------------------------------------------------------- +# Route handlers +# --------------------------------------------------------------------------- + +async def handle_chat_completions(request: web.Request) -> web.Response: + """POST /v1/chat/completions""" + + # --- Parse body --- + try: + body = await request.json() + except Exception: + return _error_json(400, "Invalid JSON body") + + messages = body.get("messages") + if not isinstance(messages, list) or len(messages) != 1: + return _error_json(400, "Only a single user message is supported") + + # Stream not yet supported + if body.get("stream", False): + return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.") + + message = messages[0] + if not isinstance(message, dict) or message.get("role") != "user": + return _error_json(400, "Only a single user message is supported") + user_content = message.get("content", "") + if isinstance(user_content, list): + # Multi-modal content array — extract text parts + user_content = " ".join( + part.get("text", "") for part in user_content if part.get("type") == "text" + ) + + agent_loop = request.app["agent_loop"] + timeout_s: float = request.app.get("request_timeout", 120.0) + model_name: str = request.app.get("model_name", "nanobot") + if (requested_model := body.get("model")) and requested_model != model_name: + return _error_json(400, f"Only configured model '{model_name}' is available") + + session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY + session_locks: dict[str, asyncio.Lock] = request.app["session_locks"] + session_lock = session_locks.setdefault(session_key, asyncio.Lock()) + + logger.info("API request session_key={} content={}", session_key, user_content[:80]) + + _FALLBACK = "I've completed processing but 
have no response to give." + + try: + async with session_lock: + try: + response = await asyncio.wait_for( + agent_loop.process_direct( + content=user_content, + session_key=session_key, + channel="api", + chat_id=API_CHAT_ID, + ), + timeout=timeout_s, + ) + response_text = _response_text(response) + + if not response_text or not response_text.strip(): + logger.warning( + "Empty response for session {}, retrying", + session_key, + ) + retry_response = await asyncio.wait_for( + agent_loop.process_direct( + content=user_content, + session_key=session_key, + channel="api", + chat_id=API_CHAT_ID, + ), + timeout=timeout_s, + ) + response_text = _response_text(retry_response) + if not response_text or not response_text.strip(): + logger.warning( + "Empty response after retry for session {}, using fallback", + session_key, + ) + response_text = _FALLBACK + + except asyncio.TimeoutError: + return _error_json(504, f"Request timed out after {timeout_s}s") + except Exception: + logger.exception("Error processing request for session {}", session_key) + return _error_json(500, "Internal server error", err_type="server_error") + except Exception: + logger.exception("Unexpected API lock error for session {}", session_key) + return _error_json(500, "Internal server error", err_type="server_error") + + return web.json_response(_chat_completion_response(response_text, model_name)) + + +async def handle_models(request: web.Request) -> web.Response: + """GET /v1/models""" + model_name = request.app.get("model_name", "nanobot") + return web.json_response({ + "object": "list", + "data": [ + { + "id": model_name, + "object": "model", + "created": 0, + "owned_by": "nanobot", + } + ], + }) + + +async def handle_health(request: web.Request) -> web.Response: + """GET /health""" + return web.json_response({"status": "ok"}) + + +# --------------------------------------------------------------------------- +# App factory +# 
--------------------------------------------------------------------------- + +def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application: + """Create the aiohttp application. + + Args: + agent_loop: An initialized AgentLoop instance. + model_name: Model name reported in responses. + request_timeout: Per-request timeout in seconds. + """ + app = web.Application() + app["agent_loop"] = agent_loop + app["model_name"] = model_name + app["request_timeout"] = request_timeout + app["session_locks"] = {} # per-user locks, keyed by session_key + + app.router.add_post("/v1/chat/completions", handle_chat_completions) + app.router.add_get("/v1/models", handle_models) + app.router.add_get("/health", handle_health) + return app diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 82eafcc00..9bf4d919c 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -1,25 +1,37 @@ -"""Discord channel implementation using Discord Gateway websocket.""" +"""Discord channel implementation using discord.py.""" + +from __future__ import annotations import asyncio -import json +import importlib.util from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal -import httpx -from pydantic import Field -import websockets from loguru import logger +from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel +from nanobot.command.builtin import build_help_text from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base -from nanobot.utils.helpers import split_message +from nanobot.utils.helpers import safe_filename, split_message + +DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None +if TYPE_CHECKING: + import discord + from discord import app_commands + from discord.abc import Messageable + +if DISCORD_AVAILABLE: + 
import discord + from discord import app_commands + from discord.abc import Messageable -DISCORD_API_BASE = "https://discord.com/api/v10" MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_MESSAGE_LEN = 2000 # Discord message character limit +TYPING_INTERVAL_S = 8 class DiscordConfig(Base): @@ -28,13 +40,205 @@ class DiscordConfig(Base): enabled: bool = False token: str = "" allow_from: list[str] = Field(default_factory=list) - gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json" intents: int = 37377 group_policy: Literal["mention", "open"] = "mention" + read_receipt_emoji: str = "👀" + working_emoji: str = "🔧" + working_emoji_delay: float = 2.0 + + +if DISCORD_AVAILABLE: + + class DiscordBotClient(discord.Client): + """discord.py client that forwards events to the channel.""" + + def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None: + super().__init__(intents=intents) + self._channel = channel + self.tree = app_commands.CommandTree(self) + self._register_app_commands() + + async def on_ready(self) -> None: + self._channel._bot_user_id = str(self.user.id) if self.user else None + logger.info("Discord bot connected as user {}", self._channel._bot_user_id) + try: + synced = await self.tree.sync() + logger.info("Discord app commands synced: {}", len(synced)) + except Exception as e: + logger.warning("Discord app command sync failed: {}", e) + + async def on_message(self, message: discord.Message) -> None: + await self._channel._handle_discord_message(message) + + async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool: + """Send an ephemeral interaction response and report success.""" + try: + await interaction.response.send_message(text, ephemeral=True) + return True + except Exception as e: + logger.warning("Discord interaction response failed: {}", e) + return False + + async def _forward_slash_command( + self, + interaction: discord.Interaction, + command_text: str, + ) -> None: + sender_id = 
str(interaction.user.id) + channel_id = interaction.channel_id + + if channel_id is None: + logger.warning("Discord slash command missing channel_id: {}", command_text) + return + + if not self._channel.is_allowed(sender_id): + await self._reply_ephemeral(interaction, "You are not allowed to use this bot.") + return + + await self._reply_ephemeral(interaction, f"Processing {command_text}...") + + await self._channel._handle_message( + sender_id=sender_id, + chat_id=str(channel_id), + content=command_text, + metadata={ + "interaction_id": str(interaction.id), + "guild_id": str(interaction.guild_id) if interaction.guild_id else None, + "is_slash_command": True, + }, + ) + + def _register_app_commands(self) -> None: + commands = ( + ("new", "Start a new conversation", "/new"), + ("stop", "Stop the current task", "/stop"), + ("restart", "Restart the bot", "/restart"), + ("status", "Show bot status", "/status"), + ) + + for name, description, command_text in commands: + @self.tree.command(name=name, description=description) + async def command_handler( + interaction: discord.Interaction, + _command_text: str = command_text, + ) -> None: + await self._forward_slash_command(interaction, _command_text) + + @self.tree.command(name="help", description="Show available commands") + async def help_command(interaction: discord.Interaction) -> None: + sender_id = str(interaction.user.id) + if not self._channel.is_allowed(sender_id): + await self._reply_ephemeral(interaction, "You are not allowed to use this bot.") + return + await self._reply_ephemeral(interaction, build_help_text()) + + @self.tree.error + async def on_app_command_error( + interaction: discord.Interaction, + error: app_commands.AppCommandError, + ) -> None: + command_name = interaction.command.qualified_name if interaction.command else "?" 
+ logger.warning( + "Discord app command failed user={} channel={} cmd={} error={}", + interaction.user.id, + interaction.channel_id, + command_name, + error, + ) + + async def send_outbound(self, msg: OutboundMessage) -> None: + """Send a nanobot outbound message using Discord transport rules.""" + channel_id = int(msg.chat_id) + + channel = self.get_channel(channel_id) + if channel is None: + try: + channel = await self.fetch_channel(channel_id) + except Exception as e: + logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e) + return + + reference, mention_settings = self._build_reply_context(channel, msg.reply_to) + sent_media = False + failed_media: list[str] = [] + + for index, media_path in enumerate(msg.media or []): + if await self._send_file( + channel, + media_path, + reference=reference if index == 0 else None, + mention_settings=mention_settings, + ): + sent_media = True + else: + failed_media.append(Path(media_path).name) + + for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)): + kwargs: dict[str, Any] = {"content": chunk} + if index == 0 and reference is not None and not sent_media: + kwargs["reference"] = reference + kwargs["allowed_mentions"] = mention_settings + await channel.send(**kwargs) + + async def _send_file( + self, + channel: Messageable, + file_path: str, + *, + reference: discord.PartialMessage | None, + mention_settings: discord.AllowedMentions, + ) -> bool: + """Send a file attachment via discord.py.""" + path = Path(file_path) + if not path.is_file(): + logger.warning("Discord file not found, skipping: {}", file_path) + return False + + if path.stat().st_size > MAX_ATTACHMENT_BYTES: + logger.warning("Discord file too large (>20MB), skipping: {}", path.name) + return False + + try: + kwargs: dict[str, Any] = {"file": discord.File(path)} + if reference is not None: + kwargs["reference"] = reference + kwargs["allowed_mentions"] = mention_settings + await channel.send(**kwargs) + 
logger.info("Discord file sent: {}", path.name) + return True + except Exception as e: + logger.error("Error sending Discord file {}: {}", path.name, e) + return False + + @staticmethod + def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]: + """Build outbound text chunks, including attachment-failure fallback text.""" + chunks = split_message(content, MAX_MESSAGE_LEN) + if chunks or not failed_media or sent_media: + return chunks + fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media) + return split_message(fallback, MAX_MESSAGE_LEN) + + @staticmethod + def _build_reply_context( + channel: Messageable, + reply_to: str | None, + ) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]: + """Build reply context for outbound messages.""" + mention_settings = discord.AllowedMentions(replied_user=False) + if not reply_to: + return None, mention_settings + try: + message_id = int(reply_to) + except ValueError: + logger.warning("Invalid Discord reply target: {}", reply_to) + return None, mention_settings + + return channel.get_partial_message(message_id), mention_settings class DiscordChannel(BaseChannel): - """Discord channel using Gateway websocket.""" + """Discord channel using discord.py.""" name = "discord" display_name = "Discord" @@ -43,353 +247,270 @@ class DiscordChannel(BaseChannel): def default_config(cls) -> dict[str, Any]: return DiscordConfig().model_dump(by_alias=True) + @staticmethod + def _channel_key(channel_or_id: Any) -> str: + """Normalize channel-like objects and ids to a stable string key.""" + channel_id = getattr(channel_or_id, "id", channel_or_id) + return str(channel_id) + def __init__(self, config: Any, bus: MessageBus): if isinstance(config, dict): config = DiscordConfig.model_validate(config) super().__init__(config, bus) self.config: DiscordConfig = config - self._ws: websockets.WebSocketClientProtocol | None = None - self._seq: int | None = None - 
self._heartbeat_task: asyncio.Task | None = None - self._typing_tasks: dict[str, asyncio.Task] = {} - self._http: httpx.AsyncClient | None = None + self._client: DiscordBotClient | None = None + self._typing_tasks: dict[str, asyncio.Task[None]] = {} self._bot_user_id: str | None = None + self._pending_reactions: dict[str, Any] = {} # chat_id -> message object + self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {} async def start(self) -> None: - """Start the Discord gateway connection.""" + """Start the Discord client.""" + if not DISCORD_AVAILABLE: + logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]") + return + if not self.config.token: logger.error("Discord bot token not configured") return - self._running = True - self._http = httpx.AsyncClient(timeout=30.0) + try: + intents = discord.Intents.none() + intents.value = self.config.intents + self._client = DiscordBotClient(self, intents=intents) + except Exception as e: + logger.error("Failed to initialize Discord client: {}", e) + self._client = None + self._running = False + return - while self._running: - try: - logger.info("Connecting to Discord gateway...") - async with websockets.connect(self.config.gateway_url) as ws: - self._ws = ws - await self._gateway_loop() - except asyncio.CancelledError: - break - except Exception as e: - logger.warning("Discord gateway error: {}", e) - if self._running: - logger.info("Reconnecting to Discord gateway in 5 seconds...") - await asyncio.sleep(5) + self._running = True + logger.info("Starting Discord client via discord.py...") + + try: + await self._client.start(self.config.token) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error("Discord client startup failed: {}", e) + finally: + self._running = False + await self._reset_runtime_state(close_client=True) async def stop(self) -> None: """Stop the Discord channel.""" self._running = False - if self._heartbeat_task: - self._heartbeat_task.cancel() - 
self._heartbeat_task = None - for task in self._typing_tasks.values(): - task.cancel() - self._typing_tasks.clear() - if self._ws: - await self._ws.close() - self._ws = None - if self._http: - await self._http.aclose() - self._http = None + await self._reset_runtime_state(close_client=True) async def send(self, msg: OutboundMessage) -> None: - """Send a message through Discord REST API, including file attachments.""" - if not self._http: - logger.warning("Discord HTTP client not initialized") + """Send a message through Discord using discord.py.""" + client = self._client + if client is None or not client.is_ready(): + logger.warning("Discord client not ready; dropping outbound message") return - url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages" - headers = {"Authorization": f"Bot {self.config.token}"} + is_progress = bool((msg.metadata or {}).get("_progress")) try: - sent_media = False - failed_media: list[str] = [] - - # Send file attachments first - for media_path in msg.media or []: - if await self._send_file(url, headers, media_path, reply_to=msg.reply_to): - sent_media = True - else: - failed_media.append(Path(media_path).name) - - # Send text content - chunks = split_message(msg.content or "", MAX_MESSAGE_LEN) - if not chunks and failed_media and not sent_media: - chunks = split_message( - "\n".join(f"[attachment: {name} - send failed]" for name in failed_media), - MAX_MESSAGE_LEN, - ) - if not chunks: - return - - for i, chunk in enumerate(chunks): - payload: dict[str, Any] = {"content": chunk} - - # Let the first successful attachment carry the reply if present. 
- if i == 0 and msg.reply_to and not sent_media: - payload["message_reference"] = {"message_id": msg.reply_to} - payload["allowed_mentions"] = {"replied_user": False} - - if not await self._send_payload(url, headers, payload): - break # Abort remaining chunks on failure + await client.send_outbound(msg) + except Exception as e: + logger.error("Error sending Discord message: {}", e) finally: - await self._stop_typing(msg.chat_id) + if not is_progress: + await self._stop_typing(msg.chat_id) + await self._clear_reactions(msg.chat_id) - async def _send_payload( - self, url: str, headers: dict[str, str], payload: dict[str, Any] - ) -> bool: - """Send a single Discord API payload with retry on rate-limit. Returns True on success.""" - for attempt in range(3): + async def _handle_discord_message(self, message: discord.Message) -> None: + """Handle incoming Discord messages from discord.py.""" + if message.author.bot: + return + + sender_id = str(message.author.id) + channel_id = self._channel_key(message.channel) + content = message.content or "" + + if not self._should_accept_inbound(message, sender_id, content): + return + + media_paths, attachment_markers = await self._download_attachments(message.attachments) + full_content = self._compose_inbound_content(content, attachment_markers) + metadata = self._build_inbound_metadata(message) + + await self._start_typing(message.channel) + + # Add read receipt reaction immediately, working emoji after delay + channel_id = self._channel_key(message.channel) + try: + await message.add_reaction(self.config.read_receipt_emoji) + self._pending_reactions[channel_id] = message + except Exception as e: + logger.debug("Failed to add read receipt reaction: {}", e) + + # Delayed working indicator (cosmetic — not tied to subagent lifecycle) + async def _delayed_working_emoji() -> None: + await asyncio.sleep(self.config.working_emoji_delay) try: - response = await self._http.post(url, headers=headers, json=payload) - if 
response.status_code == 429: - data = response.json() - retry_after = float(data.get("retry_after", 1.0)) - logger.warning("Discord rate limited, retrying in {}s", retry_after) - await asyncio.sleep(retry_after) - continue - response.raise_for_status() - return True - except Exception as e: - if attempt == 2: - logger.error("Error sending Discord message: {}", e) - else: - await asyncio.sleep(1) - return False + await message.add_reaction(self.config.working_emoji) + except Exception: + pass - async def _send_file( + self._working_emoji_tasks[channel_id] = asyncio.create_task(_delayed_working_emoji()) + + try: + await self._handle_message( + sender_id=sender_id, + chat_id=channel_id, + content=full_content, + media=media_paths, + metadata=metadata, + ) + except Exception: + await self._clear_reactions(channel_id) + await self._stop_typing(channel_id) + raise + + async def _on_message(self, message: discord.Message) -> None: + """Backward-compatible alias for legacy tests/callers.""" + await self._handle_discord_message(message) + + def _should_accept_inbound( self, - url: str, - headers: dict[str, str], - file_path: str, - reply_to: str | None = None, + message: discord.Message, + sender_id: str, + content: str, ) -> bool: - """Send a file attachment via Discord REST API using multipart/form-data.""" - path = Path(file_path) - if not path.is_file(): - logger.warning("Discord file not found, skipping: {}", file_path) - return False - - if path.stat().st_size > MAX_ATTACHMENT_BYTES: - logger.warning("Discord file too large (>20MB), skipping: {}", path.name) - return False - - payload_json: dict[str, Any] = {} - if reply_to: - payload_json["message_reference"] = {"message_id": reply_to} - payload_json["allowed_mentions"] = {"replied_user": False} - - for attempt in range(3): - try: - with open(path, "rb") as f: - files = {"files[0]": (path.name, f, "application/octet-stream")} - data: dict[str, Any] = {} - if payload_json: - data["payload_json"] = 
json.dumps(payload_json) - response = await self._http.post( - url, headers=headers, files=files, data=data - ) - if response.status_code == 429: - resp_data = response.json() - retry_after = float(resp_data.get("retry_after", 1.0)) - logger.warning("Discord rate limited, retrying in {}s", retry_after) - await asyncio.sleep(retry_after) - continue - response.raise_for_status() - logger.info("Discord file sent: {}", path.name) - return True - except Exception as e: - if attempt == 2: - logger.error("Error sending Discord file {}: {}", path.name, e) - else: - await asyncio.sleep(1) - return False - - async def _gateway_loop(self) -> None: - """Main gateway loop: identify, heartbeat, dispatch events.""" - if not self._ws: - return - - async for raw in self._ws: - try: - data = json.loads(raw) - except json.JSONDecodeError: - logger.warning("Invalid JSON from Discord gateway: {}", raw[:100]) - continue - - op = data.get("op") - event_type = data.get("t") - seq = data.get("s") - payload = data.get("d") - - if seq is not None: - self._seq = seq - - if op == 10: - # HELLO: start heartbeat and identify - interval_ms = payload.get("heartbeat_interval", 45000) - await self._start_heartbeat(interval_ms / 1000) - await self._identify() - elif op == 0 and event_type == "READY": - logger.info("Discord gateway READY") - # Capture bot user ID for mention detection - user_data = payload.get("user") or {} - self._bot_user_id = user_data.get("id") - logger.info("Discord bot connected as user {}", self._bot_user_id) - elif op == 0 and event_type == "MESSAGE_CREATE": - await self._handle_message_create(payload) - elif op == 7: - # RECONNECT: exit loop to reconnect - logger.info("Discord gateway requested reconnect") - break - elif op == 9: - # INVALID_SESSION: reconnect - logger.warning("Discord gateway invalid session") - break - - async def _identify(self) -> None: - """Send IDENTIFY payload.""" - if not self._ws: - return - - identify = { - "op": 2, - "d": { - "token": 
self.config.token, - "intents": self.config.intents, - "properties": { - "os": "nanobot", - "browser": "nanobot", - "device": "nanobot", - }, - }, - } - await self._ws.send(json.dumps(identify)) - - async def _start_heartbeat(self, interval_s: float) -> None: - """Start or restart the heartbeat loop.""" - if self._heartbeat_task: - self._heartbeat_task.cancel() - - async def heartbeat_loop() -> None: - while self._running and self._ws: - payload = {"op": 1, "d": self._seq} - try: - await self._ws.send(json.dumps(payload)) - except Exception as e: - logger.warning("Discord heartbeat failed: {}", e) - break - await asyncio.sleep(interval_s) - - self._heartbeat_task = asyncio.create_task(heartbeat_loop()) - - async def _handle_message_create(self, payload: dict[str, Any]) -> None: - """Handle incoming Discord messages.""" - author = payload.get("author") or {} - if author.get("bot"): - return - - sender_id = str(author.get("id", "")) - channel_id = str(payload.get("channel_id", "")) - content = payload.get("content") or "" - guild_id = payload.get("guild_id") - - if not sender_id or not channel_id: - return - + """Check if inbound Discord message should be processed.""" if not self.is_allowed(sender_id): - return + return False + if message.guild is not None and not self._should_respond_in_group(message, content): + return False + return True - # Check group channel policy (DMs always respond if is_allowed passes) - if guild_id is not None: - if not self._should_respond_in_group(payload, content): - return - - content_parts = [content] if content else [] + async def _download_attachments( + self, + attachments: list[discord.Attachment], + ) -> tuple[list[str], list[str]]: + """Download supported attachments and return paths + display markers.""" media_paths: list[str] = [] + markers: list[str] = [] media_dir = get_media_dir("discord") - for attachment in payload.get("attachments") or []: - url = attachment.get("url") - filename = attachment.get("filename") or 
"attachment" - size = attachment.get("size") or 0 - if not url or not self._http: - continue - if size and size > MAX_ATTACHMENT_BYTES: - content_parts.append(f"[attachment: {filename} - too large]") + for attachment in attachments: + filename = attachment.filename or "attachment" + if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES: + markers.append(f"[attachment: {filename} - too large]") continue try: media_dir.mkdir(parents=True, exist_ok=True) - file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}" - resp = await self._http.get(url) - resp.raise_for_status() - file_path.write_bytes(resp.content) + safe_name = safe_filename(filename) + file_path = media_dir / f"{attachment.id}_{safe_name}" + await attachment.save(file_path) media_paths.append(str(file_path)) - content_parts.append(f"[attachment: {file_path}]") + markers.append(f"[attachment: {file_path.name}]") except Exception as e: logger.warning("Failed to download Discord attachment: {}", e) - content_parts.append(f"[attachment: {filename} - download failed]") + markers.append(f"[attachment: {filename} - download failed]") - reply_to = (payload.get("referenced_message") or {}).get("id") + return media_paths, markers - await self._start_typing(channel_id) + @staticmethod + def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str: + """Combine message text with attachment markers.""" + content_parts = [content] if content else [] + content_parts.extend(attachment_markers) + return "\n".join(part for part in content_parts if part) or "[empty message]" - await self._handle_message( - sender_id=sender_id, - chat_id=channel_id, - content="\n".join(p for p in content_parts if p) or "[empty message]", - media=media_paths, - metadata={ - "message_id": str(payload.get("id", "")), - "guild_id": guild_id, - "reply_to": reply_to, - }, - ) + @staticmethod + def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]: + """Build 
metadata for inbound Discord messages.""" + reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None + return { + "message_id": str(message.id), + "guild_id": str(message.guild.id) if message.guild else None, + "reply_to": reply_to, + } - def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool: - """Check if bot should respond in a group channel based on policy.""" + def _should_respond_in_group(self, message: discord.Message, content: str) -> bool: + """Check if the bot should respond in a guild channel based on policy.""" if self.config.group_policy == "open": return True if self.config.group_policy == "mention": - # Check if bot was mentioned in the message - if self._bot_user_id: - # Check mentions array - mentions = payload.get("mentions") or [] - for mention in mentions: - if str(mention.get("id")) == self._bot_user_id: - return True - # Also check content for mention format <@USER_ID> - if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content: - return True - logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id")) + bot_user_id = self._bot_user_id + if bot_user_id is None: + logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id) + return False + + if any(str(user.id) == bot_user_id for user in message.mentions): + return True + if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content: + return True + + logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id) return False return True - async def _start_typing(self, channel_id: str) -> None: + async def _start_typing(self, channel: Messageable) -> None: """Start periodic typing indicator for a channel.""" + channel_id = self._channel_key(channel) await self._stop_typing(channel_id) async def typing_loop() -> None: - url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing" - headers = 
{"Authorization": f"Bot {self.config.token}"} while self._running: try: - await self._http.post(url, headers=headers) + async with channel.typing(): + await asyncio.sleep(TYPING_INTERVAL_S) except asyncio.CancelledError: return except Exception as e: logger.debug("Discord typing indicator failed for {}: {}", channel_id, e) return - await asyncio.sleep(8) self._typing_tasks[channel_id] = asyncio.create_task(typing_loop()) async def _stop_typing(self, channel_id: str) -> None: """Stop typing indicator for a channel.""" - task = self._typing_tasks.pop(channel_id, None) - if task: + task = self._typing_tasks.pop(self._channel_key(channel_id), None) + if task is None: + return + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + + async def _clear_reactions(self, chat_id: str) -> None: + """Remove all pending reactions after bot replies.""" + # Cancel delayed working emoji if it hasn't fired yet + task = self._working_emoji_tasks.pop(chat_id, None) + if task and not task.done(): task.cancel() + + msg_obj = self._pending_reactions.pop(chat_id, None) + if msg_obj is None: + return + bot_user = self._client.user if self._client else None + for emoji in (self.config.read_receipt_emoji, self.config.working_emoji): + try: + await msg_obj.remove_reaction(emoji, bot_user) + except Exception: + pass + + async def _cancel_all_typing(self) -> None: + """Stop all typing tasks.""" + channel_ids = list(self._typing_tasks) + for channel_id in channel_ids: + await self._stop_typing(channel_id) + + async def _reset_runtime_state(self, close_client: bool) -> None: + """Reset client and typing state.""" + await self._cancel_all_typing() + if close_client and self._client is not None and not self._client.is_closed(): + try: + await self._client.close() + except Exception as e: + logger.warning("Discord client close failed: {}", e) + self._client = None + self._bot_user_id = None diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index 
98926735e..bc6d9398a 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -3,6 +3,8 @@ import asyncio import logging import mimetypes +import time +from dataclasses import dataclass from pathlib import Path from typing import Any, Literal, TypeAlias @@ -28,8 +30,8 @@ try: RoomSendError, RoomTypingError, SyncError, - UploadError, - ) + UploadError, RoomSendResponse, +) from nio.crypto.attachments import decrypt_attachment from nio.exceptions import EncryptionError except ImportError as e: @@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner( link_rel="noopener noreferrer", ) +@dataclass +class _StreamBuf: + """ + Represents a buffer for managing LLM response stream data. + + :ivar text: Stores the text content of the buffer. + :type text: str + :ivar event_id: Identifier for the associated event. None indicates no + specific event association. + :type event_id: str | None + :ivar last_edit: Timestamp of the most recent edit to the buffer. + :type last_edit: float + """ + text: str = "" + event_id: str | None = None + last_edit: float = 0.0 def _render_markdown_html(text: str) -> str | None: """Render markdown to sanitized HTML; returns None for plain text.""" @@ -114,12 +132,47 @@ def _render_markdown_html(text: str) -> str | None: return formatted -def _build_matrix_text_content(text: str) -> dict[str, object]: - """Build Matrix m.text payload with optional HTML formatted_body.""" +def _build_matrix_text_content( + text: str, + event_id: str | None = None, + thread_relates_to: dict[str, object] | None = None, +) -> dict[str, object]: + """ + Constructs and returns a dictionary representing the matrix text content with optional + HTML formatting and reference to an existing event for replacement. This function is + primarily used to create content payloads compatible with the Matrix messaging protocol. + + :param text: The plain text content to include in the message. + :type text: str + :param event_id: Optional ID of the event to replace. 
If provided, the function will + include information indicating that the message is a replacement of the specified + event. + :type event_id: str | None + :param thread_relates_to: Optional Matrix thread relation metadata. For edits this is + stored in ``m.new_content`` so the replacement remains in the same thread. + :type thread_relates_to: dict[str, object] | None + :return: A dictionary containing the matrix text content, potentially enriched with + HTML formatting and replacement metadata if applicable. + :rtype: dict[str, object] + """ content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}} if html := _render_markdown_html(text): content["format"] = MATRIX_HTML_FORMAT content["formatted_body"] = html + if event_id: + content["m.new_content"] = { + "body": text, + "msgtype": "m.text", + } + content["m.relates_to"] = { + "rel_type": "m.replace", + "event_id": event_id, + } + if thread_relates_to: + content["m.new_content"]["m.relates_to"] = thread_relates_to + elif thread_relates_to: + content["m.relates_to"] = thread_relates_to + return content @@ -159,7 +212,8 @@ class MatrixConfig(Base): allow_from: list[str] = Field(default_factory=list) group_policy: Literal["open", "mention", "allowlist"] = "open" group_allow_from: list[str] = Field(default_factory=list) - allow_room_mentions: bool = False + allow_room_mentions: bool = False, + streaming: bool = False class MatrixChannel(BaseChannel): @@ -167,6 +221,8 @@ class MatrixChannel(BaseChannel): name = "matrix" display_name = "Matrix" + _STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls + monotonic_time = time.monotonic @classmethod def default_config(cls) -> dict[str, Any]: @@ -192,6 +248,8 @@ class MatrixChannel(BaseChannel): ) self._server_upload_limit_bytes: int | None = None self._server_upload_limit_checked = False + self._stream_bufs: dict[str, _StreamBuf] = {} + async def start(self) -> None: """Start Matrix client and begin sync loop.""" @@ -297,14 +355,17 
@@ class MatrixChannel(BaseChannel): room = getattr(self.client, "rooms", {}).get(room_id) return bool(getattr(room, "encrypted", False)) - async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None: + async def _send_room_content(self, room_id: str, + content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError: """Send m.room.message with E2EE options.""" if not self.client: - return + return None kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content} + if self.config.e2ee_enabled: kwargs["ignore_unverified_devices"] = True - await self.client.room_send(**kwargs) + response = await self.client.room_send(**kwargs) + return response async def _resolve_server_upload_limit_bytes(self) -> int | None: """Query homeserver upload limit once per channel lifecycle.""" @@ -414,6 +475,53 @@ class MatrixChannel(BaseChannel): if not is_progress: await self._stop_typing_keepalive(msg.chat_id, clear_typing=True) + async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + meta = metadata or {} + relates_to = self._build_thread_relates_to(metadata) + + if meta.get("_stream_end"): + buf = self._stream_bufs.pop(chat_id, None) + if not buf or not buf.event_id or not buf.text: + return + + await self._stop_typing_keepalive(chat_id, clear_typing=True) + + content = _build_matrix_text_content( + buf.text, + buf.event_id, + thread_relates_to=relates_to, + ) + await self._send_room_content(chat_id, content) + return + + buf = self._stream_bufs.get(chat_id) + if buf is None: + buf = _StreamBuf() + self._stream_bufs[chat_id] = buf + buf.text += delta + + if not buf.text.strip(): + return + + now = self.monotonic_time() + + if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: + try: + content = _build_matrix_text_content( + buf.text, + buf.event_id, + thread_relates_to=relates_to, + ) + response = await self._send_room_content(chat_id, content) + 
buf.last_edit = now + if not buf.event_id: + # we are editing the same message all the time, so only the first time the event id needs to be set + buf.event_id = response.event_id + except Exception: + await self._stop_typing_keepalive(chat_id, clear_typing=True) + pass + + def _register_event_callbacks(self) -> None: self.client.add_event_callback(self._on_message, RoomMessageText) self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 9e2caae3f..2266bc9f0 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -14,6 +14,7 @@ import base64 import hashlib import json import os +import random import re import time import uuid @@ -52,7 +53,26 @@ MESSAGE_TYPE_BOT = 2 MESSAGE_STATE_FINISH = 2 WEIXIN_MAX_MESSAGE_LEN = 4000 -WEIXIN_CHANNEL_VERSION = "1.0.3" +WEIXIN_CHANNEL_VERSION = "2.1.1" +ILINK_APP_ID = "bot" + + +def _build_client_version(version: str) -> int: + """Encode semantic version as 0x00MMNNPP (major/minor/patch in one uint32).""" + parts = version.split(".") + + def _as_int(idx: int) -> int: + try: + return int(parts[idx]) + except Exception: + return 0 + + major = _as_int(0) + minor = _as_int(1) + patch = _as_int(2) + return ((major & 0xFF) << 16) | ((minor & 0xFF) << 8) | (patch & 0xFF) + +ILINK_APP_CLIENT_VERSION = _build_client_version(WEIXIN_CHANNEL_VERSION) BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION} # Session-expired error code @@ -64,18 +84,32 @@ MAX_CONSECUTIVE_FAILURES = 3 BACKOFF_DELAY_S = 30 RETRY_DELAY_S = 2 MAX_QR_REFRESH_COUNT = 3 +TYPING_STATUS_TYPING = 1 +TYPING_STATUS_CANCEL = 2 +TYPING_TICKET_TTL_S = 24 * 60 * 60 +TYPING_KEEPALIVE_INTERVAL_S = 5 +CONFIG_CACHE_INITIAL_RETRY_S = 2 +CONFIG_CACHE_MAX_RETRY_S = 60 * 60 # Default long-poll timeout; overridden by server via longpolling_timeout_ms. 
DEFAULT_LONG_POLL_TIMEOUT_S = 35 -# Media-type codes for getuploadurl (1=image, 2=video, 3=file) +# Media-type codes for getuploadurl (1=image, 2=video, 3=file, 4=voice) UPLOAD_MEDIA_IMAGE = 1 UPLOAD_MEDIA_VIDEO = 2 UPLOAD_MEDIA_FILE = 3 +UPLOAD_MEDIA_VOICE = 4 # File extensions considered as images / videos for outbound media _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"} _VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"} +_VOICE_EXTS = {".mp3", ".wav", ".amr", ".silk", ".ogg", ".m4a", ".aac", ".flac"} + + +def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool: + if not isinstance(media, dict): + return False + return bool(str(media.get("encrypt_query_param", "") or "") or str(media.get("full_url", "") or "").strip()) class WeixinConfig(Base): @@ -124,7 +158,7 @@ class WeixinChannel(BaseChannel): self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S self._session_pause_until: float = 0.0 self._typing_tasks: dict[str, asyncio.Task] = {} - self._typing_tickets: dict[str, str] = {} + self._typing_tickets: dict[str, dict[str, Any]] = {} # ------------------------------------------------------------------ # State persistence @@ -162,9 +196,9 @@ class WeixinChannel(BaseChannel): typing_tickets = data.get("typing_tickets", {}) if isinstance(typing_tickets, dict): self._typing_tickets = { - str(user_id): str(ticket) + str(user_id): ticket for user_id, ticket in typing_tickets.items() - if str(user_id).strip() and str(ticket).strip() + if str(user_id).strip() and isinstance(ticket, dict) } else: self._typing_tickets = {} @@ -172,8 +206,7 @@ class WeixinChannel(BaseChannel): if base_url: self.config.base_url = base_url return bool(self._token) - except Exception as e: - logger.warning("Failed to load WeChat state: {}", e) + except Exception: return False def _save_state(self) -> None: @@ -187,8 +220,8 @@ class WeixinChannel(BaseChannel): "base_url": self.config.base_url, } 
state_file.write_text(json.dumps(data, ensure_ascii=False)) - except Exception as e: - logger.warning("Failed to save WeChat state: {}", e) + except Exception: + pass # ------------------------------------------------------------------ # HTTP helpers (matches api.ts buildHeaders / apiFetch) @@ -210,6 +243,8 @@ class WeixinChannel(BaseChannel): "X-WECHAT-UIN": self._random_wechat_uin(), "Content-Type": "application/json", "AuthorizationType": "ilink_bot_token", + "iLink-App-Id": ILINK_APP_ID, + "iLink-App-ClientVersion": str(ILINK_APP_CLIENT_VERSION), } if auth and self._token: headers["Authorization"] = f"Bearer {self._token}" @@ -217,6 +252,15 @@ class WeixinChannel(BaseChannel): headers["SKRouteTag"] = str(self.config.route_tag).strip() return headers + @staticmethod + def _is_retryable_media_download_error(err: Exception) -> bool: + if isinstance(err, httpx.TimeoutException | httpx.TransportError): + return True + if isinstance(err, httpx.HTTPStatusError): + status_code = err.response.status_code if err.response is not None else 0 + return status_code >= 500 + return False + async def _api_get( self, endpoint: str, @@ -234,6 +278,25 @@ class WeixinChannel(BaseChannel): resp.raise_for_status() return resp.json() + async def _api_get_with_base( + self, + *, + base_url: str, + endpoint: str, + params: dict | None = None, + auth: bool = True, + extra_headers: dict[str, str] | None = None, + ) -> dict: + """GET helper that allows overriding base_url for QR redirect polling.""" + assert self._client is not None + url = f"{base_url.rstrip('/')}/{endpoint}" + hdrs = self._make_headers(auth=auth) + if extra_headers: + hdrs.update(extra_headers) + resp = await self._client.get(url, params=params, headers=hdrs) + resp.raise_for_status() + return resp.json() + async def _api_post( self, endpoint: str, @@ -270,23 +333,27 @@ class WeixinChannel(BaseChannel): async def _qr_login(self) -> bool: """Perform QR code login flow. 
Returns True on success.""" try: - logger.info("Starting WeChat QR code login...") refresh_count = 0 qrcode_id, scan_url = await self._fetch_qr_code() self._print_qr_code(scan_url) + current_poll_base_url = self.config.base_url - logger.info("Waiting for QR code scan...") while self._running: try: - # Reference plugin sends iLink-App-ClientVersion header for - # QR status polling (login-qr.ts:81). - status_data = await self._api_get( - "ilink/bot/get_qrcode_status", + status_data = await self._api_get_with_base( + base_url=current_poll_base_url, + endpoint="ilink/bot/get_qrcode_status", params={"qrcode": qrcode_id}, auth=False, - extra_headers={"iLink-App-ClientVersion": "1"}, ) - except httpx.TimeoutException: + except Exception as e: + if self._is_retryable_qr_poll_error(e): + await asyncio.sleep(1) + continue + raise + + if not isinstance(status_data, dict): + await asyncio.sleep(1) continue status = status_data.get("status", "") @@ -309,8 +376,15 @@ class WeixinChannel(BaseChannel): else: logger.error("Login confirmed but no bot_token in response") return False - elif status == "scaned": - logger.info("QR code scanned, waiting for confirmation...") + elif status == "scaned_but_redirect": + redirect_host = str(status_data.get("redirect_host", "") or "").strip() + if redirect_host: + if redirect_host.startswith("http://") or redirect_host.startswith("https://"): + redirected_base = redirect_host + else: + redirected_base = f"https://{redirect_host}" + if redirected_base != current_poll_base_url: + current_poll_base_url = redirected_base elif status == "expired": refresh_count += 1 if refresh_count > MAX_QR_REFRESH_COUNT: @@ -320,14 +394,9 @@ class WeixinChannel(BaseChannel): MAX_QR_REFRESH_COUNT, ) return False - logger.warning( - "QR code expired, refreshing... 
({}/{})", - refresh_count, - MAX_QR_REFRESH_COUNT, - ) qrcode_id, scan_url = await self._fetch_qr_code() + current_poll_base_url = self.config.base_url self._print_qr_code(scan_url) - logger.info("New QR code generated, waiting for scan...") continue # status == "wait" — keep polling @@ -338,6 +407,16 @@ class WeixinChannel(BaseChannel): return False + @staticmethod + def _is_retryable_qr_poll_error(err: Exception) -> bool: + if isinstance(err, httpx.TimeoutException | httpx.TransportError): + return True + if isinstance(err, httpx.HTTPStatusError): + status_code = err.response.status_code if err.response is not None else 0 + if status_code >= 500: + return True + return False + @staticmethod def _print_qr_code(url: str) -> None: try: @@ -348,7 +427,6 @@ class WeixinChannel(BaseChannel): qr.make(fit=True) qr.print_ascii(invert=True) except ImportError: - logger.info("QR code URL (install 'qrcode' for terminal display): {}", url) print(f"\nLogin URL: {url}\n") # ------------------------------------------------------------------ @@ -410,12 +488,6 @@ class WeixinChannel(BaseChannel): if not self._running: break consecutive_failures += 1 - logger.error( - "WeChat poll error ({}/{}): {}", - consecutive_failures, - MAX_CONSECUTIVE_FAILURES, - e, - ) if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: consecutive_failures = 0 await asyncio.sleep(BACKOFF_DELAY_S) @@ -432,8 +504,6 @@ class WeixinChannel(BaseChannel): await self._client.aclose() self._client = None self._save_state() - logger.info("WeChat channel stopped") - # ------------------------------------------------------------------ # Polling (matches monitor.ts monitorWeixinProvider) # ------------------------------------------------------------------ @@ -459,10 +529,6 @@ class WeixinChannel(BaseChannel): async def _poll_once(self) -> None: remaining = self._session_pause_remaining_s() if remaining > 0: - logger.warning( - "WeChat session paused, waiting {} min before next poll.", - max((remaining + 59) // 60, 
1), - ) await asyncio.sleep(remaining) return @@ -512,8 +578,8 @@ class WeixinChannel(BaseChannel): for msg in msgs: try: await self._process_message(msg) - except Exception as e: - logger.error("Error processing WeChat message: {}", e) + except Exception: + pass # ------------------------------------------------------------------ # Inbound message processing (matches inbound.ts + process-message.ts) @@ -549,6 +615,7 @@ class WeixinChannel(BaseChannel): item_list: list[dict] = msg.get("item_list") or [] content_parts: list[str] = [] media_paths: list[str] = [] + has_top_level_downloadable_media = False for item in item_list: item_type = item.get("type", 0) @@ -585,6 +652,8 @@ class WeixinChannel(BaseChannel): elif item_type == ITEM_IMAGE: image_item = item.get("image_item") or {} + if _has_downloadable_media_locator(image_item.get("media")): + has_top_level_downloadable_media = True file_path = await self._download_media_item(image_item, "image") if file_path: content_parts.append(f"[image]\n[Image: source: {file_path}]") @@ -599,6 +668,8 @@ class WeixinChannel(BaseChannel): if voice_text: content_parts.append(f"[voice] {voice_text}") else: + if _has_downloadable_media_locator(voice_item.get("media")): + has_top_level_downloadable_media = True file_path = await self._download_media_item(voice_item, "voice") if file_path: transcription = await self.transcribe_audio(file_path) @@ -612,6 +683,8 @@ class WeixinChannel(BaseChannel): elif item_type == ITEM_FILE: file_item = item.get("file_item") or {} + if _has_downloadable_media_locator(file_item.get("media")): + has_top_level_downloadable_media = True file_name = file_item.get("file_name", "unknown") file_path = await self._download_media_item( file_item, @@ -626,6 +699,8 @@ class WeixinChannel(BaseChannel): elif item_type == ITEM_VIDEO: video_item = item.get("video_item") or {} + if _has_downloadable_media_locator(video_item.get("media")): + has_top_level_downloadable_media = True file_path = await 
self._download_media_item(video_item, "video") if file_path: content_parts.append(f"[video]\n[Video: source: {file_path}]") @@ -633,6 +708,52 @@ class WeixinChannel(BaseChannel): else: content_parts.append("[video]") + # Fallback: when no top-level media was downloaded, try quoted/referenced media. + # This aligns with the reference plugin behavior that checks ref_msg.message_item + # when main item_list has no downloadable media. + if not media_paths and not has_top_level_downloadable_media: + ref_media_item: dict[str, Any] | None = None + for item in item_list: + if item.get("type", 0) != ITEM_TEXT: + continue + ref = item.get("ref_msg") or {} + candidate = ref.get("message_item") or {} + if candidate.get("type", 0) in (ITEM_IMAGE, ITEM_VOICE, ITEM_FILE, ITEM_VIDEO): + ref_media_item = candidate + break + + if ref_media_item: + ref_type = ref_media_item.get("type", 0) + if ref_type == ITEM_IMAGE: + image_item = ref_media_item.get("image_item") or {} + file_path = await self._download_media_item(image_item, "image") + if file_path: + content_parts.append(f"[image]\n[Image: source: {file_path}]") + media_paths.append(file_path) + elif ref_type == ITEM_VOICE: + voice_item = ref_media_item.get("voice_item") or {} + file_path = await self._download_media_item(voice_item, "voice") + if file_path: + transcription = await self.transcribe_audio(file_path) + if transcription: + content_parts.append(f"[voice] {transcription}") + else: + content_parts.append(f"[voice]\n[Audio: source: {file_path}]") + media_paths.append(file_path) + elif ref_type == ITEM_FILE: + file_item = ref_media_item.get("file_item") or {} + file_name = file_item.get("file_name", "unknown") + file_path = await self._download_media_item(file_item, "file", file_name) + if file_path: + content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]") + media_paths.append(file_path) + elif ref_type == ITEM_VIDEO: + video_item = ref_media_item.get("video_item") or {} + file_path = await 
self._download_media_item(video_item, "video") + if file_path: + content_parts.append(f"[video]\n[Video: source: {file_path}]") + media_paths.append(file_path) + content = "\n".join(content_parts) if not content: return @@ -667,9 +788,10 @@ class WeixinChannel(BaseChannel): """Download + AES-decrypt a media item. Returns local path or None.""" try: media = typed_item.get("media") or {} - encrypt_query_param = media.get("encrypt_query_param", "") + encrypt_query_param = str(media.get("encrypt_query_param", "") or "") + full_url = str(media.get("full_url", "") or "").strip() - if not encrypt_query_param: + if not encrypt_query_param and not full_url: return None # Resolve AES key (media-download.ts:43-45, pic-decrypt.ts:40-52) @@ -686,21 +808,50 @@ class WeixinChannel(BaseChannel): elif media_aes_key_b64: aes_key_b64 = media_aes_key_b64 - # Build CDN download URL with proper URL-encoding (cdn-url.ts:7) - cdn_url = ( - f"{self.config.cdn_base_url}/download" - f"?encrypted_query_param={quote(encrypt_query_param)}" - ) + # Reference protocol behavior: VOICE/FILE/VIDEO require aes_key; + # only IMAGE may be downloaded as plain bytes when key is missing. 
+ if media_type != "image" and not aes_key_b64: + return None assert self._client is not None - resp = await self._client.get(cdn_url) - resp.raise_for_status() - data = resp.content + fallback_url = "" + if encrypt_query_param: + fallback_url = ( + f"{self.config.cdn_base_url}/download" + f"?encrypted_query_param={quote(encrypt_query_param)}" + ) + + download_candidates: list[tuple[str, str]] = [] + if full_url: + download_candidates.append(("full_url", full_url)) + if fallback_url and (not full_url or fallback_url != full_url): + download_candidates.append(("encrypt_query_param", fallback_url)) + + data = b"" + for idx, (download_source, cdn_url) in enumerate(download_candidates): + try: + resp = await self._client.get(cdn_url) + resp.raise_for_status() + data = resp.content + break + except Exception as e: + has_more_candidates = idx + 1 < len(download_candidates) + should_fallback = ( + download_source == "full_url" + and has_more_candidates + and self._is_retryable_media_download_error(e) + ) + if should_fallback: + logger.warning( + "WeChat media download failed via full_url, falling back to encrypt_query_param: type={} err={}", + media_type, + e, + ) + continue + raise if aes_key_b64 and data: data = _decrypt_aes_ecb(data, aes_key_b64) - elif not aes_key_b64: - logger.debug("No AES key for {} item, using raw bytes", media_type) if not data: return None @@ -709,12 +860,12 @@ class WeixinChannel(BaseChannel): ext = _ext_for_type(media_type) if not filename: ts = int(time.time()) - h = abs(hash(encrypt_query_param)) % 100000 + hash_seed = encrypt_query_param or full_url + h = abs(hash(hash_seed)) % 100000 filename = f"{media_type}_{ts}_{h}{ext}" safe_name = os.path.basename(filename) file_path = media_dir / safe_name file_path.write_bytes(data) - logger.debug("Downloaded WeChat {} to {}", media_type, file_path) return str(file_path) except Exception as e: @@ -725,14 +876,76 @@ class WeixinChannel(BaseChannel): # Outbound (matches send.ts buildTextMessageReq + 
sendMessageWeixin) # ------------------------------------------------------------------ + async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str: + """Get typing ticket with per-user refresh + failure backoff cache.""" + now = time.time() + entry = self._typing_tickets.get(user_id) + if entry and now < float(entry.get("next_fetch_at", 0)): + return str(entry.get("ticket", "") or "") + + body: dict[str, Any] = { + "ilink_user_id": user_id, + "context_token": context_token or None, + "base_info": BASE_INFO, + } + data = await self._api_post("ilink/bot/getconfig", body) + if data.get("ret", 0) == 0: + ticket = str(data.get("typing_ticket", "") or "") + self._typing_tickets[user_id] = { + "ticket": ticket, + "ever_succeeded": True, + "next_fetch_at": now + (random.random() * TYPING_TICKET_TTL_S), + "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S, + } + return ticket + + prev_delay = float(entry.get("retry_delay_s", CONFIG_CACHE_INITIAL_RETRY_S)) if entry else CONFIG_CACHE_INITIAL_RETRY_S + next_delay = min(prev_delay * 2, CONFIG_CACHE_MAX_RETRY_S) + if entry: + entry["next_fetch_at"] = now + next_delay + entry["retry_delay_s"] = next_delay + return str(entry.get("ticket", "") or "") + + self._typing_tickets[user_id] = { + "ticket": "", + "ever_succeeded": False, + "next_fetch_at": now + CONFIG_CACHE_INITIAL_RETRY_S, + "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S, + } + return "" + + async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None: + """Best-effort sendtyping wrapper.""" + if not typing_ticket: + return + body: dict[str, Any] = { + "ilink_user_id": user_id, + "typing_ticket": typing_ticket, + "status": status, + "base_info": BASE_INFO, + } + await self._api_post("ilink/bot/sendtyping", body) + + async def _typing_keepalive_loop(self, user_id: str, typing_ticket: str, stop_event: asyncio.Event) -> None: + try: + while not stop_event.is_set(): + await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S) + if 
stop_event.is_set(): + break + try: + await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING) + except Exception: + pass + finally: + pass + async def send(self, msg: OutboundMessage) -> None: if not self._client or not self._token: logger.warning("WeChat client not initialized or not authenticated") return try: self._assert_session_active() - except RuntimeError as e: - logger.warning("WeChat send blocked: {}", e) + except RuntimeError: return is_progress = bool((msg.metadata or {}).get("_progress", False)) @@ -748,94 +961,103 @@ class WeixinChannel(BaseChannel): ) return - # --- Send media files first (following Telegram channel pattern) --- - for media_path in (msg.media or []): - try: - await self._send_media_file(msg.chat_id, media_path, ctx_token) - except Exception as e: - filename = Path(media_path).name - logger.error("Failed to send WeChat media {}: {}", media_path, e) - # Notify user about failure via text - await self._send_text( - msg.chat_id, f"[Failed to send: {filename}]", ctx_token, - ) + typing_ticket = "" + try: + typing_ticket = await self._get_typing_ticket(msg.chat_id, ctx_token) + except Exception: + typing_ticket = "" - # --- Send text content --- - if not content: - return + if typing_ticket: + try: + await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_TYPING) + except Exception: + pass + + typing_keepalive_stop = asyncio.Event() + typing_keepalive_task: asyncio.Task | None = None + if typing_ticket: + typing_keepalive_task = asyncio.create_task( + self._typing_keepalive_loop(msg.chat_id, typing_ticket, typing_keepalive_stop) + ) try: + # --- Send media files first (following Telegram channel pattern) --- + for media_path in (msg.media or []): + try: + await self._send_media_file(msg.chat_id, media_path, ctx_token) + except Exception as e: + filename = Path(media_path).name + logger.error("Failed to send WeChat media {}: {}", media_path, e) + # Notify user about failure via text + await self._send_text( + 
msg.chat_id, f"[Failed to send: {filename}]", ctx_token, + ) + + # --- Send text content --- + if not content: + return + chunks = split_message(content, WEIXIN_MAX_MESSAGE_LEN) for chunk in chunks: await self._send_text(msg.chat_id, chunk, ctx_token) except Exception as e: logger.error("Error sending WeChat message: {}", e) raise + finally: + if typing_keepalive_task: + typing_keepalive_stop.set() + typing_keepalive_task.cancel() + try: + await typing_keepalive_task + except asyncio.CancelledError: + pass - async def _get_typing_ticket(self, user_id: str, context_token: str) -> str: - """Fetch and cache typing ticket for a user/context pair.""" - if not self._client or not self._token or not user_id or not context_token: - return "" - cached = self._typing_tickets.get(user_id, "") - if cached: - return cached - try: - data = await self._api_post( - "ilink/bot/getconfig", - { - "ilink_user_id": user_id, - "context_token": context_token, - }, - ) - except Exception as e: - logger.debug("WeChat getconfig failed for {}: {}", user_id, e) - return "" - ticket = str(data.get("typing_ticket") or "").strip() - if ticket: - self._typing_tickets[user_id] = ticket - self._save_state() - return ticket + if typing_ticket and not is_progress: + try: + await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL) + except Exception: + pass - async def _send_typing_status(self, to_user_id: str, typing_ticket: str, status: int) -> None: - if not typing_ticket: - return - await self._api_post( - "ilink/bot/sendtyping", - { - "ilink_user_id": to_user_id, - "typing_ticket": typing_ticket, - "status": status, - }, - ) - - async def _start_typing(self, chat_id: str, context_token: str) -> None: - if not self._client or not self._token or not chat_id or not context_token: + async def _start_typing(self, chat_id: str, context_token: str = "") -> None: + """Start typing indicator immediately when a message is received.""" + if not self._client or not self._token or not chat_id: 
return await self._stop_typing(chat_id, clear_remote=False) - ticket = await self._get_typing_ticket(chat_id, context_token) - if not ticket: - return try: - await self._send_typing_status(chat_id, ticket, 1) + ticket = await self._get_typing_ticket(chat_id, context_token) + if not ticket: + return + await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING) except Exception as e: - logger.debug("WeChat typing indicator failed for {}: {}", chat_id, e) + logger.debug("WeChat typing indicator start failed for {}: {}", chat_id, e) return - async def typing_loop() -> None: - try: - while self._running: - await asyncio.sleep(5) - await self._send_typing_status(chat_id, ticket, 1) - except asyncio.CancelledError: - pass - except Exception as e: - logger.debug("WeChat typing keepalive stopped for {}: {}", chat_id, e) + stop_event = asyncio.Event() - self._typing_tasks[chat_id] = asyncio.create_task(typing_loop()) + async def keepalive() -> None: + try: + while not stop_event.is_set(): + await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S) + if stop_event.is_set(): + break + try: + await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING) + except Exception: + pass + finally: + pass + + task = asyncio.create_task(keepalive()) + task._typing_stop_event = stop_event # type: ignore[attr-defined] + self._typing_tasks[chat_id] = task async def _stop_typing(self, chat_id: str, *, clear_remote: bool) -> None: + """Stop typing indicator for a chat.""" task = self._typing_tasks.pop(chat_id, None) if task and not task.done(): + stop_event = getattr(task, "_typing_stop_event", None) + if stop_event: + stop_event.set() task.cancel() try: await task @@ -843,11 +1065,12 @@ class WeixinChannel(BaseChannel): pass if not clear_remote: return - ticket = self._typing_tickets.get(chat_id, "") + entry = self._typing_tickets.get(chat_id) + ticket = str(entry.get("ticket", "") or "") if isinstance(entry, dict) else "" if not ticket: return try: - await self._send_typing_status(chat_id, 
ticket, 2) + await self._send_typing(chat_id, ticket, TYPING_STATUS_CANCEL) except Exception as e: logger.debug("WeChat typing clear failed for {}: {}", chat_id, e) @@ -923,6 +1146,10 @@ class WeixinChannel(BaseChannel): upload_type = UPLOAD_MEDIA_VIDEO item_type = ITEM_VIDEO item_key = "video_item" + elif ext in _VOICE_EXTS: + upload_type = UPLOAD_MEDIA_VOICE + item_type = ITEM_VOICE + item_key = "voice_item" else: upload_type = UPLOAD_MEDIA_FILE item_type = ITEM_FILE @@ -936,7 +1163,7 @@ class WeixinChannel(BaseChannel): # Matches aesEcbPaddedSize: Math.ceil((size + 1) / 16) * 16 padded_size = ((raw_size + 1 + 15) // 16) * 16 - # Step 1: Get upload URL (upload_param) from server + # Step 1: Get upload URL from server (prefer upload_full_url, fallback to upload_param) file_key = os.urandom(16).hex() upload_body: dict[str, Any] = { "filekey": file_key, @@ -951,22 +1178,27 @@ class WeixinChannel(BaseChannel): assert self._client is not None upload_resp = await self._api_post("ilink/bot/getuploadurl", upload_body) - logger.debug("WeChat getuploadurl response: {}", upload_resp) - upload_param = upload_resp.get("upload_param", "") - if not upload_param: - raise RuntimeError(f"getuploadurl returned no upload_param: {upload_resp}") + upload_full_url = str(upload_resp.get("upload_full_url", "") or "").strip() + upload_param = str(upload_resp.get("upload_param", "") or "") + if not upload_full_url and not upload_param: + raise RuntimeError( + "getuploadurl returned no upload URL " + f"(need upload_full_url or upload_param): {upload_resp}" + ) # Step 2: AES-128-ECB encrypt and POST to CDN aes_key_b64 = base64.b64encode(aes_key_raw).decode() encrypted_data = _encrypt_aes_ecb(raw_data, aes_key_b64) - cdn_upload_url = ( - f"{self.config.cdn_base_url}/upload" - f"?encrypted_query_param={quote(upload_param)}" - f"&filekey={quote(file_key)}" - ) - logger.debug("WeChat CDN POST url={} ciphertextSize={}", cdn_upload_url[:80], len(encrypted_data)) + if upload_full_url: + 
cdn_upload_url = upload_full_url + else: + cdn_upload_url = ( + f"{self.config.cdn_base_url}/upload" + f"?encrypted_query_param={quote(upload_param)}" + f"&filekey={quote(file_key)}" + ) cdn_resp = await self._client.post( cdn_upload_url, @@ -982,7 +1214,6 @@ class WeixinChannel(BaseChannel): "CDN upload response missing x-encrypted-param header; " f"status={cdn_resp.status_code} headers={dict(cdn_resp.headers)}" ) - logger.debug("WeChat CDN upload success for {}, got download_param", p.name) # Step 3: Send message with the media item # aes_key for CDNMedia is the hex key encoded as base64 @@ -1031,7 +1262,6 @@ class WeixinChannel(BaseChannel): raise RuntimeError( f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}" ) - logger.info("WeChat media sent: {} (type={})", p.name, item_key) # --------------------------------------------------------------------------- @@ -1103,23 +1333,42 @@ def _decrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes: logger.warning("Failed to parse AES key, returning raw data: {}", e) return data + decrypted: bytes | None = None + try: from Crypto.Cipher import AES cipher = AES.new(key, AES.MODE_ECB) - return cipher.decrypt(data) # pycryptodome auto-strips PKCS7 with unpad + decrypted = cipher.decrypt(data) except ImportError: pass - try: - from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + if decrypted is None: + try: + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - cipher_obj = Cipher(algorithms.AES(key), modes.ECB()) - decryptor = cipher_obj.decryptor() - return decryptor.update(data) + decryptor.finalize() - except ImportError: - logger.warning("Cannot decrypt media: install 'pycryptodome' or 'cryptography'") + cipher_obj = Cipher(algorithms.AES(key), modes.ECB()) + decryptor = cipher_obj.decryptor() + decrypted = decryptor.update(data) + decryptor.finalize() + except ImportError: + logger.warning("Cannot decrypt media: install 'pycryptodome' or 
'cryptography'") + return data + + return _pkcs7_unpad_safe(decrypted) + + +def _pkcs7_unpad_safe(data: bytes, block_size: int = 16) -> bytes: + """Safely remove PKCS7 padding when valid; otherwise return original bytes.""" + if not data: return data + if len(data) % block_size != 0: + return data + pad_len = data[-1] + if pad_len < 1 or pad_len > block_size: + return data + if data[-pad_len:] != bytes([pad_len]) * pad_len: + return data + return data[:-pad_len] def _ext_for_type(media_type: str) -> str: diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index cacb61ae6..49521aa16 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -415,6 +415,9 @@ def _make_provider(config: Config): api_base=p.api_base, default_model=model, ) + elif backend == "github_copilot": + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + provider = GitHubCopilotProvider(default_model=model) elif backend == "anthropic": from nanobot.providers.anthropic_provider import AnthropicProvider provider = AnthropicProvider( @@ -491,6 +494,91 @@ def _migrate_cron_store(config: "Config") -> None: shutil.move(str(legacy_path), str(new_path)) +# ============================================================================ +# OpenAI-Compatible API Server +# ============================================================================ + + +@app.command() +def serve( + port: int | None = typer.Option(None, "--port", "-p", help="API server port"), + host: str | None = typer.Option(None, "--host", "-H", help="Bind address"), + timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"), + workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"), + config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), +): + """Start the OpenAI-compatible API server 
(/v1/chat/completions).""" + try: + from aiohttp import web # noqa: F401 + except ImportError: + console.print("[red]aiohttp is required. Install with: pip install 'nanobot-ai[api]'[/red]") + raise typer.Exit(1) + + from loguru import logger + from nanobot.agent.loop import AgentLoop + from nanobot.api.server import create_app + from nanobot.bus.queue import MessageBus + from nanobot.session.manager import SessionManager + + if verbose: + logger.enable("nanobot") + else: + logger.disable("nanobot") + + runtime_config = _load_runtime_config(config, workspace) + api_cfg = runtime_config.api + host = host if host is not None else api_cfg.host + port = port if port is not None else api_cfg.port + timeout = timeout if timeout is not None else api_cfg.timeout + sync_workspace_templates(runtime_config.workspace_path) + bus = MessageBus() + provider = _make_provider(runtime_config) + session_manager = SessionManager(runtime_config.workspace_path) + agent_loop = AgentLoop( + bus=bus, + provider=provider, + workspace=runtime_config.workspace_path, + model=runtime_config.agents.defaults.model, + max_iterations=runtime_config.agents.defaults.max_tool_iterations, + context_window_tokens=runtime_config.agents.defaults.context_window_tokens, + web_search_config=runtime_config.tools.web.search, + web_proxy=runtime_config.tools.web.proxy or None, + exec_config=runtime_config.tools.exec, + restrict_to_workspace=runtime_config.tools.restrict_to_workspace, + session_manager=session_manager, + mcp_servers=runtime_config.tools.mcp_servers, + channels_config=runtime_config.channels, + timezone=runtime_config.agents.defaults.timezone, + ) + + model_name = runtime_config.agents.defaults.model + console.print(f"{__logo__} Starting OpenAI-compatible API server") + console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions") + console.print(f" [cyan]Model[/cyan] : {model_name}") + console.print(" [cyan]Session[/cyan] : api:default") + console.print(f" 
[cyan]Timeout[/cyan] : {timeout}s") + if host in {"0.0.0.0", "::"}: + console.print( + "[yellow]Warning:[/yellow] API is bound to all interfaces. " + "Only do this behind a trusted network boundary, firewall, or reverse proxy." + ) + console.print() + + api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout) + + async def on_startup(_app): + await agent_loop._connect_mcp() + + async def on_cleanup(_app): + await agent_loop.close_mcp() + + api_app.on_startup.append(on_startup) + api_app.on_cleanup.append(on_cleanup) + + web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg)) + + # ============================================================================ # Gateway / Server # ============================================================================ @@ -1204,26 +1292,16 @@ def _login_openai_codex() -> None: @_register_login("github_copilot") def _login_github_copilot() -> None: - import asyncio - - from openai import AsyncOpenAI - - console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") - - async def _trigger(): - client = AsyncOpenAI( - api_key="dummy", - base_url="https://api.githubcopilot.com", - ) - await client.chat.completions.create( - model="gpt-4o", - messages=[{"role": "user", "content": "hi"}], - max_tokens=1, - ) - try: - asyncio.run(_trigger()) - console.print("[green]✓ Authenticated with GitHub Copilot[/green]") + from nanobot.providers.github_copilot_provider import login_github_copilot + + console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") + token = login_github_copilot( + print_fn=lambda s: console.print(s), + prompt_fn=lambda s: typer.prompt(s), + ) + account = token.account_id or "GitHub" + console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]") except Exception as e: console.print(f"[red]Authentication error: {e}[/red]") raise typer.Exit(1) diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index 
0a9af3cb9..643397057 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage: async def cmd_help(ctx: CommandContext) -> OutboundMessage: """Return available slash commands.""" + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content=build_help_text(), + metadata={"render_as": "text"}, + ) + + +def build_help_text() -> str: + """Build canonical help text shared across channels.""" lines = [ "🐈 nanobot commands:", "/new — Start a new conversation", @@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage: "/status — Show bot status", "/help — Show available commands", ] - return OutboundMessage( - channel=ctx.msg.channel, - chat_id=ctx.msg.chat_id, - content="\n".join(lines), - metadata={"render_as": "text"}, - ) + return "\n".join(lines) def register_builtin_commands(router: CommandRouter) -> None: diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c8b69b42e..c4c927afd 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -96,6 +96,14 @@ class HeartbeatConfig(Base): keep_recent_messages: int = 8 +class ApiConfig(Base): + """OpenAI-compatible API server configuration.""" + + host: str = "127.0.0.1" # Safer default: local-only bind. + port: int = 8900 + timeout: float = 120.0 # Per-request timeout in seconds. 
+ + class GatewayConfig(Base): """Gateway/server configuration.""" @@ -156,6 +164,7 @@ class Config(BaseSettings): agents: AgentsConfig = Field(default_factory=AgentsConfig) channels: ChannelsConfig = Field(default_factory=ChannelsConfig) providers: ProvidersConfig = Field(default_factory=ProvidersConfig) + api: ApiConfig = Field(default_factory=ApiConfig) gateway: GatewayConfig = Field(default_factory=GatewayConfig) tools: ToolsConfig = Field(default_factory=ToolsConfig) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py new file mode 100644 index 000000000..84fb70934 --- /dev/null +++ b/nanobot/nanobot.py @@ -0,0 +1,174 @@ +"""High-level programmatic interface to nanobot.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from nanobot.agent.hook import AgentHook +from nanobot.agent.loop import AgentLoop +from nanobot.bus.queue import MessageBus + + +@dataclass(slots=True) +class RunResult: + """Result of a single agent run.""" + + content: str + tools_used: list[str] + messages: list[dict[str, Any]] + + +class Nanobot: + """Programmatic facade for running the nanobot agent. + + Usage:: + + bot = Nanobot.from_config() + result = await bot.run("Summarize this repo", hooks=[MyHook()]) + print(result.content) + """ + + def __init__(self, loop: AgentLoop) -> None: + self._loop = loop + + @classmethod + def from_config( + cls, + config_path: str | Path | None = None, + *, + workspace: str | Path | None = None, + ) -> Nanobot: + """Create a Nanobot instance from a config file. + + Args: + config_path: Path to ``config.json``. Defaults to + ``~/.nanobot/config.json``. + workspace: Override the workspace directory from config. 
+ """ + from nanobot.config.loader import load_config + from nanobot.config.schema import Config + + resolved: Path | None = None + if config_path is not None: + resolved = Path(config_path).expanduser().resolve() + if not resolved.exists(): + raise FileNotFoundError(f"Config not found: {resolved}") + + config: Config = load_config(resolved) + if workspace is not None: + config.agents.defaults.workspace = str( + Path(workspace).expanduser().resolve() + ) + + provider = _make_provider(config) + bus = MessageBus() + defaults = config.agents.defaults + + loop = AgentLoop( + bus=bus, + provider=provider, + workspace=config.workspace_path, + model=defaults.model, + max_iterations=defaults.max_tool_iterations, + context_window_tokens=defaults.context_window_tokens, + web_search_config=config.tools.web.search, + web_proxy=config.tools.web.proxy or None, + exec_config=config.tools.exec, + restrict_to_workspace=config.tools.restrict_to_workspace, + mcp_servers=config.tools.mcp_servers, + timezone=defaults.timezone, + ) + return cls(loop) + + async def run( + self, + message: str, + *, + session_key: str = "sdk:default", + hooks: list[AgentHook] | None = None, + ) -> RunResult: + """Run the agent once and return the result. + + Args: + message: The user message to process. + session_key: Session identifier for conversation isolation. + Different keys get independent history. + hooks: Optional lifecycle hooks for this run. 
+ """ + prev = self._loop._extra_hooks + if hooks is not None: + self._loop._extra_hooks = list(hooks) + try: + response = await self._loop.process_direct( + message, session_key=session_key, + ) + finally: + self._loop._extra_hooks = prev + + content = (response.content if response else None) or "" + return RunResult(content=content, tools_used=[], messages=[]) + + +def _make_provider(config: Any) -> Any: + """Create the LLM provider from config (extracted from CLI).""" + from nanobot.providers.base import GenerationSettings + from nanobot.providers.registry import find_by_name + + model = config.agents.defaults.model + provider_name = config.get_provider_name(model) + p = config.get_provider(model) + spec = find_by_name(provider_name) if provider_name else None + backend = spec.backend if spec else "openai_compat" + + if backend == "azure_openai": + if not p or not p.api_key or not p.api_base: + raise ValueError("Azure OpenAI requires api_key and api_base in config.") + elif backend == "openai_compat" and not model.startswith("bedrock/"): + needs_key = not (p and p.api_key) + exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) + if needs_key and not exempt: + raise ValueError(f"No API key configured for provider '{provider_name}'.") + + if backend == "openai_codex": + from nanobot.providers.openai_codex_provider import OpenAICodexProvider + + provider = OpenAICodexProvider(default_model=model) + elif backend == "github_copilot": + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + + provider = GitHubCopilotProvider(default_model=model) + elif backend == "azure_openai": + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + + provider = AzureOpenAIProvider( + api_key=p.api_key, api_base=p.api_base, default_model=model + ) + elif backend == "anthropic": + from nanobot.providers.anthropic_provider import AnthropicProvider + + provider = AnthropicProvider( + api_key=p.api_key if p else None, + 
api_base=config.get_api_base(model), + default_model=model, + extra_headers=p.extra_headers if p else None, + ) + else: + from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + provider = OpenAICompatProvider( + api_key=p.api_key if p else None, + api_base=config.get_api_base(model), + default_model=model, + extra_headers=p.extra_headers if p else None, + spec=spec, + ) + + defaults = config.agents.defaults + provider.generation = GenerationSettings( + temperature=defaults.temperature, + max_tokens=defaults.max_tokens, + reasoning_effort=defaults.reasoning_effort, + ) + return provider diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py index 0e259e6f0..ce2378707 100644 --- a/nanobot/providers/__init__.py +++ b/nanobot/providers/__init__.py @@ -13,6 +13,7 @@ __all__ = [ "AnthropicProvider", "OpenAICompatProvider", "OpenAICodexProvider", + "GitHubCopilotProvider", "AzureOpenAIProvider", ] @@ -20,12 +21,14 @@ _LAZY_IMPORTS = { "AnthropicProvider": ".anthropic_provider", "OpenAICompatProvider": ".openai_compat_provider", "OpenAICodexProvider": ".openai_codex_provider", + "GitHubCopilotProvider": ".github_copilot_provider", "AzureOpenAIProvider": ".azure_openai_provider", } if TYPE_CHECKING: from nanobot.providers.anthropic_provider import AnthropicProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider from nanobot.providers.openai_compat_provider import OpenAICompatProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 3c789e730..8e102d305 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -370,15 +370,22 @@ class AnthropicProvider(LLMProvider): usage: dict[str, int] = {} if response.usage: + input_tokens = response.usage.input_tokens + 
cache_creation = getattr(response.usage, "cache_creation_input_tokens", 0) or 0 + cache_read = getattr(response.usage, "cache_read_input_tokens", 0) or 0 + total_prompt_tokens = input_tokens + cache_creation + cache_read usage = { - "prompt_tokens": response.usage.input_tokens, + "prompt_tokens": total_prompt_tokens, "completion_tokens": response.usage.output_tokens, - "total_tokens": response.usage.input_tokens + response.usage.output_tokens, + "total_tokens": total_prompt_tokens + response.usage.output_tokens, } for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"): val = getattr(response.usage, attr, 0) if val: usage[attr] = val + # Normalize to cached_tokens for downstream consistency. + if cache_read: + usage["cached_tokens"] = cache_read return LLMResponse( content="".join(content_parts) or None, diff --git a/nanobot/providers/github_copilot_provider.py b/nanobot/providers/github_copilot_provider.py new file mode 100644 index 000000000..8d50006a0 --- /dev/null +++ b/nanobot/providers/github_copilot_provider.py @@ -0,0 +1,257 @@ +"""GitHub Copilot OAuth-backed provider.""" + +from __future__ import annotations + +import time +import webbrowser +from collections.abc import Callable + +import httpx +from oauth_cli_kit.models import OAuthToken +from oauth_cli_kit.storage import FileTokenStorage + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + +DEFAULT_GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code" +DEFAULT_GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token" +DEFAULT_GITHUB_USER_URL = "https://api.github.com/user" +DEFAULT_COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token" +DEFAULT_COPILOT_BASE_URL = "https://api.githubcopilot.com" +GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98" +GITHUB_COPILOT_SCOPE = "read:user" +TOKEN_FILENAME = "github-copilot.json" +TOKEN_APP_NAME = "nanobot" +USER_AGENT = "nanobot/0.1" +EDITOR_VERSION = "vscode/1.99.0" 
+EDITOR_PLUGIN_VERSION = "copilot-chat/0.26.0" +_EXPIRY_SKEW_SECONDS = 60 +_LONG_LIVED_TOKEN_SECONDS = 315360000 + + +def _storage() -> FileTokenStorage: + return FileTokenStorage( + token_filename=TOKEN_FILENAME, + app_name=TOKEN_APP_NAME, + import_codex_cli=False, + ) + + +def _copilot_headers(token: str) -> dict[str, str]: + return { + "Authorization": f"token {token}", + "Accept": "application/json", + "User-Agent": USER_AGENT, + "Editor-Version": EDITOR_VERSION, + "Editor-Plugin-Version": EDITOR_PLUGIN_VERSION, + } + + +def _load_github_token() -> OAuthToken | None: + token = _storage().load() + if not token or not token.access: + return None + return token + + +def get_github_copilot_login_status() -> OAuthToken | None: + """Return the persisted GitHub OAuth token if available.""" + return _load_github_token() + + +def login_github_copilot( + print_fn: Callable[[str], None] | None = None, + prompt_fn: Callable[[str], str] | None = None, +) -> OAuthToken: + """Run GitHub device flow and persist the GitHub OAuth token used for Copilot.""" + del prompt_fn + printer = print_fn or print + timeout = httpx.Timeout(20.0, connect=20.0) + + with httpx.Client(timeout=timeout, follow_redirects=True, trust_env=True) as client: + response = client.post( + DEFAULT_GITHUB_DEVICE_CODE_URL, + headers={"Accept": "application/json", "User-Agent": USER_AGENT}, + data={"client_id": GITHUB_COPILOT_CLIENT_ID, "scope": GITHUB_COPILOT_SCOPE}, + ) + response.raise_for_status() + payload = response.json() + + device_code = str(payload["device_code"]) + user_code = str(payload["user_code"]) + verify_url = str(payload.get("verification_uri") or payload.get("verification_uri_complete") or "") + verify_complete = str(payload.get("verification_uri_complete") or verify_url) + interval = max(1, int(payload.get("interval") or 5)) + expires_in = int(payload.get("expires_in") or 900) + + printer(f"Open: {verify_url}") + printer(f"Code: {user_code}") + if verify_complete: + try: + 
webbrowser.open(verify_complete) + except Exception: + pass + + deadline = time.time() + expires_in + current_interval = interval + access_token = None + token_expires_in = _LONG_LIVED_TOKEN_SECONDS + while time.time() < deadline: + poll = client.post( + DEFAULT_GITHUB_ACCESS_TOKEN_URL, + headers={"Accept": "application/json", "User-Agent": USER_AGENT}, + data={ + "client_id": GITHUB_COPILOT_CLIENT_ID, + "device_code": device_code, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + }, + ) + poll.raise_for_status() + poll_payload = poll.json() + + access_token = poll_payload.get("access_token") + if access_token: + token_expires_in = int(poll_payload.get("expires_in") or _LONG_LIVED_TOKEN_SECONDS) + break + + error = poll_payload.get("error") + if error == "authorization_pending": + time.sleep(current_interval) + continue + if error == "slow_down": + current_interval += 5 + time.sleep(current_interval) + continue + if error == "expired_token": + raise RuntimeError("GitHub device code expired. 
Please run login again.") + if error == "access_denied": + raise RuntimeError("GitHub device flow was denied.") + if error: + desc = poll_payload.get("error_description") or error + raise RuntimeError(str(desc)) + time.sleep(current_interval) + else: + raise RuntimeError("GitHub device flow timed out.") + + user = client.get( + DEFAULT_GITHUB_USER_URL, + headers={ + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github+json", + "User-Agent": USER_AGENT, + }, + ) + user.raise_for_status() + user_payload = user.json() + account_id = user_payload.get("login") or str(user_payload.get("id") or "") or None + + expires_ms = int((time.time() + token_expires_in) * 1000) + token = OAuthToken( + access=str(access_token), + refresh="", + expires=expires_ms, + account_id=str(account_id) if account_id else None, + ) + _storage().save(token) + return token + + +class GitHubCopilotProvider(OpenAICompatProvider): + """Provider that exchanges a stored GitHub OAuth token for Copilot access tokens.""" + + def __init__(self, default_model: str = "github-copilot/gpt-4.1"): + from nanobot.providers.registry import find_by_name + + self._copilot_access_token: str | None = None + self._copilot_expires_at: float = 0.0 + super().__init__( + api_key="no-key", + api_base=DEFAULT_COPILOT_BASE_URL, + default_model=default_model, + extra_headers={ + "Editor-Version": EDITOR_VERSION, + "Editor-Plugin-Version": EDITOR_PLUGIN_VERSION, + "User-Agent": USER_AGENT, + }, + spec=find_by_name("github_copilot"), + ) + + async def _get_copilot_access_token(self) -> str: + now = time.time() + if self._copilot_access_token and now < self._copilot_expires_at - _EXPIRY_SKEW_SECONDS: + return self._copilot_access_token + + github_token = _load_github_token() + if not github_token or not github_token.access: + raise RuntimeError("GitHub Copilot is not logged in. 
Run: nanobot provider login github-copilot") + + timeout = httpx.Timeout(20.0, connect=20.0) + async with httpx.AsyncClient(timeout=timeout, follow_redirects=True, trust_env=True) as client: + response = await client.get( + DEFAULT_COPILOT_TOKEN_URL, + headers=_copilot_headers(github_token.access), + ) + response.raise_for_status() + payload = response.json() + + token = payload.get("token") + if not token: + raise RuntimeError("GitHub Copilot token exchange returned no token.") + + expires_at = payload.get("expires_at") + if isinstance(expires_at, (int, float)): + self._copilot_expires_at = float(expires_at) + else: + refresh_in = payload.get("refresh_in") or 1500 + self._copilot_expires_at = time.time() + int(refresh_in) + self._copilot_access_token = str(token) + return self._copilot_access_token + + async def _refresh_client_api_key(self) -> str: + token = await self._get_copilot_access_token() + self.api_key = token + self._client.api_key = token + return token + + async def chat( + self, + messages: list[dict[str, object]], + tools: list[dict[str, object]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, object] | None = None, + ): + await self._refresh_client_api_key() + return await super().chat( + messages=messages, + tools=tools, + model=model, + max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + tool_choice=tool_choice, + ) + + async def chat_stream( + self, + messages: list[dict[str, object]], + tools: list[dict[str, object]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, object] | None = None, + on_content_delta: Callable[[str], None] | None = None, + ): + await self._refresh_client_api_key() + return await super().chat_stream( + messages=messages, + tools=tools, + model=model, + 
max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + tool_choice=tool_choice, + on_content_delta=on_content_delta, + ) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 397b8e797..f89879c90 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -235,7 +235,9 @@ class OpenAICompatProvider(LLMProvider): spec = self._spec if spec and spec.supports_prompt_caching: - messages, tools = self._apply_cache_control(messages, tools) + model_name = model or self.default_model + if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")): + messages, tools = self._apply_cache_control(messages, tools) if spec and spec.strip_model_prefix: model_name = model_name.split("/")[-1] @@ -308,6 +310,13 @@ class OpenAICompatProvider(LLMProvider): @classmethod def _extract_usage(cls, response: Any) -> dict[str, int]: + """Extract token usage from an OpenAI-compatible response. + + Handles both dict-based (raw JSON) and object-based (SDK Pydantic) + responses. Provider-specific ``cached_tokens`` fields are normalised + under a single key; see the priority chain inside for details. 
+ """ + # --- resolve usage object --- usage_obj = None response_map = cls._maybe_mapping(response) if response_map is not None: @@ -317,19 +326,53 @@ class OpenAICompatProvider(LLMProvider): usage_map = cls._maybe_mapping(usage_obj) if usage_map is not None: - return { + result = { "prompt_tokens": int(usage_map.get("prompt_tokens") or 0), "completion_tokens": int(usage_map.get("completion_tokens") or 0), "total_tokens": int(usage_map.get("total_tokens") or 0), } - - if usage_obj: - return { + elif usage_obj: + result = { "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0, "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0, "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0, } - return {} + else: + return {} + + # --- cached_tokens (normalised across providers) --- + # Try nested paths first (dict), fall back to attribute (SDK object). + # Priority order ensures the most specific field wins. + for path in ( + ("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI + ("cached_tokens",), # StepFun/Moonshot (top-level) + ("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow + ): + cached = cls._get_nested_int(usage_map, path) + if not cached and usage_obj: + cached = cls._get_nested_int(usage_obj, path) + if cached: + result["cached_tokens"] = cached + break + + return result + + @staticmethod + def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int: + """Drill into *obj* by *path* segments and return an ``int`` value. + + Supports both dict-key access and attribute access so it works + uniformly with raw JSON dicts **and** SDK Pydantic models. 
+ """ + current = obj + for segment in path: + if current is None: + return 0 + if isinstance(current, dict): + current = current.get(segment) + else: + current = getattr(current, segment, None) + return int(current or 0) if current is not None else 0 def _parse(self, response: Any) -> LLMResponse: if isinstance(response, str): @@ -586,4 +629,4 @@ class OpenAICompatProvider(LLMProvider): return self._handle_error(e) def get_default_model(self) -> str: - return self.default_model + return self.default_model \ No newline at end of file diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 5644fc51d..8435005e1 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -34,7 +34,7 @@ class ProviderSpec: display_name: str = "" # shown in `nanobot status` # which provider implementation to use - # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" + # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot" backend: str = "openai_compat" # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),) @@ -218,8 +218,9 @@ PROVIDERS: tuple[ProviderSpec, ...] 
= ( keywords=("github_copilot", "copilot"), env_key="", display_name="Github Copilot", - backend="openai_compat", + backend="github_copilot", default_api_base="https://api.githubcopilot.com", + strip_model_prefix=True, is_oauth=True, ), # DeepSeek: OpenAI-compatible at api.deepseek.com diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index a10a4f18b..406a4dd45 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -124,8 +124,8 @@ def build_assistant_message( msg: dict[str, Any] = {"role": "assistant", "content": content} if tool_calls: msg["tool_calls"] = tool_calls - if reasoning_content is not None: - msg["reasoning_content"] = reasoning_content + if reasoning_content is not None or thinking_blocks: + msg["reasoning_content"] = reasoning_content if reasoning_content is not None else "" if thinking_blocks: msg["thinking_blocks"] = thinking_blocks return msg @@ -255,14 +255,18 @@ def build_status_content( ) last_in = last_usage.get("prompt_tokens", 0) last_out = last_usage.get("completion_tokens", 0) + cached = last_usage.get("cached_tokens", 0) ctx_total = max(context_window_tokens, 0) ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0 ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate) ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a" + token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out" + if cached and last_in: + token_line += f" ({cached * 100 // last_in}% cached)" return "\n".join([ f"\U0001f408 nanobot v{version}", f"\U0001f9e0 Model: {model}", - f"\U0001f4ca Tokens: {last_in} in / {last_out} out", + token_line, f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)", f"\U0001f4ac Session: {session_msg_count} messages", f"\u23f1 Uptime: {uptime}", diff --git a/pyproject.toml b/pyproject.toml index d2952b039..51d494668 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ 
dependencies = [ ] [project.optional-dependencies] +api = [ + "aiohttp>=3.9.0,<4.0.0", +] wecom = [ "wecom-aibot-sdk-python>=0.1.5", ] @@ -64,12 +67,16 @@ matrix = [ "mistune>=3.0.0,<4.0.0", "nh3>=0.2.17,<1.0.0", ] +discord = [ + "discord.py>=2.5.2,<3.0.0", +] langsmith = [ "langsmith>=0.1.0", ] dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", + "aiohttp>=3.9.0,<4.0.0", "pytest-cov>=6.0.0,<7.0.0", "ruff>=0.1.0", ] diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py new file mode 100644 index 000000000..203c892fb --- /dev/null +++ b/tests/agent/test_hook_composite.py @@ -0,0 +1,351 @@ +"""Tests for CompositeHook fan-out, error isolation, and integration.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook + + +def _ctx() -> AgentHookContext: + return AgentHookContext(iteration=0, messages=[]) + + +# --------------------------------------------------------------------------- +# Fan-out: every hook is called in order +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_composite_fans_out_before_iteration(): + calls: list[str] = [] + + class H(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append(f"A:{context.iteration}") + + class H2(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append(f"B:{context.iteration}") + + hook = CompositeHook([H(), H2()]) + ctx = _ctx() + await hook.before_iteration(ctx) + assert calls == ["A:0", "B:0"] + + +@pytest.mark.asyncio +async def test_composite_fans_out_all_async_methods(): + """Verify all async methods fan out to every hook.""" + events: list[str] = [] + + class RecordingHook(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + 
events.append("before_iteration") + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + events.append(f"on_stream:{delta}") + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + events.append(f"on_stream_end:{resuming}") + + async def before_execute_tools(self, context: AgentHookContext) -> None: + events.append("before_execute_tools") + + async def after_iteration(self, context: AgentHookContext) -> None: + events.append("after_iteration") + + hook = CompositeHook([RecordingHook(), RecordingHook()]) + ctx = _ctx() + + await hook.before_iteration(ctx) + await hook.on_stream(ctx, "hi") + await hook.on_stream_end(ctx, resuming=True) + await hook.before_execute_tools(ctx) + await hook.after_iteration(ctx) + + assert events == [ + "before_iteration", "before_iteration", + "on_stream:hi", "on_stream:hi", + "on_stream_end:True", "on_stream_end:True", + "before_execute_tools", "before_execute_tools", + "after_iteration", "after_iteration", + ] + + +# --------------------------------------------------------------------------- +# Error isolation: one hook raises, others still run +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_composite_error_isolation_before_iteration(): + calls: list[str] = [] + + class Bad(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + raise RuntimeError("boom") + + class Good(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append("good") + + hook = CompositeHook([Bad(), Good()]) + await hook.before_iteration(_ctx()) + assert calls == ["good"] + + +@pytest.mark.asyncio +async def test_composite_error_isolation_on_stream(): + calls: list[str] = [] + + class Bad(AgentHook): + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + raise RuntimeError("stream-boom") + + class Good(AgentHook): + async def on_stream(self, 
context: AgentHookContext, delta: str) -> None: + calls.append(delta) + + hook = CompositeHook([Bad(), Good()]) + await hook.on_stream(_ctx(), "delta") + assert calls == ["delta"] + + +@pytest.mark.asyncio +async def test_composite_error_isolation_all_async(): + """Error isolation for on_stream_end, before_execute_tools, after_iteration.""" + calls: list[str] = [] + + class Bad(AgentHook): + async def on_stream_end(self, context, *, resuming): + raise RuntimeError("err") + async def before_execute_tools(self, context): + raise RuntimeError("err") + async def after_iteration(self, context): + raise RuntimeError("err") + + class Good(AgentHook): + async def on_stream_end(self, context, *, resuming): + calls.append("on_stream_end") + async def before_execute_tools(self, context): + calls.append("before_execute_tools") + async def after_iteration(self, context): + calls.append("after_iteration") + + hook = CompositeHook([Bad(), Good()]) + ctx = _ctx() + await hook.on_stream_end(ctx, resuming=False) + await hook.before_execute_tools(ctx) + await hook.after_iteration(ctx) + assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"] + + +# --------------------------------------------------------------------------- +# finalize_content: pipeline semantics (no error isolation) +# --------------------------------------------------------------------------- + + +def test_composite_finalize_content_pipeline(): + class Upper(AgentHook): + def finalize_content(self, context, content): + return content.upper() if content else content + + class Suffix(AgentHook): + def finalize_content(self, context, content): + return (content + "!") if content else content + + hook = CompositeHook([Upper(), Suffix()]) + result = hook.finalize_content(_ctx(), "hello") + assert result == "HELLO!" 
+ + +def test_composite_finalize_content_none_passthrough(): + hook = CompositeHook([AgentHook()]) + assert hook.finalize_content(_ctx(), None) is None + + +def test_composite_finalize_content_ordering(): + """First hook transforms first, result feeds second hook.""" + steps: list[str] = [] + + class H1(AgentHook): + def finalize_content(self, context, content): + steps.append(f"H1:{content}") + return content.upper() + + class H2(AgentHook): + def finalize_content(self, context, content): + steps.append(f"H2:{content}") + return content + "!" + + hook = CompositeHook([H1(), H2()]) + result = hook.finalize_content(_ctx(), "hi") + assert result == "HI!" + assert steps == ["H1:hi", "H2:HI"] + + +# --------------------------------------------------------------------------- +# wants_streaming: any-semantics +# --------------------------------------------------------------------------- + + +def test_composite_wants_streaming_any_true(): + class No(AgentHook): + def wants_streaming(self): + return False + + class Yes(AgentHook): + def wants_streaming(self): + return True + + hook = CompositeHook([No(), Yes(), No()]) + assert hook.wants_streaming() is True + + +def test_composite_wants_streaming_all_false(): + hook = CompositeHook([AgentHook(), AgentHook()]) + assert hook.wants_streaming() is False + + +def test_composite_wants_streaming_empty(): + hook = CompositeHook([]) + assert hook.wants_streaming() is False + + +# --------------------------------------------------------------------------- +# Empty hooks list: behaves like no-op AgentHook +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_composite_empty_hooks_no_ops(): + hook = CompositeHook([]) + ctx = _ctx() + await hook.before_iteration(ctx) + await hook.on_stream(ctx, "delta") + await hook.on_stream_end(ctx, resuming=False) + await hook.before_execute_tools(ctx) + await hook.after_iteration(ctx) + assert hook.finalize_content(ctx, "test") == 
"test" + + +# --------------------------------------------------------------------------- +# Integration: AgentLoop with extra hooks +# --------------------------------------------------------------------------- + + +def _make_loop(tmp_path, hooks=None): + from nanobot.agent.loop import AgentLoop + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.generation.max_tokens = 4096 + + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \ + patch("nanobot.agent.loop.MemoryConsolidator"): + mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0) + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, hooks=hooks, + ) + return loop + + +@pytest.mark.asyncio +async def test_agent_loop_extra_hook_receives_calls(tmp_path): + """Extra hook passed to AgentLoop is called alongside core LoopHook.""" + from nanobot.providers.base import LLMResponse + + events: list[str] = [] + + class TrackingHook(AgentHook): + async def before_iteration(self, context): + events.append(f"before_iter:{context.iteration}") + + async def after_iteration(self, context): + events.append(f"after_iter:{context.iteration}") + + loop = _make_loop(tmp_path, hooks=[TrackingHook()]) + loop.provider.chat_with_retry = AsyncMock( + return_value=LLMResponse(content="done", tool_calls=[], usage={}) + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + + content, tools_used, messages = await loop._run_agent_loop( + [{"role": "user", "content": "hi"}] + ) + + assert content == "done" + assert "before_iter:0" in events + assert "after_iter:0" in events + + +@pytest.mark.asyncio +async def test_agent_loop_extra_hook_error_isolation(tmp_path): + """A faulty extra hook does not crash the agent loop.""" + from nanobot.providers.base import LLMResponse + + class 
BadHook(AgentHook): + async def before_iteration(self, context): + raise RuntimeError("I am broken") + + loop = _make_loop(tmp_path, hooks=[BadHook()]) + loop.provider.chat_with_retry = AsyncMock( + return_value=LLMResponse(content="still works", tool_calls=[], usage={}) + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + + content, _, _ = await loop._run_agent_loop( + [{"role": "user", "content": "hi"}] + ) + + assert content == "still works" + + +@pytest.mark.asyncio +async def test_agent_loop_extra_hooks_do_not_swallow_loop_hook_errors(tmp_path): + """Extra hooks must not change the core LoopHook failure behavior.""" + from nanobot.providers.base import LLMResponse, ToolCallRequest + + loop = _make_loop(tmp_path, hooks=[AgentHook()]) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})], + usage={}, + )) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.execute = AsyncMock(return_value="ok") + + async def bad_progress(*args, **kwargs): + raise RuntimeError("progress failed") + + with pytest.raises(RuntimeError, match="progress failed"): + await loop._run_agent_loop([], on_progress=bad_progress) + + +@pytest.mark.asyncio +async def test_agent_loop_no_hooks_backward_compat(tmp_path): + """Without hooks param, behavior is identical to before.""" + from nanobot.providers.base import LLMResponse, ToolCallRequest + + loop = _make_loop(tmp_path) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})], + )) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.execute = AsyncMock(return_value="ok") + loop.max_iterations = 2 + + content, tools_used, _ = await loop._run_agent_loop([]) + assert content == ( + "I reached the maximum number of tool call iterations (2) " + "without completing 
the task. You can try breaking the task into smaller steps." + ) + assert tools_used == ["list_dir", "list_dir"] diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 86b0ba710..98f1d73ae 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon args = mgr._announce_result.await_args.args assert args[3] == "Task completed but no final response was generated." assert args[5] == "ok" + + +@pytest.mark.asyncio +async def test_runner_accumulates_usage_and_preserves_cached_tokens(): + """Runner should accumulate prompt/completion tokens across iterations + and preserve cached_tokens from provider responses.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], + usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80}, + ) + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=3, + )) + + # Usage should be accumulated across iterations + assert result.usage["prompt_tokens"] == 300 # 100 + 200 + assert result.usage["completion_tokens"] == 30 # 10 + 20 + assert result.usage["cached_tokens"] == 230 # 80 + 150 + + +@pytest.mark.asyncio +async def 
test_runner_passes_cached_tokens_to_hook_context(): + """Hook context.usage should contain cached_tokens.""" + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_usage: list[dict] = [] + + class UsageHook(AgentHook): + async def after_iteration(self, context: AgentHookContext) -> None: + captured_usage.append(dict(context.usage)) + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + hook=UsageHook(), + )) + + assert len(captured_usage) == 1 + assert captured_usage[0]["cached_tokens"] == 150 diff --git a/tests/agent/test_task_cancel.py b/tests/agent/test_task_cancel.py index 8894cd973..70f7621d1 100644 --- a/tests/agent/test_task_cancel.py +++ b/tests/agent/test_task_cancel.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -116,6 +117,43 @@ class TestDispatch: out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) assert out.content == "hi" + @pytest.mark.asyncio + async def test_dispatch_streaming_preserves_message_metadata(self): + from nanobot.bus.events import InboundMessage + + loop, bus = _make_loop() + msg = InboundMessage( + channel="matrix", + sender_id="u1", + chat_id="!room:matrix.org", + content="hello", + metadata={ + "_wants_stream": True, + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + }, + ) + + async def fake_process(_msg, *, on_stream=None, on_stream_end=None, **kwargs): + assert on_stream is not 
None + assert on_stream_end is not None + await on_stream("hi") + await on_stream_end(resuming=False) + return None + + loop._process_message = fake_process + + await loop._dispatch(msg) + first = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + second = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + + assert first.metadata["thread_root_event_id"] == "$root1" + assert first.metadata["thread_reply_to_event_id"] == "$reply1" + assert first.metadata["_stream_delta"] is True + assert second.metadata["thread_root_event_id"] == "$root1" + assert second.metadata["thread_reply_to_event_id"] == "$reply1" + assert second.metadata["_stream_end"] is True + @pytest.mark.asyncio async def test_processing_lock_serializes(self): from nanobot.bus.events import InboundMessage, OutboundMessage @@ -222,6 +260,39 @@ class TestSubagentCancellation: assert assistant_messages[0]["reasoning_content"] == "hidden reasoning" assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}] + @pytest.mark.asyncio + async def test_subagent_exec_tool_not_registered_when_disabled(self, tmp_path): + from nanobot.agent.subagent import SubagentManager + from nanobot.bus.queue import MessageBus + from nanobot.config.schema import ExecToolConfig + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + exec_config=ExecToolConfig(enable=False), + ) + mgr._announce_result = AsyncMock() + + async def fake_run(spec): + assert spec.tools.get("exec") is None + return SimpleNamespace( + stop_reason="done", + final_content="done", + error=None, + tool_events=[], + ) + + mgr.runner.run = AsyncMock(side_effect=fake_run) + + await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) + + mgr.runner.run.assert_awaited_once() + mgr._announce_result.assert_awaited_once() + @pytest.mark.asyncio async def 
test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path): from nanobot.agent.subagent import SubagentManager diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py new file mode 100644 index 000000000..d352c788c --- /dev/null +++ b/tests/channels/test_discord_channel.py @@ -0,0 +1,676 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import pytest +discord = pytest.importorskip("discord") + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig +from nanobot.command.builtin import build_help_text + + +# Minimal Discord client test double used to control startup/readiness behavior. +class _FakeDiscordClient: + instances: list["_FakeDiscordClient"] = [] + start_error: Exception | None = None + + def __init__(self, owner, *, intents) -> None: + self.owner = owner + self.intents = intents + self.closed = False + self.ready = True + self.channels: dict[int, object] = {} + self.user = SimpleNamespace(id=999) + self.__class__.instances.append(self) + + async def start(self, token: str) -> None: + self.token = token + if self.__class__.start_error is not None: + raise self.__class__.start_error + + async def close(self) -> None: + self.closed = True + + def is_closed(self) -> bool: + return self.closed + + def is_ready(self) -> bool: + return self.ready + + def get_channel(self, channel_id: int): + return self.channels.get(channel_id) + + async def send_outbound(self, msg: OutboundMessage) -> None: + channel = self.get_channel(int(msg.chat_id)) + if channel is None: + return + await channel.send(content=msg.content) + + +class _FakeAttachment: + # Attachment double that can simulate successful or failing save() calls. 
+ def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None: + self.id = attachment_id + self.filename = filename + self.size = size + self._fail = fail + + async def save(self, path: str | Path) -> None: + if self._fail: + raise RuntimeError("save failed") + Path(path).write_bytes(b"attachment") + + +class _FakePartialMessage: + # Lightweight stand-in for Discord partial message references used in replies. + def __init__(self, message_id: int) -> None: + self.id = message_id + + +class _FakeChannel: + # Channel double that records outbound payloads and typing activity. + def __init__(self, channel_id: int = 123) -> None: + self.id = channel_id + self.sent_payloads: list[dict] = [] + self.trigger_typing_calls = 0 + self.typing_enter_hook = None + + async def send(self, **kwargs) -> None: + payload = dict(kwargs) + if "file" in payload: + payload["file_name"] = payload["file"].filename + del payload["file"] + self.sent_payloads.append(payload) + + def get_partial_message(self, message_id: int) -> _FakePartialMessage: + return _FakePartialMessage(message_id) + + def typing(self): + channel = self + + class _TypingContext: + async def __aenter__(self): + channel.trigger_typing_calls += 1 + if channel.typing_enter_hook is not None: + await channel.typing_enter_hook() + + async def __aexit__(self, exc_type, exc, tb): + return False + + return _TypingContext() + + +class _FakeInteractionResponse: + def __init__(self) -> None: + self.messages: list[dict] = [] + self._done = False + + async def send_message(self, content: str, *, ephemeral: bool = False) -> None: + self.messages.append({"content": content, "ephemeral": ephemeral}) + self._done = True + + def is_done(self) -> bool: + return self._done + + +def _make_interaction( + *, + user_id: int = 123, + channel_id: int | None = 456, + guild_id: int | None = None, + interaction_id: int = 999, +): + return SimpleNamespace( + user=SimpleNamespace(id=user_id), + 
channel_id=channel_id, + guild_id=guild_id, + id=interaction_id, + command=SimpleNamespace(qualified_name="new"), + response=_FakeInteractionResponse(), + ) + + +def _make_message( + *, + author_id: int = 123, + author_bot: bool = False, + channel_id: int = 456, + message_id: int = 789, + content: str = "hello", + guild_id: int | None = None, + mentions: list[object] | None = None, + attachments: list[object] | None = None, + reply_to: int | None = None, +): + # Factory for incoming Discord message objects with optional guild/reply/attachments. + guild = SimpleNamespace(id=guild_id) if guild_id is not None else None + reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None + return SimpleNamespace( + author=SimpleNamespace(id=author_id, bot=author_bot), + channel=_FakeChannel(channel_id), + content=content, + guild=guild, + mentions=mentions or [], + attachments=attachments or [], + reference=reference, + id=message_id, + ) + + +@pytest.mark.asyncio +async def test_start_returns_when_token_missing() -> None: + # If no token is configured, startup should no-op and leave channel stopped. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None: + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_handles_client_construction_failure(monkeypatch) -> None: + # Construction errors from the Discord client should be swallowed and keep state clean. 
+ channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + + def _boom(owner, *, intents): + raise RuntimeError("bad client") + + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_handles_client_start_failure(monkeypatch) -> None: + # If client.start fails, the partially created client should be closed and detached. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + + _FakeDiscordClient.instances.clear() + _FakeDiscordClient.start_error = RuntimeError("connect failed") + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents + assert _FakeDiscordClient.instances[0].closed is True + + _FakeDiscordClient.start_error = None + + +@pytest.mark.asyncio +async def test_stop_is_safe_after_partial_start(monkeypatch) -> None: + # stop() should close/discard the client even when startup was only partially completed. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + await channel.stop() + + assert channel.is_running is False + assert client.closed is True + assert channel._client is None + + +@pytest.mark.asyncio +async def test_on_message_ignores_bot_messages() -> None: + # Incoming bot-authored messages must be ignored to prevent feedback loops. 
+ channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign] + + await channel._on_message(_make_message(author_bot=True)) + + assert handled == [] + + # If inbound handling raises, typing should be stopped for that channel. + async def fail_handle(**kwargs) -> None: + raise RuntimeError("boom") + + channel._handle_message = fail_handle # type: ignore[method-assign] + + with pytest.raises(RuntimeError, match="boom"): + await channel._on_message(_make_message(author_id=123, channel_id=456)) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_on_message_accepts_allowlisted_dm() -> None: + # Allowed direct messages should be forwarded with normalized metadata. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789)) + + assert len(handled) == 1 + assert handled[0]["chat_id"] == "456" + assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None} + + +@pytest.mark.asyncio +async def test_on_message_ignores_unmentioned_guild_message() -> None: + # With mention-only group policy, guild messages without a bot mention are dropped. 
+ channel = DiscordChannel( + DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._bot_user_id = "999" + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message(_make_message(guild_id=1, content="hello everyone")) + + assert handled == [] + + +@pytest.mark.asyncio +async def test_on_message_accepts_mentioned_guild_message() -> None: + # Mentioned guild messages should be accepted and preserve reply threading metadata. + channel = DiscordChannel( + DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._bot_user_id = "999" + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message( + _make_message( + guild_id=1, + content="<@999> hello", + mentions=[SimpleNamespace(id=999)], + reply_to=321, + ) + ) + + assert len(handled) == 1 + assert handled[0]["metadata"]["reply_to"] == "321" + + +@pytest.mark.asyncio +async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None: + # Attachment downloads should be saved and referenced in forwarded content/media. 
+ channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) + + await channel._on_message( + _make_message( + attachments=[_FakeAttachment(12, "photo.png")], + content="see file", + ) + ) + + assert len(handled) == 1 + assert handled[0]["media"] == [str(tmp_path / "12_photo.png")] + assert "[attachment:" in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None: + # Failed attachment downloads should emit a readable placeholder and no media path. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) + + await channel._on_message( + _make_message( + attachments=[_FakeAttachment(12, "photo.png", fail=True)], + content="", + ) + ) + + assert len(handled) == 1 + assert handled[0]["media"] == [] + assert handled[0]["content"] == "[attachment: photo.png - download failed]" + + +@pytest.mark.asyncio +async def test_send_warns_when_client_not_ready() -> None: + # Sending without a running/ready client should be a safe no-op. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_send_skips_when_channel_not_cached() -> None: + # Outbound sends should be skipped when the destination channel is not resolvable. 
+ owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + fetch_calls: list[int] = [] + + async def fetch_channel(channel_id: int): + fetch_calls.append(channel_id) + raise RuntimeError("not found") + + client.fetch_channel = fetch_channel # type: ignore[method-assign] + + await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert client.get_channel(123) is None + assert fetch_calls == [123] + + +@pytest.mark.asyncio +async def test_send_fetches_channel_when_not_cached() -> None: + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + + async def fetch_channel(channel_id: int): + return target if channel_id == 123 else None + + client.fetch_channel = fetch_channel # type: ignore[method-assign] + + await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert target.sent_payloads == [{"content": "hello"}] + + +@pytest.mark.asyncio +async def test_slash_new_forwards_when_user_is_allowlisted() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321) + + new_cmd = client.tree.get_command("new") + assert new_cmd is not None + await new_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": "Processing /new...", "ephemeral": True} + ] + assert len(handled) == 1 + assert handled[0]["content"] == "/new" + assert handled[0]["sender_id"] == "123" + assert 
handled[0]["chat_id"] == "456" + assert handled[0]["metadata"]["interaction_id"] == "321" + assert handled[0]["metadata"]["is_slash_command"] is True + + +@pytest.mark.asyncio +async def test_slash_new_is_blocked_for_disallowed_user() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction(user_id=123, channel_id=456) + + new_cmd = client.tree.get_command("new") + assert new_cmd is not None + await new_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": "You are not allowed to use this bot.", "ephemeral": True} + ] + assert handled == [] + + +@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"]) +@pytest.mark.asyncio +async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction() + interaction.command.qualified_name = slash_name + + cmd = client.tree.get_command(slash_name) + assert cmd is not None + await cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": f"Processing /{slash_name}...", "ephemeral": True} + ] + assert len(handled) == 1 + assert handled[0]["content"] == f"/{slash_name}" + assert handled[0]["metadata"]["is_slash_command"] is True + + +@pytest.mark.asyncio +async def test_slash_help_returns_ephemeral_help_text() -> None: + channel = 
DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction() + interaction.command.qualified_name = "help" + + help_cmd = client.tree.get_command("help") + assert help_cmd is not None + await help_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": build_help_text(), "ephemeral": True} + ] + assert handled == [] + + +@pytest.mark.asyncio +async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None: + # Outbound payloads should upload files, attach reply references, and chunk long text. + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] + + file_path = tmp_path / "demo.txt" + file_path.write_text("hi") + + await client.send_outbound( + OutboundMessage( + channel="discord", + chat_id="123", + content="a" * 2100, + reply_to="55", + media=[str(file_path)], + ) + ) + + assert len(target.sent_payloads) == 3 + assert target.sent_payloads[0]["file_name"] == "demo.txt" + assert target.sent_payloads[0]["reference"].id == 55 + assert target.sent_payloads[1]["content"] == "a" * 2000 + assert target.sent_payloads[2]["content"] == "a" * 100 + + +@pytest.mark.asyncio +async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None: + # If all attachment sends fail and no text exists, emit a failure placeholder message. 
+ owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] + + missing_file = tmp_path / "missing.txt" + + await client.send_outbound( + OutboundMessage( + channel="discord", + chat_id="123", + content="", + media=[str(missing_file)], + ) + ) + + assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}] + + +@pytest.mark.asyncio +async def test_send_stops_typing_after_send() -> None: + # Active typing indicators should be cancelled/cleared after a successful send. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing + + await channel._start_typing(typing_channel) + await start.wait() + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} + + # Progress messages should keep typing active until a final (non-progress) send. 
+ start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing_progress() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing_progress + + await channel._start_typing(typing_channel) + await start.wait() + + await channel.send( + OutboundMessage( + channel="discord", + chat_id="123", + content="progress", + metadata={"_progress": True}, + ) + ) + + assert "123" in channel._typing_tasks + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final")) + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + channel._running = True + + entered = asyncio.Event() + release = asyncio.Event() + + class _TypingCtx: + async def __aenter__(self): + entered.set() + + async def __aexit__(self, exc_type, exc, tb): + return False + + class _NoTriggerChannel: + def __init__(self, channel_id: int = 123) -> None: + self.id = channel_id + + def typing(self): + async def _waiter(): + await release.wait() + # Hold the loop so task remains active until explicitly stopped. 
+ class _Ctx(_TypingCtx): + async def __aenter__(self): + await super().__aenter__() + await _waiter() + return _Ctx() + + typing_channel = _NoTriggerChannel(channel_id=123) + await channel._start_typing(typing_channel) # type: ignore[arg-type] + await entered.wait() + + assert "123" in channel._typing_tasks + + await channel._stop_typing("123") + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index dd5e97d90..18a8e1097 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -3,6 +3,9 @@ from pathlib import Path from types import SimpleNamespace import pytest +from nio import RoomSendResponse + +from nanobot.channels.matrix import _build_matrix_text_content # Check optional matrix dependencies before importing try: @@ -65,6 +68,7 @@ class _FakeAsyncClient: self.raise_on_send = False self.raise_on_typing = False self.raise_on_upload = False + self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="") def add_event_callback(self, callback, event_type) -> None: self.callbacks.append((callback, event_type)) @@ -87,7 +91,7 @@ class _FakeAsyncClient: message_type: str, content: dict[str, object], ignore_unverified_devices: object = _ROOM_SEND_UNSET, - ) -> None: + ) -> RoomSendResponse: call: dict[str, object] = { "room_id": room_id, "message_type": message_type, @@ -98,6 +102,7 @@ class _FakeAsyncClient: self.room_send_calls.append(call) if self.raise_on_send: raise RuntimeError("send failed") + return self.room_send_response async def room_typing( self, @@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None: source={"content": {"m.mentions": {"room": True}}}, ) + channel.config.allow_room_mentions = False await channel._on_message(room, room_mention_event) assert handled == [] assert client.typing_calls == [] @@ -1322,3 +1328,302 @@ async def 
test_send_keeps_plaintext_only_for_plain_text() -> None: "body": text, "m.mentions": {}, } + + +def test_build_matrix_text_content_basic_text() -> None: + """Test basic text content without HTML formatting.""" + result = _build_matrix_text_content("Hello, World!") + expected = { + "msgtype": "m.text", + "body": "Hello, World!", + "m.mentions": {} + } + assert expected == result + + +def test_build_matrix_text_content_with_markdown() -> None: + """Test text content with markdown that renders to HTML.""" + text = "*Hello* **World**" + result = _build_matrix_text_content(text) + assert "msgtype" in result + assert "body" in result + assert result["body"] == text + assert "format" in result + assert result["format"] == "org.matrix.custom.html" + assert "formatted_body" in result + assert isinstance(result["formatted_body"], str) + assert len(result["formatted_body"]) > 0 + + +def test_build_matrix_text_content_with_event_id() -> None: + """Test text content with event_id for message replacement.""" + event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + result = _build_matrix_text_content("Updated message", event_id) + assert "msgtype" in result + assert "body" in result + assert result["m.new_content"] + assert result["m.new_content"]["body"] == "Updated message" + assert result["m.relates_to"]["rel_type"] == "m.replace" + assert result["m.relates_to"]["event_id"] == event_id + + +def test_build_matrix_text_content_with_event_id_preserves_thread_relation() -> None: + """Thread relations for edits should stay inside m.new_content.""" + relates_to = { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + result = _build_matrix_text_content("Updated message", "event-1", relates_to) + + assert result["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + assert result["m.new_content"]["m.relates_to"] == relates_to + + +def test_build_matrix_text_content_no_event_id() -> 
None: + """Test that when event_id is not provided, no extra properties are added.""" + result = _build_matrix_text_content("Regular message") + + # Basic required properties should be present + assert "msgtype" in result + assert "body" in result + assert result["body"] == "Regular message" + + # Extra properties for replacement should NOT be present + assert "m.relates_to" not in result + assert "m.new_content" not in result + assert "format" not in result + assert "formatted_body" not in result + + +def test_build_matrix_text_content_plain_text_no_html() -> None: + """Test plain text that should not include HTML formatting.""" + result = _build_matrix_text_content("Simple plain text") + assert "msgtype" in result + assert "body" in result + assert "format" not in result + assert "formatted_body" not in result + + +@pytest.mark.asyncio +async def test_send_room_content_returns_room_send_response(): + """Test that _send_room_content returns the response from client.room_send.""" + client = _FakeAsyncClient("", "", "", None) + channel = MatrixChannel(_make_config(), MessageBus()) + channel.client = client + + room_id = "!test_room:matrix.org" + content = {"msgtype": "m.text", "body": "Hello World"} + + result = await channel._send_room_content(room_id, content) + + assert result is client.room_send_response + + +@pytest.mark.asyncio +async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + await channel.send_delta("!room:matrix.org", "Hello") + + assert "!room:matrix.org" in channel._stream_bufs + buf = channel._stream_bufs["!room:matrix.org"] + assert buf.text == "Hello" + assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + assert len(client.room_send_calls) == 1 + assert 
client.room_send_calls[0]["content"]["body"] == "Hello" + + +@pytest.mark.asyncio +async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", "Hello") + assert len(client.room_send_calls) == 1 + + await channel.send_delta("!room:matrix.org", " world") + assert len(client.room_send_calls) == 1 + + buf = channel._stream_bufs["!room:matrix.org"] + assert buf.text == "Hello world" + assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + +@pytest.mark.asyncio +async def test_send_delta_edits_again_after_interval(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + times = [100.0, 102.0, 104.0, 106.0, 108.0] + times.reverse() + monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop()) + + await channel.send_delta("!room:matrix.org", "Hello") + await channel.send_delta("!room:matrix.org", " world") + + assert len(client.room_send_calls) == 2 + first_content = client.room_send_calls[0]["content"] + second_content = client.room_send_calls[1]["content"] + + assert "body" in first_content + assert first_content["body"] == "Hello" + assert "m.relates_to" not in first_content + + assert "body" in second_content + assert "m.relates_to" in second_content + assert second_content["body"] == "Hello world" + assert second_content["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo", + } + + +@pytest.mark.asyncio +async def 
test_send_delta_stream_end_replaces_existing_message() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + channel._stream_bufs["!room:matrix.org"] = matrix_module._StreamBuf( + text="Final text", + event_id="event-1", + last_edit=100.0, + ) + + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True}) + + assert "!room:matrix.org" not in channel._stream_bufs + assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS) + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "Final text" + assert client.room_send_calls[0]["content"]["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + + +@pytest.mark.asyncio +async def test_send_delta_starts_threaded_stream_inside_thread() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "event-1" + + metadata = { + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + } + await channel.send_delta("!room:matrix.org", "Hello", metadata) + + assert client.room_send_calls[0]["content"]["m.relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + + +@pytest.mark.asyncio +async def test_send_delta_threaded_edit_keeps_replace_and_thread_relation(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "event-1" + + times = [100.0, 102.0, 104.0] + times.reverse() + monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop()) + + metadata = { + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + } + await 
channel.send_delta("!room:matrix.org", "Hello", metadata) + await channel.send_delta("!room:matrix.org", " world", metadata) + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True, **metadata}) + + edit_content = client.room_send_calls[1]["content"] + final_content = client.room_send_calls[2]["content"] + + assert edit_content["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + assert edit_content["m.new_content"]["m.relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + assert final_content["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + assert final_content["m.new_content"]["m.relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_noop_when_buffer_missing() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True}) + + assert client.room_send_calls == [] + assert client.typing_calls == [] + + +@pytest.mark.asyncio +async def test_send_delta_on_error_stops_typing(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.raise_on_send = True + channel.client = client + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", "Hello", {"room_id": "!room:matrix.org"}) + + assert "!room:matrix.org" in channel._stream_bufs + assert channel._stream_bufs["!room:matrix.org"].text == "Hello" + assert len(client.room_send_calls) == 1 + + assert len(client.typing_calls) == 1 + + +@pytest.mark.asyncio +async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> 
None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", " ") + + assert "!room:matrix.org" in channel._stream_bufs + assert channel._stream_bufs["!room:matrix.org"].text == " " + assert client.room_send_calls == [] \ No newline at end of file diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 35b01db8b..3a847411b 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1,17 +1,22 @@ import asyncio import json import tempfile +from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock import pytest +import httpx +import nanobot.channels.weixin as weixin_mod from nanobot.bus.queue import MessageBus from nanobot.channels.weixin import ( ITEM_IMAGE, ITEM_TEXT, MESSAGE_TYPE_BOT, WEIXIN_CHANNEL_VERSION, + _decrypt_aes_ecb, + _encrypt_aes_ecb, WeixinChannel, WeixinConfig, ) @@ -42,10 +47,12 @@ def test_make_headers_includes_route_tag_when_configured() -> None: assert headers["Authorization"] == "Bearer token" assert headers["SKRouteTag"] == "123" + assert headers["iLink-App-Id"] == "bot" + assert headers["iLink-App-ClientVersion"] == str((2 << 16) | (1 << 8) | 1) def test_channel_version_matches_reference_plugin_version() -> None: - assert WEIXIN_CHANNEL_VERSION == "1.0.3" + assert WEIXIN_CHANNEL_VERSION == "2.1.1" def test_save_and_load_state_persists_context_tokens(tmp_path) -> None: @@ -169,6 +176,120 @@ async def test_process_message_extracts_media_and_preserves_paths() -> None: assert inbound.media == ["/tmp/test.jpg"] +@pytest.mark.asyncio +async def test_process_message_falls_back_to_referenced_media_when_no_top_level_media() -> None: + channel, bus = _make_channel() + channel._download_media_item = AsyncMock(return_value="/tmp/ref.jpg") + + 
await channel._process_message( + { + "message_type": 1, + "message_id": "m3-ref-fallback", + "from_user_id": "wx-user", + "context_token": "ctx-3-ref-fallback", + "item_list": [ + { + "type": ITEM_TEXT, + "text_item": {"text": "reply to image"}, + "ref_msg": { + "message_item": { + "type": ITEM_IMAGE, + "image_item": {"media": {"encrypt_query_param": "ref-enc"}}, + }, + }, + }, + ], + } + ) + + inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) + + channel._download_media_item.assert_awaited_once_with( + {"media": {"encrypt_query_param": "ref-enc"}}, + "image", + ) + assert inbound.media == ["/tmp/ref.jpg"] + assert "reply to image" in inbound.content + assert "[image]" in inbound.content + + +@pytest.mark.asyncio +async def test_process_message_does_not_use_referenced_fallback_when_top_level_media_exists() -> None: + channel, bus = _make_channel() + channel._download_media_item = AsyncMock(side_effect=["/tmp/top.jpg", "/tmp/ref.jpg"]) + + await channel._process_message( + { + "message_type": 1, + "message_id": "m3-ref-no-fallback", + "from_user_id": "wx-user", + "context_token": "ctx-3-ref-no-fallback", + "item_list": [ + {"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "top-enc"}}}, + { + "type": ITEM_TEXT, + "text_item": {"text": "has top-level media"}, + "ref_msg": { + "message_item": { + "type": ITEM_IMAGE, + "image_item": {"media": {"encrypt_query_param": "ref-enc"}}, + }, + }, + }, + ], + } + ) + + inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) + + channel._download_media_item.assert_awaited_once_with( + {"media": {"encrypt_query_param": "top-enc"}}, + "image", + ) + assert inbound.media == ["/tmp/top.jpg"] + assert "/tmp/ref.jpg" not in inbound.content + + +@pytest.mark.asyncio +async def test_process_message_does_not_fallback_when_top_level_media_exists_but_download_fails() -> None: + channel, bus = _make_channel() + # Top-level image download fails (None), referenced image would succeed if 
fallback were triggered. + channel._download_media_item = AsyncMock(side_effect=[None, "/tmp/ref.jpg"]) + + await channel._process_message( + { + "message_type": 1, + "message_id": "m3-ref-no-fallback-on-failure", + "from_user_id": "wx-user", + "context_token": "ctx-3-ref-no-fallback-on-failure", + "item_list": [ + {"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "top-enc"}}}, + { + "type": ITEM_TEXT, + "text_item": {"text": "quoted has media"}, + "ref_msg": { + "message_item": { + "type": ITEM_IMAGE, + "image_item": {"media": {"encrypt_query_param": "ref-enc"}}, + }, + }, + }, + ], + } + ) + + inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) + + # Should only attempt top-level media item; reference fallback must not activate. + channel._download_media_item.assert_awaited_once_with( + {"media": {"encrypt_query_param": "top-enc"}}, + "image", + ) + assert inbound.media == [] + assert "[image]" in inbound.content + assert "/tmp/ref.jpg" not in inbound.content + + @pytest.mark.asyncio async def test_send_without_context_token_does_not_send_text() -> None: channel, _bus = _make_channel() @@ -199,6 +320,70 @@ async def test_send_does_not_send_when_session_is_paused() -> None: channel._send_text.assert_not_awaited() +@pytest.mark.asyncio +async def test_get_typing_ticket_fetches_and_caches_per_user() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-1"}) + + first = await channel._get_typing_ticket("wx-user", "ctx-1") + second = await channel._get_typing_ticket("wx-user", "ctx-2") + + assert first == "ticket-1" + assert second == "ticket-1" + channel._api_post.assert_awaited_once_with( + "ilink/bot/getconfig", + {"ilink_user_id": "wx-user", "context_token": "ctx-1", "base_info": weixin_mod.BASE_INFO}, + ) + + +@pytest.mark.asyncio +async def test_send_uses_typing_start_and_cancel_when_ticket_available() 
-> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-typing" + channel._send_text = AsyncMock() + channel._api_post = AsyncMock( + side_effect=[ + {"ret": 0, "typing_ticket": "ticket-typing"}, + {"ret": 0}, + {"ret": 0}, + ] + ) + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-typing") + assert channel._api_post.await_count == 3 + assert channel._api_post.await_args_list[0].args[0] == "ilink/bot/getconfig" + assert channel._api_post.await_args_list[1].args[0] == "ilink/bot/sendtyping" + assert channel._api_post.await_args_list[1].args[1]["status"] == 1 + assert channel._api_post.await_args_list[2].args[0] == "ilink/bot/sendtyping" + assert channel._api_post.await_args_list[2].args[1]["status"] == 2 + + +@pytest.mark.asyncio +async def test_send_still_sends_text_when_typing_ticket_missing() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-no-ticket" + channel._send_text = AsyncMock() + channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"}) + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-no-ticket") + channel._api_post.assert_awaited_once() + assert channel._api_post.await_args_list[0].args[0] == "ilink/bot/getconfig" + + @pytest.mark.asyncio async def test_poll_once_pauses_session_on_expired_errcode() -> None: channel, _bus = _make_channel() @@ -220,8 +405,12 @@ async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None: channel._api_get = AsyncMock( side_effect=[ {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, - {"status": "expired"}, {"qrcode": "qr-2", 
"qrcode_img_content": "url-2"}, + ] + ) + channel._api_get_with_base = AsyncMock( + side_effect=[ + {"status": "expired"}, { "status": "confirmed", "bot_token": "token-2", @@ -247,12 +436,16 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None: channel._api_get = AsyncMock( side_effect=[ {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, - {"status": "expired"}, {"qrcode": "qr-2", "qrcode_img_content": "url-2"}, - {"status": "expired"}, {"qrcode": "qr-3", "qrcode_img_content": "url-3"}, - {"status": "expired"}, {"qrcode": "qr-4", "qrcode_img_content": "url-4"}, + ] + ) + channel._api_get_with_base = AsyncMock( + side_effect=[ + {"status": "expired"}, + {"status": "expired"}, + {"status": "expired"}, {"status": "expired"}, ] ) @@ -262,6 +455,105 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None: assert ok is False +@pytest.mark.asyncio +async def test_qr_login_switches_polling_base_url_on_redirect_status() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1")) + + status_side_effect = [ + {"status": "scaned_but_redirect", "redirect_host": "idc.redirect.test"}, + { + "status": "confirmed", + "bot_token": "token-3", + "ilink_bot_id": "bot-3", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + channel._api_get = AsyncMock(side_effect=list(status_side_effect)) + channel._api_get_with_base = AsyncMock(side_effect=list(status_side_effect)) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-3" + assert channel._api_get_with_base.await_count == 2 + first_call = channel._api_get_with_base.await_args_list[0] + second_call = channel._api_get_with_base.await_args_list[1] + assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com" + assert second_call.kwargs["base_url"] == 
"https://idc.redirect.test" + + +@pytest.mark.asyncio +async def test_qr_login_redirect_without_host_keeps_current_polling_base_url() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1")) + + status_side_effect = [ + {"status": "scaned_but_redirect"}, + { + "status": "confirmed", + "bot_token": "token-4", + "ilink_bot_id": "bot-4", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + channel._api_get = AsyncMock(side_effect=list(status_side_effect)) + channel._api_get_with_base = AsyncMock(side_effect=list(status_side_effect)) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-4" + assert channel._api_get_with_base.await_count == 2 + first_call = channel._api_get_with_base.await_args_list[0] + second_call = channel._api_get_with_base.await_args_list[1] + assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com" + assert second_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com" + + +@pytest.mark.asyncio +async def test_qr_login_resets_redirect_base_url_after_qr_refresh() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._fetch_qr_code = AsyncMock(side_effect=[("qr-1", "url-1"), ("qr-2", "url-2")]) + + channel._api_get_with_base = AsyncMock( + side_effect=[ + {"status": "scaned_but_redirect", "redirect_host": "idc.redirect.test"}, + {"status": "expired"}, + { + "status": "confirmed", + "bot_token": "token-5", + "ilink_bot_id": "bot-5", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + ) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-5" + assert channel._api_get_with_base.await_count == 3 + first_call = 
channel._api_get_with_base.await_args_list[0] + second_call = channel._api_get_with_base.await_args_list[1] + third_call = channel._api_get_with_base.await_args_list[2] + assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com" + assert second_call.kwargs["base_url"] == "https://idc.redirect.test" + assert third_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com" + + @pytest.mark.asyncio async def test_process_message_skips_bot_messages() -> None: channel, bus = _make_channel() @@ -281,12 +573,13 @@ async def test_process_message_skips_bot_messages() -> None: @pytest.mark.asyncio -async def test_process_message_fetches_typing_ticket_and_starts_typing() -> None: +async def test_process_message_starts_typing_on_inbound() -> None: + """Typing indicator fires immediately when user message arrives.""" channel, _bus = _make_channel() channel._running = True channel._client = object() channel._token = "token" - channel._api_post = AsyncMock(return_value={"typing_ticket": "ticket-1"}) + channel._start_typing = AsyncMock() await channel._process_message( { @@ -300,42 +593,42 @@ async def test_process_message_fetches_typing_ticket_and_starts_typing() -> None } ) - assert channel._typing_tickets["wx-user"] == "ticket-1" - assert "wx-user" in channel._typing_tasks - await channel._stop_typing("wx-user", clear_remote=False) + channel._start_typing.assert_awaited_once_with("wx-user", "ctx-typing") @pytest.mark.asyncio async def test_send_final_message_clears_typing_indicator() -> None: + """Non-progress send should cancel typing status.""" channel, _bus = _make_channel() channel._client = object() channel._token = "token" channel._context_tokens["wx-user"] = "ctx-2" - channel._typing_tickets["wx-user"] = "ticket-2" + channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999} channel._send_text = AsyncMock() - channel._api_post = AsyncMock(return_value={}) + channel._api_post = AsyncMock(return_value={"ret": 0}) await 
channel.send( type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() ) channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") - channel._api_post.assert_awaited_once() - endpoint, body = channel._api_post.await_args.args - assert endpoint == "ilink/bot/sendtyping" - assert body["status"] == 2 - assert body["typing_ticket"] == "ticket-2" + typing_cancel_calls = [ + c for c in channel._api_post.await_args_list + if c.args[0] == "ilink/bot/sendtyping" and c.args[1]["status"] == 2 + ] + assert len(typing_cancel_calls) >= 1 @pytest.mark.asyncio async def test_send_progress_message_keeps_typing_indicator() -> None: + """Progress messages must not cancel typing status.""" channel, _bus = _make_channel() channel._client = object() channel._token = "token" channel._context_tokens["wx-user"] = "ctx-2" - channel._typing_tickets["wx-user"] = "ticket-2" + channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999} channel._send_text = AsyncMock() - channel._api_post = AsyncMock(return_value={}) + channel._api_post = AsyncMock(return_value={"ret": 0}) await channel.send( type( @@ -351,4 +644,362 @@ async def test_send_progress_message_keeps_typing_indicator() -> None: ) channel._send_text.assert_awaited_once_with("wx-user", "thinking", "ctx-2") - channel._api_post.assert_not_awaited() + typing_cancel_calls = [ + c for c in channel._api_post.await_args_list + if c.args and c.args[0] == "ilink/bot/sendtyping" and c.args[1].get("status") == 2 + ] + assert len(typing_cancel_calls) == 0 + + +class _DummyHttpResponse: + def __init__(self, *, headers: dict[str, str] | None = None, status_code: int = 200) -> None: + self.headers = headers or {} + self.status_code = status_code + + def raise_for_status(self) -> None: + return None + + +@pytest.mark.asyncio +async def test_send_media_uses_upload_full_url_when_present(tmp_path) -> None: + channel, _bus = _make_channel() + + media_file = tmp_path / "photo.jpg" 
+ media_file.write_bytes(b"hello-weixin") + + cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "dl-param"})) + channel._client = SimpleNamespace(post=cdn_post) + channel._api_post = AsyncMock( + side_effect=[ + { + "upload_full_url": "https://upload-full.example.test/path?foo=bar", + "upload_param": "should-not-be-used", + }, + {"ret": 0}, + ] + ) + + await channel._send_media_file("wx-user", str(media_file), "ctx-1") + + # first POST call is CDN upload + cdn_url = cdn_post.await_args_list[0].args[0] + assert cdn_url == "https://upload-full.example.test/path?foo=bar" + + +@pytest.mark.asyncio +async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None: + channel, _bus = _make_channel() + + media_file = tmp_path / "photo.jpg" + media_file.write_bytes(b"hello-weixin") + + cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "dl-param"})) + channel._client = SimpleNamespace(post=cdn_post) + channel._api_post = AsyncMock( + side_effect=[ + {"upload_param": "enc-need-fallback"}, + {"ret": 0}, + ] + ) + + await channel._send_media_file("wx-user", str(media_file), "ctx-1") + + cdn_url = cdn_post.await_args_list[0].args[0] + assert cdn_url.startswith(f"{channel.config.cdn_base_url}/upload?encrypted_query_param=enc-need-fallback") + assert "&filekey=" in cdn_url + + +@pytest.mark.asyncio +async def test_send_media_voice_file_uses_voice_item_and_voice_upload_type(tmp_path) -> None: + channel, _bus = _make_channel() + + media_file = tmp_path / "voice.mp3" + media_file.write_bytes(b"voice-bytes") + + cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "voice-dl-param"})) + channel._client = SimpleNamespace(post=cdn_post) + channel._api_post = AsyncMock( + side_effect=[ + {"upload_full_url": "https://upload-full.example.test/voice?foo=bar"}, + {"ret": 0}, + ] + ) + + await channel._send_media_file("wx-user", str(media_file), "ctx-voice") + + getupload_body = 
channel._api_post.await_args_list[0].args[1] + assert getupload_body["media_type"] == 4 + + sendmessage_body = channel._api_post.await_args_list[1].args[1] + item = sendmessage_body["msg"]["item_list"][0] + assert item["type"] == 3 + assert "voice_item" in item + assert "file_item" not in item + assert item["voice_item"]["media"]["encrypt_query_param"] == "voice-dl-param" + + +@pytest.mark.asyncio +async def test_send_typing_uses_keepalive_until_send_finishes() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-typing-loop" + async def _api_post_side_effect(endpoint: str, _body: dict | None = None, *, auth: bool = True): + if endpoint == "ilink/bot/getconfig": + return {"ret": 0, "typing_ticket": "ticket-keepalive"} + return {"ret": 0} + + channel._api_post = AsyncMock(side_effect=_api_post_side_effect) + + async def _slow_send_text(*_args, **_kwargs) -> None: + await asyncio.sleep(0.03) + + channel._send_text = AsyncMock(side_effect=_slow_send_text) + + old_interval = weixin_mod.TYPING_KEEPALIVE_INTERVAL_S + weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = 0.01 + try: + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + finally: + weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = old_interval + + status_calls = [ + c.args[1]["status"] + for c in channel._api_post.await_args_list + if c.args and c.args[0] == "ilink/bot/sendtyping" + ] + assert status_calls.count(1) >= 2 + assert status_calls[-1] == 2 + + +@pytest.mark.asyncio +async def test_get_typing_ticket_failure_uses_backoff_and_cached_ticket(monkeypatch) -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + + now = {"value": 1000.0} + monkeypatch.setattr(weixin_mod.time, "time", lambda: now["value"]) + monkeypatch.setattr(weixin_mod.random, "random", lambda: 0.5) + + channel._api_post = AsyncMock(return_value={"ret": 0, 
"typing_ticket": "ticket-ok"}) + first = await channel._get_typing_ticket("wx-user", "ctx-1") + assert first == "ticket-ok" + + # force refresh window reached + now["value"] = now["value"] + (12 * 60 * 60) + 1 + channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "temporary failure"}) + + # On refresh failure, should still return cached ticket and apply backoff. + second = await channel._get_typing_ticket("wx-user", "ctx-2") + assert second == "ticket-ok" + assert channel._api_post.await_count == 1 + + # Before backoff expiry, no extra fetch should happen. + now["value"] += 1 + third = await channel._get_typing_ticket("wx-user", "ctx-3") + assert third == "ticket-ok" + assert channel._api_post.await_count == 1 + + +@pytest.mark.asyncio +async def test_qr_login_treats_temporary_connect_error_as_wait_and_recovers() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1")) + + request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status") + channel._api_get_with_base = AsyncMock( + side_effect=[ + httpx.ConnectError("temporary network", request=request), + { + "status": "confirmed", + "bot_token": "token-net-ok", + "ilink_bot_id": "bot-id", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + ) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-net-ok" + + +@pytest.mark.asyncio +async def test_qr_login_treats_5xx_gateway_response_error_as_wait_and_recovers() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1")) + + request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status") + response = 
httpx.Response(status_code=524, request=request) + channel._api_get_with_base = AsyncMock( + side_effect=[ + httpx.HTTPStatusError("gateway timeout", request=request, response=response), + { + "status": "confirmed", + "bot_token": "token-5xx-ok", + "ilink_bot_id": "bot-id", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + ) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-5xx-ok" + + +def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None: + key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg==" # base64("0123456789abcdef") + plaintext = b"hello-weixin-padding" + + ciphertext = _encrypt_aes_ecb(plaintext, key_b64) + decrypted = _decrypt_aes_ecb(ciphertext, key_b64) + + assert decrypted == plaintext + + +class _DummyDownloadResponse: + def __init__(self, content: bytes, status_code: int = 200) -> None: + self.content = content + self.status_code = status_code + + def raise_for_status(self) -> None: + return None + + +class _DummyErrorDownloadResponse(_DummyDownloadResponse): + def __init__(self, url: str, status_code: int) -> None: + super().__init__(content=b"", status_code=status_code) + self._url = url + + def raise_for_status(self) -> None: + request = httpx.Request("GET", self._url) + response = httpx.Response(self.status_code, request=request) + raise httpx.HTTPStatusError( + f"download failed with status {self.status_code}", + request=request, + response=response, + ) + + +@pytest.mark.asyncio +async def test_download_media_item_uses_full_url_when_present(tmp_path) -> None: + channel, _bus = _make_channel() + weixin_mod.get_media_dir = lambda _name: tmp_path + + full_url = "https://cdn.example.test/download/full" + channel._client = SimpleNamespace( + get=AsyncMock(return_value=_DummyDownloadResponse(content=b"raw-image-bytes")) + ) + + item = { + "media": { + "full_url": full_url, + "encrypt_query_param": "enc-fallback-should-not-be-used", + }, + } + saved_path = await 
channel._download_media_item(item, "image") + + assert saved_path is not None + assert Path(saved_path).read_bytes() == b"raw-image-bytes" + channel._client.get.assert_awaited_once_with(full_url) + + +@pytest.mark.asyncio +async def test_download_media_item_falls_back_when_full_url_returns_retryable_error(tmp_path) -> None: + channel, _bus = _make_channel() + weixin_mod.get_media_dir = lambda _name: tmp_path + + full_url = "https://cdn.example.test/download/full?taskid=123" + channel._client = SimpleNamespace( + get=AsyncMock( + side_effect=[ + _DummyErrorDownloadResponse(full_url, 500), + _DummyDownloadResponse(content=b"fallback-bytes"), + ] + ) + ) + + item = { + "media": { + "full_url": full_url, + "encrypt_query_param": "enc-fallback", + }, + } + saved_path = await channel._download_media_item(item, "image") + + assert saved_path is not None + assert Path(saved_path).read_bytes() == b"fallback-bytes" + assert channel._client.get.await_count == 2 + assert channel._client.get.await_args_list[0].args[0] == full_url + fallback_url = channel._client.get.await_args_list[1].args[0] + assert fallback_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback") + + +@pytest.mark.asyncio +async def test_download_media_item_falls_back_to_encrypt_query_param(tmp_path) -> None: + channel, _bus = _make_channel() + weixin_mod.get_media_dir = lambda _name: tmp_path + + channel._client = SimpleNamespace( + get=AsyncMock(return_value=_DummyDownloadResponse(content=b"fallback-bytes")) + ) + + item = {"media": {"encrypt_query_param": "enc-fallback"}} + saved_path = await channel._download_media_item(item, "image") + + assert saved_path is not None + assert Path(saved_path).read_bytes() == b"fallback-bytes" + called_url = channel._client.get.await_args_list[0].args[0] + assert called_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback") + + +@pytest.mark.asyncio +async def 
test_download_media_item_does_not_retry_when_full_url_fails_without_fallback(tmp_path) -> None: + channel, _bus = _make_channel() + weixin_mod.get_media_dir = lambda _name: tmp_path + + full_url = "https://cdn.example.test/download/full" + channel._client = SimpleNamespace( + get=AsyncMock(return_value=_DummyErrorDownloadResponse(full_url, 500)) + ) + + item = {"media": {"full_url": full_url}} + saved_path = await channel._download_media_item(item, "image") + + assert saved_path is None + channel._client.get.assert_awaited_once_with(full_url) + + +@pytest.mark.asyncio +async def test_download_media_item_non_image_requires_aes_key_even_with_full_url(tmp_path) -> None: + channel, _bus = _make_channel() + weixin_mod.get_media_dir = lambda _name: tmp_path + + full_url = "https://cdn.example.test/download/voice" + channel._client = SimpleNamespace( + get=AsyncMock(return_value=_DummyDownloadResponse(content=b"ciphertext-or-unknown")) + ) + + item = { + "media": { + "full_url": full_url, + }, + } + saved_path = await channel._download_media_item(item, "voice") + + assert saved_path is None + channel._client.get.assert_not_awaited() diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index a8fcc4aa0..0f6ff8177 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -317,6 +317,75 @@ def test_openai_compat_provider_passes_model_through(): assert provider.get_default_model() == "github-copilot/gpt-5.3-codex" +def test_make_provider_uses_github_copilot_backend(): + from nanobot.cli.commands import _make_provider + from nanobot.config.schema import Config + + config = Config.model_validate( + { + "agents": { + "defaults": { + "provider": "github-copilot", + "model": "github-copilot/gpt-4.1", + } + } + } + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = _make_provider(config) + + assert provider.__class__.__name__ == "GitHubCopilotProvider" + + +def 
test_github_copilot_provider_strips_prefixed_model_name(): + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1") + + kwargs = provider._build_kwargs( + messages=[{"role": "user", "content": "hi"}], + tools=None, + model="github-copilot/gpt-5.1", + max_tokens=16, + temperature=0.1, + reasoning_effort=None, + tool_choice=None, + ) + + assert kwargs["model"] == "gpt-5.1" + + +@pytest.mark.asyncio +async def test_github_copilot_provider_refreshes_client_api_key_before_chat(): + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + + mock_client = MagicMock() + mock_client.api_key = "no-key" + mock_client.chat.completions.create = AsyncMock(return_value={ + "choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI", return_value=mock_client): + provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1") + + provider._get_copilot_access_token = AsyncMock(return_value="copilot-access-token") + + response = await provider.chat( + messages=[{"role": "user", "content": "hi"}], + model="github-copilot/gpt-5.1", + max_tokens=16, + temperature=0.1, + ) + + assert response.content == "ok" + assert provider._client.api_key == "copilot-access-token" + provider._get_copilot_access_token.assert_awaited_once() + mock_client.chat.completions.create.assert_awaited_once() + + def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex" assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex" @@ -642,27 +711,105 @@ def test_heartbeat_retains_recent_messages_by_default(): assert config.gateway.heartbeat.keep_recent_messages 
== 8 -def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: +def _write_instance_config(tmp_path: Path) -> Path: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) config_file.write_text("{}") + return config_file - config = Config() - config.agents.defaults.workspace = str(tmp_path / "config-workspace") - seen: dict[str, Path] = {} +def _stop_gateway_provider(_config) -> object: + raise _StopGatewayError("stop") + + +def _patch_cli_command_runtime( + monkeypatch, + config: Config, + *, + set_config_path=None, + sync_templates=None, + make_provider=None, + message_bus=None, + session_manager=None, + cron_service=None, + get_cron_dir=None, +) -> None: monkeypatch.setattr( "nanobot.config.loader.set_config_path", - lambda path: seen.__setitem__("config_path", path), + set_config_path or (lambda _path: None), ) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr( "nanobot.cli.commands.sync_workspace_templates", - lambda path: seen.__setitem__("workspace", path), + sync_templates or (lambda _path: None), ) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + make_provider or (lambda _config: object()), + ) + + if message_bus is not None: + monkeypatch.setattr("nanobot.bus.queue.MessageBus", message_bus) + if session_manager is not None: + monkeypatch.setattr("nanobot.session.manager.SessionManager", session_manager) + if cron_service is not None: + monkeypatch.setattr("nanobot.cron.service.CronService", cron_service) + if get_cron_dir is not None: + monkeypatch.setattr("nanobot.config.paths.get_cron_dir", get_cron_dir) + + +def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -> None: + pytest.importorskip("aiohttp") + + class _FakeApiApp: + def __init__(self) -> None: + self.on_startup: list[object] = [] + self.on_cleanup: list[object] = 
[] + + class _FakeAgentLoop: + def __init__(self, **kwargs) -> None: + seen["workspace"] = kwargs["workspace"] + + async def _connect_mcp(self) -> None: + return None + + async def close_mcp(self) -> None: + return None + + def _fake_create_app(agent_loop, model_name: str, request_timeout: float): + seen["agent_loop"] = agent_loop + seen["model_name"] = model_name + seen["request_timeout"] = request_timeout + return _FakeApiApp() + + def _fake_run_app(api_app, host: str, port: int, print): + seen["api_app"] = api_app + seen["host"] = host + seen["port"] = port + + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + ) + monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app) + monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app) + + +def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: + config_file = _write_instance_config(tmp_path) + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + seen: dict[str, Path] = {} + + _patch_cli_command_runtime( + monkeypatch, + config, + set_config_path=lambda path: seen.__setitem__("config_path", path), + sync_templates=lambda path: seen.__setitem__("workspace", path), + make_provider=_stop_gateway_provider, ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -673,24 +820,17 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.agents.defaults.workspace = str(tmp_path / "config-workspace") override = tmp_path / "override-workspace" seen: 
dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr( - "nanobot.cli.commands.sync_workspace_templates", - lambda path: seen.__setitem__("workspace", path), - ) - monkeypatch.setattr( - "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + _patch_cli_command_runtime( + monkeypatch, + config, + sync_templates=lambda path: seen.__setitem__("workspace", path), + make_provider=_stop_gateway_provider, ) result = runner.invoke( @@ -704,27 +844,23 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.agents.defaults.workspace = str(tmp_path / "config-workspace") seen: dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) - monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) - monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) - class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path raise _StopGatewayError("stop") - monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + cron_service=_StopCron, 
+ ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -735,10 +871,7 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: def test_gateway_workspace_override_does_not_migrate_legacy_cron( monkeypatch, tmp_path: Path ) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) legacy_dir = tmp_path / "global" / "cron" legacy_dir.mkdir(parents=True) legacy_file = legacy_dir / "jobs.json" @@ -748,20 +881,19 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron( config = Config() seen: dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) - monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) - monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) - class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path raise _StopGatewayError("stop") - monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + cron_service=_StopCron, + get_cron_dir=lambda: legacy_dir, + ) result = runner.invoke( app, @@ -777,10 +909,7 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron( def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron( monkeypatch, tmp_path: Path ) -> None: - config_file = tmp_path / "instance" / "config.json" - 
config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) legacy_dir = tmp_path / "global" / "cron" legacy_dir.mkdir(parents=True) legacy_file = legacy_dir / "jobs.json" @@ -791,20 +920,19 @@ def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron( config.agents.defaults.workspace = str(custom_workspace) seen: dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) - monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) - monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) - class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path raise _StopGatewayError("stop") - monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + cron_service=_StopCron, + get_cron_dir=lambda: legacy_dir, + ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -856,19 +984,14 @@ def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) -> def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.gateway.port = 18791 - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - 
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr( - "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + _patch_cli_command_runtime( + monkeypatch, + config, + make_provider=_stop_gateway_provider, ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -878,19 +1001,14 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.gateway.port = 18791 - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr( - "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + _patch_cli_command_runtime( + monkeypatch, + config, + make_provider=_stop_gateway_provider, ) result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"]) @@ -899,6 +1017,63 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) assert "port 18792" in result.stdout +def test_serve_uses_api_config_defaults_and_workspace_override( + monkeypatch, tmp_path: Path +) -> None: + config_file = _write_instance_config(tmp_path) + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + config.api.host = "127.0.0.2" + config.api.port = 18900 + config.api.timeout = 45.0 + override_workspace = tmp_path / "override-workspace" 
+ seen: dict[str, object] = {} + + _patch_serve_runtime(monkeypatch, config, seen) + + result = runner.invoke( + app, + ["serve", "--config", str(config_file), "--workspace", str(override_workspace)], + ) + + assert result.exit_code == 0 + assert seen["workspace"] == override_workspace + assert seen["host"] == "127.0.0.2" + assert seen["port"] == 18900 + assert seen["request_timeout"] == 45.0 + + +def test_serve_cli_options_override_api_config(monkeypatch, tmp_path: Path) -> None: + config_file = _write_instance_config(tmp_path) + config = Config() + config.api.host = "127.0.0.2" + config.api.port = 18900 + config.api.timeout = 45.0 + seen: dict[str, object] = {} + + _patch_serve_runtime(monkeypatch, config, seen) + + result = runner.invoke( + app, + [ + "serve", + "--config", + str(config_file), + "--host", + "127.0.0.1", + "--port", + "18901", + "--timeout", + "46", + ], + ) + + assert result.exit_code == 0 + assert seen["host"] == "127.0.0.1" + assert seen["port"] == 18901 + assert seen["request_timeout"] == 46.0 + + def test_channels_login_requires_channel_name() -> None: result = runner.invoke(app, ["channels", "login"]) diff --git a/tests/cli/test_restart_command.py b/tests/cli/test_restart_command.py index 3281afe2d..6efcdad0d 100644 --- a/tests/cli/test_restart_command.py +++ b/tests/cli/test_restart_command.py @@ -152,10 +152,12 @@ class TestRestartCommand: ]) await loop._run_agent_loop([]) - assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4} + assert loop._last_usage["prompt_tokens"] == 9 + assert loop._last_usage["completion_tokens"] == 4 await loop._run_agent_loop([]) - assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0} + assert loop._last_usage["prompt_tokens"] == 0 + assert loop._last_usage["completion_tokens"] == 0 @pytest.mark.asyncio async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self): diff --git a/tests/cron/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py index 
22a502fa4..42ad7d419 100644 --- a/tests/cron/test_cron_tool_list.py +++ b/tests/cron/test_cron_tool_list.py @@ -285,6 +285,28 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None: assert job.schedule.at_ms == expected +def test_add_job_delivers_by_default(tmp_path) -> None: + tool = _make_tool(tmp_path) + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Morning standup", 60, None, None, None) + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + assert job.payload.deliver is True + + +def test_add_job_can_disable_delivery(tmp_path) -> None: + tool = _make_tool(tmp_path) + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Background refresh", 60, None, None, None, deliver=False) + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + assert job.payload.deliver is False + + def test_list_excludes_disabled_jobs(tmp_path) -> None: tool = _make_tool(tmp_path) job = tool._cron.add_job( diff --git a/tests/providers/test_cached_tokens.py b/tests/providers/test_cached_tokens.py new file mode 100644 index 000000000..1b01408a4 --- /dev/null +++ b/tests/providers/test_cached_tokens.py @@ -0,0 +1,233 @@ +"""Tests for cached token extraction from OpenAI-compatible providers.""" + +from __future__ import annotations + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +class FakeUsage: + """Mimics an OpenAI SDK usage object (has attributes, not dict keys).""" + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class FakePromptDetails: + """Mimics prompt_tokens_details sub-object.""" + def __init__(self, cached_tokens=0): + self.cached_tokens = cached_tokens + + +class _FakeSpec: + supports_prompt_caching = False + model_id_prefix = None + strip_model_prefix = False + max_completion_tokens = False + reasoning_effort = None + + +def _provider(): + from unittest.mock import MagicMock + p = 
OpenAICompatProvider.__new__(OpenAICompatProvider) + p.client = MagicMock() + p.spec = _FakeSpec() + return p + + +# Minimal valid choice so _parse reaches _extract_usage. +_DICT_CHOICE = {"message": {"content": "Hello"}} + +class _FakeMessage: + content = "Hello" + tool_calls = None + + +class _FakeChoice: + message = _FakeMessage() + finish_reason = "stop" + + +# --- dict-based response (raw JSON / mapping) --- + +def test_extract_usage_openai_cached_tokens_dict(): + """prompt_tokens_details.cached_tokens from a dict response.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 1200}, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + assert result.usage["prompt_tokens"] == 2000 + + +def test_extract_usage_deepseek_cached_tokens_dict(): + """prompt_cache_hit_tokens from a DeepSeek dict response.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 1500, + "completion_tokens": 200, + "total_tokens": 1700, + "prompt_cache_hit_tokens": 1200, + "prompt_cache_miss_tokens": 300, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_no_cached_tokens_dict(): + """Response without any cache fields -> no cached_tokens key.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 200, + "total_tokens": 1200, + } + } + result = p._parse(response) + assert "cached_tokens" not in result.usage + + +def test_extract_usage_openai_cached_zero_dict(): + """cached_tokens=0 should NOT be included (same as existing fields).""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 0}, + } + } + 
result = p._parse(response) + assert "cached_tokens" not in result.usage + + +# --- object-based response (OpenAI SDK Pydantic model) --- + +def test_extract_usage_openai_cached_tokens_obj(): + """prompt_tokens_details.cached_tokens from an SDK object response.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=2000, + completion_tokens=300, + total_tokens=2300, + prompt_tokens_details=FakePromptDetails(cached_tokens=1200), + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_deepseek_cached_tokens_obj(): + """prompt_cache_hit_tokens from a DeepSeek SDK object response.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=1500, + completion_tokens=200, + total_tokens=1700, + prompt_cache_hit_tokens=1200, + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_stepfun_top_level_cached_tokens_dict(): + """StepFun/Moonshot: usage.cached_tokens at top level (not nested).""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 591, + "completion_tokens": 120, + "total_tokens": 711, + "cached_tokens": 512, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 512 + + +def test_extract_usage_stepfun_top_level_cached_tokens_obj(): + """StepFun/Moonshot: usage.cached_tokens as SDK object attribute.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=591, + completion_tokens=120, + total_tokens=711, + cached_tokens=512, + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 512 + + +def test_extract_usage_priority_nested_over_top_level_dict(): + """When both nested and top-level cached_tokens exist, nested wins.""" + p = _provider() + response = { + "choices": 
[_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 100}, + "cached_tokens": 500, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 100 + + +def test_anthropic_maps_cache_fields_to_cached_tokens(): + """Anthropic's cache_read_input_tokens should map to cached_tokens.""" + from nanobot.providers.anthropic_provider import AnthropicProvider + + usage_obj = FakeUsage( + input_tokens=800, + output_tokens=200, + cache_creation_input_tokens=300, + cache_read_input_tokens=1200, + ) + content_block = FakeUsage(type="text", text="hello") + response = FakeUsage( + id="msg_1", + type="message", + stop_reason="end_turn", + content=[content_block], + usage=usage_obj, + ) + result = AnthropicProvider._parse_response(response) + assert result.usage["cached_tokens"] == 1200 + assert result.usage["prompt_tokens"] == 2300 + assert result.usage["total_tokens"] == 2500 + assert result.usage["cache_creation_input_tokens"] == 300 + + +def test_anthropic_no_cache_fields(): + """Anthropic response without cache fields should not have cached_tokens.""" + from nanobot.providers.anthropic_provider import AnthropicProvider + + usage_obj = FakeUsage(input_tokens=800, output_tokens=200) + content_block = FakeUsage(type="text", text="hello") + response = FakeUsage( + id="msg_1", + type="message", + stop_reason="end_turn", + content=[content_block], + usage=usage_obj, + ) + result = AnthropicProvider._parse_response(response) + assert "cached_tokens" not in result.usage diff --git a/tests/providers/test_providers_init.py b/tests/providers/test_providers_init.py index 32cbab478..d6912b437 100644 --- a/tests/providers/test_providers_init.py +++ b/tests/providers/test_providers_init.py @@ -11,6 +11,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) 
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False) providers = importlib.import_module("nanobot.providers") @@ -18,6 +19,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: assert "nanobot.providers.anthropic_provider" not in sys.modules assert "nanobot.providers.openai_compat_provider" not in sys.modules assert "nanobot.providers.openai_codex_provider" not in sys.modules + assert "nanobot.providers.github_copilot_provider" not in sys.modules assert "nanobot.providers.azure_openai_provider" not in sys.modules assert providers.__all__ == [ "LLMProvider", @@ -25,6 +27,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: "AnthropicProvider", "OpenAICompatProvider", "OpenAICodexProvider", + "GitHubCopilotProvider", "AzureOpenAIProvider", ] diff --git a/tests/test_build_status.py b/tests/test_build_status.py new file mode 100644 index 000000000..d98301cf7 --- /dev/null +++ b/tests/test_build_status.py @@ -0,0 +1,59 @@ +"""Tests for build_status_content cache hit rate display.""" + +from nanobot.utils.helpers import build_status_content + + +def test_status_shows_cache_hit_rate(): + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "60% cached" in content + assert "2000 in / 300 out" in content + + +def test_status_no_cache_info(): + """Without cached_tokens, display should not show cache percentage.""" + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + 
start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "cached" not in content.lower() + assert "2000 in / 300 out" in content + + +def test_status_zero_cached_tokens(): + """cached_tokens=0 should not show cache percentage.""" + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "cached" not in content.lower() + + +def test_status_100_percent_cached(): + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000}, + context_window_tokens=128000, + session_msg_count=5, + context_tokens_estimate=3000, + ) + assert "100% cached" in content diff --git a/tests/test_nanobot_facade.py b/tests/test_nanobot_facade.py new file mode 100644 index 000000000..9ad9c5db1 --- /dev/null +++ b/tests/test_nanobot_facade.py @@ -0,0 +1,168 @@ +"""Tests for the Nanobot programmatic facade.""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.nanobot import Nanobot, RunResult + + +def _write_config(tmp_path: Path, overrides: dict | None = None) -> Path: + data = { + "providers": {"openrouter": {"apiKey": "sk-test-key"}}, + "agents": {"defaults": {"model": "openai/gpt-4.1"}}, + } + if overrides: + data.update(overrides) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(data)) + return config_path + + +def test_from_config_missing_file(): + with pytest.raises(FileNotFoundError): + Nanobot.from_config("/nonexistent/config.json") + + +def 
test_from_config_creates_instance(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + assert bot._loop is not None + assert bot._loop.workspace == tmp_path + + +def test_from_config_default_path(): + from nanobot.config.schema import Config + + with patch("nanobot.config.loader.load_config") as mock_load, \ + patch("nanobot.nanobot._make_provider") as mock_prov: + mock_load.return_value = Config() + mock_prov.return_value = MagicMock() + mock_prov.return_value.get_default_model.return_value = "test" + mock_prov.return_value.generation.max_tokens = 4096 + Nanobot.from_config() + mock_load.assert_called_once_with(None) + + +@pytest.mark.asyncio +async def test_run_returns_result(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + from nanobot.bus.events import OutboundMessage + + mock_response = OutboundMessage( + channel="cli", chat_id="direct", content="Hello back!" + ) + bot._loop.process_direct = AsyncMock(return_value=mock_response) + + result = await bot.run("hi") + + assert isinstance(result, RunResult) + assert result.content == "Hello back!" 
+ bot._loop.process_direct.assert_awaited_once_with("hi", session_key="sdk:default") + + +@pytest.mark.asyncio +async def test_run_with_hooks(tmp_path): + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.bus.events import OutboundMessage + + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + class TestHook(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + pass + + mock_response = OutboundMessage( + channel="cli", chat_id="direct", content="done" + ) + bot._loop.process_direct = AsyncMock(return_value=mock_response) + + result = await bot.run("hi", hooks=[TestHook()]) + + assert result.content == "done" + assert bot._loop._extra_hooks == [] + + +@pytest.mark.asyncio +async def test_run_hooks_restored_on_error(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + from nanobot.agent.hook import AgentHook + + bot._loop.process_direct = AsyncMock(side_effect=RuntimeError("boom")) + original_hooks = bot._loop._extra_hooks + + with pytest.raises(RuntimeError): + await bot.run("hi", hooks=[AgentHook()]) + + assert bot._loop._extra_hooks is original_hooks + + +@pytest.mark.asyncio +async def test_run_none_response(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + bot._loop.process_direct = AsyncMock(return_value=None) + + result = await bot.run("hi") + assert result.content == "" + + +def test_workspace_override(tmp_path): + config_path = _write_config(tmp_path) + custom_ws = tmp_path / "custom_workspace" + custom_ws.mkdir() + + bot = Nanobot.from_config(config_path, workspace=custom_ws) + assert bot._loop.workspace == custom_ws + + +def test_sdk_make_provider_uses_github_copilot_backend(): + from nanobot.config.schema import Config + from nanobot.nanobot import _make_provider + + config = Config.model_validate( + { + "agents": { + 
"defaults": { + "provider": "github-copilot", + "model": "github-copilot/gpt-4.1", + } + } + } + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = _make_provider(config) + + assert provider.__class__.__name__ == "GitHubCopilotProvider" + + +@pytest.mark.asyncio +async def test_run_custom_session_key(tmp_path): + from nanobot.bus.events import OutboundMessage + + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + mock_response = OutboundMessage( + channel="cli", chat_id="direct", content="ok" + ) + bot._loop.process_direct = AsyncMock(return_value=mock_response) + + await bot.run("hi", session_key="user-alice") + bot._loop.process_direct.assert_awaited_once_with("hi", session_key="user-alice") + + +def test_import_from_top_level(): + from nanobot import Nanobot as N, RunResult as R + assert N is Nanobot + assert R is RunResult diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py new file mode 100644 index 000000000..42fec33ed --- /dev/null +++ b/tests/test_openai_api.py @@ -0,0 +1,371 @@ +"""Focused tests for the fixed-session OpenAI-compatible API.""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +import pytest_asyncio + +from nanobot.api.server import ( + API_CHAT_ID, + API_SESSION_KEY, + _chat_completion_response, + _error_json, + create_app, + handle_chat_completions, +) + +try: + from aiohttp.test_utils import TestClient, TestServer + + HAS_AIOHTTP = True +except ImportError: + HAS_AIOHTTP = False + +pytest_plugins = ("pytest_asyncio",) + + +def _make_mock_agent(response_text: str = "mock response") -> MagicMock: + agent = MagicMock() + agent.process_direct = AsyncMock(return_value=response_text) + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + return agent + + +@pytest.fixture +def mock_agent(): + return _make_mock_agent() + + +@pytest.fixture +def 
app(mock_agent): + return create_app(mock_agent, model_name="test-model", request_timeout=10.0) + + +@pytest_asyncio.fixture +async def aiohttp_client(): + clients: list[TestClient] = [] + + async def _make_client(app): + client = TestClient(TestServer(app)) + await client.start_server() + clients.append(client) + return client + + try: + yield _make_client + finally: + for client in clients: + await client.close() + + +def test_error_json() -> None: + resp = _error_json(400, "bad request") + assert resp.status == 400 + body = json.loads(resp.body) + assert body["error"]["message"] == "bad request" + assert body["error"]["code"] == 400 + + +def test_chat_completion_response() -> None: + result = _chat_completion_response("hello world", "test-model") + assert result["object"] == "chat.completion" + assert result["model"] == "test-model" + assert result["choices"][0]["message"]["content"] == "hello world" + assert result["choices"][0]["finish_reason"] == "stop" + assert result["id"].startswith("chatcmpl-") + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_missing_messages_returns_400(aiohttp_client, app) -> None: + client = await aiohttp_client(app) + resp = await client.post("/v1/chat/completions", json={"model": "test"}) + assert resp.status == 400 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_no_user_message_returns_400(aiohttp_client, app) -> None: + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "system", "content": "you are a bot"}]}, + ) + assert resp.status == 400 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_true_returns_400(aiohttp_client, app) -> None: + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", 
"content": "hello"}], "stream": True}, + ) + assert resp.status == 400 + body = await resp.json() + assert "stream" in body["error"]["message"].lower() + + +@pytest.mark.asyncio +async def test_model_mismatch_returns_400() -> None: + request = MagicMock() + request.json = AsyncMock( + return_value={ + "model": "other-model", + "messages": [{"role": "user", "content": "hello"}], + } + ) + request.app = { + "agent_loop": _make_mock_agent(), + "model_name": "test-model", + "request_timeout": 10.0, + "session_lock": asyncio.Lock(), + } + + resp = await handle_chat_completions(request) + assert resp.status == 400 + body = json.loads(resp.body) + assert "test-model" in body["error"]["message"] + + +@pytest.mark.asyncio +async def test_single_user_message_required() -> None: + request = MagicMock() + request.json = AsyncMock( + return_value={ + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous reply"}, + ], + } + ) + request.app = { + "agent_loop": _make_mock_agent(), + "model_name": "test-model", + "request_timeout": 10.0, + "session_lock": asyncio.Lock(), + } + + resp = await handle_chat_completions(request) + assert resp.status == 400 + body = json.loads(resp.body) + assert "single user message" in body["error"]["message"].lower() + + +@pytest.mark.asyncio +async def test_single_user_message_must_have_user_role() -> None: + request = MagicMock() + request.json = AsyncMock( + return_value={ + "messages": [{"role": "system", "content": "you are a bot"}], + } + ) + request.app = { + "agent_loop": _make_mock_agent(), + "model_name": "test-model", + "request_timeout": 10.0, + "session_lock": asyncio.Lock(), + } + + resp = await handle_chat_completions(request) + assert resp.status == 400 + body = json.loads(resp.body) + assert "single user message" in body["error"]["message"].lower() + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def 
test_successful_request_uses_fixed_api_session(aiohttp_client, mock_agent) -> None: + app = create_app(mock_agent, model_name="test-model") + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "mock response" + assert body["model"] == "test-model" + mock_agent.process_direct.assert_called_once_with( + content="hello", + session_key=API_SESSION_KEY, + channel="api", + chat_id=API_CHAT_ID, + ) + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_followup_requests_share_same_session_key(aiohttp_client) -> None: + call_log: list[str] = [] + + async def fake_process(content, session_key="", channel="", chat_id=""): + call_log.append(session_key) + return f"reply to {content}" + + agent = MagicMock() + agent.process_direct = fake_process + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + r1 = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "first"}]}, + ) + r2 = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "second"}]}, + ) + + assert r1.status == 200 + assert r2.status == 200 + assert call_log == [API_SESSION_KEY, API_SESSION_KEY] + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None: + order: list[str] = [] + + async def slow_process(content, session_key="", channel="", chat_id=""): + order.append(f"start:{content}") + await asyncio.sleep(0.1) + order.append(f"end:{content}") + return content + + agent = MagicMock() + agent.process_direct = slow_process + agent._connect_mcp = AsyncMock() + 
agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + async def send(msg: str): + return await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": msg}]}, + ) + + r1, r2 = await asyncio.gather(send("first"), send("second")) + assert r1.status == 200 + assert r2.status == 200 + # Verify serialization: one process must fully finish before the other starts + if order[0] == "start:first": + assert order.index("end:first") < order.index("start:second") + else: + assert order.index("end:second") < order.index("start:first") + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_models_endpoint(aiohttp_client, app) -> None: + client = await aiohttp_client(app) + resp = await client.get("/v1/models") + assert resp.status == 200 + body = await resp.json() + assert body["object"] == "list" + assert body["data"][0]["id"] == "test-model" + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_health_endpoint(aiohttp_client, app) -> None: + client = await aiohttp_client(app) + resp = await client.get("/health") + assert resp.status == 200 + body = await resp.json() + assert body["status"] == "ok" + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> None: + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ], + } + ] + }, + ) + assert resp.status == 200 + mock_agent.process_direct.assert_called_once_with( + content="describe this", + session_key=API_SESSION_KEY, + channel="api", + 
chat_id=API_CHAT_ID, + ) + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_empty_response_retry_then_success(aiohttp_client) -> None: + call_count = 0 + + async def sometimes_empty(content, session_key="", channel="", chat_id=""): + nonlocal call_count + call_count += 1 + if call_count == 1: + return "" + return "recovered response" + + agent = MagicMock() + agent.process_direct = sometimes_empty + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "recovered response" + assert call_count == 2 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_empty_response_falls_back(aiohttp_client) -> None: + call_count = 0 + + async def always_empty(content, session_key="", channel="", chat_id=""): + nonlocal call_count + call_count += 1 + return "" + + agent = MagicMock() + agent.process_direct = always_empty + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give." 
+ assert call_count == 2 diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py index a95418fe5..98a3dc903 100644 --- a/tests/tools/test_tool_validation.py +++ b/tests/tools/test_tool_validation.py @@ -95,6 +95,14 @@ def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None: assert paths == [r"C:\user\workspace\txt"] +def test_exec_extract_absolute_paths_captures_windows_drive_root_path() -> None: + """Windows drive root paths like `E:\\` must be extracted for workspace guarding.""" + # Note: raw strings cannot end with a single backslash. + cmd = "dir E:\\" + paths = ExecTool._extract_absolute_paths(cmd) + assert paths == ["E:\\"] + + def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None: cmd = ".venv/bin/python script.py" paths = ExecTool._extract_absolute_paths(cmd) @@ -134,6 +142,45 @@ def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None: assert error == "Error: Command blocked by safety guard (path outside working dir)" +def test_exec_guard_blocks_windows_drive_root_outside_workspace(monkeypatch) -> None: + import nanobot.agent.tools.shell as shell_mod + + class FakeWindowsPath: + def __init__(self, raw: str) -> None: + self.raw = raw.rstrip("\\") + ("\\" if raw.endswith("\\") else "") + + def resolve(self) -> "FakeWindowsPath": + return self + + def expanduser(self) -> "FakeWindowsPath": + return self + + def is_absolute(self) -> bool: + return len(self.raw) >= 3 and self.raw[1:3] == ":\\" + + @property + def parents(self) -> list["FakeWindowsPath"]: + if not self.is_absolute(): + return [] + trimmed = self.raw.rstrip("\\") + if len(trimmed) <= 2: + return [] + idx = trimmed.rfind("\\") + if idx <= 2: + return [FakeWindowsPath(trimmed[:2] + "\\")] + parent = FakeWindowsPath(trimmed[:idx]) + return [parent, *parent.parents] + + def __eq__(self, other: object) -> bool: + return isinstance(other, FakeWindowsPath) and self.raw.lower() == other.raw.lower() + + 
monkeypatch.setattr(shell_mod, "Path", FakeWindowsPath) + + tool = ExecTool(restrict_to_workspace=True) + error = tool._guard_command("dir E:\\", "E:\\workspace") + assert error == "Error: Command blocked by safety guard (path outside working dir)" + + # --- cast_params tests ---