Merge remote-tracking branch 'origin/main' into pr-2646

This commit is contained in:
Xubin Ren 2026-04-03 17:53:52 +00:00
commit f409337fcf
45 changed files with 3677 additions and 1006 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
.assets
.docs
.env
.web
*.pyc
dist/
build/

View File

@ -20,13 +20,20 @@
## 📢 News
> [!IMPORTANT]
> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We have fully removed the `litellm` since **v0.1.4.post6**.
- **2026-04-02** 🧱 **Long-running tasks** run more reliably — core runtime hardening.
- **2026-04-01** 🔑 GitHub Copilot auth restored; stricter workspace paths; OpenRouter Claude caching fix.
- **2026-03-31** 🛰️ WeChat multimodal alignment, Discord/Matrix polish, Python SDK facade, MCP and tool fixes.
- **2026-03-30** 🧩 OpenAI-compatible API tightened; composable agent lifecycle hooks.
- **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API.
- **2026-03-28** 📚 Provider docs refresh; skill template wording fix.
- **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
<details>
<summary>Earlier news</summary>
- **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI.
- **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command.
- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
@ -34,10 +41,6 @@
- **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly.
- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details.
- **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable.
<details>
<summary>Earlier news</summary>
- **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details.
- **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility.
- **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling.

View File

@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
if isinstance(left, str) and isinstance(right, str):
return f"{left}\n\n{right}" if left else right
def _to_blocks(value: Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value]
if value is None:
return []
return [{"type": "text", "text": str(value)}]
return _to_blocks(left) + _to_blocks(right)
def _load_bootstrap_files(self) -> str:
"""Load all bootstrap files from workspace."""
parts = []
@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
merged = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
return [
messages = [
{"role": "system", "content": self.build_system_prompt(skill_names)},
*history,
{"role": current_role, "content": merged},
]
if messages[-1].get("role") == current_role:
last = dict(messages[-1])
last["content"] = self._merge_message_content(last.get("content"), merged)
messages[-1] = last
return messages
messages.append({"role": current_role, "content": merged})
return messages
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
"""Build user message content with optional base64-encoded images."""

View File

@ -29,8 +29,11 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.bus.queue import MessageBus
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text, truncate_text
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
if TYPE_CHECKING:
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
@ -38,11 +41,7 @@ if TYPE_CHECKING:
class _LoopHook(AgentHook):
"""Core lifecycle hook for the main agent loop.
Handles streaming delta relay, progress reporting, tool-call logging,
and think-tag stripping for the built-in agent path.
"""
"""Core hook for the main loop."""
def __init__(
self,
@ -111,11 +110,7 @@ class _LoopHook(AgentHook):
class _LoopHookChain(AgentHook):
"""Run the core loop hook first, then best-effort extra hooks.
This preserves the historical failure behavior of ``_LoopHook`` while still
letting user-supplied hooks opt into ``CompositeHook`` isolation.
"""
"""Run the core hook before extra hooks."""
__slots__ = ("_primary", "_extras")
@ -163,7 +158,7 @@ class AgentLoop:
5. Sends responses back
"""
_TOOL_RESULT_MAX_CHARS = 16_000
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
def __init__(
self,
@ -171,8 +166,11 @@ class AgentLoop:
provider: LLMProvider,
workspace: Path,
model: str | None = None,
max_iterations: int = 40,
context_window_tokens: int = 65_536,
max_iterations: int | None = None,
context_window_tokens: int | None = None,
context_block_limit: int | None = None,
max_tool_result_chars: int | None = None,
provider_retry_mode: str = "standard",
web_search_config: WebSearchConfig | None = None,
web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None,
@ -186,13 +184,27 @@ class AgentLoop:
):
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
defaults = AgentDefaults()
self.bus = bus
self.channels_config = channels_config
self.provider = provider
self.workspace = workspace
self.model = model or provider.get_default_model()
self.max_iterations = max_iterations
self.context_window_tokens = context_window_tokens
self.max_iterations = (
max_iterations if max_iterations is not None else defaults.max_tool_iterations
)
self.context_window_tokens = (
context_window_tokens
if context_window_tokens is not None
else defaults.context_window_tokens
)
self.context_block_limit = context_block_limit
self.max_tool_result_chars = (
max_tool_result_chars
if max_tool_result_chars is not None
else defaults.max_tool_result_chars
)
self.provider_retry_mode = provider_retry_mode
self.web_search_config = web_search_config or WebSearchConfig()
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@ -211,6 +223,7 @@ class AgentLoop:
workspace=workspace,
bus=bus,
model=self.model,
max_tool_result_chars=self.max_tool_result_chars,
web_search_config=self.web_search_config,
web_proxy=web_proxy,
exec_config=self.exec_config,
@ -322,6 +335,7 @@ class AgentLoop:
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
*,
session: Session | None = None,
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
@ -348,14 +362,27 @@ class AgentLoop:
else loop_hook
)
async def _checkpoint(payload: dict[str, Any]) -> None:
if session is None:
return
self._set_runtime_checkpoint(session, payload)
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars,
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
workspace=self.workspace,
session_key=session.key if session else None,
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
))
self._last_usage = result.usage
if result.stop_reason == "max_iterations":
@ -493,6 +520,8 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0)
@ -503,10 +532,11 @@ class AgentLoop:
current_role=current_role,
)
final_content, _, all_msgs = await self._run_agent_loop(
messages, channel=channel, chat_id=chat_id,
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
)
self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id,
@ -517,6 +547,8 @@ class AgentLoop:
key = session_key or msg.session_key
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
# Slash commands
raw = msg.content.strip()
@ -552,14 +584,16 @@ class AgentLoop:
on_progress=on_progress or _bus_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
session=session,
channel=msg.channel, chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
)
if final_content is None:
final_content = "I've completed processing but have no response to give."
if final_content is None or not final_content.strip():
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
@ -577,12 +611,6 @@ class AgentLoop:
metadata=meta,
)
@staticmethod
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
"""Convert an inline image block into a compact text placeholder."""
path = (block.get("_meta") or {}).get("path", "")
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
def _sanitize_persisted_blocks(
self,
content: list[dict[str, Any]],
@ -609,13 +637,14 @@ class AgentLoop:
block.get("type") == "image_url"
and block.get("image_url", {}).get("url", "").startswith("data:image/")
):
filtered.append(self._image_placeholder(block))
path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text = block["text"]
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS:
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
if truncate_text and len(text) > self.max_tool_result_chars:
text = truncate_text(text, self.max_tool_result_chars)
filtered.append({**block, "text": text})
continue
@ -632,8 +661,8 @@ class AgentLoop:
if role == "assistant" and not content and not entry.get("tool_calls"):
continue # skip empty assistant messages — they poison session context
if role == "tool":
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
entry["content"] = truncate_text(content, self.max_tool_result_chars)
elif isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
if not filtered:
@ -656,6 +685,78 @@ class AgentLoop:
session.messages.append(entry)
session.updated_at = datetime.now()
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
"""Persist the latest in-flight turn state into session metadata."""
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
self.sessions.save(session)
def _clear_runtime_checkpoint(self, session: Session) -> None:
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
@staticmethod
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
return (
message.get("role"),
message.get("content"),
message.get("tool_call_id"),
message.get("name"),
message.get("tool_calls"),
message.get("reasoning_content"),
message.get("thinking_blocks"),
)
def _restore_runtime_checkpoint(self, session: Session) -> bool:
"""Materialize an unfinished turn into session history before a new request."""
from datetime import datetime
checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
if not isinstance(checkpoint, dict):
return False
assistant_message = checkpoint.get("assistant_message")
completed_tool_results = checkpoint.get("completed_tool_results") or []
pending_tool_calls = checkpoint.get("pending_tool_calls") or []
restored_messages: list[dict[str, Any]] = []
if isinstance(assistant_message, dict):
restored = dict(assistant_message)
restored.setdefault("timestamp", datetime.now().isoformat())
restored_messages.append(restored)
for message in completed_tool_results:
if isinstance(message, dict):
restored = dict(message)
restored.setdefault("timestamp", datetime.now().isoformat())
restored_messages.append(restored)
for tool_call in pending_tool_calls:
if not isinstance(tool_call, dict):
continue
tool_id = tool_call.get("id")
name = ((tool_call.get("function") or {}).get("name")) or "tool"
restored_messages.append({
"role": "tool",
"tool_call_id": tool_id,
"name": name,
"content": "Error: Task interrupted before this tool finished.",
"timestamp": datetime.now().isoformat(),
})
overlap = 0
max_overlap = min(len(session.messages), len(restored_messages))
for size in range(max_overlap, 0, -1):
existing = session.messages[-size:]
restored = restored_messages[:size]
if all(
self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
for left, right in zip(existing, restored)
):
overlap = size
break
session.messages.extend(restored_messages[overlap:])
self._clear_runtime_checkpoint(session)
return True
async def process_direct(
self,
content: str,

View File

@ -4,20 +4,36 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, ToolCallRequest
from nanobot.utils.helpers import build_assistant_message
from nanobot.utils.helpers import (
build_assistant_message,
estimate_message_tokens,
estimate_prompt_tokens_chain,
find_legal_message_start,
maybe_persist_tool_result,
truncate_text,
)
from nanobot.utils.runtime import (
EMPTY_FINAL_RESPONSE_MESSAGE,
build_finalization_retry_message,
ensure_nonempty_tool_result,
is_blank_text,
repeated_external_lookup_error,
)
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
"I reached the maximum number of tool call iterations ({max_iterations}) "
"without completing the task. You can try breaking the task into smaller steps."
)
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_SNIP_SAFETY_BUFFER = 1024
@dataclass(slots=True)
class AgentRunSpec:
"""Configuration for a single agent execution."""
@ -26,6 +42,7 @@ class AgentRunSpec:
tools: ToolRegistry
model: str
max_iterations: int
max_tool_result_chars: int
temperature: float | None = None
max_tokens: int | None = None
reasoning_effort: str | None = None
@ -34,6 +51,13 @@ class AgentRunSpec:
max_iterations_message: str | None = None
concurrent_tools: bool = False
fail_on_tool_error: bool = False
workspace: Path | None = None
session_key: str | None = None
context_window_tokens: int | None = None
context_block_limit: int | None = None
provider_retry_mode: str = "standard"
progress_callback: Any | None = None
checkpoint_callback: Any | None = None
@dataclass(slots=True)
@ -60,91 +84,142 @@ class AgentRunner:
messages = list(spec.initial_messages)
final_content: str | None = None
tools_used: list[str] = []
usage: dict[str, int] = {}
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}
error: str | None = None
stop_reason = "completed"
tool_events: list[dict[str, str]] = []
external_lookup_counts: dict[str, int] = {}
for iteration in range(spec.max_iterations):
try:
messages = self._apply_tool_result_budget(spec, messages)
messages_for_model = self._snip_history(spec, messages)
except Exception as exc:
logger.warning(
"Context governance failed on turn {} for {}: {}; using raw messages",
iteration,
spec.session_key or "default",
exc,
)
messages_for_model = messages
context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context)
kwargs: dict[str, Any] = {
"messages": messages,
"tools": spec.tools.get_definitions(),
"model": spec.model,
}
if spec.temperature is not None:
kwargs["temperature"] = spec.temperature
if spec.max_tokens is not None:
kwargs["max_tokens"] = spec.max_tokens
if spec.reasoning_effort is not None:
kwargs["reasoning_effort"] = spec.reasoning_effort
if hook.wants_streaming():
async def _stream(delta: str) -> None:
await hook.on_stream(context, delta)
response = await self.provider.chat_stream_with_retry(
**kwargs,
on_content_delta=_stream,
)
else:
response = await self.provider.chat_with_retry(**kwargs)
raw_usage = response.usage or {}
response = await self._request_model(spec, messages_for_model, hook, context)
raw_usage = self._usage_dict(response.usage)
context.response = response
context.usage = raw_usage
context.usage = dict(raw_usage)
context.tool_calls = list(response.tool_calls)
# Accumulate standard fields into result usage.
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
cached = raw_usage.get("cached_tokens")
if cached:
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
self._accumulate_usage(usage, raw_usage)
if response.has_tool_calls:
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True)
messages.append(build_assistant_message(
assistant_message = build_assistant_message(
response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
))
)
messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls)
await self._emit_checkpoint(
spec,
{
"phase": "awaiting_tools",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
},
)
await hook.before_execute_tools(context)
results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls)
results, new_events, fatal_error = await self._execute_tools(
spec,
response.tool_calls,
external_lookup_counts,
)
tool_events.extend(new_events)
context.tool_results = list(results)
context.tool_events = list(new_events)
if fatal_error is not None:
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
final_content = error
stop_reason = "tool_error"
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
break
completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results):
messages.append({
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call.name,
"content": result,
})
"content": self._normalize_tool_result(
spec,
tool_call.id,
tool_call.name,
result,
),
}
messages.append(tool_message)
completed_tool_results.append(tool_message)
await self._emit_checkpoint(
spec,
{
"phase": "tools_completed",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": completed_tool_results,
"pending_tool_calls": [],
},
)
await hook.after_iteration(context)
continue
clean = hook.finalize_content(context, response.content)
if response.finish_reason != "error" and is_blank_text(clean):
logger.warning(
"Empty final response on turn {} for {}; retrying with explicit finalization prompt",
iteration,
spec.session_key or "default",
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
response = await self._request_finalization_retry(spec, messages_for_model)
retry_usage = self._usage_dict(response.usage)
self._accumulate_usage(usage, retry_usage)
raw_usage = self._merge_usage(raw_usage, retry_usage)
context.response = response
context.usage = dict(raw_usage)
context.tool_calls = list(response.tool_calls)
clean = hook.finalize_content(context, response.content)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
clean = hook.finalize_content(context, response.content)
if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error"
error = final_content
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
break
if is_blank_text(clean):
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
stop_reason = "empty_final_response"
error = final_content
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
@ -156,6 +231,17 @@ class AgentRunner:
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
))
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": messages[-1],
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
final_content = clean
context.final_content = final_content
context.stop_reason = stop_reason
@ -165,6 +251,7 @@ class AgentRunner:
stop_reason = "max_iterations"
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
final_content = template.format(max_iterations=spec.max_iterations)
self._append_final_message(messages, final_content)
return AgentRunResult(
final_content=final_content,
@ -176,21 +263,101 @@ class AgentRunner:
tool_events=tool_events,
)
def _build_request_kwargs(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
*,
tools: list[dict[str, Any]] | None,
) -> dict[str, Any]:
kwargs: dict[str, Any] = {
"messages": messages,
"tools": tools,
"model": spec.model,
"retry_mode": spec.provider_retry_mode,
"on_retry_wait": spec.progress_callback,
}
if spec.temperature is not None:
kwargs["temperature"] = spec.temperature
if spec.max_tokens is not None:
kwargs["max_tokens"] = spec.max_tokens
if spec.reasoning_effort is not None:
kwargs["reasoning_effort"] = spec.reasoning_effort
return kwargs
async def _request_model(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
hook: AgentHook,
context: AgentHookContext,
):
kwargs = self._build_request_kwargs(
spec,
messages,
tools=spec.tools.get_definitions(),
)
if hook.wants_streaming():
async def _stream(delta: str) -> None:
await hook.on_stream(context, delta)
return await self.provider.chat_stream_with_retry(
**kwargs,
on_content_delta=_stream,
)
return await self.provider.chat_with_retry(**kwargs)
async def _request_finalization_retry(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
):
retry_messages = list(messages)
retry_messages.append(build_finalization_retry_message())
kwargs = self._build_request_kwargs(spec, retry_messages, tools=None)
return await self.provider.chat_with_retry(**kwargs)
@staticmethod
def _usage_dict(usage: dict[str, Any] | None) -> dict[str, int]:
if not usage:
return {}
result: dict[str, int] = {}
for key, value in usage.items():
try:
result[key] = int(value or 0)
except (TypeError, ValueError):
continue
return result
@staticmethod
def _accumulate_usage(target: dict[str, int], addition: dict[str, int]) -> None:
for key, value in addition.items():
target[key] = target.get(key, 0) + value
@staticmethod
def _merge_usage(left: dict[str, int], right: dict[str, int]) -> dict[str, int]:
merged = dict(left)
for key, value in right.items():
merged[key] = merged.get(key, 0) + value
return merged
async def _execute_tools(
self,
spec: AgentRunSpec,
tool_calls: list[ToolCallRequest],
external_lookup_counts: dict[str, int],
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
if spec.concurrent_tools:
tool_results = await asyncio.gather(*(
self._run_tool(spec, tool_call)
for tool_call in tool_calls
))
else:
tool_results = [
await self._run_tool(spec, tool_call)
for tool_call in tool_calls
]
batches = self._partition_tool_batches(spec, tool_calls)
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
for batch in batches:
if spec.concurrent_tools and len(batch) > 1:
tool_results.extend(await asyncio.gather(*(
self._run_tool(spec, tool_call, external_lookup_counts)
for tool_call in batch
)))
else:
for tool_call in batch:
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
results: list[Any] = []
events: list[dict[str, str]] = []
@ -206,9 +373,44 @@ class AgentRunner:
self,
spec: AgentRunSpec,
tool_call: ToolCallRequest,
external_lookup_counts: dict[str, int],
) -> tuple[Any, dict[str, str], BaseException | None]:
_HINT = "\n\n[Analyze the error above and try a different approach.]"
lookup_error = repeated_external_lookup_error(
tool_call.name,
tool_call.arguments,
external_lookup_counts,
)
if lookup_error:
event = {
"name": tool_call.name,
"status": "error",
"detail": "repeated external lookup blocked",
}
if spec.fail_on_tool_error:
return lookup_error + _HINT, event, RuntimeError(lookup_error)
return lookup_error + _HINT, event, None
prepare_call = getattr(spec.tools, "prepare_call", None)
tool, params, prep_error = None, tool_call.arguments, None
if callable(prepare_call):
try:
prepared = prepare_call(tool_call.name, tool_call.arguments)
if isinstance(prepared, tuple) and len(prepared) == 3:
tool, params, prep_error = prepared
except Exception:
pass
if prep_error:
event = {
"name": tool_call.name,
"status": "error",
"detail": prep_error.split(": ", 1)[-1][:120],
}
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
try:
result = await spec.tools.execute(tool_call.name, tool_call.arguments)
if tool is not None:
result = await tool.execute(**params)
else:
result = await spec.tools.execute(tool_call.name, params)
except asyncio.CancelledError:
raise
except BaseException as exc:
@ -221,14 +423,178 @@ class AgentRunner:
return f"Error: {type(exc).__name__}: {exc}", event, exc
return f"Error: {type(exc).__name__}: {exc}", event, None
if isinstance(result, str) and result.startswith("Error"):
event = {
"name": tool_call.name,
"status": "error",
"detail": result.replace("\n", " ").strip()[:120],
}
if spec.fail_on_tool_error:
return result + _HINT, event, RuntimeError(result)
return result + _HINT, event, None
detail = "" if result is None else str(result)
detail = detail.replace("\n", " ").strip()
if not detail:
detail = "(empty)"
elif len(detail) > 120:
detail = detail[:120] + "..."
return result, {
"name": tool_call.name,
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok",
"detail": detail,
}, None
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
async def _emit_checkpoint(
self,
spec: AgentRunSpec,
payload: dict[str, Any],
) -> None:
callback = spec.checkpoint_callback
if callback is not None:
await callback(payload)
@staticmethod
def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None:
if not content:
return
if (
messages
and messages[-1].get("role") == "assistant"
and not messages[-1].get("tool_calls")
):
if messages[-1].get("content") == content:
return
messages[-1] = build_assistant_message(content)
return
messages.append(build_assistant_message(content))
def _normalize_tool_result(
self,
spec: AgentRunSpec,
tool_call_id: str,
tool_name: str,
result: Any,
) -> Any:
result = ensure_nonempty_tool_result(tool_name, result)
try:
content = maybe_persist_tool_result(
spec.workspace,
spec.session_key,
tool_call_id,
result,
max_chars=spec.max_tool_result_chars,
)
except Exception as exc:
logger.warning(
"Tool result persist failed for {} in {}: {}; using raw result",
tool_call_id,
spec.session_key or "default",
exc,
)
content = result
if isinstance(content, str) and len(content) > spec.max_tool_result_chars:
return truncate_text(content, spec.max_tool_result_chars)
return content
def _apply_tool_result_budget(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
updated = messages
for idx, message in enumerate(messages):
if message.get("role") != "tool":
continue
normalized = self._normalize_tool_result(
spec,
str(message.get("tool_call_id") or f"tool_{idx}"),
str(message.get("name") or "tool"),
message.get("content"),
)
if normalized != message.get("content"):
if updated is messages:
updated = [dict(m) for m in messages]
updated[idx]["content"] = normalized
return updated
def _snip_history(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
if not messages or not spec.context_window_tokens:
return messages
provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else (
provider_max_tokens if isinstance(provider_max_tokens, int) else 4096
)
budget = spec.context_block_limit or (
spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER
)
if budget <= 0:
return messages
estimate, _ = estimate_prompt_tokens_chain(
self.provider,
spec.model,
messages,
spec.tools.get_definitions(),
)
if estimate <= budget:
return messages
system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"]
non_system = [dict(msg) for msg in messages if msg.get("role") != "system"]
if not non_system:
return messages
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
remaining_budget = max(128, budget - system_tokens)
kept: list[dict[str, Any]] = []
kept_tokens = 0
for message in reversed(non_system):
msg_tokens = estimate_message_tokens(message)
if kept and kept_tokens + msg_tokens > remaining_budget:
break
kept.append(message)
kept_tokens += msg_tokens
kept.reverse()
if kept:
for i, message in enumerate(kept):
if message.get("role") == "user":
kept = kept[i:]
break
start = find_legal_message_start(kept)
if start:
kept = kept[start:]
if not kept:
kept = non_system[-min(len(non_system), 4) :]
start = find_legal_message_start(kept)
if start:
kept = kept[start:]
return system_messages + kept
def _partition_tool_batches(
self,
spec: AgentRunSpec,
tool_calls: list[ToolCallRequest],
) -> list[list[ToolCallRequest]]:
if not spec.concurrent_tools:
return [[tool_call] for tool_call in tool_calls]
batches: list[list[ToolCallRequest]] = []
current: list[ToolCallRequest] = []
for tool_call in tool_calls:
get_tool = getattr(spec.tools, "get", None)
tool = get_tool(tool_call.name) if callable(get_tool) else None
can_batch = bool(tool and tool.concurrency_safe)
if can_batch:
current.append(tool_call)
continue
if current:
batches.append(current)
current = []
batches.append([tool_call])
if current:
batches.append(current)
return batches

View File

@ -44,6 +44,7 @@ class SubagentManager:
provider: LLMProvider,
workspace: Path,
bus: MessageBus,
max_tool_result_chars: int,
model: str | None = None,
web_search_config: "WebSearchConfig | None" = None,
web_proxy: str | None = None,
@ -56,6 +57,7 @@ class SubagentManager:
self.workspace = workspace
self.bus = bus
self.model = model or provider.get_default_model()
self.max_tool_result_chars = max_tool_result_chars
self.web_search_config = web_search_config or WebSearchConfig()
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@ -136,6 +138,7 @@ class SubagentManager:
tools=tools,
model=self.model,
max_iterations=15,
max_tool_result_chars=self.max_tool_result_chars,
hook=_SubagentHook(task_id),
max_iterations_message="Task completed but no final response was generated.",
error_message=None,

View File

@ -53,6 +53,21 @@ class Tool(ABC):
"""JSON Schema for tool parameters."""
pass
@property
def read_only(self) -> bool:
"""Whether this tool is side-effect free and safe to parallelize."""
return False
@property
def concurrency_safe(self) -> bool:
"""Whether this tool can run alongside other concurrency-safe tools."""
return self.read_only and not self.exclusive
@property
def exclusive(self) -> bool:
"""Whether this tool should run alone even if concurrency is enabled."""
return False
@abstractmethod
async def execute(self, **kwargs: Any) -> Any:
"""

View File

@ -73,6 +73,10 @@ class ReadFileTool(_FsTool):
"Use offset and limit to paginate through large files."
)
@property
def read_only(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@ -344,6 +348,10 @@ class ListDirTool(_FsTool):
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
)
@property
def read_only(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@ -84,6 +84,9 @@ class MessageTool(Tool):
media: list[str] | None = None,
**kwargs: Any
) -> str:
from nanobot.utils.helpers import strip_think
content = strip_think(content)
channel = channel or self._default_channel
chat_id = chat_id or self._default_chat_id
# Only inherit default message_id when targeting the same channel+chat.

View File

@ -35,22 +35,35 @@ class ToolRegistry:
"""Get all tool definitions in OpenAI format."""
return [tool.to_schema() for tool in self._tools.values()]
def prepare_call(
self,
name: str,
params: dict[str, Any],
) -> tuple[Tool | None, dict[str, Any], str | None]:
"""Resolve, cast, and validate one tool call."""
tool = self._tools.get(name)
if not tool:
return None, params, (
f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
)
cast_params = tool.cast_params(params)
errors = tool.validate_params(cast_params)
if errors:
return tool, cast_params, (
f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
)
return tool, cast_params, None
async def execute(self, name: str, params: dict[str, Any]) -> Any:
"""Execute a tool by name with given parameters."""
_HINT = "\n\n[Analyze the error above and try a different approach.]"
tool = self._tools.get(name)
if not tool:
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
tool, params, error = self.prepare_call(name, params)
if error:
return error + _HINT
try:
# Attempt to cast parameters to match schema types
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params)
if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
assert tool is not None # guarded by prepare_call()
result = await tool.execute(**params)
if isinstance(result, str) and result.startswith("Error"):
return result + _HINT

