mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
Merge branch 'main' into nightly
This commit is contained in:
commit
4741026538
76
README.md
76
README.md
@ -115,6 +115,7 @@
|
||||
- [Configuration](#️-configuration)
|
||||
- [Multiple Instances](#-multiple-instances)
|
||||
- [CLI Reference](#-cli-reference)
|
||||
- [OpenAI-Compatible API](#-openai-compatible-api)
|
||||
- [Docker](#-docker)
|
||||
- [Linux Service](#-linux-service)
|
||||
- [Project Structure](#-project-structure)
|
||||
@ -1618,6 +1619,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo
|
||||
| `nanobot agent` | Interactive chat mode |
|
||||
| `nanobot agent --no-markdown` | Show plain-text replies |
|
||||
| `nanobot agent --logs` | Show runtime logs during chat |
|
||||
| `nanobot serve` | Start the OpenAI-compatible API |
|
||||
| `nanobot gateway` | Start the gateway |
|
||||
| `nanobot status` | Show status |
|
||||
| `nanobot provider login openai-codex` | OAuth login for providers |
|
||||
@ -1646,6 +1648,80 @@ The agent can also manage this file itself — ask it to "add a periodic task" a
|
||||
|
||||
</details>
|
||||
|
||||
## 🔌 OpenAI-Compatible API
|
||||
|
||||
nanobot can expose a minimal OpenAI-compatible endpoint for local integrations:
|
||||
|
||||
```bash
|
||||
pip install "nanobot-ai[api]"
|
||||
nanobot serve
|
||||
```
|
||||
|
||||
By default, the API binds to `127.0.0.1:8900`.
|
||||
|
||||
### Behavior
|
||||
|
||||
- Fixed session: all requests share the same nanobot session (`api:default`)
|
||||
- Single-message input: each request must contain exactly one `user` message
|
||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||
- No streaming: `stream=true` is not supported
|
||||
|
||||
### Endpoints
|
||||
|
||||
- `GET /health`
|
||||
- `GET /v1/models`
|
||||
- `POST /v1/chat/completions`
|
||||
|
||||
### curl
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Python (`requests`)
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
resp = requests.post(
|
||||
"http://127.0.0.1:8900/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "user", "content": "hi"}
|
||||
]
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
print(resp.json()["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Python (`openai`)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://127.0.0.1:8900/v1",
|
||||
api_key="dummy",
|
||||
)
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model="MiniMax-M2.7",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
print(resp.choices[0].message.content)
|
||||
```
|
||||
|
||||
## 🐳 Docker
|
||||
|
||||
> [!TIP]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
|
||||
# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters)
|
||||
cd "$(dirname "$0")" || exit 1
|
||||
|
||||
echo "nanobot core agent line count"
|
||||
@ -15,7 +15,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||
|
||||
echo ""
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
|
||||
echo " Core total: $total lines"
|
||||
echo ""
|
||||
echo " (excludes: channels/, cli/, command/, providers/, skills/)"
|
||||
echo " (excludes: channels/, cli/, api/, command/, providers/, skills/)"
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
"""Agent core module."""
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
|
||||
__all__ = [
|
||||
"AgentHook",
|
||||
"AgentHookContext",
|
||||
"AgentLoop",
|
||||
"CompositeHook",
|
||||
"ContextBuilder",
|
||||
"MemoryStore",
|
||||
"SkillsLoader",
|
||||
"SubagentManager",
|
||||
]
|
||||
|
||||
@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
@ -47,3 +49,60 @@ class AgentHook:
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return content
|
||||
|
||||
|
||||
class CompositeHook(AgentHook):
|
||||
"""Fan-out hook that delegates to an ordered list of hooks.
|
||||
|
||||
Error isolation: async methods catch and log per-hook exceptions
|
||||
so a faulty custom hook cannot crash the agent loop.
|
||||
``finalize_content`` is a pipeline (no isolation — bugs should surface).
|
||||
"""
|
||||
|
||||
__slots__ = ("_hooks",)
|
||||
|
||||
def __init__(self, hooks: list[AgentHook]) -> None:
|
||||
self._hooks = list(hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return any(h.wants_streaming() for h in self._hooks)
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_iteration error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream(context, delta)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream_end(context, resuming=resuming)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_execute_tools(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.after_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.after_iteration error in {}", type(h).__name__)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
for h in self._hooks:
|
||||
content = h.finalize_content(context, content)
|
||||
return content
|
||||
|
||||
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
@ -37,6 +37,111 @@ if TYPE_CHECKING:
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core lifecycle hook for the main agent loop.
|
||||
|
||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
||||
and think-tag stripping for the built-in agent path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_loop: AgentLoop,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
self._loop = agent_loop
|
||||
self._on_progress = on_progress
|
||||
self._on_stream = on_stream
|
||||
self._on_stream_end = on_stream_end
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._message_id = message_id
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if self._on_stream_end:
|
||||
await self._on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if self._on_progress:
|
||||
if not self._on_stream:
|
||||
thought = self._loop._strip_think(
|
||||
context.response.content if context.response else None
|
||||
)
|
||||
if thought:
|
||||
await self._on_progress(thought)
|
||||
tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
|
||||
await self._on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class _LoopHookChain(AgentHook):
|
||||
"""Run the core loop hook first, then best-effort extra hooks.
|
||||
|
||||
This preserves the historical failure behavior of ``_LoopHook`` while still
|
||||
letting user-supplied hooks opt into ``CompositeHook`` isolation.
|
||||
"""
|
||||
|
||||
__slots__ = ("_primary", "_extras")
|
||||
|
||||
def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
|
||||
self._primary = primary
|
||||
self._extras = CompositeHook(extra_hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._primary.wants_streaming() or self._extras.wants_streaming()
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_iteration(context)
|
||||
await self._extras.before_iteration(context)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
await self._primary.on_stream(context, delta)
|
||||
await self._extras.on_stream(context, delta)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
await self._primary.on_stream_end(context, resuming=resuming)
|
||||
await self._extras.on_stream_end(context, resuming=resuming)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_execute_tools(context)
|
||||
await self._extras.before_execute_tools(context)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.after_iteration(context)
|
||||
await self._extras.after_iteration(context)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
content = self._primary.finalize_content(context, content)
|
||||
return self._extras.finalize_content(context, content)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
@ -68,6 +173,7 @@ class AgentLoop:
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||
|
||||
@ -85,6 +191,7 @@ class AgentLoop:
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self._start_time = time.time()
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
@ -217,52 +324,27 @@ class AgentLoop:
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
"""
|
||||
loop_self = self
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
def __init__(self) -> None:
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and on_stream:
|
||||
await on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if on_progress:
|
||||
if not on_stream:
|
||||
thought = loop_self._strip_think(context.response.content if context.response else None)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
loop_self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return loop_self._strip_think(content)
|
||||
loop_hook = _LoopHook(
|
||||
self,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
_LoopHookChain(loop_hook, self._extra_hooks)
|
||||
if self._extra_hooks
|
||||
else loop_hook
|
||||
)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
hook=_LoopHook(),
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
))
|
||||
|
||||
@ -21,6 +21,21 @@ from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
"""Logging-only hook for subagent execution."""
|
||||
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self._task_id = task_id
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug(
|
||||
"Subagent [{}] executing: {} with arguments: {}",
|
||||
self._task_id, tool_call.name, args_str,
|
||||
)
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
"""Manages background subagent execution."""
|
||||
|
||||
@ -108,25 +123,19 @@ class SubagentManager:
|
||||
))
|
||||
tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
|
||||
tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
|
||||
|
||||
system_prompt = self._build_subagent_prompt()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
hook=_SubagentHook(),
|
||||
hook=_SubagentHook(task_id),
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
fail_on_tool_error=True,
|
||||
@ -213,7 +222,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
||||
lines.append("Failure:")
|
||||
lines.append(f"- {result.error}")
|
||||
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
|
||||
|
||||
|
||||
def _build_subagent_prompt(self) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
1
nanobot/api/__init__.py
Normal file
1
nanobot/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""OpenAI-compatible HTTP API for nanobot."""
|
||||
190
nanobot/api/server.py
Normal file
190
nanobot/api/server.py
Normal file
@ -0,0 +1,190 @@
|
||||
"""OpenAI-compatible HTTP API server for a fixed nanobot session.
|
||||
|
||||
Provides /v1/chat/completions and /v1/models endpoints.
|
||||
All requests route to a single persistent API session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from loguru import logger
|
||||
|
||||
API_SESSION_KEY = "api:default"
|
||||
API_CHAT_ID = "default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"message": message, "type": err_type, "code": status}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
def _response_text(value: Any) -> str:
|
||||
"""Normalize process_direct output to plain assistant text."""
|
||||
if value is None:
|
||||
return ""
|
||||
if hasattr(value, "content"):
|
||||
return str(getattr(value, "content") or "")
|
||||
return str(value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
"""POST /v1/chat/completions"""
|
||||
|
||||
# --- Parse body ---
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
|
||||
# Stream not yet supported
|
||||
if body.get("stream", False):
|
||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||
|
||||
message = messages[0]
|
||||
if not isinstance(message, dict) or message.get("role") != "user":
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
user_content = message.get("content", "")
|
||||
if isinstance(user_content, list):
|
||||
# Multi-modal content array — extract text parts
|
||||
user_content = " ".join(
|
||||
part.get("text", "") for part in user_content if part.get("type") == "text"
|
||||
)
|
||||
|
||||
agent_loop = request.app["agent_loop"]
|
||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||
model_name: str = request.app.get("model_name", "nanobot")
|
||||
if (requested_model := body.get("model")) and requested_model != model_name:
|
||||
return _error_json(400, f"Only configured model '{model_name}' is available")
|
||||
session_lock: asyncio.Lock = request.app["session_lock"]
|
||||
|
||||
logger.info("API request session_key={} content={}", API_SESSION_KEY, user_content[:80])
|
||||
|
||||
_FALLBACK = "I've completed processing but have no response to give."
|
||||
|
||||
try:
|
||||
async with session_lock:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
response_text = _response_text(response)
|
||||
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response for session {}, retrying",
|
||||
API_SESSION_KEY,
|
||||
)
|
||||
retry_response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
response_text = _response_text(retry_response)
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response after retry for session {}, using fallback",
|
||||
API_SESSION_KEY,
|
||||
)
|
||||
response_text = _FALLBACK
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return _error_json(504, f"Request timed out after {timeout_s}s")
|
||||
except Exception:
|
||||
logger.exception("Error processing request for session {}", API_SESSION_KEY)
|
||||
return _error_json(500, "Internal server error", err_type="server_error")
|
||||
except Exception:
|
||||
logger.exception("Unexpected API lock error for session {}", API_SESSION_KEY)
|
||||
return _error_json(500, "Internal server error", err_type="server_error")
|
||||
|
||||
return web.json_response(_chat_completion_response(response_text, model_name))
|
||||
|
||||
|
||||
async def handle_models(request: web.Request) -> web.Response:
|
||||
"""GET /v1/models"""
|
||||
model_name = request.app.get("model_name", "nanobot")
|
||||
return web.json_response({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "nanobot",
|
||||
}
|
||||
],
|
||||
})
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
|
||||
"""GET /health"""
|
||||
return web.json_response({"status": "ok"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
|
||||
"""Create the aiohttp application.
|
||||
|
||||
Args:
|
||||
agent_loop: An initialized AgentLoop instance.
|
||||
model_name: Model name reported in responses.
|
||||
request_timeout: Per-request timeout in seconds.
|
||||
"""
|
||||
app = web.Application()
|
||||
app["agent_loop"] = agent_loop
|
||||
app["model_name"] = model_name
|
||||
app["request_timeout"] = request_timeout
|
||||
app["session_lock"] = asyncio.Lock()
|
||||
|
||||
app.router.add_post("/v1/chat/completions", handle_chat_completions)
|
||||
app.router.add_get("/v1/models", handle_models)
|
||||
app.router.add_get("/health", handle_health)
|
||||
return app
|
||||
@ -491,6 +491,91 @@ def _migrate_cron_store(config: "Config") -> None:
|
||||
shutil.move(str(legacy_path), str(new_path))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OpenAI-Compatible API Server
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
port: int | None = typer.Option(None, "--port", "-p", help="API server port"),
|
||||
host: str | None = typer.Option(None, "--host", "-H", help="Bind address"),
|
||||
timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Start the OpenAI-compatible API server (/v1/chat/completions)."""
|
||||
try:
|
||||
from aiohttp import web # noqa: F401
|
||||
except ImportError:
|
||||
console.print("[red]aiohttp is required. Install with: pip install 'nanobot-ai[api]'[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
from loguru import logger
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.api.server import create_app
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
if verbose:
|
||||
logger.enable("nanobot")
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
runtime_config = _load_runtime_config(config, workspace)
|
||||
api_cfg = runtime_config.api
|
||||
host = host if host is not None else api_cfg.host
|
||||
port = port if port is not None else api_cfg.port
|
||||
timeout = timeout if timeout is not None else api_cfg.timeout
|
||||
sync_workspace_templates(runtime_config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(runtime_config)
|
||||
session_manager = SessionManager(runtime_config.workspace_path)
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=runtime_config.workspace_path,
|
||||
model=runtime_config.agents.defaults.model,
|
||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||
web_search_config=runtime_config.tools.web.search,
|
||||
web_proxy=runtime_config.tools.web.proxy or None,
|
||||
exec_config=runtime_config.tools.exec,
|
||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=runtime_config.tools.mcp_servers,
|
||||
channels_config=runtime_config.channels,
|
||||
timezone=runtime_config.agents.defaults.timezone,
|
||||
)
|
||||
|
||||
model_name = runtime_config.agents.defaults.model
|
||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
||||
console.print(f" [cyan]Model[/cyan] : {model_name}")
|
||||
console.print(" [cyan]Session[/cyan] : api:default")
|
||||
console.print(f" [cyan]Timeout[/cyan] : {timeout}s")
|
||||
if host in {"0.0.0.0", "::"}:
|
||||
console.print(
|
||||
"[yellow]Warning:[/yellow] API is bound to all interfaces. "
|
||||
"Only do this behind a trusted network boundary, firewall, or reverse proxy."
|
||||
)
|
||||
console.print()
|
||||
|
||||
api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)
|
||||
|
||||
async def on_startup(_app):
|
||||
await agent_loop._connect_mcp()
|
||||
|
||||
async def on_cleanup(_app):
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
api_app.on_startup.append(on_startup)
|
||||
api_app.on_cleanup.append(on_cleanup)
|
||||
|
||||
web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Gateway / Server
|
||||
# ============================================================================
|
||||
|
||||
@ -96,6 +96,14 @@ class HeartbeatConfig(Base):
|
||||
keep_recent_messages: int = 8
|
||||
|
||||
|
||||
class ApiConfig(Base):
|
||||
"""OpenAI-compatible API server configuration."""
|
||||
|
||||
host: str = "127.0.0.1" # Safer default: local-only bind.
|
||||
port: int = 8900
|
||||
timeout: float = 120.0 # Per-request timeout in seconds.
|
||||
|
||||
|
||||
class GatewayConfig(Base):
|
||||
"""Gateway/server configuration."""
|
||||
|
||||
@ -156,6 +164,7 @@ class Config(BaseSettings):
|
||||
agents: AgentsConfig = Field(default_factory=AgentsConfig)
|
||||
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
|
||||
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
|
||||
api: ApiConfig = Field(default_factory=ApiConfig)
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||
|
||||
|
||||
@ -51,6 +51,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
wecom = [
|
||||
"wecom-aibot-sdk-python>=0.1.5",
|
||||
]
|
||||
@ -73,6 +76,7 @@ langsmith = [
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"pytest-cov>=6.0.0,<7.0.0",
|
||||
"ruff>=0.1.0",
|
||||
]
|
||||
|
||||
351
tests/agent/test_hook_composite.py
Normal file
351
tests/agent/test_hook_composite.py
Normal file
@ -0,0 +1,351 @@
|
||||
"""Tests for CompositeHook fan-out, error isolation, and integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
|
||||
|
||||
def _ctx() -> AgentHookContext:
|
||||
return AgentHookContext(iteration=0, messages=[])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fan-out: every hook is called in order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class H(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"A:{context.iteration}")
|
||||
|
||||
class H2(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"B:{context.iteration}")
|
||||
|
||||
hook = CompositeHook([H(), H2()])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
assert calls == ["A:0", "B:0"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_all_async_methods():
|
||||
"""Verify all async methods fan out to every hook."""
|
||||
events: list[str] = []
|
||||
|
||||
class RecordingHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("before_iteration")
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
events.append(f"on_stream:{delta}")
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
events.append(f"on_stream_end:{resuming}")
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
events.append("before_execute_tools")
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([RecordingHook(), RecordingHook()])
|
||||
ctx = _ctx()
|
||||
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "hi")
|
||||
await hook.on_stream_end(ctx, resuming=True)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
|
||||
assert events == [
|
||||
"before_iteration", "before_iteration",
|
||||
"on_stream:hi", "on_stream:hi",
|
||||
"on_stream_end:True", "on_stream_end:True",
|
||||
"before_execute_tools", "before_execute_tools",
|
||||
"after_iteration", "after_iteration",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error isolation: one hook raises, others still run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append("good")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.before_iteration(_ctx())
|
||||
assert calls == ["good"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_on_stream():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
raise RuntimeError("stream-boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
calls.append(delta)
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.on_stream(_ctx(), "delta")
|
||||
assert calls == ["delta"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_all_async():
|
||||
"""Error isolation for on_stream_end, before_execute_tools, after_iteration."""
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
raise RuntimeError("err")
|
||||
async def before_execute_tools(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def after_iteration(self, context):
|
||||
raise RuntimeError("err")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
calls.append("on_stream_end")
|
||||
async def before_execute_tools(self, context):
|
||||
calls.append("before_execute_tools")
|
||||
async def after_iteration(self, context):
|
||||
calls.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
ctx = _ctx()
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# finalize_content: pipeline semantics (no error isolation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_finalize_content_pipeline():
|
||||
class Upper(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return content.upper() if content else content
|
||||
|
||||
class Suffix(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return (content + "!") if content else content
|
||||
|
||||
hook = CompositeHook([Upper(), Suffix()])
|
||||
result = hook.finalize_content(_ctx(), "hello")
|
||||
assert result == "HELLO!"
|
||||
|
||||
|
||||
def test_composite_finalize_content_none_passthrough():
|
||||
hook = CompositeHook([AgentHook()])
|
||||
assert hook.finalize_content(_ctx(), None) is None
|
||||
|
||||
|
||||
def test_composite_finalize_content_ordering():
|
||||
"""First hook transforms first, result feeds second hook."""
|
||||
steps: list[str] = []
|
||||
|
||||
class H1(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H1:{content}")
|
||||
return content.upper()
|
||||
|
||||
class H2(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H2:{content}")
|
||||
return content + "!"
|
||||
|
||||
hook = CompositeHook([H1(), H2()])
|
||||
result = hook.finalize_content(_ctx(), "hi")
|
||||
assert result == "HI!"
|
||||
assert steps == ["H1:hi", "H2:HI"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wants_streaming: any-semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_wants_streaming_any_true():
|
||||
class No(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return False
|
||||
|
||||
class Yes(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return True
|
||||
|
||||
hook = CompositeHook([No(), Yes(), No()])
|
||||
assert hook.wants_streaming() is True
|
||||
|
||||
|
||||
def test_composite_wants_streaming_all_false():
|
||||
hook = CompositeHook([AgentHook(), AgentHook()])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
def test_composite_wants_streaming_empty():
|
||||
hook = CompositeHook([])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty hooks list: behaves like no-op AgentHook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_empty_hooks_no_ops():
|
||||
hook = CompositeHook([])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "delta")
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert hook.finalize_content(ctx, "test") == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: AgentLoop with extra hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_loop(tmp_path, hooks=None):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation.max_tokens = 4096
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \
|
||||
patch("nanobot.agent.loop.MemoryConsolidator"):
|
||||
mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, hooks=hooks,
|
||||
)
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
"""Extra hook passed to AgentLoop is called alongside core LoopHook."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
class TrackingHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
events.append(f"before_iter:{context.iteration}")
|
||||
|
||||
async def after_iteration(self, context):
|
||||
events.append(f"after_iter:{context.iteration}")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, tools_used, messages = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "done"
|
||||
assert "before_iter:0" in events
|
||||
assert "after_iter:0" in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
||||
"""A faulty extra hook does not crash the agent loop."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
class BadHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
raise RuntimeError("I am broken")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[BadHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="still works", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "still works"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hooks_do_not_swallow_loop_hook_errors(tmp_path):
|
||||
"""Extra hooks must not change the core LoopHook failure behavior."""
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[AgentHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
|
||||
usage={},
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
async def bad_progress(*args, **kwargs):
|
||||
raise RuntimeError("progress failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="progress failed"):
|
||||
await loop._run_agent_loop([], on_progress=bad_progress)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
||||
"""Without hooks param, behavior is identical to before."""
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
content, tools_used, _ = await loop._run_agent_loop([])
|
||||
assert content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
assert tools_used == ["list_dir", "list_dir"]
|
||||
@ -642,27 +642,105 @@ def test_heartbeat_retains_recent_messages_by_default():
|
||||
assert config.gateway.heartbeat.keep_recent_messages == 8
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
def _write_instance_config(tmp_path: Path) -> Path:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
return config_file
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
def _stop_gateway_provider(_config) -> object:
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
|
||||
def _patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config: Config,
|
||||
*,
|
||||
set_config_path=None,
|
||||
sync_templates=None,
|
||||
make_provider=None,
|
||||
message_bus=None,
|
||||
session_manager=None,
|
||||
cron_service=None,
|
||||
get_cron_dir=None,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
set_config_path or (lambda _path: None),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
sync_templates or (lambda _path: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
make_provider or (lambda _config: object()),
|
||||
)
|
||||
|
||||
if message_bus is not None:
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", message_bus)
|
||||
if session_manager is not None:
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", session_manager)
|
||||
if cron_service is not None:
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", cron_service)
|
||||
if get_cron_dir is not None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", get_cron_dir)
|
||||
|
||||
|
||||
def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -> None:
|
||||
pytest.importorskip("aiohttp")
|
||||
|
||||
class _FakeApiApp:
|
||||
def __init__(self) -> None:
|
||||
self.on_startup: list[object] = []
|
||||
self.on_cleanup: list[object] = []
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
seen["workspace"] = kwargs["workspace"]
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
def _fake_create_app(agent_loop, model_name: str, request_timeout: float):
|
||||
seen["agent_loop"] = agent_loop
|
||||
seen["model_name"] = model_name
|
||||
seen["request_timeout"] = request_timeout
|
||||
return _FakeApiApp()
|
||||
|
||||
def _fake_run_app(api_app, host: str, port: int, print):
|
||||
seen["api_app"] = api_app
|
||||
seen["host"] = host
|
||||
seen["port"] = port
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app)
|
||||
monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app)
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
set_config_path=lambda path: seen.__setitem__("config_path", path),
|
||||
sync_templates=lambda path: seen.__setitem__("workspace", path),
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
@ -673,24 +751,17 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
|
||||
|
||||
|
||||
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
override = tmp_path / "override-workspace"
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
sync_templates=lambda path: seen.__setitem__("workspace", path),
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
@ -704,27 +775,23 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
@ -735,10 +802,7 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path:
|
||||
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
@ -748,20 +812,19 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
get_cron_dir=lambda: legacy_dir,
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
@ -777,10 +840,7 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
@ -791,20 +851,19 @@ def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
config.agents.defaults.workspace = str(custom_workspace)
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
get_cron_dir=lambda: legacy_dir,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
@ -856,19 +915,14 @@ def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) ->
|
||||
|
||||
|
||||
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
@ -878,19 +932,14 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
|
||||
|
||||
|
||||
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||
@ -899,6 +948,63 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
|
||||
def test_serve_uses_api_config_defaults_and_workspace_override(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
config.api.host = "127.0.0.2"
|
||||
config.api.port = 18900
|
||||
config.api.timeout = 45.0
|
||||
override_workspace = tmp_path / "override-workspace"
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
_patch_serve_runtime(monkeypatch, config, seen)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["serve", "--config", str(config_file), "--workspace", str(override_workspace)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["workspace"] == override_workspace
|
||||
assert seen["host"] == "127.0.0.2"
|
||||
assert seen["port"] == 18900
|
||||
assert seen["request_timeout"] == 45.0
|
||||
|
||||
|
||||
def test_serve_cli_options_override_api_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.api.host = "127.0.0.2"
|
||||
config.api.port = 18900
|
||||
config.api.timeout = 45.0
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
_patch_serve_runtime(monkeypatch, config, seen)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"serve",
|
||||
"--config",
|
||||
str(config_file),
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"18901",
|
||||
"--timeout",
|
||||
"46",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["host"] == "127.0.0.1"
|
||||
assert seen["port"] == 18901
|
||||
assert seen["request_timeout"] == 46.0
|
||||
|
||||
|
||||
def test_channels_login_requires_channel_name() -> None:
|
||||
result = runner.invoke(app, ["channels", "login"])
|
||||
|
||||
|
||||
372
tests/test_openai_api.py
Normal file
372
tests/test_openai_api.py
Normal file
@ -0,0 +1,372 @@
|
||||
"""Focused tests for the fixed-session OpenAI-compatible API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
API_CHAT_ID,
|
||||
API_SESSION_KEY,
|
||||
_chat_completion_response,
|
||||
_error_json,
|
||||
create_app,
|
||||
handle_chat_completions,
|
||||
)
|
||||
|
||||
try:
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
|
||||
agent = MagicMock()
|
||||
agent.process_direct = AsyncMock(return_value=response_text)
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return _make_mock_agent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_agent):
|
||||
return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def aiohttp_client():
|
||||
clients: list[TestClient] = []
|
||||
|
||||
async def _make_client(app):
|
||||
client = TestClient(TestServer(app))
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
try:
|
||||
yield _make_client
|
||||
finally:
|
||||
for client in clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
def test_error_json() -> None:
|
||||
resp = _error_json(400, "bad request")
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert body["error"]["message"] == "bad request"
|
||||
assert body["error"]["code"] == 400
|
||||
|
||||
|
||||
def test_chat_completion_response() -> None:
|
||||
result = _chat_completion_response("hello world", "test-model")
|
||||
assert result["object"] == "chat.completion"
|
||||
assert result["model"] == "test-model"
|
||||
assert result["choices"][0]["message"]["content"] == "hello world"
|
||||
assert result["choices"][0]["finish_reason"] == "stop"
|
||||
assert result["id"].startswith("chatcmpl-")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_messages_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/v1/chat/completions", json={"model": "test"})
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "system", "content": "you are a bot"}]},
|
||||
)
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_true_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
|
||||
)
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert "stream" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_mismatch_returns_400() -> None:
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"model": "other-model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
}
|
||||
)
|
||||
request.app = {
|
||||
"agent_loop": _make_mock_agent(),
|
||||
"model_name": "test-model",
|
||||
"request_timeout": 10.0,
|
||||
"session_lock": asyncio.Lock(),
|
||||
}
|
||||
|
||||
resp = await handle_chat_completions(request)
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert "test-model" in body["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_single_user_message_required() -> None:
    """More than one message in the request is rejected with HTTP 400.

    The API accepts exactly one user message per request; a history that
    includes an assistant turn must fail. The skipif guard matches the
    sibling tests: the handler returns an aiohttp-style response, so it
    cannot run without aiohttp installed.
    """
    request = MagicMock()
    request.json = AsyncMock(
        return_value={
            "messages": [
                {"role": "user", "content": "hello"},
                {"role": "assistant", "content": "previous reply"},
            ],
        }
    )
    # Minimal app state the handler reads.
    request.app = {
        "agent_loop": _make_mock_agent(),
        "model_name": "test-model",
        "request_timeout": 10.0,
        "session_lock": asyncio.Lock(),
    }

    resp = await handle_chat_completions(request)
    assert resp.status == 400
    body = json.loads(resp.body)
    assert "single user message" in body["error"]["message"].lower()
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_single_user_message_must_have_user_role() -> None:
    """A lone non-user message (role=system) is rejected with HTTP 400.

    The skipif guard matches the sibling tests: the handler returns an
    aiohttp-style response, so it cannot run without aiohttp installed.
    """
    request = MagicMock()
    request.json = AsyncMock(
        return_value={
            "messages": [{"role": "system", "content": "you are a bot"}],
        }
    )
    # Minimal app state the handler reads.
    request.app = {
        "agent_loop": _make_mock_agent(),
        "model_name": "test-model",
        "request_timeout": 10.0,
        "session_lock": asyncio.Lock(),
    }

    resp = await handle_chat_completions(request)
    assert resp.status == 400
    body = json.loads(resp.body)
    assert "single user message" in body["error"]["message"].lower()
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_agent) -> None:
    """A valid request succeeds and is routed to the fixed API session."""
    client = await aiohttp_client(create_app(mock_agent, model_name="test-model"))
    response = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
    )
    assert response.status == 200

    payload = await response.json()
    assert payload["choices"][0]["message"]["content"] == "mock response"
    assert payload["model"] == "test-model"
    # The agent must have been invoked once, on the shared API session.
    mock_agent.process_direct.assert_called_once_with(
        content="hello",
        session_key=API_SESSION_KEY,
        channel="api",
        chat_id=API_CHAT_ID,
    )
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
    """Consecutive requests all target the single fixed API session key."""
    seen_keys: list[str] = []

    async def fake_process(content, session_key="", channel="", chat_id=""):
        # Record which session each call was routed to.
        seen_keys.append(session_key)
        return f"reply to {content}"

    agent = MagicMock()
    agent.process_direct = fake_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))

    first = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "first"}]},
    )
    second = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "second"}]},
    )

    assert first.status == 200
    assert second.status == 200
    assert seen_keys == [API_SESSION_KEY, API_SESSION_KEY]
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
    """Concurrent requests on the fixed session must run one at a time.

    The first request sleeps while "holding" the session; if the server
    serializes requests, the second cannot start until the first ends.
    """
    events: list[str] = []
    gate = asyncio.Event()

    async def slow_process(content, session_key="", channel="", chat_id=""):
        events.append(f"start:{content}")
        if content == "first":
            # Let the second request proceed past the gate, then stall so an
            # unserialized server would interleave the two.
            gate.set()
            await asyncio.sleep(0.1)
        else:
            await gate.wait()
        events.append(f"end:{content}")
        return content

    agent = MagicMock()
    agent.process_direct = slow_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))

    async def send(msg: str):
        payload = {"messages": [{"role": "user", "content": msg}]}
        return await client.post("/v1/chat/completions", json=payload)

    first, second = await asyncio.gather(send("first"), send("second"))

    assert first.status == 200
    assert second.status == 200
    # The second request may only start after the first fully completed.
    assert events.index("end:first") < events.index("start:second")
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_models_endpoint(aiohttp_client, app) -> None:
    """GET /v1/models lists exactly the configured model."""
    client = await aiohttp_client(app)
    response = await client.get("/v1/models")
    assert response.status == 200
    payload = await response.json()
    assert payload["object"] == "list"
    assert payload["data"][0]["id"] == "test-model"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_health_endpoint(aiohttp_client, app) -> None:
    """GET /health reports an ok status."""
    client = await aiohttp_client(app)
    response = await client.get("/health")
    assert response.status == 200
    payload = await response.json()
    assert payload["status"] == "ok"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> None:
    """List-style (multimodal) content: only the text parts reach the agent."""
    client = await aiohttp_client(create_app(mock_agent, model_name="m"))
    multimodal_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "describe this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
        ],
    }
    response = await client.post(
        "/v1/chat/completions",
        json={"messages": [multimodal_message]},
    )
    assert response.status == 200
    # The image part is dropped; only the extracted text is forwarded.
    mock_agent.process_direct.assert_called_once_with(
        content="describe this",
        session_key=API_SESSION_KEY,
        channel="api",
        chat_id=API_CHAT_ID,
    )
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
    """An empty agent reply is retried once; the retry's reply is returned."""
    call_count = 0

    async def sometimes_empty(content, session_key="", channel="", chat_id=""):
        nonlocal call_count
        call_count += 1
        # First call returns nothing, forcing a retry.
        return "" if call_count == 1 else "recovered response"

    agent = MagicMock()
    agent.process_direct = sometimes_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    response = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
    )

    assert response.status == 200
    payload = await response.json()
    assert payload["choices"][0]["message"]["content"] == "recovered response"
    assert call_count == 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_falls_back(aiohttp_client) -> None:
    """When the agent stays empty after the retry, a fallback text is used."""
    call_count = 0

    async def always_empty(content, session_key="", channel="", chat_id=""):
        nonlocal call_count
        call_count += 1
        return ""

    agent = MagicMock()
    agent.process_direct = always_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    response = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
    )

    assert response.status == 200
    payload = await response.json()
    expected = "I've completed processing but have no response to give."
    assert payload["choices"][0]["message"]["content"] == expected
    # Exactly one retry: original attempt + one more, then the fallback.
    assert call_count == 2
Loading…
x
Reference in New Issue
Block a user