View File

@ -52,6 +52,10 @@ class ExecTool(Tool):
def description(self) -> str:
return "Execute a shell command and return its output. Use with caution."
@property
def exclusive(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@ -92,6 +92,10 @@ class WebSearchTool(Tool):
self.config = config if config is not None else WebSearchConfig()
self.proxy = proxy
@property
def read_only(self) -> bool:
return True
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
provider = self.config.provider.strip().lower() or "brave"
n = min(max(count or self.config.max_results, 1), 10)
@ -234,6 +238,10 @@ class WebFetchTool(Tool):
self.max_chars = max_chars
self.proxy = proxy
@property
def read_only(self) -> bool:
return True
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
max_chars = maxChars or self.max_chars
is_valid, error_msg = _validate_url_safe(url)

View File

@ -14,6 +14,8 @@ from typing import Any
from aiohttp import web
from loguru import logger
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
API_SESSION_KEY = "api:default"
API_CHAT_ID = "default"
@ -98,7 +100,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
logger.info("API request session_key={} content={}", session_key, user_content[:80])
_FALLBACK = "I've completed processing but have no response to give."
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
try:
async with session_lock:

View File

@ -134,6 +134,7 @@ class QQConfig(Base):
secret: str = ""
allow_from: list[str] = Field(default_factory=list)
msg_format: Literal["plain", "markdown"] = "plain"
ack_message: str = "⏳ Processing..."
# Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq").
media_dir: str = ""
@ -484,6 +485,17 @@ class QQChannel(BaseChannel):
if not content and not media_paths:
return
if self.config.ack_message:
try:
await self._send_text_only(
chat_id=chat_id,
is_group=is_group,
msg_id=data.id,
content=self.config.ack_message,
)
except Exception:
logger.debug("QQ ack message failed for chat_id={}", chat_id)
await self._handle_message(
sender_id=user_id,
chat_id=chat_id,

View File

@ -275,13 +275,10 @@ class TelegramChannel(BaseChannel):
self._app = builder.build()
self._app.add_error_handler(self._on_error)
# Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("stop", self._forward_command))
self._app.add_handler(CommandHandler("restart", self._forward_command))
self._app.add_handler(CommandHandler("status", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help))
# Add command handlers (using Regex to support @username suffixes before bot initialization)
self._app.add_handler(MessageHandler(filters.Regex(r"^/start(?:@\w+)?$"), self._on_start))
self._app.add_handler(MessageHandler(filters.Regex(r"^/(new|stop|restart|status)(?:@\w+)?$"), self._forward_command))
self._app.add_handler(MessageHandler(filters.Regex(r"^/help(?:@\w+)?$"), self._on_help))
# Add message handler for text, photos, voice, documents
self._app.add_handler(
@ -313,7 +310,7 @@ class TelegramChannel(BaseChannel):
# Start polling (this runs until stopped)
await self._app.updater.start_polling(
allowed_updates=["message"],
drop_pending_updates=True # Ignore old messages on startup
drop_pending_updates=False # Process pending messages on startup
)
# Keep running until stopped
@ -362,9 +359,14 @@ class TelegramChannel(BaseChannel):
logger.warning("Telegram bot not running")
return
# Only stop typing indicator for final responses
# Only stop typing indicator and remove reaction for final responses
if not msg.metadata.get("_progress", False):
self._stop_typing(msg.chat_id)
if reply_to_message_id := msg.metadata.get("message_id"):
try:
await self._remove_reaction(msg.chat_id, int(reply_to_message_id))
except ValueError:
pass
try:
chat_id = int(msg.chat_id)
@ -435,7 +437,9 @@ class TelegramChannel(BaseChannel):
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
async def _call_with_retry(self, fn, *args, **kwargs):
"""Call an async Telegram API function with retry on pool/network timeout."""
"""Call an async Telegram API function with retry on pool/network timeout and RetryAfter."""
from telegram.error import RetryAfter
for attempt in range(1, _SEND_MAX_RETRIES + 1):
try:
return await fn(*args, **kwargs)
@ -448,6 +452,15 @@ class TelegramChannel(BaseChannel):
attempt, _SEND_MAX_RETRIES, delay,
)
await asyncio.sleep(delay)
except RetryAfter as e:
if attempt == _SEND_MAX_RETRIES:
raise
delay = float(e.retry_after)
logger.warning(
"Telegram Flood Control (attempt {}/{}), retrying in {:.1f}s",
attempt, _SEND_MAX_RETRIES, delay,
)
await asyncio.sleep(delay)
async def _send_text(
self,
@ -498,6 +511,11 @@ class TelegramChannel(BaseChannel):
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
return
self._stop_typing(chat_id)
if reply_to_message_id := meta.get("message_id"):
try:
await self._remove_reaction(chat_id, int(reply_to_message_id))
except ValueError:
pass
try:
html = _markdown_to_telegram_html(buf.text)
await self._call_with_retry(
@ -619,8 +637,7 @@ class TelegramChannel(BaseChannel):
"reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
}
@staticmethod
def _extract_reply_context(message) -> str | None:
async def _extract_reply_context(self, message) -> str | None:
"""Extract text from the message being replied to, if any."""
reply = getattr(message, "reply_to_message", None)
if not reply:
@ -628,7 +645,21 @@ class TelegramChannel(BaseChannel):
text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
return f"[Reply to: {text}]" if text else None
if not text:
return None
bot_id, _ = await self._ensure_bot_identity()
reply_user = getattr(reply, "from_user", None)
if bot_id and reply_user and getattr(reply_user, "id", None) == bot_id:
return f"[Reply to bot: {text}]"
elif reply_user and getattr(reply_user, "username", None):
return f"[Reply to @{reply_user.username}: {text}]"
elif reply_user and getattr(reply_user, "first_name", None):
return f"[Reply to {reply_user.first_name}: {text}]"
else:
return f"[Reply to: {text}]"
async def _download_message_media(
self, msg, *, add_failure_content: bool = False
@ -765,10 +796,18 @@ class TelegramChannel(BaseChannel):
message = update.message
user = update.effective_user
self._remember_thread_context(message)
# Strip @bot_username suffix if present
content = message.text or ""
if content.startswith("/") and "@" in content:
cmd_part, *rest = content.split(" ", 1)
cmd_part = cmd_part.split("@")[0]
content = f"{cmd_part} {rest[0]}" if rest else cmd_part
await self._handle_message(
sender_id=self._sender_id(user),
chat_id=str(message.chat_id),
content=message.text or "",
content=content,
metadata=self._build_message_metadata(message, user),
session_key=self._derive_topic_session_key(message),
)
@ -812,7 +851,7 @@ class TelegramChannel(BaseChannel):
# Reply context: text and/or media from the replied-to message
reply = getattr(message, "reply_to_message", None)
if reply is not None:
reply_ctx = self._extract_reply_context(message)
reply_ctx = await self._extract_reply_context(message)
reply_media, reply_media_parts = await self._download_message_media(reply)
if reply_media:
media_paths = reply_media + media_paths
@ -903,6 +942,19 @@ class TelegramChannel(BaseChannel):
except Exception as e:
logger.debug("Telegram reaction failed: {}", e)
async def _remove_reaction(self, chat_id: str, message_id: int) -> None:
"""Remove emoji reaction from a message (best-effort, non-blocking)."""
if not self._app:
return
try:
await self._app.bot.set_message_reaction(
chat_id=int(chat_id),
message_id=message_id,
reaction=[],
)
except Exception as e:
logger.debug("Telegram reaction removal failed: {}", e)
async def _typing_loop(self, chat_id: str) -> None:
"""Repeatedly send 'typing' action until cancelled."""
try:

View File

@ -542,6 +542,9 @@ def serve(
model=runtime_config.agents.defaults.model,
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
context_block_limit=runtime_config.agents.defaults.context_block_limit,
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
web_search_config=runtime_config.tools.web.search,
web_proxy=runtime_config.tools.web.proxy or None,
exec_config=runtime_config.tools.exec,
@ -629,6 +632,9 @@ def gateway(
model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
@ -835,6 +841,9 @@ def agent(
model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
@ -1023,12 +1032,18 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status")
def channels_status():
def channels_status(
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Show channel status."""
from nanobot.channels.registry import discover_all
from nanobot.config.loader import load_config
from nanobot.config.loader import load_config, set_config_path
config = load_config()
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
if resolved_config_path is not None:
set_config_path(resolved_config_path)
config = load_config(resolved_config_path)
table = Table(title="Channel Status")
table.add_column("Channel", style="cyan")
@ -1115,12 +1130,17 @@ def _get_bridge_dir() -> Path:
def channels_login(
channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Authenticate with a channel via QR code or other interactive login."""
from nanobot.channels.registry import discover_all
from nanobot.config.loader import load_config
from nanobot.config.loader import load_config, set_config_path
config = load_config()
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
if resolved_config_path is not None:
set_config_path(resolved_config_path)
config = load_config(resolved_config_path)
channel_cfg = getattr(config.channels, channel_name, None) or {}
# Validate channel exists

View File

@ -26,7 +26,10 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key)
total = cancelled + sub_cancelled
content = f"Stopped {total} task(s)." if total else "No active task to stop."
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content=content)
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content,
metadata=dict(msg.metadata or {})
)
async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
@ -38,7 +41,10 @@ async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
asyncio.create_task(_do_restart())
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="Restarting...")
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
metadata=dict(msg.metadata or {})
)
async def cmd_status(ctx: CommandContext) -> OutboundMessage:
@ -62,7 +68,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
session_msg_count=len(session.get_history(max_messages=0)),
context_tokens_estimate=ctx_est,
),
metadata={"render_as": "text"},
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
)
@ -79,6 +85,7 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
return OutboundMessage(
channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
content="New session started.",
metadata=dict(ctx.msg.metadata or {})
)
@ -88,7 +95,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content=build_help_text(),
metadata={"render_as": "text"},
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
)

View File

@ -38,8 +38,11 @@ class AgentDefaults(Base):
)
max_tokens: int = 8192
context_window_tokens: int = 65_536
context_block_limit: int | None = None
temperature: float = 0.1
max_tool_iterations: int = 40
max_tool_iterations: int = 200
max_tool_result_chars: int = 16_000
provider_retry_mode: Literal["standard", "persistent"] = "standard"
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"

View File

@ -73,6 +73,9 @@ class Nanobot:
model=defaults.model,
max_iterations=defaults.max_tool_iterations,
context_window_tokens=defaults.context_window_tokens,
context_block_limit=defaults.context_block_limit,
max_tool_result_chars=defaults.max_tool_result_chars,
provider_retry_mode=defaults.provider_retry_mode,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,

View File

@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
import os
import re
import secrets
import string
@ -434,13 +436,33 @@ class AnthropicProvider(LLMProvider):
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try:
async with self._client.messages.stream(**kwargs) as stream:
if on_content_delta:
async for text in stream.text_stream:
stream_iter = stream.text_stream.__aiter__()
while True:
try:
text = await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
await on_content_delta(text)
response = await stream.get_final_message()
response = await asyncio.wait_for(
stream.get_final_message(),
timeout=idle_timeout_s,
)
return self._parse_response(response)
except asyncio.TimeoutError:
return LLMResponse(
content=(
f"Error calling LLM: stream stalled for more than "
f"{idle_timeout_s} seconds"
),
finish_reason="error",
)
except Exception as e:
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")

View File

@ -1,31 +1,36 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
"""Azure OpenAI provider using the OpenAI SDK Responses API.
Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
routes to the Responses API (``/responses``). Reuses shared conversion
helpers from :mod:`nanobot.providers.openai_responses`.
"""
from __future__ import annotations
import json
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.openai_responses import (
consume_sdk_stream,
convert_messages,
convert_tools,
parse_response_output,
)
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
"""Azure OpenAI provider backed by the Responses API.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
- Uses the OpenAI Python SDK (``AsyncOpenAI``) with
``base_url = {endpoint}/openai/v1/``
- Calls ``client.responses.create()`` (Responses API)
- Reuses shared message/tool/SSE conversion from
``openai_responses``
"""
def __init__(
@ -36,40 +41,28 @@ class AzureOpenAIProvider(LLMProvider):
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
# Normalise: ensure trailing slash
if not api_base.endswith("/"):
api_base += "/"
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
# SDK client targeting the Azure Responses API endpoint
base_url = f"{api_base.rstrip('/')}/openai/v1/"
self._client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
default_headers={"x-session-affinity": uuid.uuid4().hex},
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _supports_temperature(
@ -82,36 +75,51 @@ class AzureOpenAIProvider(LLMProvider):
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
def _build_body(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None,
model: str | None,
max_tokens: int,
temperature: float,
reasoning_effort: str | None,
tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
"""Build the Responses API request body from Chat-Completions-style args."""
deployment = model or self.default_model
instructions, input_items = convert_messages(self._sanitize_empty_content(messages))
body: dict[str, Any] = {
"model": deployment,
"instructions": instructions or None,
"input": input_items,
"max_output_tokens": max(1, max_tokens),
"store": False,
"stream": False,
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if self._supports_temperature(deployment, reasoning_effort):
body["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
body["reasoning"] = {"effort": reasoning_effort}
body["include"] = ["reasoning.encrypted_content"]
if tools:
payload["tools"] = tools
payload["tool_choice"] = tool_choice or "auto"
body["tools"] = convert_tools(tools)
body["tool_choice"] = tool_choice or "auto"
return payload
return body
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
return LLMResponse(content=msg, finish_reason="error")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def chat(
self,
@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
tool_choice=tool_choice,
body = self._build_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
response = await self._client.responses.create(**body)
return parse_response_output(response)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
return self._handle_error(e)
async def chat_stream(
self,
@ -221,89 +152,26 @@ class AzureOpenAIProvider(LLMProvider):
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Stream a chat completion via Azure OpenAI SSE."""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature,
reasoning_effort, tool_choice=tool_choice,
body = self._build_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
payload["stream"] = True
body["stream"] = True
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
async with client.stream("POST", url, headers=headers, json=payload) as response:
if response.status_code != 200:
text = await response.aread()
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
finish_reason="error",
)
return await self._consume_stream(response, on_content_delta)
except Exception as e:
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
async def _consume_stream(
self,
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
content_parts: list[str] = []
tool_call_buffers: dict[int, dict[str, str]] = {}
finish_reason = "stop"
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
data = line[6:].strip()
if data == "[DONE]":
break
try:
chunk = json.loads(data)
except Exception:
continue
choices = chunk.get("choices") or []
if not choices:
continue
choice = choices[0]
if choice.get("finish_reason"):
finish_reason = choice["finish_reason"]
delta = choice.get("delta") or {}
text = delta.get("content")
if text:
content_parts.append(text)
if on_content_delta:
await on_content_delta(text)
for tc in delta.get("tool_calls") or []:
idx = tc.get("index", 0)
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
if tc.get("id"):
buf["id"] = tc["id"]
fn = tc.get("function") or {}
if fn.get("name"):
buf["name"] = fn["name"]
if fn.get("arguments"):
buf["arguments"] += fn["arguments"]
tool_calls = [
ToolCallRequest(
id=buf["id"], name=buf["name"],
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
stream = await self._client.responses.create(**body)
content, tool_calls, finish_reason, usage, reasoning_content = (
await consume_sdk_stream(stream, on_content_delta)
)
for buf in tool_call_buffers.values()
]
return LLMResponse(
content="".join(content_parts) or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content,
)
except Exception as e:
return self._handle_error(e)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model

View File

@ -2,6 +2,7 @@
import asyncio
import json
import re
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
@ -9,6 +10,8 @@ from typing import Any
from loguru import logger
from nanobot.utils.helpers import image_placeholder_text
@dataclass
class ToolCallRequest:
@ -57,13 +60,7 @@ class LLMResponse:
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation parameters for LLM calls.
Stored on the provider so every call site inherits the same defaults
without having to pass temperature / max_tokens / reasoning_effort
through every layer. Individual call sites can still override by
passing explicit keyword arguments to chat() / chat_with_retry().
"""
"""Default generation settings."""
temperature: float = 0.7
max_tokens: int = 4096
@ -71,14 +68,12 @@ class GenerationSettings:
class LLMProvider(ABC):
"""
Abstract base class for LLM providers.
Implementations should handle the specifics of each provider's API
while maintaining a consistent interface.
"""
"""Base class for LLM providers."""
_CHAT_RETRY_DELAYS = (1, 2, 4)
_PERSISTENT_MAX_DELAY = 60
_PERSISTENT_IDENTICAL_ERROR_LIMIT = 10
_RETRY_HEARTBEAT_CHUNK = 30
_TRANSIENT_ERROR_MARKERS = (
"429",
"rate limit",
@ -208,7 +203,7 @@ class LLMProvider(ABC):
for b in content:
if isinstance(b, dict) and b.get("type") == "image_url":
path = (b.get("_meta") or {}).get("path", "")
placeholder = f"[image: {path}]" if path else "[image omitted]"
placeholder = image_placeholder_text(path, empty="[image omitted]")
new_content.append({"type": "text", "text": placeholder})
found = True
else:
@ -273,6 +268,8 @@ class LLMProvider(ABC):
reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Call chat_stream() with retry on transient provider failures."""
if max_tokens is self._SENTINEL:
@ -288,28 +285,13 @@ class LLMProvider(ABC):
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
on_content_delta=on_content_delta,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
response = await self._safe_chat_stream(**kw)
if response.finish_reason != "error":
return response
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat_stream(**{**kw, "messages": stripped})
return response
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
return await self._safe_chat_stream(**kw)
return await self._run_with_retry(
self._safe_chat_stream,
kw,
messages,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
)
async def chat_with_retry(
self,
@ -320,6 +302,8 @@ class LLMProvider(ABC):
temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Call chat() with retry on transient provider failures.
@ -339,28 +323,118 @@ class LLMProvider(ABC):
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
)
return await self._run_with_retry(
self._safe_chat,
kw,
messages,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
response = await self._safe_chat(**kw)
@classmethod
def _extract_retry_after(cls, content: str | None) -> float | None:
text = (content or "").lower()
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
if not match:
return None
value = float(match.group(1))
unit = (match.group(2) or "s").lower()
if unit in {"ms", "milliseconds"}:
return max(0.1, value / 1000.0)
if unit in {"m", "min", "minutes"}:
return value * 60.0
return value
async def _sleep_with_heartbeat(
self,
delay: float,
*,
attempt: int,
persistent: bool,
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> None:
remaining = max(0.0, delay)
while remaining > 0:
if on_retry_wait:
kind = "persistent retry" if persistent else "retry"
await on_retry_wait(
f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
f"(attempt {attempt})."
)
chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
await asyncio.sleep(chunk)
remaining -= chunk
async def _run_with_retry(
self,
call: Callable[..., Awaitable[LLMResponse]],
kw: dict[str, Any],
original_messages: list[dict[str, Any]],
*,
retry_mode: str,
on_retry_wait: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
attempt = 0
delays = list(self._CHAT_RETRY_DELAYS)
persistent = retry_mode == "persistent"
last_response: LLMResponse | None = None
last_error_key: str | None = None
identical_error_count = 0
while True:
attempt += 1
response = await call(**kw)
if response.finish_reason != "error":
return response
last_response = response
error_key = ((response.content or "").strip().lower() or None)
if error_key and error_key == last_error_key:
identical_error_count += 1
else:
last_error_key = error_key
identical_error_count = 1 if error_key else 0
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat(**{**kw, "messages": stripped})
stripped = self._strip_image_content(original_messages)
if stripped is not None and stripped != kw["messages"]:
logger.warning(
"Non-transient LLM error with image content, retrying without images"
)
retry_kw = dict(kw)
retry_kw["messages"] = stripped
return await call(**retry_kw)
return response
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
logger.warning(
"Stopping persistent retry after {} identical transient errors: {}",
identical_error_count,
(response.content or "")[:120].lower(),
)
return response
if not persistent and attempt > len(delays):
break
base_delay = delays[min(attempt - 1, len(delays) - 1)]
delay = self._extract_retry_after(response.content) or base_delay
if persistent:
delay = min(delay, self._PERSISTENT_MAX_DELAY)
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
"LLM transient error (attempt {}{}), retrying in {}s: {}",
attempt,
"+" if persistent and attempt > len(delays) else f"/{len(delays)}",
int(round(delay)),
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
await self._sleep_with_heartbeat(
delay,
attempt=attempt,
persistent=persistent,
on_retry_wait=on_retry_wait,
)
return await self._safe_chat(**kw)
return last_response if last_response is not None else await call(**kw)
@abstractmethod
def get_default_model(self) -> str:

View File

@ -6,13 +6,18 @@ import asyncio
import hashlib
import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator
from typing import Any
import httpx
from loguru import logger
from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses import (
consume_sse,
convert_messages,
convert_tools,
)
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
DEFAULT_ORIGINATOR = "nanobot"
@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
) -> LLMResponse:
"""Shared request logic for both chat() and chat_stream()."""
model = model or self.default_model
system_prompt, input_items = _convert_messages(messages)
system_prompt, input_items = convert_messages(messages)
token = await asyncio.to_thread(get_codex_token)
headers = _build_headers(token.account_id, token.access)
@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools:
body["tools"] = _convert_tools(tools)
body["tools"] = convert_tools(tools)
try:
try:
@ -127,96 +132,7 @@ async def _request_codex(
if response.status_code != 200:
text = await response.aread()
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
return await _consume_sse(response, on_content_delta)
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert OpenAI function-calling schema to Codex flat format."""
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append({
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
})
return converted
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = ""
input_items: list[dict[str, Any]] = []
for idx, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(_convert_user_message(content))
continue
if role == "assistant":
if isinstance(content, str) and content:
input_items.append({
"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed", "id": f"msg_{idx}",
})
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
input_items.append({
"type": "function_call",
"id": item_id or f"fc_{idx}",
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
})
continue
if role == "tool":
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
return system_prompt, input_items
def _convert_user_message(content: Any) -> dict[str, Any]:
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None
return await consume_sse(response, on_content_delta)
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
buffer: list[str] = []
async for line in response.aiter_lines():
if line == "":
if buffer:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer = []
if not data_lines:
continue
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
continue
try:
yield json.loads(data)
except Exception:
continue
continue
buffer.append(line)
async def _consume_sse(
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
async for event in _iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
}
elif event_type == "response.output_text.delta":
delta_text = event.get("delta") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
elif event_type == "response.output_item.done":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
try:
args = json.loads(args_raw)
except Exception:
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name"),
arguments=args,
)
)
elif event_type == "response.completed":
status = (event.get("response") or {}).get("status")
finish_reason = _map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
raise RuntimeError("Codex response failed")
return content, tool_calls, finish_reason
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
def _map_finish_reason(status: str | None) -> str:
return _FINISH_REASON_MAP.get(status or "completed", "stop")
def _friendly_error(status_code: int, raw: str) -> str:
if status_code == 429:
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import hashlib
import os
import secrets
@ -20,7 +21,6 @@ if TYPE_CHECKING:
_ALLOWED_MSG_KEYS = frozenset({
"role", "content", "tool_calls", "tool_call_id", "name",
"reasoning_content", "extra_content",
})
_ALNUM = string.ascii_letters + string.digits
@ -615,16 +615,33 @@ class OpenAICompatProvider(LLMProvider):
)
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try:
stream = await self._client.chat.completions.create(**kwargs)
chunks: list[Any] = []
async for chunk in stream:
stream_iter = stream.__aiter__()
while True:
try:
chunk = await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
chunks.append(chunk)
if on_content_delta and chunk.choices:
text = getattr(chunk.choices[0].delta, "content", None)
if text:
await on_content_delta(text)
return self._parse_chunks(chunks)
except asyncio.TimeoutError:
return LLMResponse(
content=(
f"Error calling LLM: stream stalled for more than "
f"{idle_timeout_s} seconds"
),
finish_reason="error",
)
except Exception as e:
return self._handle_error(e)

View File

@ -0,0 +1,29 @@
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""
from nanobot.providers.openai_responses.converters import (
convert_messages,
convert_tools,
convert_user_message,
split_tool_call_id,
)
from nanobot.providers.openai_responses.parsing import (
FINISH_REASON_MAP,
consume_sdk_stream,
consume_sse,
iter_sse,
map_finish_reason,
parse_response_output,
)
__all__ = [
"convert_messages",
"convert_tools",
"convert_user_message",
"split_tool_call_id",
"iter_sse",
"consume_sse",
"consume_sdk_stream",
"map_finish_reason",
"parse_response_output",
"FINISH_REASON_MAP",
]

View File

@ -0,0 +1,110 @@
"""Convert Chat Completions messages/tools to Responses API format."""
from __future__ import annotations
import json
from typing import Any
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
"""Convert Chat Completions messages to Responses API input items.
Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted
from any ``system`` role message and *input_items* is the Responses API
``input`` array.
"""
system_prompt = ""
input_items: list[dict[str, Any]] = []
for idx, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(convert_user_message(content))
continue
if role == "assistant":
if isinstance(content, str) and content:
input_items.append({
"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed", "id": f"msg_{idx}",
})
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = split_tool_call_id(tool_call.get("id"))
input_items.append({
"type": "function_call",
"id": item_id or f"fc_{idx}",
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
})
continue
if role == "tool":
call_id, _ = split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
return system_prompt, input_items
def convert_user_message(content: Any) -> dict[str, Any]:
"""Convert a user message's content to Responses API format.
Handles plain strings, ``text`` blocks -> ``input_text``, and
``image_url`` blocks -> ``input_image``.
"""
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert OpenAI function-calling tool schema to Responses API flat format."""
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append({
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
})
return converted
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
"""Split a compound ``call_id|item_id`` string.
Returns ``(call_id, item_id)`` where *item_id* may be ``None``.
"""
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None

View File

@ -0,0 +1,297 @@
"""Parse Responses API SSE streams and SDK response objects."""
from __future__ import annotations
import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator
import httpx
import json_repair
from loguru import logger
from nanobot.providers.base import LLMResponse, ToolCallRequest
FINISH_REASON_MAP = {
"completed": "stop",
"incomplete": "length",
"failed": "error",
"cancelled": "error",
}
def map_finish_reason(status: str | None) -> str:
"""Map a Responses API status string to a Chat-Completions-style finish_reason."""
return FINISH_REASON_MAP.get(status or "completed", "stop")
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
"""Yield parsed JSON events from a Responses API SSE stream."""
buffer: list[str] = []
def _flush() -> dict[str, Any] | None:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer.clear()
if not data_lines:
return None
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
return None
try:
return json.loads(data)
except Exception:
logger.warning("Failed to parse SSE event JSON: {}", data[:200])
return None
async for line in response.aiter_lines():
if line == "":
if buffer:
event = _flush()
if event is not None:
yield event
continue
buffer.append(line)
# Flush any remaining buffer at EOF (#10)
if buffer:
event = _flush()
if event is not None:
yield event
async def consume_sse(
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
"""Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``."""
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
async for event in iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
}
elif event_type == "response.output_text.delta":
delta_text = event.get("delta") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
elif event_type == "response.output_item.done":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or item.get("name"),
args_raw[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name") or "",
arguments=args,
)
)
elif event_type == "response.completed":
status = (event.get("response") or {}).get("status")
finish_reason = map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
detail = event.get("error") or event.get("message") or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason
def parse_response_output(response: Any) -> LLMResponse:
"""Parse an SDK ``Response`` object into an ``LLMResponse``."""
if not isinstance(response, dict):
dump = getattr(response, "model_dump", None)
response = dump() if callable(dump) else vars(response)
output = response.get("output") or []
content_parts: list[str] = []
tool_calls: list[ToolCallRequest] = []
reasoning_content: str | None = None
for item in output:
if not isinstance(item, dict):
dump = getattr(item, "model_dump", None)
item = dump() if callable(dump) else vars(item)
item_type = item.get("type")
if item_type == "message":
for block in item.get("content") or []:
if not isinstance(block, dict):
dump = getattr(block, "model_dump", None)
block = dump() if callable(dump) else vars(block)
if block.get("type") == "output_text":
content_parts.append(block.get("text") or "")
elif item_type == "reasoning":
for s in item.get("summary") or []:
if not isinstance(s, dict):
dump = getattr(s, "model_dump", None)
s = dump() if callable(dump) else vars(s)
if s.get("type") == "summary_text" and s.get("text"):
reasoning_content = (reasoning_content or "") + s["text"]
elif item_type == "function_call":
call_id = item.get("call_id") or ""
item_id = item.get("id") or "fc_0"
args_raw = item.get("arguments") or "{}"
try:
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
item.get("name"),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(ToolCallRequest(
id=f"{call_id}|{item_id}",
name=item.get("name") or "",
arguments=args if isinstance(args, dict) else {},
))
usage_raw = response.get("usage") or {}
if not isinstance(usage_raw, dict):
dump = getattr(usage_raw, "model_dump", None)
usage_raw = dump() if callable(dump) else vars(usage_raw)
usage = {}
if usage_raw:
usage = {
"prompt_tokens": int(usage_raw.get("input_tokens") or 0),
"completion_tokens": int(usage_raw.get("output_tokens") or 0),
"total_tokens": int(usage_raw.get("total_tokens") or 0),
}
status = response.get("status")
finish_reason = map_finish_reason(status)
return LLMResponse(
content="".join(content_parts) or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
)
async def consume_sdk_stream(
stream: Any,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
"""Consume an SDK async stream from ``client.responses.create(stream=True)``."""
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
usage: dict[str, int] = {}
reasoning_content: str | None = None
async for event in stream:
event_type = getattr(event, "type", None)
if event_type == "response.output_item.added":
item = getattr(event, "item", None)
if item and getattr(item, "type", None) == "function_call":
call_id = getattr(item, "call_id", None)
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": getattr(item, "id", None) or "fc_0",
"name": getattr(item, "name", None),
"arguments": getattr(item, "arguments", None) or "",
}
elif event_type == "response.output_text.delta":
delta_text = getattr(event, "delta", "") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
elif event_type == "response.function_call_arguments.done":
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""
elif event_type == "response.output_item.done":
item = getattr(event, "item", None)
if item and getattr(item, "type", None) == "function_call":
call_id = getattr(item, "call_id", None)
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or getattr(item, "name", None),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
name=buf.get("name") or getattr(item, "name", None) or "",
arguments=args,
)
)
elif event_type == "response.completed":
resp = getattr(event, "response", None)
status = getattr(resp, "status", None) if resp else None
finish_reason = map_finish_reason(status)
if resp:
usage_obj = getattr(resp, "usage", None)
if usage_obj:
usage = {
"prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
"completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
"total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
}
for out_item in getattr(resp, "output", None) or []:
if getattr(out_item, "type", None) == "reasoning":
for s in getattr(out_item, "summary", None) or []:
if getattr(s, "type", None) == "summary_text":
text = getattr(s, "text", None)
if text:
reasoning_content = (reasoning_content or "") + text
elif event_type in {"error", "response.failed"}:
detail = getattr(event, "error", None) or getattr(event, "message", None) or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason, usage, reasoning_content

View File

@ -10,20 +10,12 @@ from typing import Any
from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename
from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename
@dataclass
class Session:
"""
A conversation session.
Stores messages in JSONL format for easy reading and persistence.
Important: Messages are append-only for LLM cache efficiency.
The consolidation process writes summaries to MEMORY.md/HISTORY.md
but does NOT modify the messages list or get_history() output.
"""
"""A conversation session."""
key: str # channel:chat_id
messages: list[dict[str, Any]] = field(default_factory=list)
@ -43,43 +35,19 @@ class Session:
self.messages.append(msg)
self.updated_at = datetime.now()
@staticmethod
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
"""Find first index where every tool result has a matching assistant tool_call."""
declared: set[str] = set()
start = 0
for i, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
elif role == "tool":
tid = msg.get("tool_call_id")
if tid and str(tid) not in declared:
start = i + 1
declared.clear()
for prev in messages[start:i + 1]:
if prev.get("role") == "assistant":
for tc in prev.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
return start
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:]
# Drop leading non-user messages to avoid starting mid-turn when possible.
# Avoid starting mid-turn when possible.
for i, message in enumerate(sliced):
if message.get("role") == "user":
sliced = sliced[i:]
break
# Some providers reject orphan tool results if the matching assistant
# tool_calls message fell outside the fixed-size history window.
start = self._find_legal_start(sliced)
# Drop orphan tool results at the front.
start = find_legal_message_start(sliced)
if start:
sliced = sliced[start:]
@ -115,7 +83,7 @@ class Session:
retained = self.messages[start_idx:]
# Mirror get_history(): avoid persisting orphan tool results at the front.
start = self._find_legal_start(retained)
start = find_legal_message_start(retained)
if start:
retained = retained[start:]

View File

@ -3,12 +3,15 @@
import base64
import json
import re
import shutil
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
import tiktoken
from loguru import logger
def strip_think(text: str) -> str:
@ -56,11 +59,7 @@ def timestamp() -> str:
def current_time_str(timezone: str | None = None) -> str:
"""Human-readable current time with weekday and UTC offset.
When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time
is converted to that zone. Otherwise falls back to the host local time.
"""
"""Return the current time string."""
from zoneinfo import ZoneInfo
try:
@ -76,12 +75,164 @@ def current_time_str(timezone: str | None = None) -> str:
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
_TOOL_RESULT_PREVIEW_CHARS = 1200
_TOOL_RESULTS_DIR = ".nanobot/tool-results"
_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60
_TOOL_RESULT_MAX_BUCKETS = 32
def safe_filename(name: str) -> str:
"""Replace unsafe path characters with underscores."""
return _UNSAFE_CHARS.sub("_", name).strip()
def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str:
"""Build an image placeholder string."""
return f"[image: {path}]" if path else empty
def truncate_text(text: str, max_chars: int) -> str:
"""Truncate text with a stable suffix."""
if max_chars <= 0 or len(text) <= max_chars:
return text
return text[:max_chars] + "\n... (truncated)"
def find_legal_message_start(messages: list[dict[str, Any]]) -> int:
"""Find the first index whose tool results have matching assistant calls."""
declared: set[str] = set()
start = 0
for i, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
elif role == "tool":
tid = msg.get("tool_call_id")
if tid and str(tid) not in declared:
start = i + 1
declared.clear()
for prev in messages[start : i + 1]:
if prev.get("role") == "assistant":
for tc in prev.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
return start
def stringify_text_blocks(content: list[dict[str, Any]]) -> str | None:
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
return None
if block.get("type") != "text":
return None
text = block.get("text")
if not isinstance(text, str):
return None
parts.append(text)
return "\n".join(parts)
def _render_tool_result_reference(
filepath: Path,
*,
original_size: int,
preview: str,
truncated_preview: bool,
) -> str:
result = (
f"[tool output persisted]\n"
f"Full output saved to: {filepath}\n"
f"Original size: {original_size} chars\n"
f"Preview:\n{preview}"
)
if truncated_preview:
result += "\n...\n(Read the saved file if you need the full output.)"
return result
def _bucket_mtime(path: Path) -> float:
try:
return path.stat().st_mtime
except OSError:
return 0.0
def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None:
siblings = [path for path in root.iterdir() if path.is_dir() and path != current_bucket]
cutoff = time.time() - _TOOL_RESULT_RETENTION_SECS
for path in siblings:
if _bucket_mtime(path) < cutoff:
shutil.rmtree(path, ignore_errors=True)
keep = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0)
siblings = [path for path in siblings if path.exists()]
if len(siblings) <= keep:
return
siblings.sort(key=_bucket_mtime, reverse=True)
for path in siblings[keep:]:
shutil.rmtree(path, ignore_errors=True)
def _write_text_atomic(path: Path, content: str) -> None:
tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
try:
tmp.write_text(content, encoding="utf-8")
tmp.replace(path)
finally:
if tmp.exists():
tmp.unlink(missing_ok=True)
def maybe_persist_tool_result(
workspace: Path | None,
session_key: str | None,
tool_call_id: str,
content: Any,
*,
max_chars: int,
) -> Any:
"""Persist oversized tool output and replace it with a stable reference string."""
if workspace is None or max_chars <= 0:
return content
text_payload: str | None = None
suffix = "txt"
if isinstance(content, str):
text_payload = content
elif isinstance(content, list):
text_payload = stringify_text_blocks(content)
if text_payload is None:
return content
suffix = "json"
else:
return content
if len(text_payload) <= max_chars:
return content
root = ensure_dir(workspace / _TOOL_RESULTS_DIR)
bucket = ensure_dir(root / safe_filename(session_key or "default"))
try:
_cleanup_tool_result_buckets(root, bucket)
except Exception as exc:
logger.warning("Failed to clean stale tool result buckets in {}: {}", root, exc)
path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
if not path.exists():
if suffix == "json" and isinstance(content, list):
_write_text_atomic(path, json.dumps(content, ensure_ascii=False, indent=2))
else:
_write_text_atomic(path, text_payload)
preview = text_payload[:_TOOL_RESULT_PREVIEW_CHARS]
return _render_tool_result_reference(
path,
original_size=len(text_payload),
preview=preview,
truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS,
)
def split_message(content: str, max_len: int = 2000) -> list[str]:
"""
Split content into chunks within max_len, preferring line breaks.

88
nanobot/utils/runtime.py Normal file
View File

@ -0,0 +1,88 @@
"""Runtime-specific helper functions and constants."""
from __future__ import annotations
from typing import Any
from loguru import logger
from nanobot.utils.helpers import stringify_text_blocks
_MAX_REPEAT_EXTERNAL_LOOKUPS = 2
EMPTY_FINAL_RESPONSE_MESSAGE = (
"I completed the tool steps but couldn't produce a final answer. "
"Please try again or narrow the task."
)
FINALIZATION_RETRY_PROMPT = (
"You have already finished the tool work. Do not call any more tools. "
"Using only the conversation and tool results above, provide the final answer for the user now."
)
def empty_tool_result_message(tool_name: str) -> str:
"""Short prompt-safe marker for tools that completed without visible output."""
return f"({tool_name} completed with no output)"
def ensure_nonempty_tool_result(tool_name: str, content: Any) -> Any:
"""Replace semantically empty tool results with a short marker string."""
if content is None:
return empty_tool_result_message(tool_name)
if isinstance(content, str) and not content.strip():
return empty_tool_result_message(tool_name)
if isinstance(content, list):
if not content:
return empty_tool_result_message(tool_name)
text_payload = stringify_text_blocks(content)
if text_payload is not None and not text_payload.strip():
return empty_tool_result_message(tool_name)
return content
def is_blank_text(content: str | None) -> bool:
"""True when *content* is missing or only whitespace."""
return content is None or not content.strip()
def build_finalization_retry_message() -> dict[str, str]:
"""A short no-tools-allowed prompt for final answer recovery."""
return {"role": "user", "content": FINALIZATION_RETRY_PROMPT}
def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None:
"""Stable signature for repeated external lookups we want to throttle."""
if tool_name == "web_fetch":
url = str(arguments.get("url") or "").strip()
if url:
return f"web_fetch:{url.lower()}"
if tool_name == "web_search":
query = str(arguments.get("query") or arguments.get("search_term") or "").strip()
if query:
return f"web_search:{query.lower()}"
return None
def repeated_external_lookup_error(
tool_name: str,
arguments: dict[str, Any],
seen_counts: dict[str, int],
) -> str | None:
"""Block repeated external lookups after a small retry budget."""
signature = external_lookup_signature(tool_name, arguments)
if signature is None:
return None
count = seen_counts.get(signature, 0) + 1
seen_counts[signature] = count
if count <= _MAX_REPEAT_EXTERNAL_LOOKUPS:
return None
logger.warning(
"Blocking repeated external lookup {} on attempt {}",
signature[:160],
count,
)
return (
"Error: repeated external lookup blocked. "
"Use the results you already have to answer, or try a meaningfully different source."
)

View File

@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content
assert "Return exactly: OK" in user_content
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
messages = builder.build_messages(
history=[{"role": "assistant", "content": "previous result"}],
current_message="subagent result",
channel="cli",
chat_id="direct",
current_role="assistant",
)
for left, right in zip(messages, messages[1:]):
assert not (left.get("role") == right.get("role") == "assistant")

View File

@ -5,7 +5,9 @@ from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
from nanobot.config.schema import AgentDefaults
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
return loop
@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None:
)
assert session.messages[0]["content"] == content
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
loop = _mk_loop()
session = Session(
key="test:checkpoint",
metadata={
AgentLoop._RUNTIME_CHECKPOINT_KEY: {
"assistant_message": {
"role": "assistant",
"content": "working",
"tool_calls": [
{
"id": "call_done",
"type": "function",
"function": {"name": "read_file", "arguments": "{}"},
},
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
},
],
},
"completed_tool_results": [
{
"role": "tool",
"tool_call_id": "call_done",
"name": "read_file",
"content": "ok",
}
],
"pending_tool_calls": [
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
}
},
)
restored = loop._restore_runtime_checkpoint(session)
assert restored is True
assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
assert session.messages[0]["role"] == "assistant"
assert session.messages[1]["tool_call_id"] == "call_done"
assert session.messages[2]["tool_call_id"] == "call_pending"
assert "interrupted before this tool finished" in session.messages[2]["content"].lower()
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
loop = _mk_loop()
session = Session(
key="test:checkpoint-overlap",
messages=[
{
"role": "assistant",
"content": "working",
"tool_calls": [
{
"id": "call_done",
"type": "function",
"function": {"name": "read_file", "arguments": "{}"},
},
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
},
],
},
{
"role": "tool",
"tool_call_id": "call_done",
"name": "read_file",
"content": "ok",
},
],
metadata={
AgentLoop._RUNTIME_CHECKPOINT_KEY: {
"assistant_message": {
"role": "assistant",
"content": "working",
"tool_calls": [
{
"id": "call_done",
"type": "function",
"function": {"name": "read_file", "arguments": "{}"},
},
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
},
],
},
"completed_tool_results": [
{
"role": "tool",
"tool_call_id": "call_done",
"name": "read_file",
"content": "ok",
}
],
"pending_tool_calls": [
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
}
},
)
restored = loop._restore_runtime_checkpoint(session)
assert restored is True
assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
assert len(session.messages) == 3
assert session.messages[0]["role"] == "assistant"
assert session.messages[1]["tool_call_id"] == "call_done"
assert session.messages[2]["tool_call_id"] == "call_pending"

View File

@ -2,12 +2,20 @@
from __future__ import annotations
import asyncio
import os
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.config.schema import AgentDefaults
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMResponse, ToolCallRequest
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
def _make_loop(tmp_path):
from nanobot.agent.loop import AgentLoop
@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order():
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=RecordingHook(),
))
@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=StreamingHook(),
))
@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback():
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.stop_reason == "max_iterations"
@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback():
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."
)
assert result.messages[-1]["role"] == "assistant"
assert result.messages[-1]["content"] == result.final_content
@pytest.mark.asyncio
async def test_runner_returns_structured_tool_error():
@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error():
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
fail_on_tool_error=True,
))
@ -258,6 +272,457 @@ async def test_runner_returns_structured_tool_error():
]
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
usage={"prompt_tokens": 5, "completion_tokens": 3},
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="x" * 20_000)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=2,
workspace=tmp_path,
session_key="test:runner",
max_tool_result_chars=2048,
))
assert result.final_content == "done"
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
assert "[tool output persisted]" in tool_message["content"]
assert "tool-results" in tool_message["content"]
assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
from nanobot.utils.helpers import maybe_persist_tool_result
root = tmp_path / ".nanobot" / "tool-results"
old_bucket = root / "old_session"
recent_bucket = root / "recent_session"
old_bucket.mkdir(parents=True)
recent_bucket.mkdir(parents=True)
(old_bucket / "old.txt").write_text("old", encoding="utf-8")
(recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")
stale = time.time() - (8 * 24 * 60 * 60)
os.utime(old_bucket, (stale, stale))
os.utime(old_bucket / "old.txt", (stale, stale))
persisted = maybe_persist_tool_result(
tmp_path,
"current:session",
"call_big",
"x" * 5000,
max_chars=64,
)
assert "[tool output persisted]" in persisted
assert not old_bucket.exists()
assert recent_bucket.exists()
assert (root / "current_session" / "call_big.txt").exists()
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
from nanobot.utils.helpers import maybe_persist_tool_result
root = tmp_path / ".nanobot" / "tool-results"
maybe_persist_tool_result(
tmp_path,
"current:session",
"call_big",
"x" * 5000,
max_chars=64,
)
assert (root / "current_session" / "call_big.txt").exists()
assert list((root / "current_session").glob("*.tmp")) == []
def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
from nanobot.utils.helpers import maybe_persist_tool_result
warnings: list[str] = []
monkeypatch.setattr(
"nanobot.utils.helpers._cleanup_tool_result_buckets",
lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")),
)
monkeypatch.setattr(
"nanobot.utils.helpers.logger.warning",
lambda message, *args: warnings.append(message.format(*args)),
)
persisted = maybe_persist_tool_result(
tmp_path,
"current:session",
"call_big",
"x" * 5000,
max_chars=64,
)
assert "[tool output persisted]" in persisted
assert warnings and "Failed to clean stale tool result buckets" in warnings[0]
@pytest.mark.asyncio
async def test_runner_replaces_empty_tool_result_with_marker():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})],
usage={},
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="")
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
assert tool_message["content"] == "(noop completed with no output)"
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_messages: list[dict] = []
async def chat_with_retry(*, messages, **kwargs):
captured_messages[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
initial_messages = [
{"role": "system", "content": "system"},
{"role": "user", "content": "hello"},
]
runner = AgentRunner(provider)
runner._snip_history = MagicMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign]
result = await runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
assert captured_messages == initial_messages
@pytest.mark.asyncio
async def test_runner_retries_empty_final_response_with_summary_prompt():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
calls: list[dict] = []
async def chat_with_retry(*, messages, tools=None, **kwargs):
calls.append({"messages": messages, "tools": tools})
if len(calls) == 1:
return LLMResponse(
content=None,
tool_calls=[],
usage={"prompt_tokens": 10, "completion_tokens": 1},
)
return LLMResponse(
content="final answer",
tool_calls=[],
usage={"prompt_tokens": 3, "completion_tokens": 7},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "final answer"
assert len(calls) == 2
assert calls[1]["tools"] is None
assert "Do not call any more tools" in calls[1]["messages"][-1]["content"]
assert result.usage["prompt_tokens"] == 13
assert result.usage["completion_tokens"] == 8
@pytest.mark.asyncio
async def test_runner_uses_specific_message_after_empty_finalization_retry():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
provider = MagicMock()
async def chat_with_retry(*, messages, **kwargs):
return LLMResponse(content=None, tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE
assert result.stop_reason == "empty_final_response"
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
messages = [
{"role": "system", "content": "system"},
{"role": "user", "content": "old user"},
{
"role": "assistant",
"content": "tool call",
"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}],
},
{"role": "tool", "tool_call_id": "call_1", "content": "tool output"},
{"role": "assistant", "content": "after tool"},
]
spec = AgentRunSpec(
initial_messages=messages,
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
context_window_tokens=2000,
context_block_limit=100,
)
monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_args, **_kwargs: (500, None))
token_sizes = {
"old user": 120,
"tool call": 120,
"tool output": 40,
"after tool": 40,
"system": 0,
}
monkeypatch.setattr(
"nanobot.agent.runner.estimate_message_tokens",
lambda msg: token_sizes.get(str(msg.get("content")), 40),
)
trimmed = runner._snip_history(spec, messages)
assert trimmed == [
{"role": "system", "content": "system"},
{"role": "assistant", "content": "after tool"},
]
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
usage={"prompt_tokens": 5, "completion_tokens": 3},
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
runner = AgentRunner(provider)
with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
assert tool_message["content"] == "tool result"
class _DelayTool(Tool):
def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
self._name = name
self._delay = delay
self._read_only = read_only
self._shared_events = shared_events
@property
def name(self) -> str:
return self._name
@property
def description(self) -> str:
return self._name
@property
def parameters(self) -> dict:
return {"type": "object", "properties": {}, "required": []}
@property
def read_only(self) -> bool:
return self._read_only
async def execute(self, **kwargs):
self._shared_events.append(f"start:{self._name}")
await asyncio.sleep(self._delay)
self._shared_events.append(f"end:{self._name}")
return self._name
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
tools = ToolRegistry()
shared_events: list[str] = []
read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events)
write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events)
tools.register(read_a)
tools.register(read_b)
tools.register(write_a)
runner = AgentRunner(MagicMock())
await runner._execute_tools(
AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
concurrent_tools=True,
),
[
ToolCallRequest(id="ro1", name="read_a", arguments={}),
ToolCallRequest(id="ro2", name="read_b", arguments={}),
ToolCallRequest(id="rw1", name="write_a", arguments={}),
],
{},
)
assert shared_events[0:2] == ["start:read_a", "start:read_b"]
assert "end:read_a" in shared_events and "end:read_b" in shared_events
assert shared_events.index("end:read_a") < shared_events.index("start:write_a")
assert shared_events.index("end:read_b") < shared_events.index("start:write_a")
assert shared_events[-2:] == ["start:write_a", "end:write_a"]
@pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_final_call: list[dict] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] <= 3:
return LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id=f"call_{call_count['n']}", name="web_fetch", arguments={"url": "https://example.com"})],
usage={},
)
captured_final_call[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="page content")
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "research task"}],
tools=tools,
model="test-model",
max_iterations=4,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
assert tools.execute.await_count == 2
blocked_tool_message = [
msg for msg in captured_final_call
if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_3"
][0]
assert "repeated external lookup blocked" in blocked_tool_message["content"]
@pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path):
loop = _make_loop(tmp_path)
@ -307,6 +772,57 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
assert endings == [False]
@pytest.mark.asyncio
async def test_loop_retries_think_only_final_response(tmp_path):
loop = _make_loop(tmp_path)
call_count = {"n": 0}
async def chat_with_retry(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(content="<think>hidden</think>", tool_calls=[], usage={})
return LLMResponse(content="Recovered answer", tool_calls=[], usage={})
loop.provider.chat_with_retry = chat_with_retry
final_content, _, _ = await loop._run_agent_loop([])
assert final_content == "Recovered answer"
assert call_count["n"] == 2
@pytest.mark.asyncio
async def test_runner_tool_error_sets_final_content():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
async def chat_with_retry(*, messages, **kwargs):
return LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
usage={},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(side_effect=RuntimeError("boom"))
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
fail_on_tool_error=True,
))
assert result.final_content == "Error: RuntimeError: boom"
assert result.stop_reason == "tool_error"
@pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
from nanobot.agent.subagent import SubagentManager
@ -317,15 +833,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock()
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -369,6 +890,7 @@ async def test_runner_accumulates_usage_and_preserves_cached_tokens():
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
# Usage should be accumulated across iterations
@ -407,6 +929,7 @@ async def test_runner_passes_cached_tokens_to_hook_context():
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=UsageHook(),
))

View File

@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.config.schema import AgentDefaults
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
def _make_loop(*, exec_config=None):
"""Create a minimal AgentLoop with mocked dependencies."""
@ -186,7 +190,12 @@ class TestSubagentCancellation:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=MagicMock(),
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
cancelled = asyncio.Event()
@ -214,7 +223,12 @@ class TestSubagentCancellation:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=MagicMock(),
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio
@ -236,19 +250,24 @@ class TestSubagentCancellation:
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -273,6 +292,7 @@ class TestSubagentCancellation:
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
exec_config=ExecToolConfig(enable=False),
)
mgr._announce_result = AsyncMock()
@ -304,20 +324,25 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock()
calls = {"n": 0}
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
calls["n"] += 1
if calls["n"] == 1:
return "first result"
raise RuntimeError("boom")
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -340,15 +365,20 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock()
started = asyncio.Event()
cancelled = asyncio.Event()
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
started.set()
try:
await asyncio.sleep(60)
@ -356,7 +386,7 @@ class TestSubagentCancellation:
cancelled.set()
raise
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
task = asyncio.create_task(
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -364,7 +394,7 @@ class TestSubagentCancellation:
mgr._running_tasks["sub-1"] = task
mgr._session_tasks["test:c1"] = {"sub-1"}
await started.wait()
await asyncio.wait_for(started.wait(), timeout=1.0)
count = await mgr.cancel_by_session("test:c1")

View File

@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
seen["config"] = self.config
return True
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {"fakeplugin": _LoginPlugin},
@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
assert seen["force"] is True
def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path):
from nanobot.cli.commands import app
from nanobot.config.schema import Config
from typer.testing import CliRunner
runner = CliRunner()
seen: dict[str, object] = {}
config_path = tmp_path / "custom-config.json"
class _LoginPlugin(_FakePlugin):
async def login(self, force: bool = False) -> bool:
return True
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {"fakeplugin": _LoginPlugin},
)
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)])
assert result.exit_code == 0
assert seen["config_path"] == config_path.resolve()
def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path):
from nanobot.cli.commands import app
from nanobot.config.schema import Config
from typer.testing import CliRunner
runner = CliRunner()
seen: dict[str, object] = {}
config_path = tmp_path / "custom-config.json"
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
result = runner.invoke(app, ["channels", "status", "--config", str(config_path)])
assert result.exit_code == 0
assert seen["config_path"] == config_path.resolve()
@pytest.mark.asyncio
async def test_manager_skips_disabled_plugin():
fake_config = SimpleNamespace(

View File

@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None:
typing_channel.typing_enter_hook = slow_typing
await channel._start_typing(typing_channel)
await start.wait()
await asyncio.wait_for(start.wait(), timeout=1.0)
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
release.set()
@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None:
typing_channel.typing_enter_hook = slow_typing_progress
await channel._start_typing(typing_channel)
await start.wait()
await asyncio.wait_for(start.wait(), timeout=1.0)
await channel.send(
OutboundMessage(
@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
typing_channel = _NoTriggerChannel(channel_id=123)
await channel._start_typing(typing_channel) # type: ignore[arg-type]
await entered.wait()
await asyncio.wait_for(entered.wait(), timeout=1.0)
assert "123" in channel._typing_tasks

View File

@ -3,16 +3,14 @@ from pathlib import Path
from types import SimpleNamespace
import pytest
pytest.importorskip("nio")
pytest.importorskip("nh3")
pytest.importorskip("mistune")
from nio import RoomSendResponse
from nanobot.channels.matrix import _build_matrix_text_content
# Check optional matrix dependencies before importing
try:
import nh3 # noqa: F401
except ImportError:
pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True)
import nanobot.channels.matrix as matrix_module
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus

View File

@ -0,0 +1,172 @@
"""Tests for QQ channel ack_message feature.
Covers the four verification points from the PR:
1. C2C message: ack appears instantly
2. Group message: ack appears instantly
3. ack_message set to "": no ack sent
4. Custom ack_message text: correct text delivered
Each test also verifies that normal message processing is not blocked.
"""
from types import SimpleNamespace
import pytest
try:
from nanobot.channels import qq
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
except ImportError:
QQ_AVAILABLE = False
if not QQ_AVAILABLE:
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel, QQConfig
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeClient:
def __init__(self) -> None:
self.api = _FakeApi()
@pytest.mark.asyncio
async def test_ack_sent_on_c2c_message() -> None:
"""Ack is sent immediately for C2C messages, then normal processing continues."""
channel = QQChannel(
QQConfig(
app_id="app",
secret="secret",
allow_from=["*"],
ack_message="⏳ Processing...",
),
MessageBus(),
)
channel._client = _FakeClient()
data = SimpleNamespace(
id="msg1",
content="hello",
author=SimpleNamespace(user_openid="user1"),
attachments=[],
)
await channel._on_message(data, is_group=False)
assert len(channel._client.api.c2c_calls) >= 1
ack_call = channel._client.api.c2c_calls[0]
assert ack_call["content"] == "⏳ Processing..."
assert ack_call["openid"] == "user1"
assert ack_call["msg_id"] == "msg1"
assert ack_call["msg_type"] == 0
msg = await channel.bus.consume_inbound()
assert msg.content == "hello"
assert msg.sender_id == "user1"
@pytest.mark.asyncio
async def test_ack_sent_on_group_message() -> None:
"""Ack is sent immediately for group messages, then normal processing continues."""
channel = QQChannel(
QQConfig(
app_id="app",
secret="secret",
allow_from=["*"],
ack_message="⏳ Processing...",
),
MessageBus(),
)
channel._client = _FakeClient()
data = SimpleNamespace(
id="msg2",
content="hello group",
group_openid="group123",
author=SimpleNamespace(member_openid="user1"),
attachments=[],
)
await channel._on_message(data, is_group=True)
assert len(channel._client.api.group_calls) >= 1
ack_call = channel._client.api.group_calls[0]
assert ack_call["content"] == "⏳ Processing..."
assert ack_call["group_openid"] == "group123"
assert ack_call["msg_id"] == "msg2"
assert ack_call["msg_type"] == 0
msg = await channel.bus.consume_inbound()
assert msg.content == "hello group"
assert msg.chat_id == "group123"
@pytest.mark.asyncio
async def test_no_ack_when_ack_message_empty() -> None:
"""Setting ack_message to empty string disables the ack entirely."""
channel = QQChannel(
QQConfig(
app_id="app",
secret="secret",
allow_from=["*"],
ack_message="",
),
MessageBus(),
)
channel._client = _FakeClient()
data = SimpleNamespace(
id="msg3",
content="hello",
author=SimpleNamespace(user_openid="user1"),
attachments=[],
)
await channel._on_message(data, is_group=False)
assert len(channel._client.api.c2c_calls) == 0
assert len(channel._client.api.group_calls) == 0
msg = await channel.bus.consume_inbound()
assert msg.content == "hello"
@pytest.mark.asyncio
async def test_custom_ack_message_text() -> None:
"""Custom Chinese ack_message text is delivered correctly."""
custom = "正在处理中,请稍候..."
channel = QQChannel(
QQConfig(
app_id="app",
secret="secret",
allow_from=["*"],
ack_message=custom,
),
MessageBus(),
)
channel._client = _FakeClient()
data = SimpleNamespace(
id="msg4",
content="test input",
author=SimpleNamespace(user_openid="user1"),
attachments=[],
)
await channel._on_message(data, is_group=False)
assert len(channel._client.api.c2c_calls) >= 1
ack_call = channel._client.api.c2c_calls[0]
assert ack_call["content"] == custom
msg = await channel.bus.consume_inbound()
assert msg.content == "test input"

View File

@ -647,43 +647,56 @@ async def test_group_policy_open_accepts_plain_group_message() -> None:
assert channel._app.bot.get_me_calls == 0
def test_extract_reply_context_no_reply() -> None:
@pytest.mark.asyncio
async def test_extract_reply_context_no_reply() -> None:
"""When there is no reply_to_message, _extract_reply_context returns None."""
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
message = SimpleNamespace(reply_to_message=None)
assert TelegramChannel._extract_reply_context(message) is None
assert await channel._extract_reply_context(message) is None
def test_extract_reply_context_with_text() -> None:
@pytest.mark.asyncio
async def test_extract_reply_context_with_text() -> None:
"""When reply has text, return prefixed string."""
reply = SimpleNamespace(text="Hello world", caption=None)
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
channel._app = _FakeApp(lambda: None)
reply = SimpleNamespace(text="Hello world", caption=None, from_user=SimpleNamespace(id=2, username="testuser", first_name="Test"))
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
assert await channel._extract_reply_context(message) == "[Reply to @testuser: Hello world]"
def test_extract_reply_context_with_caption_only() -> None:
@pytest.mark.asyncio
async def test_extract_reply_context_with_caption_only() -> None:
"""When reply has only caption (no text), caption is used."""
reply = SimpleNamespace(text=None, caption="Photo caption")
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
channel._app = _FakeApp(lambda: None)
reply = SimpleNamespace(text=None, caption="Photo caption", from_user=SimpleNamespace(id=2, username=None, first_name="Test"))
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
assert await channel._extract_reply_context(message) == "[Reply to Test: Photo caption]"
def test_extract_reply_context_truncation() -> None:
@pytest.mark.asyncio
async def test_extract_reply_context_truncation() -> None:
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
channel._app = _FakeApp(lambda: None)
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
reply = SimpleNamespace(text=long_text, caption=None)
reply = SimpleNamespace(text=long_text, caption=None, from_user=SimpleNamespace(id=2, username=None, first_name=None))
message = SimpleNamespace(reply_to_message=reply)
result = TelegramChannel._extract_reply_context(message)
result = await channel._extract_reply_context(message)
assert result is not None
assert result.startswith("[Reply to: ")
assert result.endswith("...]")
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
def test_extract_reply_context_no_text_returns_none() -> None:
@pytest.mark.asyncio
async def test_extract_reply_context_no_text_returns_none() -> None:
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
reply = SimpleNamespace(text=None, caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) is None
assert await channel._extract_reply_context(message) is None
@pytest.mark.asyncio

View File

@ -1,6 +1,6 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -8,392 +8,401 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse
def test_azure_openai_provider_init():
"""Test AzureOpenAIProvider initialization without deployment_name."""
# ---------------------------------------------------------------------------
# Init & validation
# ---------------------------------------------------------------------------
def test_init_creates_sdk_client():
"""Provider creates an AsyncOpenAI client with correct base_url."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
assert provider.api_key == "test-key"
assert provider.api_base == "https://test-resource.openai.azure.com/"
assert provider.default_model == "gpt-4o-deployment"
assert provider.api_version == "2024-10-21"
# SDK client base_url ends with /openai/v1/
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
def test_azure_openai_provider_init_validation():
"""Test AzureOpenAIProvider initialization validation."""
# Missing api_key
def test_init_base_url_no_trailing_slash():
"""Trailing slashes are normalised before building base_url."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://res.openai.azure.com",
)
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
def test_init_base_url_with_trailing_slash():
provider = AzureOpenAIProvider(
api_key="k", api_base="https://res.openai.azure.com/",
)
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
def test_init_validation_missing_key():
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
AzureOpenAIProvider(api_key="", api_base="https://test.com")
# Missing api_base
def test_init_validation_missing_base():
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
AzureOpenAIProvider(api_key="test", api_base="")
def test_build_chat_url():
"""Test Azure OpenAI URL building with different deployment names."""
def test_no_api_version_in_base_url():
"""The /openai/v1/ path should NOT contain an api-version query param."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
base = str(provider._client.base_url)
assert "api-version" not in base
# ---------------------------------------------------------------------------
# _supports_temperature
# ---------------------------------------------------------------------------
def test_supports_temperature_standard_model():
assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True
def test_supports_temperature_reasoning_model():
assert AzureOpenAIProvider._supports_temperature("o3-mini") is False
assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False
assert AzureOpenAIProvider._supports_temperature("o4-mini") is False
def test_supports_temperature_with_reasoning_effort():
assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False
# ---------------------------------------------------------------------------
# _build_body — Responses API body construction
# ---------------------------------------------------------------------------
def test_build_body_basic():
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
)
# Test various deployment names
test_cases = [
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
]
for deployment_name, expected_url in test_cases:
url = provider._build_chat_url(deployment_name)
assert url == expected_url
messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
def test_build_chat_url_api_base_without_slash():
"""Test URL building when api_base doesn't end with slash."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com", # No trailing slash
default_model="gpt-4o",
assert body["model"] == "gpt-4o"
assert body["instructions"] == "You are helpful."
assert body["temperature"] == 0.7
assert body["max_output_tokens"] == 4096
assert body["store"] is False
assert "reasoning" not in body
# input should contain the converted user message only (system extracted)
assert any(
item.get("role") == "user"
for item in body["input"]
)
url = provider._build_chat_url("test-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
def test_build_headers():
"""Test Azure OpenAI header building with api-key authentication."""
provider = AzureOpenAIProvider(
api_key="test-api-key-123",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
headers = provider._build_headers()
assert headers["Content-Type"] == "application/json"
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
assert "x-session-affinity" in headers
def test_build_body_max_tokens_minimum():
"""max_output_tokens should never be less than 1."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None)
assert body["max_output_tokens"] == 1
def test_prepare_request_payload():
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert payload["temperature"] == 0.8
assert "tools" not in payload
# Test with tools
def test_build_body_with_tools():
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
body = provider._build_body(
[{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
)
assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
assert body["tool_choice"] == "auto"
def test_prepare_request_payload_sanitizes_messages():
"""Test Azure payload strips non-standard message keys before sending."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
def test_build_body_with_reasoning():
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
body = provider._build_body(
[{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None,
)
assert body["reasoning"] == {"effort": "medium"}
assert "reasoning.encrypted_content" in body.get("include", [])
# temperature omitted for reasoning models
assert "temperature" not in body
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
"reasoning_content": "hidden chain-of-thought",
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
"extra_field": "should be removed",
},
]
payload = provider._prepare_request_payload("gpt-4o", messages)
def test_build_body_image_conversion():
"""image_url content blocks should be converted to input_image."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
messages = [{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
],
}]
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
user_item = body["input"][0]
content_types = [b["type"] for b in user_item["content"]]
assert "input_text" in content_types
assert "input_image" in content_types
image_block = next(b for b in user_item["content"] if b["type"] == "input_image")
assert image_block["image_url"] == "https://example.com/img.png"
assert payload["messages"] == [
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
def test_build_body_sanitizes_single_dict_content_block():
"""Single content dicts should be preserved via shared message sanitization."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
messages = [{
"role": "user",
"content": {"type": "text", "text": "Hi from dict content"},
}]
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
assert body["input"][0]["content"] == [{"type": "input_text", "text": "Hi from dict content"}]
# ---------------------------------------------------------------------------
# chat() — non-streaming
# ---------------------------------------------------------------------------
def _make_sdk_response(
content="Hello!", tool_calls=None, status="completed",
usage=None,
):
"""Build a mock that quacks like an openai Response object."""
resp = MagicMock()
resp.model_dump = MagicMock(return_value={
"output": [
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
*([{
"type": "function_call",
"call_id": tc["call_id"], "id": tc["id"],
"name": tc["name"], "arguments": tc["arguments"],
} for tc in (tool_calls or [])]),
],
"status": status,
"usage": {
"input_tokens": (usage or {}).get("input_tokens", 10),
"output_tokens": (usage or {}).get("output_tokens", 5),
"total_tokens": (usage or {}).get("total_tokens", 15),
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
})
return resp
@pytest.mark.asyncio
async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
# Mock response data
mock_response_data = {
"choices": [{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
mock_resp = _make_sdk_response(content="Hello!")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
result = await provider.chat([{"role": "user", "content": "Hi"}])
assert isinstance(result, LLMResponse)
assert result.content == "Hello!"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 10
@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
"""Test that chat uses default_model when no model is specified."""
async def test_chat_uses_default_model():
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
)
mock_response_data = {
"choices": [{
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified
# Verify URL was built with default model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
mock_resp = _make_sdk_response(content="ok")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
await provider.chat([{"role": "user", "content": "test"}])
call_kwargs = provider._client.responses.create.call_args[1]
assert call_kwargs["model"] == "my-deployment"
@pytest.mark.asyncio
async def test_chat_custom_model():
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
mock_resp = _make_sdk_response(content="ok")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy")
call_kwargs = provider._client.responses.create.call_args[1]
assert call_kwargs["model"] == "custom-deploy"
@pytest.mark.asyncio
async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
# Mock response with tool calls
mock_response_data = {
"choices": [{
"message": {
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
mock_resp = _make_sdk_response(
content=None,
tool_calls=[{
"call_id": "call_123", "id": "fc_1",
"name": "get_weather", "arguments": '{"location": "SF"}',
}],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
)
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
result = await provider.chat(
[{"role": "user", "content": "Weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "SF"}
@pytest.mark.asyncio
async def test_chat_api_error():
"""Test chat request API error handling."""
async def test_chat_error_handling():
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
result = await provider.chat([{"role": "user", "content": "Hi"}])
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content
assert "Connection failed" in result.content
assert result.finish_reason == "error"
@pytest.mark.asyncio
async def test_chat_reasoning_param_format():
"""reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
)
mock_resp = _make_sdk_response(content="thought")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
await provider.chat(
[{"role": "user", "content": "think"}], reasoning_effort="medium",
)
call_kwargs = provider._client.responses.create.call_args[1]
assert call_kwargs["reasoning"] == {"effort": "medium"}
assert "reasoning_effort" not in call_kwargs
# ---------------------------------------------------------------------------
# chat_stream()
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_stream_success():
"""Streaming should call on_content_delta and return combined response."""
provider = AzureOpenAIProvider(
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
# Build mock SDK stream events
events = []
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
resp_obj = MagicMock(status="completed")
ev3 = MagicMock(type="response.completed", response=resp_obj)
events = [ev1, ev2, ev3]
async def mock_stream():
for e in events:
yield e
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_stream())
deltas: list[str] = []
async def on_delta(text: str) -> None:
deltas.append(text)
result = await provider.chat_stream(
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
)
assert result.content == "Hello world"
assert result.finish_reason == "stop"
assert deltas == ["Hello", " world"]
@pytest.mark.asyncio
async def test_chat_stream_with_tool_calls():
"""Streaming tool calls should be accumulated correctly."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
item_added.name = "get_weather"
ev_added = MagicMock(type="response.output_item.added", item=item_added)
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
ev_args_done = MagicMock(
type="response.function_call_arguments.done",
call_id="call_1", arguments='{"location":"SF"}',
)
item_done = MagicMock(
type="function_call", call_id="call_1", id="fc_1",
arguments='{"location":"SF"}',
)
item_done.name = "get_weather"
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed")
ev_completed = MagicMock(type="response.completed", response=resp_obj)
async def mock_stream():
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
yield e
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_stream())
result = await provider.chat_stream(
[{"role": "user", "content": "weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "SF"}
@pytest.mark.asyncio
async def test_chat_stream_error():
"""Streaming should return error when SDK raises."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
assert "Connection failed" in result.content
assert result.finish_reason == "error"
# ---------------------------------------------------------------------------
# get_default_model
# ---------------------------------------------------------------------------
def test_get_default_model():
"""Test get_default_model method."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="my-custom-deployment",
api_key="k", api_base="https://r.com", default_model="my-deploy",
)
assert provider.get_default_model() == "my-custom-deployment"
if __name__ == "__main__":
# Run basic tests
print("Running basic Azure OpenAI provider tests...")
# Test initialization
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
print("✅ Provider initialization successful")
# Test URL building
url = provider._build_chat_url("my-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
print("✅ URL building works correctly")
# Test headers
headers = provider._build_headers()
assert headers["api-key"] == "test-key"
assert headers["Content-Type"] == "application/json"
print("✅ Header building works correctly")
# Test payload preparation
messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly")
print("✅ All basic tests passed! Updated test file is working correctly.")
assert provider.get_default_model() == "my-deploy"

View File

@ -8,6 +8,7 @@ Validates that:
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
return SimpleNamespace(choices=[choice], usage=usage)
class _StalledStream:
def __aiter__(self):
return self
async def __anext__(self):
await asyncio.sleep(3600)
raise StopAsyncIteration
def test_openrouter_spec_is_gateway() -> None:
spec = find_by_name("openrouter")
assert spec is not None
@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None:
spec=spec,
)
assert provider.get_default_model() == "gpt-4o"
def test_openai_compat_strips_message_level_reasoning_fields() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
sanitized = provider._sanitize_messages([
{
"role": "assistant",
"content": "done",
"reasoning_content": "hidden",
"extra_content": {"debug": True},
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "fn", "arguments": "{}"},
"extra_content": {"google": {"thought_signature": "sig"}},
}
],
}
])
assert "reasoning_content" not in sanitized[0]
assert "extra_content" not in sanitized[0]
assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
@pytest.mark.asyncio
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
mock_create = AsyncMock(return_value=_StalledStream())
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_create
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-4o",
spec=spec,
)
result = await provider.chat_stream(
messages=[{"role": "user", "content": "hello"}],
model="gpt-4o",
)
assert result.finish_reason == "error"
assert result.content is not None
assert "stream stalled" in result.content

View File

@ -0,0 +1,522 @@
"""Tests for the shared openai_responses converters and parsers."""
from unittest.mock import MagicMock, patch
import pytest
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses.converters import (
convert_messages,
convert_tools,
convert_user_message,
split_tool_call_id,
)
from nanobot.providers.openai_responses.parsing import (
consume_sdk_stream,
map_finish_reason,
parse_response_output,
)
# ======================================================================
# converters - split_tool_call_id
# ======================================================================
class TestSplitToolCallId:
def test_plain_id(self):
assert split_tool_call_id("call_abc") == ("call_abc", None)
def test_compound_id(self):
assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1")
def test_compound_empty_item_id(self):
assert split_tool_call_id("call_abc|") == ("call_abc", None)
def test_none(self):
assert split_tool_call_id(None) == ("call_0", None)
def test_empty_string(self):
assert split_tool_call_id("") == ("call_0", None)
def test_non_string(self):
assert split_tool_call_id(42) == ("call_0", None)
# ======================================================================
# converters - convert_user_message
# ======================================================================
class TestConvertUserMessage:
def test_string_content(self):
result = convert_user_message("hello")
assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}
def test_text_block(self):
result = convert_user_message([{"type": "text", "text": "hi"}])
assert result["content"] == [{"type": "input_text", "text": "hi"}]
def test_image_url_block(self):
result = convert_user_message([
{"type": "image_url", "image_url": {"url": "https://img.example/a.png"}},
])
assert result["content"] == [
{"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"},
]
def test_mixed_text_and_image(self):
result = convert_user_message([
{"type": "text", "text": "what's this?"},
{"type": "image_url", "image_url": {"url": "https://img.example/b.png"}},
])
assert len(result["content"]) == 2
assert result["content"][0]["type"] == "input_text"
assert result["content"][1]["type"] == "input_image"
def test_empty_list_falls_back(self):
result = convert_user_message([])
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_none_falls_back(self):
result = convert_user_message(None)
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_image_without_url_skipped(self):
result = convert_user_message([{"type": "image_url", "image_url": {}}])
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_meta_fields_not_leaked(self):
"""_meta on content blocks must never appear in converted output."""
result = convert_user_message([
{"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}},
])
assert "_meta" not in result["content"][0]
def test_non_dict_items_skipped(self):
result = convert_user_message(["just a string", 42])
assert result["content"] == [{"type": "input_text", "text": ""}]
# ======================================================================
# converters - convert_messages
# ======================================================================
class TestConvertMessages:
def test_system_extracted_as_instructions(self):
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi"},
]
instructions, items = convert_messages(msgs)
assert instructions == "You are helpful."
assert len(items) == 1
assert items[0]["role"] == "user"
def test_multiple_system_messages_last_wins(self):
msgs = [
{"role": "system", "content": "first"},
{"role": "system", "content": "second"},
{"role": "user", "content": "x"},
]
instructions, _ = convert_messages(msgs)
assert instructions == "second"
def test_user_message_converted(self):
_, items = convert_messages([{"role": "user", "content": "hello"}])
assert items[0]["role"] == "user"
assert items[0]["content"][0]["type"] == "input_text"
def test_assistant_text_message(self):
_, items = convert_messages([
{"role": "assistant", "content": "I'll help"},
])
assert items[0]["type"] == "message"
assert items[0]["role"] == "assistant"
assert items[0]["content"][0]["type"] == "output_text"
assert items[0]["content"][0]["text"] == "I'll help"
def test_assistant_empty_content_skipped(self):
_, items = convert_messages([{"role": "assistant", "content": ""}])
assert len(items) == 0
def test_assistant_with_tool_calls(self):
_, items = convert_messages([{
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_abc|fc_1",
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
}],
}])
assert items[0]["type"] == "function_call"
assert items[0]["call_id"] == "call_abc"
assert items[0]["id"] == "fc_1"
assert items[0]["name"] == "get_weather"
def test_assistant_with_tool_calls_no_id(self):
"""Fallback IDs when tool_call.id is missing."""
_, items = convert_messages([{
"role": "assistant",
"content": None,
"tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}],
}])
assert items[0]["call_id"] == "call_0"
assert items[0]["id"].startswith("fc_")
def test_tool_message(self):
_, items = convert_messages([{
"role": "tool",
"tool_call_id": "call_abc",
"content": "result text",
}])
assert items[0]["type"] == "function_call_output"
assert items[0]["call_id"] == "call_abc"
assert items[0]["output"] == "result text"
def test_tool_message_dict_content(self):
_, items = convert_messages([{
"role": "tool",
"tool_call_id": "call_1",
"content": {"key": "value"},
}])
assert items[0]["output"] == '{"key": "value"}'
def test_non_standard_keys_not_leaked(self):
"""Extra keys on messages must not appear in converted items."""
_, items = convert_messages([{
"role": "user",
"content": "hi",
"extra_field": "should vanish",
"_meta": {"path": "/tmp"},
}])
item = items[0]
assert "extra_field" not in str(item)
assert "_meta" not in str(item)
def test_full_conversation_roundtrip(self):
"""System + user + assistant(tool_call) + tool -> correct structure."""
msgs = [
{"role": "system", "content": "Be concise."},
{"role": "user", "content": "Weather in SF?"},
{
"role": "assistant", "content": None,
"tool_calls": [{
"id": "c1|fc1",
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
}],
},
{"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'},
]
instructions, items = convert_messages(msgs)
assert instructions == "Be concise."
assert len(items) == 3 # user, function_call, function_call_output
assert items[0]["role"] == "user"
assert items[1]["type"] == "function_call"
assert items[2]["type"] == "function_call_output"
# ======================================================================
# converters - convert_tools
# ======================================================================
class TestConvertTools:
def test_standard_function_tool(self):
tools = [{"type": "function", "function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}}]
result = convert_tools(tools)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["name"] == "get_weather"
assert result[0]["description"] == "Get weather"
assert "properties" in result[0]["parameters"]
def test_tool_without_name_skipped(self):
tools = [{"type": "function", "function": {"parameters": {}}}]
assert convert_tools(tools) == []
def test_tool_without_function_wrapper(self):
"""Direct dict without type=function wrapper."""
tools = [{"name": "f1", "description": "d", "parameters": {}}]
result = convert_tools(tools)
assert result[0]["name"] == "f1"
def test_missing_optional_fields_default(self):
tools = [{"type": "function", "function": {"name": "f"}}]
result = convert_tools(tools)
assert result[0]["description"] == ""
assert result[0]["parameters"] == {}
def test_multiple_tools(self):
tools = [
{"type": "function", "function": {"name": "a", "parameters": {}}},
{"type": "function", "function": {"name": "b", "parameters": {}}},
]
assert len(convert_tools(tools)) == 2
# ======================================================================
# parsing - map_finish_reason
# ======================================================================
class TestMapFinishReason:
def test_completed(self):
assert map_finish_reason("completed") == "stop"
def test_incomplete(self):
assert map_finish_reason("incomplete") == "length"
def test_failed(self):
assert map_finish_reason("failed") == "error"
def test_cancelled(self):
assert map_finish_reason("cancelled") == "error"
def test_none_defaults_to_stop(self):
assert map_finish_reason(None) == "stop"
def test_unknown_defaults_to_stop(self):
assert map_finish_reason("some_new_status") == "stop"
# ======================================================================
# parsing - parse_response_output
# ======================================================================
class TestParseResponseOutput:
def test_text_response(self):
resp = {
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "Hello!"}]}],
"status": "completed",
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
}
result = parse_response_output(resp)
assert result.content == "Hello!"
assert result.finish_reason == "stop"
assert result.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
assert result.tool_calls == []
def test_tool_call_response(self):
resp = {
"output": [{
"type": "function_call",
"call_id": "call_1", "id": "fc_1",
"name": "get_weather",
"arguments": '{"city": "SF"}',
}],
"status": "completed",
"usage": {},
}
result = parse_response_output(resp)
assert result.content is None
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"city": "SF"}
assert result.tool_calls[0].id == "call_1|fc_1"
def test_malformed_tool_arguments_logged(self):
"""Malformed JSON arguments should log a warning and fallback."""
resp = {
"output": [{
"type": "function_call",
"call_id": "c1", "id": "fc1",
"name": "f", "arguments": "{bad json",
}],
"status": "completed", "usage": {},
}
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
result = parse_response_output(resp)
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
def test_reasoning_content_extracted(self):
resp = {
"output": [
{"type": "reasoning", "summary": [
{"type": "summary_text", "text": "I think "},
{"type": "summary_text", "text": "therefore I am."},
]},
{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "42"}]},
],
"status": "completed", "usage": {},
}
result = parse_response_output(resp)
assert result.content == "42"
assert result.reasoning_content == "I think therefore I am."
def test_empty_output(self):
resp = {"output": [], "status": "completed", "usage": {}}
result = parse_response_output(resp)
assert result.content is None
assert result.tool_calls == []
def test_incomplete_status(self):
resp = {"output": [], "status": "incomplete", "usage": {}}
result = parse_response_output(resp)
assert result.finish_reason == "length"
def test_sdk_model_object(self):
"""parse_response_output should handle SDK objects with model_dump()."""
mock = MagicMock()
mock.model_dump.return_value = {
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "sdk"}]}],
"status": "completed",
"usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
}
result = parse_response_output(mock)
assert result.content == "sdk"
assert result.usage["prompt_tokens"] == 1
def test_usage_maps_responses_api_keys(self):
"""Responses API uses input_tokens/output_tokens, not prompt_tokens/completion_tokens."""
resp = {
"output": [],
"status": "completed",
"usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
}
result = parse_response_output(resp)
assert result.usage["prompt_tokens"] == 100
assert result.usage["completion_tokens"] == 50
assert result.usage["total_tokens"] == 150
# ======================================================================
# parsing - consume_sdk_stream
# ======================================================================
class TestConsumeSdkStream:
@pytest.mark.asyncio
async def test_text_stream(self):
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev3 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3]:
yield e
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
assert content == "Hello world"
assert tool_calls == []
assert finish_reason == "stop"
@pytest.mark.asyncio
async def test_on_content_delta_called(self):
ev1 = MagicMock(type="response.output_text.delta", delta="hi")
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev2 = MagicMock(type="response.completed", response=resp_obj)
deltas = []
async def cb(text):
deltas.append(text)
async def stream():
for e in [ev1, ev2]:
yield e
await consume_sdk_stream(stream(), on_content_delta=cb)
assert deltas == ["hi"]
@pytest.mark.asyncio
async def test_tool_call_stream(self):
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
item_added.name = "get_weather"
ev1 = MagicMock(type="response.output_item.added", item=item_added)
ev2 = MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci')
ev3 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}')
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments='{"city":"SF"}')
item_done.name = "get_weather"
ev4 = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev5 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3, ev4, ev5]:
yield e
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
assert content == ""
assert len(tool_calls) == 1
assert tool_calls[0].name == "get_weather"
assert tool_calls[0].arguments == {"city": "SF"}
@pytest.mark.asyncio
async def test_usage_extracted(self):
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
resp_obj = MagicMock(status="completed", usage=usage_obj, output=[])
ev = MagicMock(type="response.completed", response=resp_obj)
async def stream():
yield ev
_, _, _, usage, _ = await consume_sdk_stream(stream())
assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
@pytest.mark.asyncio
async def test_reasoning_extracted(self):
summary_item = MagicMock(type="summary_text", text="thinking...")
reasoning_item = MagicMock(type="reasoning", summary=[summary_item])
resp_obj = MagicMock(status="completed", usage=None, output=[reasoning_item])
ev = MagicMock(type="response.completed", response=resp_obj)
async def stream():
yield ev
_, _, _, _, reasoning = await consume_sdk_stream(stream())
assert reasoning == "thinking..."
@pytest.mark.asyncio
async def test_error_event_raises(self):
ev = MagicMock(type="error", error="rate_limit_exceeded")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio
async def test_failed_event_raises(self):
ev = MagicMock(type="response.failed", error="server_error")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed.*server_error"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio
async def test_malformed_tool_args_logged(self):
"""Malformed JSON in streaming tool args should log a warning."""
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
item_added.name = "f"
ev1 = MagicMock(type="response.output_item.added", item=item_added)
ev2 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad")
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="{bad")
item_done.name = "f"
ev3 = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev4 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3, ev4]:
yield e
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
assert tool_calls[0].arguments == {"raw": "{bad"}
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)

View File

@ -211,3 +211,56 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
content = msg.get("content")
if isinstance(content, list):
assert any("[image omitted]" in (b.get("text") or "") for b in content)
@pytest.mark.asyncio
async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"),
LLMResponse(content="ok"),
])
delays: list[float] = []
progress: list[str] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
async def _progress(msg: str) -> None:
progress.append(msg)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
on_retry_wait=_progress,
)
assert response.content == "ok"
assert delays == [7.0]
assert progress and "7s" in progress[0]
@pytest.mark.asyncio
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
provider = ScriptedProvider([
*[LLMResponse(content="429 rate limit", finish_reason="error") for _ in range(10)],
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
retry_mode="persistent",
)
assert response.finish_reason == "error"
assert response.content == "429 rate limit"
assert provider.calls == 10
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]

View File

@ -347,6 +347,8 @@ async def test_empty_response_retry_then_success(aiohttp_client) -> None:
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_falls_back(aiohttp_client) -> None:
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
call_count = 0
async def always_empty(content, session_key="", channel="", chat_id=""):
@ -367,5 +369,5 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
)
assert resp.status == 200
body = await resp.json()
assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
assert call_count == 2

View File

@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None:
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
task = asyncio.create_task(wrapper.execute())
await started.wait()
await asyncio.wait_for(started.wait(), timeout=1.0)
task.cancel()