mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-04 10:22:33 +00:00
Merge remote-tracking branch 'origin/main' into pr-2646
This commit is contained in:
commit
f409337fcf
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
||||
.assets
|
||||
.docs
|
||||
.env
|
||||
.web
|
||||
*.pyc
|
||||
dist/
|
||||
build/
|
||||
|
||||
17
README.md
17
README.md
@ -20,13 +20,20 @@
|
||||
|
||||
## 📢 News
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We have fully removed the `litellm` since **v0.1.4.post6**.
|
||||
|
||||
- **2026-04-02** 🧱 **Long-running tasks** run more reliably — core runtime hardening.
|
||||
- **2026-04-01** 🔑 GitHub Copilot auth restored; stricter workspace paths; OpenRouter Claude caching fix.
|
||||
- **2026-03-31** 🛰️ WeChat multimodal alignment, Discord/Matrix polish, Python SDK facade, MCP and tool fixes.
|
||||
- **2026-03-30** 🧩 OpenAI-compatible API tightened; composable agent lifecycle hooks.
|
||||
- **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API.
|
||||
- **2026-03-28** 📚 Provider docs refresh; skill template wording fix.
|
||||
- **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
|
||||
- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
|
||||
- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
|
||||
- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
|
||||
|
||||
<details>
|
||||
<summary>Earlier news</summary>
|
||||
|
||||
- **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI.
|
||||
- **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command.
|
||||
- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
|
||||
@ -34,10 +41,6 @@
|
||||
- **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly.
|
||||
- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details.
|
||||
- **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable.
|
||||
|
||||
<details>
|
||||
<summary>Earlier news</summary>
|
||||
|
||||
- **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details.
|
||||
- **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility.
|
||||
- **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling.
|
||||
|
||||
@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
if isinstance(left, str) and isinstance(right, str):
|
||||
return f"{left}\n\n{right}" if left else right
|
||||
|
||||
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(value, list):
|
||||
return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value]
|
||||
if value is None:
|
||||
return []
|
||||
return [{"type": "text", "text": str(value)}]
|
||||
|
||||
return _to_blocks(left) + _to_blocks(right)
|
||||
|
||||
def _load_bootstrap_files(self) -> str:
|
||||
"""Load all bootstrap files from workspace."""
|
||||
parts = []
|
||||
@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
||||
merged = f"{runtime_ctx}\n\n{user_content}"
|
||||
else:
|
||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||
|
||||
return [
|
||||
messages = [
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||
*history,
|
||||
{"role": current_role, "content": merged},
|
||||
]
|
||||
if messages[-1].get("role") == current_role:
|
||||
last = dict(messages[-1])
|
||||
last["content"] = self._merge_message_content(last.get("content"), merged)
|
||||
messages[-1] = last
|
||||
return messages
|
||||
messages.append({"role": current_role, "content": merged})
|
||||
return messages
|
||||
|
||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||
"""Build user message content with optional base64-encoded images."""
|
||||
|
||||
@ -29,8 +29,11 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.helpers import image_placeholder_text, truncate_text
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
|
||||
@ -38,11 +41,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core lifecycle hook for the main agent loop.
|
||||
|
||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
||||
and think-tag stripping for the built-in agent path.
|
||||
"""
|
||||
"""Core hook for the main loop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -111,11 +110,7 @@ class _LoopHook(AgentHook):
|
||||
|
||||
|
||||
class _LoopHookChain(AgentHook):
|
||||
"""Run the core loop hook first, then best-effort extra hooks.
|
||||
|
||||
This preserves the historical failure behavior of ``_LoopHook`` while still
|
||||
letting user-supplied hooks opt into ``CompositeHook`` isolation.
|
||||
"""
|
||||
"""Run the core hook before extra hooks."""
|
||||
|
||||
__slots__ = ("_primary", "_extras")
|
||||
|
||||
@ -163,7 +158,7 @@ class AgentLoop:
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 16_000
|
||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -171,8 +166,11 @@ class AgentLoop:
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
context_window_tokens: int = 65_536,
|
||||
max_iterations: int | None = None,
|
||||
context_window_tokens: int | None = None,
|
||||
context_block_limit: int | None = None,
|
||||
max_tool_result_chars: int | None = None,
|
||||
provider_retry_mode: str = "standard",
|
||||
web_search_config: WebSearchConfig | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
@ -186,13 +184,27 @@ class AgentLoop:
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||
|
||||
defaults = AgentDefaults()
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.max_iterations = (
|
||||
max_iterations if max_iterations is not None else defaults.max_tool_iterations
|
||||
)
|
||||
self.context_window_tokens = (
|
||||
context_window_tokens
|
||||
if context_window_tokens is not None
|
||||
else defaults.context_window_tokens
|
||||
)
|
||||
self.context_block_limit = context_block_limit
|
||||
self.max_tool_result_chars = (
|
||||
max_tool_result_chars
|
||||
if max_tool_result_chars is not None
|
||||
else defaults.max_tool_result_chars
|
||||
)
|
||||
self.provider_retry_mode = provider_retry_mode
|
||||
self.web_search_config = web_search_config or WebSearchConfig()
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
@ -211,6 +223,7 @@ class AgentLoop:
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
web_search_config=self.web_search_config,
|
||||
web_proxy=web_proxy,
|
||||
exec_config=self.exec_config,
|
||||
@ -322,6 +335,7 @@ class AgentLoop:
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
session: Session | None = None,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
@ -348,14 +362,27 @@ class AgentLoop:
|
||||
else loop_hook
|
||||
)
|
||||
|
||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||
if session is None:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
))
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
@ -493,6 +520,8 @@ class AgentLoop:
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
@ -503,10 +532,11 @@ class AgentLoop:
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
messages, channel=channel, chat_id=chat_id,
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
@ -517,6 +547,8 @@ class AgentLoop:
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
# Slash commands
|
||||
raw = msg.content.strip()
|
||||
@ -552,14 +584,16 @@ class AgentLoop:
|
||||
on_progress=on_progress or _bus_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
session=session,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
|
||||
if final_content is None:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
if final_content is None or not final_content.strip():
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||
|
||||
@ -577,12 +611,6 @@ class AgentLoop:
|
||||
metadata=meta,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
|
||||
"""Convert an inline image block into a compact text placeholder."""
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
|
||||
|
||||
def _sanitize_persisted_blocks(
|
||||
self,
|
||||
content: list[dict[str, Any]],
|
||||
@ -609,13 +637,14 @@ class AgentLoop:
|
||||
block.get("type") == "image_url"
|
||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||
):
|
||||
filtered.append(self._image_placeholder(block))
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||
continue
|
||||
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text = block["text"]
|
||||
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS:
|
||||
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
if truncate_text and len(text) > self.max_tool_result_chars:
|
||||
text = truncate_text(text, self.max_tool_result_chars)
|
||||
filtered.append({**block, "text": text})
|
||||
continue
|
||||
|
||||
@ -632,8 +661,8 @@ class AgentLoop:
|
||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||
continue # skip empty assistant messages — they poison session context
|
||||
if role == "tool":
|
||||
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||
entry["content"] = truncate_text(content, self.max_tool_result_chars)
|
||||
elif isinstance(content, list):
|
||||
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
||||
if not filtered:
|
||||
@ -656,6 +685,78 @@ class AgentLoop:
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
|
||||
"""Persist the latest in-flight turn state into session metadata."""
|
||||
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
||||
self.sessions.save(session)
|
||||
|
||||
def _clear_runtime_checkpoint(self, session: Session) -> None:
|
||||
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
|
||||
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
|
||||
return (
|
||||
message.get("role"),
|
||||
message.get("content"),
|
||||
message.get("tool_call_id"),
|
||||
message.get("name"),
|
||||
message.get("tool_calls"),
|
||||
message.get("reasoning_content"),
|
||||
message.get("thinking_blocks"),
|
||||
)
|
||||
|
||||
def _restore_runtime_checkpoint(self, session: Session) -> bool:
|
||||
"""Materialize an unfinished turn into session history before a new request."""
|
||||
from datetime import datetime
|
||||
|
||||
checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
|
||||
if not isinstance(checkpoint, dict):
|
||||
return False
|
||||
|
||||
assistant_message = checkpoint.get("assistant_message")
|
||||
completed_tool_results = checkpoint.get("completed_tool_results") or []
|
||||
pending_tool_calls = checkpoint.get("pending_tool_calls") or []
|
||||
|
||||
restored_messages: list[dict[str, Any]] = []
|
||||
if isinstance(assistant_message, dict):
|
||||
restored = dict(assistant_message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for message in completed_tool_results:
|
||||
if isinstance(message, dict):
|
||||
restored = dict(message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for tool_call in pending_tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||
restored_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
|
||||
overlap = 0
|
||||
max_overlap = min(len(session.messages), len(restored_messages))
|
||||
for size in range(max_overlap, 0, -1):
|
||||
existing = session.messages[-size:]
|
||||
restored = restored_messages[:size]
|
||||
if all(
|
||||
self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
|
||||
for left, right in zip(existing, restored)
|
||||
):
|
||||
overlap = size
|
||||
break
|
||||
session.messages.extend(restored_messages[overlap:])
|
||||
|
||||
self._clear_runtime_checkpoint(session)
|
||||
return True
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
|
||||
@ -4,20 +4,36 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
||||
from nanobot.utils.helpers import build_assistant_message
|
||||
from nanobot.utils.helpers import (
|
||||
build_assistant_message,
|
||||
estimate_message_tokens,
|
||||
estimate_prompt_tokens_chain,
|
||||
find_legal_message_start,
|
||||
maybe_persist_tool_result,
|
||||
truncate_text,
|
||||
)
|
||||
from nanobot.utils.runtime import (
|
||||
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||
build_finalization_retry_message,
|
||||
ensure_nonempty_tool_result,
|
||||
is_blank_text,
|
||||
repeated_external_lookup_error,
|
||||
)
|
||||
|
||||
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
||||
"I reached the maximum number of tool call iterations ({max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||
|
||||
|
||||
_SNIP_SAFETY_BUFFER = 1024
|
||||
@dataclass(slots=True)
|
||||
class AgentRunSpec:
|
||||
"""Configuration for a single agent execution."""
|
||||
@ -26,6 +42,7 @@ class AgentRunSpec:
|
||||
tools: ToolRegistry
|
||||
model: str
|
||||
max_iterations: int
|
||||
max_tool_result_chars: int
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
reasoning_effort: str | None = None
|
||||
@ -34,6 +51,13 @@ class AgentRunSpec:
|
||||
max_iterations_message: str | None = None
|
||||
concurrent_tools: bool = False
|
||||
fail_on_tool_error: bool = False
|
||||
workspace: Path | None = None
|
||||
session_key: str | None = None
|
||||
context_window_tokens: int | None = None
|
||||
context_block_limit: int | None = None
|
||||
provider_retry_mode: str = "standard"
|
||||
progress_callback: Any | None = None
|
||||
checkpoint_callback: Any | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -60,91 +84,142 @@ class AgentRunner:
|
||||
messages = list(spec.initial_messages)
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = []
|
||||
usage: dict[str, int] = {}
|
||||
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
error: str | None = None
|
||||
stop_reason = "completed"
|
||||
tool_events: list[dict[str, str]] = []
|
||||
external_lookup_counts: dict[str, int] = {}
|
||||
|
||||
for iteration in range(spec.max_iterations):
|
||||
try:
|
||||
messages = self._apply_tool_result_budget(spec, messages)
|
||||
messages_for_model = self._snip_history(spec, messages)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Context governance failed on turn {} for {}: {}; using raw messages",
|
||||
iteration,
|
||||
spec.session_key or "default",
|
||||
exc,
|
||||
)
|
||||
messages_for_model = messages
|
||||
context = AgentHookContext(iteration=iteration, messages=messages)
|
||||
await hook.before_iteration(context)
|
||||
kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"tools": spec.tools.get_definitions(),
|
||||
"model": spec.model,
|
||||
}
|
||||
if spec.temperature is not None:
|
||||
kwargs["temperature"] = spec.temperature
|
||||
if spec.max_tokens is not None:
|
||||
kwargs["max_tokens"] = spec.max_tokens
|
||||
if spec.reasoning_effort is not None:
|
||||
kwargs["reasoning_effort"] = spec.reasoning_effort
|
||||
|
||||
if hook.wants_streaming():
|
||||
async def _stream(delta: str) -> None:
|
||||
await hook.on_stream(context, delta)
|
||||
|
||||
response = await self.provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=_stream,
|
||||
)
|
||||
else:
|
||||
response = await self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
raw_usage = response.usage or {}
|
||||
response = await self._request_model(spec, messages_for_model, hook, context)
|
||||
raw_usage = self._usage_dict(response.usage)
|
||||
context.response = response
|
||||
context.usage = raw_usage
|
||||
context.usage = dict(raw_usage)
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
# Accumulate standard fields into result usage.
|
||||
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
|
||||
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
|
||||
cached = raw_usage.get("cached_tokens")
|
||||
if cached:
|
||||
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
|
||||
self._accumulate_usage(usage, raw_usage)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
assistant_message = build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "awaiting_tools",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
await hook.before_execute_tools(context)
|
||||
|
||||
results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls)
|
||||
results, new_events, fatal_error = await self._execute_tools(
|
||||
spec,
|
||||
response.tool_calls,
|
||||
external_lookup_counts,
|
||||
)
|
||||
tool_events.extend(new_events)
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
if fatal_error is not None:
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
final_content = error
|
||||
stop_reason = "tool_error"
|
||||
self._append_final_message(messages, final_content)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
messages.append({
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
"content": self._normalize_tool_result(
|
||||
spec,
|
||||
tool_call.id,
|
||||
tool_call.name,
|
||||
result,
|
||||
),
|
||||
}
|
||||
messages.append(tool_message)
|
||||
completed_tool_results.append(tool_message)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "tools_completed",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": completed_tool_results,
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
clean = hook.finalize_content(context, response.content)
|
||||
if response.finish_reason != "error" and is_blank_text(clean):
|
||||
logger.warning(
|
||||
"Empty final response on turn {} for {}; retrying with explicit finalization prompt",
|
||||
iteration,
|
||||
spec.session_key or "default",
|
||||
)
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
response = await self._request_finalization_retry(spec, messages_for_model)
|
||||
retry_usage = self._usage_dict(response.usage)
|
||||
self._accumulate_usage(usage, retry_usage)
|
||||
raw_usage = self._merge_usage(raw_usage, retry_usage)
|
||||
context.response = response
|
||||
context.usage = dict(raw_usage)
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
clean = hook.finalize_content(context, response.content)
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
|
||||
clean = hook.finalize_content(context, response.content)
|
||||
if response.finish_reason == "error":
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
stop_reason = "error"
|
||||
error = final_content
|
||||
self._append_final_message(messages, final_content)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
if is_blank_text(clean):
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
stop_reason = "empty_final_response"
|
||||
error = final_content
|
||||
self._append_final_message(messages, final_content)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
@ -156,6 +231,17 @@ class AgentRunner:
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": messages[-1],
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
final_content = clean
|
||||
context.final_content = final_content
|
||||
context.stop_reason = stop_reason
|
||||
@ -165,6 +251,7 @@ class AgentRunner:
|
||||
stop_reason = "max_iterations"
|
||||
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
|
||||
final_content = template.format(max_iterations=spec.max_iterations)
|
||||
self._append_final_message(messages, final_content)
|
||||
|
||||
return AgentRunResult(
|
||||
final_content=final_content,
|
||||
@ -176,21 +263,101 @@ class AgentRunner:
|
||||
tool_events=tool_events,
|
||||
)
|
||||
|
||||
def _build_request_kwargs(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"model": spec.model,
|
||||
"retry_mode": spec.provider_retry_mode,
|
||||
"on_retry_wait": spec.progress_callback,
|
||||
}
|
||||
if spec.temperature is not None:
|
||||
kwargs["temperature"] = spec.temperature
|
||||
if spec.max_tokens is not None:
|
||||
kwargs["max_tokens"] = spec.max_tokens
|
||||
if spec.reasoning_effort is not None:
|
||||
kwargs["reasoning_effort"] = spec.reasoning_effort
|
||||
return kwargs
|
||||
|
||||
async def _request_model(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
hook: AgentHook,
|
||||
context: AgentHookContext,
|
||||
):
|
||||
kwargs = self._build_request_kwargs(
|
||||
spec,
|
||||
messages,
|
||||
tools=spec.tools.get_definitions(),
|
||||
)
|
||||
if hook.wants_streaming():
|
||||
async def _stream(delta: str) -> None:
|
||||
await hook.on_stream(context, delta)
|
||||
|
||||
return await self.provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=_stream,
|
||||
)
|
||||
return await self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
async def _request_finalization_retry(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
):
|
||||
retry_messages = list(messages)
|
||||
retry_messages.append(build_finalization_retry_message())
|
||||
kwargs = self._build_request_kwargs(spec, retry_messages, tools=None)
|
||||
return await self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _usage_dict(usage: dict[str, Any] | None) -> dict[str, int]:
|
||||
if not usage:
|
||||
return {}
|
||||
result: dict[str, int] = {}
|
||||
for key, value in usage.items():
|
||||
try:
|
||||
result[key] = int(value or 0)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _accumulate_usage(target: dict[str, int], addition: dict[str, int]) -> None:
|
||||
for key, value in addition.items():
|
||||
target[key] = target.get(key, 0) + value
|
||||
|
||||
@staticmethod
|
||||
def _merge_usage(left: dict[str, int], right: dict[str, int]) -> dict[str, int]:
|
||||
merged = dict(left)
|
||||
for key, value in right.items():
|
||||
merged[key] = merged.get(key, 0) + value
|
||||
return merged
|
||||
|
||||
async def _execute_tools(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
tool_calls: list[ToolCallRequest],
|
||||
external_lookup_counts: dict[str, int],
|
||||
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
|
||||
if spec.concurrent_tools:
|
||||
tool_results = await asyncio.gather(*(
|
||||
self._run_tool(spec, tool_call)
|
||||
for tool_call in tool_calls
|
||||
))
|
||||
else:
|
||||
tool_results = [
|
||||
await self._run_tool(spec, tool_call)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
batches = self._partition_tool_batches(spec, tool_calls)
|
||||
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
||||
for batch in batches:
|
||||
if spec.concurrent_tools and len(batch) > 1:
|
||||
tool_results.extend(await asyncio.gather(*(
|
||||
self._run_tool(spec, tool_call, external_lookup_counts)
|
||||
for tool_call in batch
|
||||
)))
|
||||
else:
|
||||
for tool_call in batch:
|
||||
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
|
||||
|
||||
results: list[Any] = []
|
||||
events: list[dict[str, str]] = []
|
||||
@ -206,9 +373,44 @@ class AgentRunner:
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
tool_call: ToolCallRequest,
|
||||
external_lookup_counts: dict[str, int],
|
||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
lookup_error = repeated_external_lookup_error(
|
||||
tool_call.name,
|
||||
tool_call.arguments,
|
||||
external_lookup_counts,
|
||||
)
|
||||
if lookup_error:
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
"status": "error",
|
||||
"detail": "repeated external lookup blocked",
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return lookup_error + _HINT, event, RuntimeError(lookup_error)
|
||||
return lookup_error + _HINT, event, None
|
||||
prepare_call = getattr(spec.tools, "prepare_call", None)
|
||||
tool, params, prep_error = None, tool_call.arguments, None
|
||||
if callable(prepare_call):
|
||||
try:
|
||||
prepared = prepare_call(tool_call.name, tool_call.arguments)
|
||||
if isinstance(prepared, tuple) and len(prepared) == 3:
|
||||
tool, params, prep_error = prepared
|
||||
except Exception:
|
||||
pass
|
||||
if prep_error:
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
"status": "error",
|
||||
"detail": prep_error.split(": ", 1)[-1][:120],
|
||||
}
|
||||
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||
try:
|
||||
result = await spec.tools.execute(tool_call.name, tool_call.arguments)
|
||||
if tool is not None:
|
||||
result = await tool.execute(**params)
|
||||
else:
|
||||
result = await spec.tools.execute(tool_call.name, params)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except BaseException as exc:
|
||||
@ -221,14 +423,178 @@ class AgentRunner:
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
"status": "error",
|
||||
"detail": result.replace("\n", " ").strip()[:120],
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return result + _HINT, event, RuntimeError(result)
|
||||
return result + _HINT, event, None
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
detail = detail.replace("\n", " ").strip()
|
||||
if not detail:
|
||||
detail = "(empty)"
|
||||
elif len(detail) > 120:
|
||||
detail = detail[:120] + "..."
|
||||
return result, {
|
||||
"name": tool_call.name,
|
||||
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok",
|
||||
"detail": detail,
|
||||
}, None
|
||||
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
|
||||
|
||||
async def _emit_checkpoint(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
callback = spec.checkpoint_callback
|
||||
if callback is not None:
|
||||
await callback(payload)
|
||||
|
||||
@staticmethod
|
||||
def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None:
|
||||
if not content:
|
||||
return
|
||||
if (
|
||||
messages
|
||||
and messages[-1].get("role") == "assistant"
|
||||
and not messages[-1].get("tool_calls")
|
||||
):
|
||||
if messages[-1].get("content") == content:
|
||||
return
|
||||
messages[-1] = build_assistant_message(content)
|
||||
return
|
||||
messages.append(build_assistant_message(content))
|
||||
|
||||
def _normalize_tool_result(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: Any,
|
||||
) -> Any:
|
||||
result = ensure_nonempty_tool_result(tool_name, result)
|
||||
try:
|
||||
content = maybe_persist_tool_result(
|
||||
spec.workspace,
|
||||
spec.session_key,
|
||||
tool_call_id,
|
||||
result,
|
||||
max_chars=spec.max_tool_result_chars,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Tool result persist failed for {} in {}: {}; using raw result",
|
||||
tool_call_id,
|
||||
spec.session_key or "default",
|
||||
exc,
|
||||
)
|
||||
content = result
|
||||
if isinstance(content, str) and len(content) > spec.max_tool_result_chars:
|
||||
return truncate_text(content, spec.max_tool_result_chars)
|
||||
return content
|
||||
|
||||
def _apply_tool_result_budget(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
updated = messages
|
||||
for idx, message in enumerate(messages):
|
||||
if message.get("role") != "tool":
|
||||
continue
|
||||
normalized = self._normalize_tool_result(
|
||||
spec,
|
||||
str(message.get("tool_call_id") or f"tool_{idx}"),
|
||||
str(message.get("name") or "tool"),
|
||||
message.get("content"),
|
||||
)
|
||||
if normalized != message.get("content"):
|
||||
if updated is messages:
|
||||
updated = [dict(m) for m in messages]
|
||||
updated[idx]["content"] = normalized
|
||||
return updated
|
||||
|
||||
def _snip_history(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
if not messages or not spec.context_window_tokens:
|
||||
return messages
|
||||
|
||||
provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
|
||||
max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else (
|
||||
provider_max_tokens if isinstance(provider_max_tokens, int) else 4096
|
||||
)
|
||||
budget = spec.context_block_limit or (
|
||||
spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER
|
||||
)
|
||||
if budget <= 0:
|
||||
return messages
|
||||
|
||||
estimate, _ = estimate_prompt_tokens_chain(
|
||||
self.provider,
|
||||
spec.model,
|
||||
messages,
|
||||
spec.tools.get_definitions(),
|
||||
)
|
||||
if estimate <= budget:
|
||||
return messages
|
||||
|
||||
system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"]
|
||||
non_system = [dict(msg) for msg in messages if msg.get("role") != "system"]
|
||||
if not non_system:
|
||||
return messages
|
||||
|
||||
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
|
||||
remaining_budget = max(128, budget - system_tokens)
|
||||
kept: list[dict[str, Any]] = []
|
||||
kept_tokens = 0
|
||||
for message in reversed(non_system):
|
||||
msg_tokens = estimate_message_tokens(message)
|
||||
if kept and kept_tokens + msg_tokens > remaining_budget:
|
||||
break
|
||||
kept.append(message)
|
||||
kept_tokens += msg_tokens
|
||||
kept.reverse()
|
||||
|
||||
if kept:
|
||||
for i, message in enumerate(kept):
|
||||
if message.get("role") == "user":
|
||||
kept = kept[i:]
|
||||
break
|
||||
start = find_legal_message_start(kept)
|
||||
if start:
|
||||
kept = kept[start:]
|
||||
if not kept:
|
||||
kept = non_system[-min(len(non_system), 4) :]
|
||||
start = find_legal_message_start(kept)
|
||||
if start:
|
||||
kept = kept[start:]
|
||||
return system_messages + kept
|
||||
|
||||
def _partition_tool_batches(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
tool_calls: list[ToolCallRequest],
|
||||
) -> list[list[ToolCallRequest]]:
|
||||
if not spec.concurrent_tools:
|
||||
return [[tool_call] for tool_call in tool_calls]
|
||||
|
||||
batches: list[list[ToolCallRequest]] = []
|
||||
current: list[ToolCallRequest] = []
|
||||
for tool_call in tool_calls:
|
||||
get_tool = getattr(spec.tools, "get", None)
|
||||
tool = get_tool(tool_call.name) if callable(get_tool) else None
|
||||
can_batch = bool(tool and tool.concurrency_safe)
|
||||
if can_batch:
|
||||
current.append(tool_call)
|
||||
continue
|
||||
if current:
|
||||
batches.append(current)
|
||||
current = []
|
||||
batches.append([tool_call])
|
||||
if current:
|
||||
batches.append(current)
|
||||
return batches
|
||||
|
||||
|
||||
@ -44,6 +44,7 @@ class SubagentManager:
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
bus: MessageBus,
|
||||
max_tool_result_chars: int,
|
||||
model: str | None = None,
|
||||
web_search_config: "WebSearchConfig | None" = None,
|
||||
web_proxy: str | None = None,
|
||||
@ -56,6 +57,7 @@ class SubagentManager:
|
||||
self.workspace = workspace
|
||||
self.bus = bus
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_tool_result_chars = max_tool_result_chars
|
||||
self.web_search_config = web_search_config or WebSearchConfig()
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
@ -136,6 +138,7 @@ class SubagentManager:
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=_SubagentHook(task_id),
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
|
||||
@ -53,6 +53,21 @@ class Tool(ABC):
|
||||
"""JSON Schema for tool parameters."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
"""Whether this tool is side-effect free and safe to parallelize."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def concurrency_safe(self) -> bool:
|
||||
"""Whether this tool can run alongside other concurrency-safe tools."""
|
||||
return self.read_only and not self.exclusive
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
"""Whether this tool should run alone even if concurrency is enabled."""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs: Any) -> Any:
|
||||
"""
|
||||
|
||||
@ -73,6 +73,10 @@ class ReadFileTool(_FsTool):
|
||||
"Use offset and limit to paginate through large files."
|
||||
)
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
@ -344,6 +348,10 @@ class ListDirTool(_FsTool):
|
||||
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
||||
)
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@ -84,6 +84,9 @@ class MessageTool(Tool):
|
||||
media: list[str] | None = None,
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
content = strip_think(content)
|
||||
|
||||
channel = channel or self._default_channel
|
||||
chat_id = chat_id or self._default_chat_id
|
||||
# Only inherit default message_id when targeting the same channel+chat.
|
||||
|
||||
@ -35,22 +35,35 @@ class ToolRegistry:
|
||||
"""Get all tool definitions in OpenAI format."""
|
||||
return [tool.to_schema() for tool in self._tools.values()]
|
||||
|
||||
def prepare_call(
|
||||
self,
|
||||
name: str,
|
||||
params: dict[str, Any],
|
||||
) -> tuple[Tool | None, dict[str, Any], str | None]:
|
||||
"""Resolve, cast, and validate one tool call."""
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
return None, params, (
|
||||
f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
)
|
||||
|
||||
cast_params = tool.cast_params(params)
|
||||
errors = tool.validate_params(cast_params)
|
||||
if errors:
|
||||
return tool, cast_params, (
|
||||
f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
|
||||
)
|
||||
return tool, cast_params, None
|
||||
|
||||
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
||||
"""Execute a tool by name with given parameters."""
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
tool, params, error = self.prepare_call(name, params)
|
||||
if error:
|
||||
return error + _HINT
|
||||
|
||||
try:
|
||||
# Attempt to cast parameters to match schema types
|
||||
params = tool.cast_params(params)
|
||||
|
||||
# Validate parameters
|
||||
errors = tool.validate_params(params)
|
||||
if errors:
|
||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||
assert tool is not None # guarded by prepare_call()
|
||||
result = await tool.execute(**params)
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
return result + _HINT
|
||||
|
||||
@ -52,6 +52,10 @@ class ExecTool(Tool):
|
||||
def description(self) -> str:
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@ -92,6 +92,10 @@ class WebSearchTool(Tool):
|
||||
self.config = config if config is not None else WebSearchConfig()
|
||||
self.proxy = proxy
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||
provider = self.config.provider.strip().lower() or "brave"
|
||||
n = min(max(count or self.config.max_results, 1), 10)
|
||||
@ -234,6 +238,10 @@ class WebFetchTool(Tool):
|
||||
self.max_chars = max_chars
|
||||
self.proxy = proxy
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
|
||||
max_chars = maxChars or self.max_chars
|
||||
is_valid, error_msg = _validate_url_safe(url)
|
||||
|
||||
@ -14,6 +14,8 @@ from typing import Any
|
||||
from aiohttp import web
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
API_SESSION_KEY = "api:default"
|
||||
API_CHAT_ID = "default"
|
||||
|
||||
@ -98,7 +100,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
|
||||
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
||||
|
||||
_FALLBACK = "I've completed processing but have no response to give."
|
||||
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
try:
|
||||
async with session_lock:
|
||||
|
||||
@ -134,6 +134,7 @@ class QQConfig(Base):
|
||||
secret: str = ""
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
msg_format: Literal["plain", "markdown"] = "plain"
|
||||
ack_message: str = "⏳ Processing..."
|
||||
|
||||
# Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq").
|
||||
media_dir: str = ""
|
||||
@ -484,6 +485,17 @@ class QQChannel(BaseChannel):
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
if self.config.ack_message:
|
||||
try:
|
||||
await self._send_text_only(
|
||||
chat_id=chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=data.id,
|
||||
content=self.config.ack_message,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=chat_id,
|
||||
|
||||
@ -275,13 +275,10 @@ class TelegramChannel(BaseChannel):
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("restart", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("status", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
# Add command handlers (using Regex to support @username suffixes before bot initialization)
|
||||
self._app.add_handler(MessageHandler(filters.Regex(r"^/start(?:@\w+)?$"), self._on_start))
|
||||
self._app.add_handler(MessageHandler(filters.Regex(r"^/(new|stop|restart|status)(?:@\w+)?$"), self._forward_command))
|
||||
self._app.add_handler(MessageHandler(filters.Regex(r"^/help(?:@\w+)?$"), self._on_help))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
@ -313,7 +310,7 @@ class TelegramChannel(BaseChannel):
|
||||
# Start polling (this runs until stopped)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=["message"],
|
||||
drop_pending_updates=True # Ignore old messages on startup
|
||||
drop_pending_updates=False # Process pending messages on startup
|
||||
)
|
||||
|
||||
# Keep running until stopped
|
||||
@ -362,9 +359,14 @@ class TelegramChannel(BaseChannel):
|
||||
logger.warning("Telegram bot not running")
|
||||
return
|
||||
|
||||
# Only stop typing indicator for final responses
|
||||
# Only stop typing indicator and remove reaction for final responses
|
||||
if not msg.metadata.get("_progress", False):
|
||||
self._stop_typing(msg.chat_id)
|
||||
if reply_to_message_id := msg.metadata.get("message_id"):
|
||||
try:
|
||||
await self._remove_reaction(msg.chat_id, int(reply_to_message_id))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
chat_id = int(msg.chat_id)
|
||||
@ -435,7 +437,9 @@ class TelegramChannel(BaseChannel):
|
||||
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||
|
||||
async def _call_with_retry(self, fn, *args, **kwargs):
|
||||
"""Call an async Telegram API function with retry on pool/network timeout."""
|
||||
"""Call an async Telegram API function with retry on pool/network timeout and RetryAfter."""
|
||||
from telegram.error import RetryAfter
|
||||
|
||||
for attempt in range(1, _SEND_MAX_RETRIES + 1):
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
@ -448,6 +452,15 @@ class TelegramChannel(BaseChannel):
|
||||
attempt, _SEND_MAX_RETRIES, delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
except RetryAfter as e:
|
||||
if attempt == _SEND_MAX_RETRIES:
|
||||
raise
|
||||
delay = float(e.retry_after)
|
||||
logger.warning(
|
||||
"Telegram Flood Control (attempt {}/{}), retrying in {:.1f}s",
|
||||
attempt, _SEND_MAX_RETRIES, delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async def _send_text(
|
||||
self,
|
||||
@ -498,6 +511,11 @@ class TelegramChannel(BaseChannel):
|
||||
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
|
||||
return
|
||||
self._stop_typing(chat_id)
|
||||
if reply_to_message_id := meta.get("message_id"):
|
||||
try:
|
||||
await self._remove_reaction(chat_id, int(reply_to_message_id))
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
html = _markdown_to_telegram_html(buf.text)
|
||||
await self._call_with_retry(
|
||||
@ -619,8 +637,7 @@ class TelegramChannel(BaseChannel):
|
||||
"reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_reply_context(message) -> str | None:
|
||||
async def _extract_reply_context(self, message) -> str | None:
|
||||
"""Extract text from the message being replied to, if any."""
|
||||
reply = getattr(message, "reply_to_message", None)
|
||||
if not reply:
|
||||
@ -628,7 +645,21 @@ class TelegramChannel(BaseChannel):
|
||||
text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
|
||||
if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
|
||||
text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
|
||||
return f"[Reply to: {text}]" if text else None
|
||||
|
||||
if not text:
|
||||
return None
|
||||
|
||||
bot_id, _ = await self._ensure_bot_identity()
|
||||
reply_user = getattr(reply, "from_user", None)
|
||||
|
||||
if bot_id and reply_user and getattr(reply_user, "id", None) == bot_id:
|
||||
return f"[Reply to bot: {text}]"
|
||||
elif reply_user and getattr(reply_user, "username", None):
|
||||
return f"[Reply to @{reply_user.username}: {text}]"
|
||||
elif reply_user and getattr(reply_user, "first_name", None):
|
||||
return f"[Reply to {reply_user.first_name}: {text}]"
|
||||
else:
|
||||
return f"[Reply to: {text}]"
|
||||
|
||||
async def _download_message_media(
|
||||
self, msg, *, add_failure_content: bool = False
|
||||
@ -765,10 +796,18 @@ class TelegramChannel(BaseChannel):
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
self._remember_thread_context(message)
|
||||
|
||||
# Strip @bot_username suffix if present
|
||||
content = message.text or ""
|
||||
if content.startswith("/") and "@" in content:
|
||||
cmd_part, *rest = content.split(" ", 1)
|
||||
cmd_part = cmd_part.split("@")[0]
|
||||
content = f"{cmd_part} {rest[0]}" if rest else cmd_part
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=self._sender_id(user),
|
||||
chat_id=str(message.chat_id),
|
||||
content=message.text or "",
|
||||
content=content,
|
||||
metadata=self._build_message_metadata(message, user),
|
||||
session_key=self._derive_topic_session_key(message),
|
||||
)
|
||||
@ -812,7 +851,7 @@ class TelegramChannel(BaseChannel):
|
||||
# Reply context: text and/or media from the replied-to message
|
||||
reply = getattr(message, "reply_to_message", None)
|
||||
if reply is not None:
|
||||
reply_ctx = self._extract_reply_context(message)
|
||||
reply_ctx = await self._extract_reply_context(message)
|
||||
reply_media, reply_media_parts = await self._download_message_media(reply)
|
||||
if reply_media:
|
||||
media_paths = reply_media + media_paths
|
||||
@ -903,6 +942,19 @@ class TelegramChannel(BaseChannel):
|
||||
except Exception as e:
|
||||
logger.debug("Telegram reaction failed: {}", e)
|
||||
|
||||
async def _remove_reaction(self, chat_id: str, message_id: int) -> None:
|
||||
"""Remove emoji reaction from a message (best-effort, non-blocking)."""
|
||||
if not self._app:
|
||||
return
|
||||
try:
|
||||
await self._app.bot.set_message_reaction(
|
||||
chat_id=int(chat_id),
|
||||
message_id=message_id,
|
||||
reaction=[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Telegram reaction removal failed: {}", e)
|
||||
|
||||
async def _typing_loop(self, chat_id: str) -> None:
|
||||
"""Repeatedly send 'typing' action until cancelled."""
|
||||
try:
|
||||
|
||||
@ -542,6 +542,9 @@ def serve(
|
||||
model=runtime_config.agents.defaults.model,
|
||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
||||
web_search_config=runtime_config.tools.web.search,
|
||||
web_proxy=runtime_config.tools.web.proxy or None,
|
||||
exec_config=runtime_config.tools.exec,
|
||||
@ -629,6 +632,9 @@ def gateway(
|
||||
model=config.agents.defaults.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
@ -835,6 +841,9 @@ def agent(
|
||||
model=config.agents.defaults.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
@ -1023,12 +1032,18 @@ app.add_typer(channels_app, name="channels")
|
||||
|
||||
|
||||
@channels_app.command("status")
|
||||
def channels_status():
|
||||
def channels_status(
|
||||
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Show channel status."""
|
||||
from nanobot.channels.registry import discover_all
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.loader import load_config, set_config_path
|
||||
|
||||
config = load_config()
|
||||
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
|
||||
if resolved_config_path is not None:
|
||||
set_config_path(resolved_config_path)
|
||||
|
||||
config = load_config(resolved_config_path)
|
||||
|
||||
table = Table(title="Channel Status")
|
||||
table.add_column("Channel", style="cyan")
|
||||
@ -1115,12 +1130,17 @@ def _get_bridge_dir() -> Path:
|
||||
def channels_login(
|
||||
channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
|
||||
force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
|
||||
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Authenticate with a channel via QR code or other interactive login."""
|
||||
from nanobot.channels.registry import discover_all
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.loader import load_config, set_config_path
|
||||
|
||||
config = load_config()
|
||||
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
|
||||
if resolved_config_path is not None:
|
||||
set_config_path(resolved_config_path)
|
||||
|
||||
config = load_config(resolved_config_path)
|
||||
channel_cfg = getattr(config.channels, channel_name, None) or {}
|
||||
|
||||
# Validate channel exists
|
||||
|
||||
@ -26,7 +26,10 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
|
||||
sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key)
|
||||
total = cancelled + sub_cancelled
|
||||
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content=content)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||
metadata=dict(msg.metadata or {})
|
||||
)
|
||||
|
||||
|
||||
async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
|
||||
@ -38,7 +41,10 @@ async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
|
||||
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
|
||||
|
||||
asyncio.create_task(_do_restart())
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="Restarting...")
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
|
||||
metadata=dict(msg.metadata or {})
|
||||
)
|
||||
|
||||
|
||||
async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
||||
@ -62,7 +68,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
||||
session_msg_count=len(session.get_history(max_messages=0)),
|
||||
context_tokens_estimate=ctx_est,
|
||||
),
|
||||
metadata={"render_as": "text"},
|
||||
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
|
||||
)
|
||||
|
||||
|
||||
@ -79,6 +85,7 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
|
||||
content="New session started.",
|
||||
metadata=dict(ctx.msg.metadata or {})
|
||||
)
|
||||
|
||||
|
||||
@ -88,7 +95,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=build_help_text(),
|
||||
metadata={"render_as": "text"},
|
||||
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -38,8 +38,11 @@ class AgentDefaults(Base):
|
||||
)
|
||||
max_tokens: int = 8192
|
||||
context_window_tokens: int = 65_536
|
||||
context_block_limit: int | None = None
|
||||
temperature: float = 0.1
|
||||
max_tool_iterations: int = 40
|
||||
max_tool_iterations: int = 200
|
||||
max_tool_result_chars: int = 16_000
|
||||
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
|
||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||
|
||||
|
||||
@ -73,6 +73,9 @@ class Nanobot:
|
||||
model=defaults.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=defaults.context_window_tokens,
|
||||
context_block_limit=defaults.context_block_limit,
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
@ -434,13 +436,33 @@ class AnthropicProvider(LLMProvider):
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
async with self._client.messages.stream(**kwargs) as stream:
|
||||
if on_content_delta:
|
||||
async for text in stream.text_stream:
|
||||
stream_iter = stream.text_stream.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
text = await asyncio.wait_for(
|
||||
stream_iter.__anext__(),
|
||||
timeout=idle_timeout_s,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
await on_content_delta(text)
|
||||
response = await stream.get_final_message()
|
||||
response = await asyncio.wait_for(
|
||||
stream.get_final_message(),
|
||||
timeout=idle_timeout_s,
|
||||
)
|
||||
return self._parse_response(response)
|
||||
except asyncio.TimeoutError:
|
||||
return LLMResponse(
|
||||
content=(
|
||||
f"Error calling LLM: stream stalled for more than "
|
||||
f"{idle_timeout_s} seconds"
|
||||
),
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||
|
||||
|
||||
@ -1,31 +1,36 @@
|
||||
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||
"""Azure OpenAI provider using the OpenAI SDK Responses API.
|
||||
|
||||
Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
|
||||
routes to the Responses API (``/responses``). Reuses shared conversion
|
||||
helpers from :mod:`nanobot.providers.openai_responses`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.openai_responses import (
|
||||
consume_sdk_stream,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIProvider(LLMProvider):
|
||||
"""
|
||||
Azure OpenAI provider with API version 2024-10-21 compliance.
|
||||
|
||||
"""Azure OpenAI provider backed by the Responses API.
|
||||
|
||||
Features:
|
||||
- Hardcoded API version 2024-10-21
|
||||
- Uses model field as Azure deployment name in URL path
|
||||
- Uses api-key header instead of Authorization Bearer
|
||||
- Uses max_completion_tokens instead of max_tokens
|
||||
- Direct HTTP calls, bypasses LiteLLM
|
||||
- Uses the OpenAI Python SDK (``AsyncOpenAI``) with
|
||||
``base_url = {endpoint}/openai/v1/``
|
||||
- Calls ``client.responses.create()`` (Responses API)
|
||||
- Reuses shared message/tool/SSE conversion from
|
||||
``openai_responses``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -36,40 +41,28 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.api_version = "2024-10-21"
|
||||
|
||||
# Validate required parameters
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("Azure OpenAI api_key is required")
|
||||
if not api_base:
|
||||
raise ValueError("Azure OpenAI api_base is required")
|
||||
|
||||
# Ensure api_base ends with /
|
||||
if not api_base.endswith('/'):
|
||||
api_base += '/'
|
||||
|
||||
# Normalise: ensure trailing slash
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
self.api_base = api_base
|
||||
|
||||
def _build_chat_url(self, deployment_name: str) -> str:
|
||||
"""Build the Azure OpenAI chat completions URL."""
|
||||
# Azure OpenAI URL format:
|
||||
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
||||
base_url = self.api_base
|
||||
if not base_url.endswith('/'):
|
||||
base_url += '/'
|
||||
|
||||
url = urljoin(
|
||||
base_url,
|
||||
f"openai/deployments/{deployment_name}/chat/completions"
|
||||
# SDK client targeting the Azure Responses API endpoint
|
||||
base_url = f"{api_base.rstrip('/')}/openai/v1/"
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
default_headers={"x-session-affinity": uuid.uuid4().hex},
|
||||
)
|
||||
return f"{url}?api-version={self.api_version}"
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Build headers for Azure OpenAI API with api-key header."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||
}
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _supports_temperature(
|
||||
@ -82,36 +75,51 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
name = deployment_name.lower()
|
||||
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||
|
||||
def _prepare_request_payload(
|
||||
def _build_body(
|
||||
self,
|
||||
deployment_name: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||
payload: dict[str, Any] = {
|
||||
"messages": self._sanitize_request_messages(
|
||||
self._sanitize_empty_content(messages),
|
||||
_AZURE_MSG_KEYS,
|
||||
),
|
||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||
"""Build the Responses API request body from Chat-Completions-style args."""
|
||||
deployment = model or self.default_model
|
||||
instructions, input_items = convert_messages(self._sanitize_empty_content(messages))
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": deployment,
|
||||
"instructions": instructions or None,
|
||||
"input": input_items,
|
||||
"max_output_tokens": max(1, max_tokens),
|
||||
"store": False,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||
payload["temperature"] = temperature
|
||||
if self._supports_temperature(deployment, reasoning_effort):
|
||||
body["temperature"] = temperature
|
||||
|
||||
if reasoning_effort:
|
||||
payload["reasoning_effort"] = reasoning_effort
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
body["include"] = ["reasoning.encrypted_content"]
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = tool_choice or "auto"
|
||||
body["tools"] = convert_tools(tools)
|
||||
body["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return payload
|
||||
return body
|
||||
|
||||
@staticmethod
|
||||
def _handle_error(e: Exception) -> LLMResponse:
|
||||
body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
|
||||
msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
|
||||
return LLMResponse(content=msg, finish_reason="error")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request to Azure OpenAI.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (used as deployment name).
|
||||
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
||||
temperature: Sampling temperature.
|
||||
reasoning_effort: Optional reasoning effort parameter.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
body = self._build_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
return self._parse_response(response_data)
|
||||
|
||||
response = await self._client.responses.create(**body)
|
||||
return parse_response_output(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Azure OpenAI: {repr(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse Azure OpenAI response into our standard format."""
|
||||
try:
|
||||
choice = response["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
tool_calls = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
usage = {}
|
||||
if response.get("usage"):
|
||||
usage_data = response["usage"]
|
||||
usage = {
|
||||
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage_data.get("completion_tokens", 0),
|
||||
"total_tokens": usage_data.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
reasoning_content = message.get("reasoning_content") or None
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content"),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
return LLMResponse(
|
||||
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return self._handle_error(e)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
@ -221,89 +152,26 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion via Azure OpenAI SSE."""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice=tool_choice,
|
||||
body = self._build_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
payload["stream"] = True
|
||||
body["stream"] = True
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return await self._consume_stream(response, on_content_delta)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
|
||||
|
||||
async def _consume_stream(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None,
|
||||
) -> LLMResponse:
|
||||
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
|
||||
content_parts: list[str] = []
|
||||
tool_call_buffers: dict[int, dict[str, str]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:].strip()
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
choice = choices[0]
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice["finish_reason"]
|
||||
delta = choice.get("delta") or {}
|
||||
|
||||
text = delta.get("content")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
if on_content_delta:
|
||||
await on_content_delta(text)
|
||||
|
||||
for tc in delta.get("tool_calls") or []:
|
||||
idx = tc.get("index", 0)
|
||||
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
|
||||
if tc.get("id"):
|
||||
buf["id"] = tc["id"]
|
||||
fn = tc.get("function") or {}
|
||||
if fn.get("name"):
|
||||
buf["name"] = fn["name"]
|
||||
if fn.get("arguments"):
|
||||
buf["arguments"] += fn["arguments"]
|
||||
|
||||
tool_calls = [
|
||||
ToolCallRequest(
|
||||
id=buf["id"], name=buf["name"],
|
||||
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
|
||||
stream = await self._client.responses.create(**body)
|
||||
content, tool_calls, finish_reason, usage, reasoning_content = (
|
||||
await consume_sdk_stream(stream, on_content_delta)
|
||||
)
|
||||
for buf in tool_call_buffers.values()
|
||||
]
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model (also used as default deployment name)."""
|
||||
return self.default_model
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
@ -9,6 +10,8 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
@ -57,13 +60,7 @@ class LLMResponse:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationSettings:
|
||||
"""Default generation parameters for LLM calls.
|
||||
|
||||
Stored on the provider so every call site inherits the same defaults
|
||||
without having to pass temperature / max_tokens / reasoning_effort
|
||||
through every layer. Individual call sites can still override by
|
||||
passing explicit keyword arguments to chat() / chat_with_retry().
|
||||
"""
|
||||
"""Default generation settings."""
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
@ -71,14 +68,12 @@ class GenerationSettings:
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Implementations should handle the specifics of each provider's API
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
"""Base class for LLM providers."""
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_PERSISTENT_MAX_DELAY = 60
|
||||
_PERSISTENT_IDENTICAL_ERROR_LIMIT = 10
|
||||
_RETRY_HEARTBEAT_CHUNK = 30
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429",
|
||||
"rate limit",
|
||||
@ -208,7 +203,7 @@ class LLMProvider(ABC):
|
||||
for b in content:
|
||||
if isinstance(b, dict) and b.get("type") == "image_url":
|
||||
path = (b.get("_meta") or {}).get("path", "")
|
||||
placeholder = f"[image: {path}]" if path else "[image omitted]"
|
||||
placeholder = image_placeholder_text(path, empty="[image omitted]")
|
||||
new_content.append({"type": "text", "text": placeholder})
|
||||
found = True
|
||||
else:
|
||||
@ -273,6 +268,8 @@ class LLMProvider(ABC):
|
||||
reasoning_effort: object = _SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
retry_mode: str = "standard",
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat_stream() with retry on transient provider failures."""
|
||||
if max_tokens is self._SENTINEL:
|
||||
@ -288,28 +285,13 @@ class LLMProvider(ABC):
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
response = await self._safe_chat_stream(**kw)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
stripped = self._strip_image_content(messages)
|
||||
if stripped is not None:
|
||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
||||
return await self._safe_chat_stream(**{**kw, "messages": stripped})
|
||||
return response
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
return await self._safe_chat_stream(**kw)
|
||||
return await self._run_with_retry(
|
||||
self._safe_chat_stream,
|
||||
kw,
|
||||
messages,
|
||||
retry_mode=retry_mode,
|
||||
on_retry_wait=on_retry_wait,
|
||||
)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
@ -320,6 +302,8 @@ class LLMProvider(ABC):
|
||||
temperature: object = _SENTINEL,
|
||||
reasoning_effort: object = _SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
retry_mode: str = "standard",
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat() with retry on transient provider failures.
|
||||
|
||||
@ -339,28 +323,118 @@ class LLMProvider(ABC):
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
)
|
||||
return await self._run_with_retry(
|
||||
self._safe_chat,
|
||||
kw,
|
||||
messages,
|
||||
retry_mode=retry_mode,
|
||||
on_retry_wait=on_retry_wait,
|
||||
)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
response = await self._safe_chat(**kw)
|
||||
@classmethod
|
||||
def _extract_retry_after(cls, content: str | None) -> float | None:
|
||||
text = (content or "").lower()
|
||||
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
|
||||
if not match:
|
||||
return None
|
||||
value = float(match.group(1))
|
||||
unit = (match.group(2) or "s").lower()
|
||||
if unit in {"ms", "milliseconds"}:
|
||||
return max(0.1, value / 1000.0)
|
||||
if unit in {"m", "min", "minutes"}:
|
||||
return value * 60.0
|
||||
return value
|
||||
|
||||
async def _sleep_with_heartbeat(
|
||||
self,
|
||||
delay: float,
|
||||
*,
|
||||
attempt: int,
|
||||
persistent: bool,
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
remaining = max(0.0, delay)
|
||||
while remaining > 0:
|
||||
if on_retry_wait:
|
||||
kind = "persistent retry" if persistent else "retry"
|
||||
await on_retry_wait(
|
||||
f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
|
||||
f"(attempt {attempt})."
|
||||
)
|
||||
chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
|
||||
await asyncio.sleep(chunk)
|
||||
remaining -= chunk
|
||||
|
||||
async def _run_with_retry(
|
||||
self,
|
||||
call: Callable[..., Awaitable[LLMResponse]],
|
||||
kw: dict[str, Any],
|
||||
original_messages: list[dict[str, Any]],
|
||||
*,
|
||||
retry_mode: str,
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
||||
) -> LLMResponse:
|
||||
attempt = 0
|
||||
delays = list(self._CHAT_RETRY_DELAYS)
|
||||
persistent = retry_mode == "persistent"
|
||||
last_response: LLMResponse | None = None
|
||||
last_error_key: str | None = None
|
||||
identical_error_count = 0
|
||||
while True:
|
||||
attempt += 1
|
||||
response = await call(**kw)
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
last_response = response
|
||||
error_key = ((response.content or "").strip().lower() or None)
|
||||
if error_key and error_key == last_error_key:
|
||||
identical_error_count += 1
|
||||
else:
|
||||
last_error_key = error_key
|
||||
identical_error_count = 1 if error_key else 0
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
stripped = self._strip_image_content(messages)
|
||||
if stripped is not None:
|
||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
||||
return await self._safe_chat(**{**kw, "messages": stripped})
|
||||
stripped = self._strip_image_content(original_messages)
|
||||
if stripped is not None and stripped != kw["messages"]:
|
||||
logger.warning(
|
||||
"Non-transient LLM error with image content, retrying without images"
|
||||
)
|
||||
retry_kw = dict(kw)
|
||||
retry_kw["messages"] = stripped
|
||||
return await call(**retry_kw)
|
||||
return response
|
||||
|
||||
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
|
||||
logger.warning(
|
||||
"Stopping persistent retry after {} identical transient errors: {}",
|
||||
identical_error_count,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
return response
|
||||
|
||||
if not persistent and attempt > len(delays):
|
||||
break
|
||||
|
||||
base_delay = delays[min(attempt - 1, len(delays) - 1)]
|
||||
delay = self._extract_retry_after(response.content) or base_delay
|
||||
if persistent:
|
||||
delay = min(delay, self._PERSISTENT_MAX_DELAY)
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||
"LLM transient error (attempt {}{}), retrying in {}s: {}",
|
||||
attempt,
|
||||
"+" if persistent and attempt > len(delays) else f"/{len(delays)}",
|
||||
int(round(delay)),
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
await self._sleep_with_heartbeat(
|
||||
delay,
|
||||
attempt=attempt,
|
||||
persistent=persistent,
|
||||
on_retry_wait=on_retry_wait,
|
||||
)
|
||||
|
||||
return await self._safe_chat(**kw)
|
||||
return last_response if last_response is not None else await call(**kw)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
|
||||
@ -6,13 +6,18 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_responses import (
|
||||
consume_sse,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
)
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "nanobot"
|
||||
@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
) -> LLMResponse:
|
||||
"""Shared request logic for both chat() and chat_stream()."""
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
system_prompt, input_items = convert_messages(messages)
|
||||
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
headers = _build_headers(token.account_id, token.access)
|
||||
@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
if reasoning_effort:
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
body["tools"] = convert_tools(tools)
|
||||
|
||||
try:
|
||||
try:
|
||||
@ -127,96 +132,7 @@ async def _request_codex(
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||
return await _consume_sse(response, on_content_delta)
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI function-calling schema to Codex flat format."""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
})
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append(_convert_user_message(content))
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append({
|
||||
"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed", "id": f"msg_{idx}",
|
||||
})
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
input_items.append({
|
||||
"type": "function_call",
|
||||
"id": item_id or f"fc_{idx}",
|
||||
"call_id": call_id or f"call_{idx}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
})
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
return await consume_sse(response, on_content_delta)
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content = ""
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for event in _iter_sse(response):
|
||||
event_type = event.get("type")
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": item.get("id") or "fc_0",
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
elif event_type == "response.output_text.delta":
|
||||
delta_text = event.get("delta") or ""
|
||||
content += delta_text
|
||||
if on_content_delta and delta_text:
|
||||
await on_content_delta(delta_text)
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
name=buf.get("name") or item.get("name"),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
elif event_type == "response.completed":
|
||||
status = (event.get("response") or {}).get("status")
|
||||
finish_reason = _map_finish_reason(status)
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Codex response failed")
|
||||
|
||||
return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
|
||||
|
||||
|
||||
def _map_finish_reason(status: str | None) -> str:
|
||||
return _FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, raw: str) -> str:
|
||||
if status_code == 429:
|
||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
@ -20,7 +21,6 @@ if TYPE_CHECKING:
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({
|
||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||
"reasoning_content", "extra_content",
|
||||
})
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
@ -615,16 +615,33 @@ class OpenAICompatProvider(LLMProvider):
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
stream = await self._client.chat.completions.create(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
async for chunk in stream:
|
||||
stream_iter = stream.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
stream_iter.__anext__(),
|
||||
timeout=idle_timeout_s,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
if on_content_delta and chunk.choices:
|
||||
text = getattr(chunk.choices[0].delta, "content", None)
|
||||
if text:
|
||||
await on_content_delta(text)
|
||||
return self._parse_chunks(chunks)
|
||||
except asyncio.TimeoutError:
|
||||
return LLMResponse(
|
||||
content=(
|
||||
f"Error calling LLM: stream stalled for more than "
|
||||
f"{idle_timeout_s} seconds"
|
||||
),
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
|
||||
29
nanobot/providers/openai_responses/__init__.py
Normal file
29
nanobot/providers/openai_responses/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""
|
||||
|
||||
from nanobot.providers.openai_responses.converters import (
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
convert_user_message,
|
||||
split_tool_call_id,
|
||||
)
|
||||
from nanobot.providers.openai_responses.parsing import (
|
||||
FINISH_REASON_MAP,
|
||||
consume_sdk_stream,
|
||||
consume_sse,
|
||||
iter_sse,
|
||||
map_finish_reason,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"convert_messages",
|
||||
"convert_tools",
|
||||
"convert_user_message",
|
||||
"split_tool_call_id",
|
||||
"iter_sse",
|
||||
"consume_sse",
|
||||
"consume_sdk_stream",
|
||||
"map_finish_reason",
|
||||
"parse_response_output",
|
||||
"FINISH_REASON_MAP",
|
||||
]
|
||||
110
nanobot/providers/openai_responses/converters.py
Normal file
110
nanobot/providers/openai_responses/converters.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Convert Chat Completions messages/tools to Responses API format."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Convert Chat Completions messages to Responses API input items.

    Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted
    from any ``system`` role message (the last one wins, non-string content is
    dropped) and *input_items* is the Responses API ``input`` array.
    """
    system_prompt = ""
    items: list[dict[str, Any]] = []

    for idx, message in enumerate(messages):
        role = message.get("role")
        content = message.get("content")

        if role == "system":
            # Only plain-string system prompts are carried over.
            system_prompt = content if isinstance(content, str) else ""
        elif role == "user":
            items.append(convert_user_message(content))
        elif role == "assistant":
            # Text reply (if any) first, then one function_call item per tool call.
            if isinstance(content, str) and content:
                items.append({
                    "type": "message", "role": "assistant",
                    "content": [{"type": "output_text", "text": content}],
                    "status": "completed", "id": f"msg_{idx}",
                })
            for tc in message.get("tool_calls") or []:
                fn_spec = tc.get("function") or {}
                call_id, item_id = split_tool_call_id(tc.get("id"))
                items.append({
                    "type": "function_call",
                    "id": item_id or f"fc_{idx}",
                    "call_id": call_id or f"call_{idx}",
                    "name": fn_spec.get("name"),
                    "arguments": fn_spec.get("arguments") or "{}",
                })
        elif role == "tool":
            call_id, _ = split_tool_call_id(message.get("tool_call_id"))
            if isinstance(content, str):
                output_text = content
            else:
                output_text = json.dumps(content, ensure_ascii=False)
            items.append({
                "type": "function_call_output",
                "call_id": call_id,
                "output": output_text,
            })

    return system_prompt, items
|
||||
|
||||
|
||||
def convert_user_message(content: Any) -> dict[str, Any]:
    """Convert a user message's content to Responses API format.

    Handles plain strings, ``text`` blocks -> ``input_text``, and
    ``image_url`` blocks -> ``input_image``. Anything unrecognized degrades
    to a single empty ``input_text`` block.
    """
    if isinstance(content, str):
        return {"role": "user", "content": [{"type": "input_text", "text": content}]}

    blocks: list[dict[str, Any]] = []
    if isinstance(content, list):
        for part in content:
            if not isinstance(part, dict):
                continue
            kind = part.get("type")
            if kind == "text":
                blocks.append({"type": "input_text", "text": part.get("text", "")})
            elif kind == "image_url":
                url = (part.get("image_url") or {}).get("url")
                if url:
                    blocks.append({"type": "input_image", "image_url": url, "detail": "auto"})

    if not blocks:
        # Non-list/unknown content, or a list with nothing convertible.
        blocks = [{"type": "input_text", "text": ""}]
    return {"role": "user", "content": blocks}
|
||||
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert OpenAI function-calling tool schema to Responses API flat format."""
    out: list[dict[str, Any]] = []
    for tool in tools:
        # Unwrap {"type": "function", "function": {...}} envelopes; already-flat
        # dicts pass through unchanged.
        spec = (tool.get("function") or {}) if tool.get("type") == "function" else tool
        name = spec.get("name")
        if not name:
            # Nameless tools cannot be called; skip them.
            continue
        parameters = spec.get("parameters") or {}
        if not isinstance(parameters, dict):
            parameters = {}
        out.append({
            "type": "function",
            "name": name,
            "description": spec.get("description") or "",
            "parameters": parameters,
        })
    return out
|
||||
|
||||
|
||||
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
"""Split a compound ``call_id|item_id`` string.
|
||||
|
||||
Returns ``(call_id, item_id)`` where *item_id* may be ``None``.
|
||||
"""
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
297
nanobot/providers/openai_responses/parsing.py
Normal file
297
nanobot/providers/openai_responses/parsing.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""Parse Responses API SSE streams and SDK response objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
FINISH_REASON_MAP = {
|
||||
"completed": "stop",
|
||||
"incomplete": "length",
|
||||
"failed": "error",
|
||||
"cancelled": "error",
|
||||
}
|
||||
|
||||
|
||||
def map_finish_reason(status: str | None) -> str:
|
||||
"""Map a Responses API status string to a Chat-Completions-style finish_reason."""
|
||||
return FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
    """Yield parsed JSON events from a Responses API SSE stream.

    Args:
        response: A streaming ``httpx.Response`` whose body is Server-Sent
            Events text.

    Yields:
        One dict per SSE event, decoded from the event's joined ``data:``
        lines. ``[DONE]`` sentinels and events whose payload is not valid
        JSON are skipped (the latter with a warning).
    """
    # Raw lines of the event currently being accumulated; an SSE event is
    # terminated by a blank line.
    buffer: list[str] = []

    def _flush() -> dict[str, Any] | None:
        # Keep only "data:" field lines; other SSE fields (event:, id:,
        # retry:, comments) are ignored. Multiple data lines are joined
        # with newlines per the SSE spec.
        data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
        buffer.clear()
        if not data_lines:
            return None
        data = "\n".join(data_lines).strip()
        if not data or data == "[DONE]":
            # Empty payload or end-of-stream sentinel: nothing to yield.
            return None
        try:
            return json.loads(data)
        except Exception:
            # Malformed event: log a truncated sample and drop it rather
            # than aborting the whole stream.
            logger.warning("Failed to parse SSE event JSON: {}", data[:200])
            return None

    async for line in response.aiter_lines():
        if line == "":
            # Blank line = event boundary; parse and emit what we buffered.
            if buffer:
                event = _flush()
                if event is not None:
                    yield event
            continue
        buffer.append(line)

    # Flush any remaining buffer at EOF (#10) — the final event may not be
    # followed by a terminating blank line.
    if buffer:
        event = _flush()
        if event is not None:
            yield event
|
||||
|
||||
|
||||
async def consume_sse(
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``.

    Args:
        response: Streaming ``httpx.Response`` carrying Responses API events.
        on_content_delta: Optional async callback invoked with each non-empty
            text delta as it arrives.

    Returns:
        The accumulated assistant text, the completed tool calls, and the
        mapped finish reason.

    Raises:
        RuntimeError: On ``error`` / ``response.failed`` events.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    # In-flight function calls keyed by call_id; arguments accumulate across
    # delta events until the matching output_item.done arrives.
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in iter_sse(response):
        event_type = event.get("type")
        if event_type == "response.output_item.added":
            # A new output item started; only function calls need buffering.
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                tool_call_buffers[call_id] = {
                    "id": item.get("id") or "fc_0",
                    "name": item.get("name"),
                    "arguments": item.get("arguments") or "",
                }
        elif event_type == "response.output_text.delta":
            delta_text = event.get("delta") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            # Append incremental JSON-argument text to the matching buffer.
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
        elif event_type == "response.function_call_arguments.done":
            # The "done" event carries the authoritative full argument string;
            # it replaces whatever the deltas accumulated.
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or item.get("arguments") or "{}"
                try:
                    args = json.loads(args_raw)
                except Exception:
                    # Fall back to lenient repair for truncated/invalid JSON.
                    logger.warning(
                        "Failed to parse tool call arguments for '{}': {}",
                        buf.get("name") or item.get("name"),
                        args_raw[:200],
                    )
                    args = json_repair.loads(args_raw)
                if not isinstance(args, dict):
                    # Preserve whatever we got under a "raw" key so nothing is lost.
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        # Compound id "call_id|item_id" — undone later by
                        # split_tool_call_id when converting back.
                        id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
                        name=buf.get("name") or item.get("name") or "",
                        arguments=args,
                    )
                )
        elif event_type == "response.completed":
            status = (event.get("response") or {}).get("status")
            finish_reason = map_finish_reason(status)
        elif event_type in {"error", "response.failed"}:
            detail = event.get("error") or event.get("message") or event
            raise RuntimeError(f"Response failed: {str(detail)[:500]}")

    return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
def parse_response_output(response: Any) -> LLMResponse:
    """Parse an SDK ``Response`` object into an ``LLMResponse``.

    Accepts either a plain dict or an SDK model object; model objects (and
    nested items/blocks) are normalized via ``model_dump()`` when available,
    falling back to ``vars()``.
    """
    if not isinstance(response, dict):
        dump = getattr(response, "model_dump", None)
        response = dump() if callable(dump) else vars(response)

    output = response.get("output") or []
    content_parts: list[str] = []
    tool_calls: list[ToolCallRequest] = []
    reasoning_content: str | None = None

    for item in output:
        # Normalize nested SDK objects to dicts before inspecting them.
        if not isinstance(item, dict):
            dump = getattr(item, "model_dump", None)
            item = dump() if callable(dump) else vars(item)

        item_type = item.get("type")
        if item_type == "message":
            # Assistant text lives in output_text blocks inside a message item.
            for block in item.get("content") or []:
                if not isinstance(block, dict):
                    dump = getattr(block, "model_dump", None)
                    block = dump() if callable(dump) else vars(block)
                if block.get("type") == "output_text":
                    content_parts.append(block.get("text") or "")
        elif item_type == "reasoning":
            # Concatenate all summary_text fragments into one reasoning string.
            for s in item.get("summary") or []:
                if not isinstance(s, dict):
                    dump = getattr(s, "model_dump", None)
                    s = dump() if callable(dump) else vars(s)
                if s.get("type") == "summary_text" and s.get("text"):
                    reasoning_content = (reasoning_content or "") + s["text"]
        elif item_type == "function_call":
            call_id = item.get("call_id") or ""
            item_id = item.get("id") or "fc_0"
            args_raw = item.get("arguments") or "{}"
            try:
                args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
            except Exception:
                # Lenient fallback for malformed argument JSON.
                logger.warning(
                    "Failed to parse tool call arguments for '{}': {}",
                    item.get("name"),
                    str(args_raw)[:200],
                )
                args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
            if not isinstance(args, dict):
                args = {"raw": args_raw}
            tool_calls.append(ToolCallRequest(
                # Compound "call_id|item_id" id, split again on conversion back.
                id=f"{call_id}|{item_id}",
                name=item.get("name") or "",
                arguments=args if isinstance(args, dict) else {},
            ))

    # Usage may also arrive as an SDK object; normalize then remap field
    # names to Chat-Completions-style keys.
    usage_raw = response.get("usage") or {}
    if not isinstance(usage_raw, dict):
        dump = getattr(usage_raw, "model_dump", None)
        usage_raw = dump() if callable(dump) else vars(usage_raw)
    usage = {}
    if usage_raw:
        usage = {
            "prompt_tokens": int(usage_raw.get("input_tokens") or 0),
            "completion_tokens": int(usage_raw.get("output_tokens") or 0),
            "total_tokens": int(usage_raw.get("total_tokens") or 0),
        }

    status = response.get("status")
    finish_reason = map_finish_reason(status)

    return LLMResponse(
        # Empty text collapses to None so callers can distinguish "no text".
        content="".join(content_parts) or None,
        tool_calls=tool_calls,
        finish_reason=finish_reason,
        usage=usage,
        reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
    )
|
||||
|
||||
|
||||
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
    """Consume an SDK async stream from ``client.responses.create(stream=True)``.

    Mirrors ``consume_sse`` but reads attribute-style SDK event objects
    instead of raw dicts.

    Args:
        stream: Async iterable of SDK streaming events.
        on_content_delta: Optional async callback for each non-empty text delta.

    Returns:
        ``(content, tool_calls, finish_reason, usage, reasoning_content)``.

    Raises:
        RuntimeError: On ``error`` / ``response.failed`` events.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    # In-flight function calls keyed by call_id; arguments accumulate across
    # delta events until the matching output_item.done arrives.
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}
    reasoning_content: str | None = None

    async for event in stream:
        event_type = getattr(event, "type", None)
        if event_type == "response.output_item.added":
            # New output item; only function calls need buffering.
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                tool_call_buffers[call_id] = {
                    "id": getattr(item, "id", None) or "fc_0",
                    "name": getattr(item, "name", None),
                    "arguments": getattr(item, "arguments", None) or "",
                }
        elif event_type == "response.output_text.delta":
            delta_text = getattr(event, "delta", "") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            # Append incremental argument text to the matching buffer.
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
        elif event_type == "response.function_call_arguments.done":
            # The "done" event carries the authoritative full argument string.
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""
        elif event_type == "response.output_item.done":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
                try:
                    args = json.loads(args_raw)
                except Exception:
                    # Lenient repair for truncated/invalid JSON arguments.
                    logger.warning(
                        "Failed to parse tool call arguments for '{}': {}",
                        buf.get("name") or getattr(item, "name", None),
                        str(args_raw)[:200],
                    )
                    args = json_repair.loads(args_raw)
                if not isinstance(args, dict):
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        # Compound "call_id|item_id" id, split again on conversion back.
                        id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
                        name=buf.get("name") or getattr(item, "name", None) or "",
                        arguments=args,
                    )
                )
        elif event_type == "response.completed":
            # Terminal event: extract status, usage, and reasoning summaries.
            resp = getattr(event, "response", None)
            status = getattr(resp, "status", None) if resp else None
            finish_reason = map_finish_reason(status)
            if resp:
                usage_obj = getattr(resp, "usage", None)
                if usage_obj:
                    usage = {
                        "prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
                        "completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
                        "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
                    }
                for out_item in getattr(resp, "output", None) or []:
                    if getattr(out_item, "type", None) == "reasoning":
                        for s in getattr(out_item, "summary", None) or []:
                            if getattr(s, "type", None) == "summary_text":
                                text = getattr(s, "text", None)
                                if text:
                                    reasoning_content = (reasoning_content or "") + text
        elif event_type in {"error", "response.failed"}:
            detail = getattr(event, "error", None) or getattr(event, "message", None) or event
            raise RuntimeError(f"Response failed: {str(detail)[:500]}")

    return content, tool_calls, finish_reason, usage, reasoning_content
|
||||
@ -10,20 +10,12 @@ from typing import Any
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_legacy_sessions_dir
|
||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||
from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""
|
||||
A conversation session.
|
||||
|
||||
Stores messages in JSONL format for easy reading and persistence.
|
||||
|
||||
Important: Messages are append-only for LLM cache efficiency.
|
||||
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
||||
but does NOT modify the messages list or get_history() output.
|
||||
"""
|
||||
"""A conversation session."""
|
||||
|
||||
key: str # channel:chat_id
|
||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
@ -43,43 +35,19 @@ class Session:
|
||||
self.messages.append(msg)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
@staticmethod
|
||||
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
|
||||
"""Find first index where every tool result has a matching assistant tool_call."""
|
||||
declared: set[str] = set()
|
||||
start = 0
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
elif role == "tool":
|
||||
tid = msg.get("tool_call_id")
|
||||
if tid and str(tid) not in declared:
|
||||
start = i + 1
|
||||
declared.clear()
|
||||
for prev in messages[start:i + 1]:
|
||||
if prev.get("role") == "assistant":
|
||||
for tc in prev.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
return start
|
||||
|
||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||
unconsolidated = self.messages[self.last_consolidated:]
|
||||
sliced = unconsolidated[-max_messages:]
|
||||
|
||||
# Drop leading non-user messages to avoid starting mid-turn when possible.
|
||||
# Avoid starting mid-turn when possible.
|
||||
for i, message in enumerate(sliced):
|
||||
if message.get("role") == "user":
|
||||
sliced = sliced[i:]
|
||||
break
|
||||
|
||||
# Some providers reject orphan tool results if the matching assistant
|
||||
# tool_calls message fell outside the fixed-size history window.
|
||||
start = self._find_legal_start(sliced)
|
||||
# Drop orphan tool results at the front.
|
||||
start = find_legal_message_start(sliced)
|
||||
if start:
|
||||
sliced = sliced[start:]
|
||||
|
||||
@ -115,7 +83,7 @@ class Session:
|
||||
retained = self.messages[start_idx:]
|
||||
|
||||
# Mirror get_history(): avoid persisting orphan tool results at the front.
|
||||
start = self._find_legal_start(retained)
|
||||
start = find_legal_message_start(retained)
|
||||
if start:
|
||||
retained = retained[start:]
|
||||
|
||||
|
||||
@ -3,12 +3,15 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def strip_think(text: str) -> str:
|
||||
@ -56,11 +59,7 @@ def timestamp() -> str:
|
||||
|
||||
|
||||
def current_time_str(timezone: str | None = None) -> str:
|
||||
"""Human-readable current time with weekday and UTC offset.
|
||||
|
||||
When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time
|
||||
is converted to that zone. Otherwise falls back to the host local time.
|
||||
"""
|
||||
"""Return the current time string."""
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
try:
|
||||
@ -76,12 +75,164 @@ def current_time_str(timezone: str | None = None) -> str:
|
||||
|
||||
|
||||
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
|
||||
_TOOL_RESULT_PREVIEW_CHARS = 1200
|
||||
_TOOL_RESULTS_DIR = ".nanobot/tool-results"
|
||||
_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60
|
||||
_TOOL_RESULT_MAX_BUCKETS = 32
|
||||
|
||||
def safe_filename(name: str) -> str:
    """Replace characters unsafe in file names with underscores."""
    # _UNSAFE_CHARS matches <>:"/\|?* — substitute each, then trim whitespace.
    sanitized = _UNSAFE_CHARS.sub("_", name)
    return sanitized.strip()
|
||||
|
||||
|
||||
def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str:
|
||||
"""Build an image placeholder string."""
|
||||
return f"[image: {path}]" if path else empty
|
||||
|
||||
|
||||
def truncate_text(text: str, max_chars: int) -> str:
    """Truncate *text* to *max_chars* characters with a stable suffix.

    A non-positive *max_chars* disables truncation entirely.
    """
    if max_chars <= 0:
        return text
    if len(text) <= max_chars:
        return text
    return f"{text[:max_chars]}\n... (truncated)"
|
||||
|
||||
|
||||
def find_legal_message_start(messages: list[dict[str, Any]]) -> int:
    """Find the first index whose tool results have matching assistant calls.

    Providers reject a ``tool`` message whose ``tool_call_id`` was never
    declared by an earlier ``assistant`` message's ``tool_calls``. Scanning
    forward, every orphan tool result pushes the legal start just past it.

    Args:
        messages: Chat-style message dicts in conversation order.

    Returns:
        The smallest index such that ``messages[start:]`` contains no orphan
        tool results (0 when the whole list is already legal).
    """
    declared: set[str] = set()
    start = 0
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role == "assistant":
            for tc in msg.get("tool_calls") or []:
                if isinstance(tc, dict) and tc.get("id"):
                    declared.add(str(tc["id"]))
        elif role == "tool":
            tid = msg.get("tool_call_id")
            if tid and str(tid) not in declared:
                # Orphan tool result: the window must begin after it. All
                # previously declared ids are now before the window, so reset.
                # (The original version re-scanned messages[start:i+1] here to
                # rebuild `declared`, but with start == i + 1 that slice is
                # always empty — the rescan was dead code and is removed.)
                start = i + 1
                declared.clear()
    return start
|
||||
|
||||
|
||||
def stringify_text_blocks(content: list[dict[str, Any]]) -> str | None:
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
return None
|
||||
if block.get("type") != "text":
|
||||
return None
|
||||
text = block.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
parts.append(text)
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _render_tool_result_reference(
|
||||
filepath: Path,
|
||||
*,
|
||||
original_size: int,
|
||||
preview: str,
|
||||
truncated_preview: bool,
|
||||
) -> str:
|
||||
result = (
|
||||
f"[tool output persisted]\n"
|
||||
f"Full output saved to: {filepath}\n"
|
||||
f"Original size: {original_size} chars\n"
|
||||
f"Preview:\n{preview}"
|
||||
)
|
||||
if truncated_preview:
|
||||
result += "\n...\n(Read the saved file if you need the full output.)"
|
||||
return result
|
||||
|
||||
|
||||
def _bucket_mtime(path: Path) -> float:
|
||||
try:
|
||||
return path.stat().st_mtime
|
||||
except OSError:
|
||||
return 0.0
|
||||
|
||||
|
||||
def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None:
    """Prune stale per-session tool-result buckets under *root*.

    Two policies, applied in order, never touching *current_bucket*:
    age (buckets older than the retention window are removed) and count
    (at most ``_TOOL_RESULT_MAX_BUCKETS`` buckets total, keeping the most
    recently modified siblings).
    """
    siblings = [path for path in root.iterdir() if path.is_dir() and path != current_bucket]
    cutoff = time.time() - _TOOL_RESULT_RETENTION_SECS
    # Pass 1: drop buckets past the retention window.
    for path in siblings:
        if _bucket_mtime(path) < cutoff:
            shutil.rmtree(path, ignore_errors=True)
    # Pass 2: cap the total bucket count. Reserve one slot for the current
    # bucket, hence "- 1".
    keep = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0)
    # Re-filter: pass 1 may have deleted some entries already.
    siblings = [path for path in siblings if path.exists()]
    if len(siblings) <= keep:
        return
    # Newest first; everything beyond the keep budget is removed.
    siblings.sort(key=_bucket_mtime, reverse=True)
    for path in siblings[keep:]:
        shutil.rmtree(path, ignore_errors=True)
|
||||
|
||||
|
||||
def _write_text_atomic(path: Path, content: str) -> None:
|
||||
tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
|
||||
try:
|
||||
tmp.write_text(content, encoding="utf-8")
|
||||
tmp.replace(path)
|
||||
finally:
|
||||
if tmp.exists():
|
||||
tmp.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def maybe_persist_tool_result(
    workspace: Path | None,
    session_key: str | None,
    tool_call_id: str,
    content: Any,
    *,
    max_chars: int,
) -> Any:
    """Persist oversized tool output and replace it with a stable reference string.

    Args:
        workspace: Workspace root; ``None`` disables persistence entirely.
        session_key: Session identifier used to bucket files (``"default"``
            when missing).
        tool_call_id: Used (sanitized) as the file name, so the same call id
            maps to the same file.
        content: The tool result — a string, or a list of text blocks.
        max_chars: Size threshold; non-positive disables persistence.

    Returns:
        *content* unchanged when small or not flattenable to text; otherwise
        a short reference string pointing at the persisted file.
    """
    if workspace is None or max_chars <= 0:
        return content

    # Flatten the result to text so its size can be measured. Lists are only
    # persisted when they are pure text blocks (saved as JSON).
    text_payload: str | None = None
    suffix = "txt"
    if isinstance(content, str):
        text_payload = content
    elif isinstance(content, list):
        text_payload = stringify_text_blocks(content)
        if text_payload is None:
            # Mixed/non-text blocks (e.g. images): pass through untouched.
            return content
        suffix = "json"
    else:
        return content

    if len(text_payload) <= max_chars:
        return content

    root = ensure_dir(workspace / _TOOL_RESULTS_DIR)
    bucket = ensure_dir(root / safe_filename(session_key or "default"))
    # Cleanup is best-effort; never let it break the tool result itself.
    try:
        _cleanup_tool_result_buckets(root, bucket)
    except Exception as exc:
        logger.warning("Failed to clean stale tool result buckets in {}: {}", root, exc)
    path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
    # Idempotent per tool_call_id: an existing file is assumed up to date.
    if not path.exists():
        if suffix == "json" and isinstance(content, list):
            # Preserve the original block structure, not just the joined text.
            _write_text_atomic(path, json.dumps(content, ensure_ascii=False, indent=2))
        else:
            _write_text_atomic(path, text_payload)

    preview = text_payload[:_TOOL_RESULT_PREVIEW_CHARS]
    return _render_tool_result_reference(
        path,
        original_size=len(text_payload),
        preview=preview,
        truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS,
    )
|
||||
|
||||
|
||||
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
||||
"""
|
||||
Split content into chunks within max_len, preferring line breaks.
|
||||
|
||||
88
nanobot/utils/runtime.py
Normal file
88
nanobot/utils/runtime.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""Runtime-specific helper functions and constants."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import stringify_text_blocks
|
||||
|
||||
_MAX_REPEAT_EXTERNAL_LOOKUPS = 2
|
||||
|
||||
EMPTY_FINAL_RESPONSE_MESSAGE = (
|
||||
"I completed the tool steps but couldn't produce a final answer. "
|
||||
"Please try again or narrow the task."
|
||||
)
|
||||
|
||||
FINALIZATION_RETRY_PROMPT = (
|
||||
"You have already finished the tool work. Do not call any more tools. "
|
||||
"Using only the conversation and tool results above, provide the final answer for the user now."
|
||||
)
|
||||
|
||||
|
||||
def empty_tool_result_message(tool_name: str) -> str:
    """Short prompt-safe marker for tools that completed without visible output."""
    return "({} completed with no output)".format(tool_name)
|
||||
|
||||
|
||||
def ensure_nonempty_tool_result(tool_name: str, content: Any) -> Any:
    """Replace semantically empty tool results with a short marker string.

    ``None``, whitespace-only strings, empty lists, and all-text block lists
    whose combined text is blank are all considered empty.
    """
    if content is None:
        return empty_tool_result_message(tool_name)
    if isinstance(content, str):
        return content if content.strip() else empty_tool_result_message(tool_name)
    if isinstance(content, list):
        if not content:
            return empty_tool_result_message(tool_name)
        flattened = stringify_text_blocks(content)
        # Only pure text-block lists can be judged empty; mixed content
        # (flattened is None) passes through untouched.
        if flattened is not None and not flattened.strip():
            return empty_tool_result_message(tool_name)
    return content
|
||||
|
||||
|
||||
def is_blank_text(content: str | None) -> bool:
|
||||
"""True when *content* is missing or only whitespace."""
|
||||
return content is None or not content.strip()
|
||||
|
||||
|
||||
def build_finalization_retry_message() -> dict[str, str]:
    """A short no-tools-allowed prompt for final answer recovery."""
    # Sent as a user turn so every provider accepts it mid-conversation.
    return dict(role="user", content=FINALIZATION_RETRY_PROMPT)
|
||||
|
||||
|
||||
def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None:
|
||||
"""Stable signature for repeated external lookups we want to throttle."""
|
||||
if tool_name == "web_fetch":
|
||||
url = str(arguments.get("url") or "").strip()
|
||||
if url:
|
||||
return f"web_fetch:{url.lower()}"
|
||||
if tool_name == "web_search":
|
||||
query = str(arguments.get("query") or arguments.get("search_term") or "").strip()
|
||||
if query:
|
||||
return f"web_search:{query.lower()}"
|
||||
return None
|
||||
|
||||
|
||||
def repeated_external_lookup_error(
    tool_name: str,
    arguments: dict[str, Any],
    seen_counts: dict[str, int],
) -> str | None:
    """Block repeated external lookups after a small retry budget.

    Args:
        tool_name: The tool being invoked.
        arguments: Its arguments (used to build the lookup signature).
        seen_counts: Mutable per-run counter keyed by signature; this call
            increments it as a side effect.

    Returns:
        ``None`` to allow the call, or an error string to return to the
        model once the same lookup exceeds ``_MAX_REPEAT_EXTERNAL_LOOKUPS``
        attempts.
    """
    signature = external_lookup_signature(tool_name, arguments)
    if signature is None:
        # Not a throttled lookup tool (or no identifying argument): allow.
        return None
    count = seen_counts.get(signature, 0) + 1
    seen_counts[signature] = count
    if count <= _MAX_REPEAT_EXTERNAL_LOOKUPS:
        return None
    logger.warning(
        "Blocking repeated external lookup {} on attempt {}",
        signature[:160],
        count,
    )
    return (
        "Error: repeated external lookup blocked. "
        "Use the results you already have to answer, or try a meaningfully different source."
    )
|
||||
@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||
assert "Channel: cli" in user_content
|
||||
assert "Chat ID: direct" in user_content
|
||||
assert "Return exactly: OK" in user_content
|
||||
|
||||
|
||||
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
messages = builder.build_messages(
|
||||
history=[{"role": "assistant", "content": "previous result"}],
|
||||
current_message="subagent result",
|
||||
channel="cli",
|
||||
chat_id="direct",
|
||||
current_role="assistant",
|
||||
)
|
||||
|
||||
for left, right in zip(messages, messages[1:]):
|
||||
assert not (left.get("role") == right.get("role") == "assistant")
|
||||
|
||||
@ -5,7 +5,9 @@ from nanobot.session.manager import Session
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
|
||||
return loop
|
||||
|
||||
|
||||
@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None:
|
||||
)
|
||||
|
||||
assert session.messages[0]["content"] == content
|
||||
|
||||
|
||||
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(
|
||||
key="test:checkpoint",
|
||||
metadata={
|
||||
AgentLoop._RUNTIME_CHECKPOINT_KEY: {
|
||||
"assistant_message": {
|
||||
"role": "assistant",
|
||||
"content": "working",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_done",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
"completed_tool_results": [
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_done",
|
||||
"name": "read_file",
|
||||
"content": "ok",
|
||||
}
|
||||
],
|
||||
"pending_tool_calls": [
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
restored = loop._restore_runtime_checkpoint(session)
|
||||
|
||||
assert restored is True
|
||||
assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
|
||||
assert session.messages[0]["role"] == "assistant"
|
||||
assert session.messages[1]["tool_call_id"] == "call_done"
|
||||
assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||
assert "interrupted before this tool finished" in session.messages[2]["content"].lower()
|
||||
|
||||
|
||||
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(
|
||||
key="test:checkpoint-overlap",
|
||||
messages=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "working",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_done",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_done",
|
||||
"name": "read_file",
|
||||
"content": "ok",
|
||||
},
|
||||
],
|
||||
metadata={
|
||||
AgentLoop._RUNTIME_CHECKPOINT_KEY: {
|
||||
"assistant_message": {
|
||||
"role": "assistant",
|
||||
"content": "working",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_done",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
"completed_tool_results": [
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_done",
|
||||
"name": "read_file",
|
||||
"content": "ok",
|
||||
}
|
||||
],
|
||||
"pending_tool_calls": [
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
restored = loop._restore_runtime_checkpoint(session)
|
||||
|
||||
assert restored is True
|
||||
assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
|
||||
assert len(session.messages) == 3
|
||||
assert session.messages[0]["role"] == "assistant"
|
||||
assert session.messages[1]["tool_call_id"] == "call_done"
|
||||
assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||
|
||||
@ -2,12 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=RecordingHook(),
|
||||
))
|
||||
|
||||
@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=StreamingHook(),
|
||||
))
|
||||
|
||||
@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "max_iterations"
|
||||
@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback():
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
assert result.messages[-1]["content"] == result.final_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_structured_tool_error():
|
||||
@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
fail_on_tool_error=True,
|
||||
))
|
||||
|
||||
@ -258,6 +272,457 @@ async def test_runner_returns_structured_tool_error():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_second_call: list[dict] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
|
||||
usage={"prompt_tokens": 5, "completion_tokens": 3},
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="x" * 20_000)
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "do task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
workspace=tmp_path,
|
||||
session_key="test:runner",
|
||||
max_tool_result_chars=2048,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
|
||||
assert "[tool output persisted]" in tool_message["content"]
|
||||
assert "tool-results" in tool_message["content"]
|
||||
assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
|
||||
from nanobot.utils.helpers import maybe_persist_tool_result
|
||||
|
||||
root = tmp_path / ".nanobot" / "tool-results"
|
||||
old_bucket = root / "old_session"
|
||||
recent_bucket = root / "recent_session"
|
||||
old_bucket.mkdir(parents=True)
|
||||
recent_bucket.mkdir(parents=True)
|
||||
(old_bucket / "old.txt").write_text("old", encoding="utf-8")
|
||||
(recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")
|
||||
|
||||
stale = time.time() - (8 * 24 * 60 * 60)
|
||||
os.utime(old_bucket, (stale, stale))
|
||||
os.utime(old_bucket / "old.txt", (stale, stale))
|
||||
|
||||
persisted = maybe_persist_tool_result(
|
||||
tmp_path,
|
||||
"current:session",
|
||||
"call_big",
|
||||
"x" * 5000,
|
||||
max_chars=64,
|
||||
)
|
||||
|
||||
assert "[tool output persisted]" in persisted
|
||||
assert not old_bucket.exists()
|
||||
assert recent_bucket.exists()
|
||||
assert (root / "current_session" / "call_big.txt").exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
|
||||
from nanobot.utils.helpers import maybe_persist_tool_result
|
||||
|
||||
root = tmp_path / ".nanobot" / "tool-results"
|
||||
maybe_persist_tool_result(
|
||||
tmp_path,
|
||||
"current:session",
|
||||
"call_big",
|
||||
"x" * 5000,
|
||||
max_chars=64,
|
||||
)
|
||||
|
||||
assert (root / "current_session" / "call_big.txt").exists()
|
||||
assert list((root / "current_session").glob("*.tmp")) == []
|
||||
|
||||
|
||||
def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
|
||||
from nanobot.utils.helpers import maybe_persist_tool_result
|
||||
|
||||
warnings: list[str] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.helpers._cleanup_tool_result_buckets",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.helpers.logger.warning",
|
||||
lambda message, *args: warnings.append(message.format(*args)),
|
||||
)
|
||||
|
||||
persisted = maybe_persist_tool_result(
|
||||
tmp_path,
|
||||
"current:session",
|
||||
"call_big",
|
||||
"x" * 5000,
|
||||
max_chars=64,
|
||||
)
|
||||
|
||||
assert "[tool output persisted]" in persisted
|
||||
assert warnings and "Failed to clean stale tool result buckets" in warnings[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_replaces_empty_tool_result_with_marker():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_second_call: list[dict] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})],
|
||||
usage={},
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "do task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
|
||||
assert tool_message["content"] == "(noop completed with no output)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_uses_raw_messages_when_context_governance_fails():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_messages: list[dict] = []
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
captured_messages[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
initial_messages = [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
runner._snip_history = MagicMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign]
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
assert captured_messages == initial_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_retries_empty_final_response_with_summary_prompt():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
calls: list[dict] = []
|
||||
|
||||
async def chat_with_retry(*, messages, tools=None, **kwargs):
|
||||
calls.append({"messages": messages, "tools": tools})
|
||||
if len(calls) == 1:
|
||||
return LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[],
|
||||
usage={"prompt_tokens": 10, "completion_tokens": 1},
|
||||
)
|
||||
return LLMResponse(
|
||||
content="final answer",
|
||||
tool_calls=[],
|
||||
usage={"prompt_tokens": 3, "completion_tokens": 7},
|
||||
)
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "do task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == "final answer"
|
||||
assert len(calls) == 2
|
||||
assert calls[1]["tools"] is None
|
||||
assert "Do not call any more tools" in calls[1]["messages"][-1]["content"]
|
||||
assert result.usage["prompt_tokens"] == 13
|
||||
assert result.usage["completion_tokens"] == 8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_uses_specific_message_after_empty_finalization_retry():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
provider = MagicMock()
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
return LLMResponse(content=None, tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "do task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
assert result.stop_reason == "empty_final_response"
|
||||
|
||||
|
||||
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
runner = AgentRunner(provider)
|
||||
messages = [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "old user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "tool call",
|
||||
"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "tool output"},
|
||||
{"role": "assistant", "content": "after tool"},
|
||||
]
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
context_window_tokens=2000,
|
||||
context_block_limit=100,
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_args, **_kwargs: (500, None))
|
||||
token_sizes = {
|
||||
"old user": 120,
|
||||
"tool call": 120,
|
||||
"tool output": 40,
|
||||
"after tool": 40,
|
||||
"system": 0,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runner.estimate_message_tokens",
|
||||
lambda msg: token_sizes.get(str(msg.get("content")), 40),
|
||||
)
|
||||
|
||||
trimmed = runner._snip_history(spec, messages)
|
||||
|
||||
assert trimmed == [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "assistant", "content": "after tool"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_keeps_going_when_tool_result_persistence_fails():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_second_call: list[dict] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
usage={"prompt_tokens": 5, "completion_tokens": 3},
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="tool result")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "do task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
|
||||
assert tool_message["content"] == "tool result"
|
||||
|
||||
|
||||
class _DelayTool(Tool):
|
||||
def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
|
||||
self._name = name
|
||||
self._delay = delay
|
||||
self._read_only = read_only
|
||||
self._shared_events = shared_events
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict:
|
||||
return {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return self._read_only
|
||||
|
||||
async def execute(self, **kwargs):
|
||||
self._shared_events.append(f"start:{self._name}")
|
||||
await asyncio.sleep(self._delay)
|
||||
self._shared_events.append(f"end:{self._name}")
|
||||
return self._name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_batches_read_only_tools_before_exclusive_work():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
tools = ToolRegistry()
|
||||
shared_events: list[str] = []
|
||||
read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
|
||||
read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events)
|
||||
write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events)
|
||||
tools.register(read_a)
|
||||
tools.register(read_b)
|
||||
tools.register(write_a)
|
||||
|
||||
runner = AgentRunner(MagicMock())
|
||||
await runner._execute_tools(
|
||||
AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
concurrent_tools=True,
|
||||
),
|
||||
[
|
||||
ToolCallRequest(id="ro1", name="read_a", arguments={}),
|
||||
ToolCallRequest(id="ro2", name="read_b", arguments={}),
|
||||
ToolCallRequest(id="rw1", name="write_a", arguments={}),
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
assert shared_events[0:2] == ["start:read_a", "start:read_b"]
|
||||
assert "end:read_a" in shared_events and "end:read_b" in shared_events
|
||||
assert shared_events.index("end:read_a") < shared_events.index("start:write_a")
|
||||
assert shared_events.index("end:read_b") < shared_events.index("start:write_a")
|
||||
assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_blocks_repeated_external_fetches():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_final_call: list[dict] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] <= 3:
|
||||
return LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id=f"call_{call_count['n']}", name="web_fetch", arguments={"url": "https://example.com"})],
|
||||
usage={},
|
||||
)
|
||||
captured_final_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="page content")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "research task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=4,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
assert tools.execute.await_count == 2
|
||||
blocked_tool_message = [
|
||||
msg for msg in captured_final_call
|
||||
if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_3"
|
||||
][0]
|
||||
assert "repeated external lookup blocked" in blocked_tool_message["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
@ -307,6 +772,57 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
|
||||
assert endings == [False]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_retries_think_only_final_response(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content="<think>hidden</think>", tool_calls=[], usage={})
|
||||
return LLMResponse(content="Recovered answer", tool_calls=[], usage={})
|
||||
|
||||
loop.provider.chat_with_retry = chat_with_retry
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([])
|
||||
|
||||
assert final_content == "Recovered answer"
|
||||
assert call_count["n"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_tool_error_sets_final_content():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
return LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
|
||||
usage={},
|
||||
)
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "do task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
fail_on_tool_error=True,
|
||||
))
|
||||
|
||||
assert result.final_content == "Error: RuntimeError: boom"
|
||||
assert result.stop_reason == "tool_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
@ -317,15 +833,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
@ -369,6 +890,7 @@ async def test_runner_accumulates_usage_and_preserves_cached_tokens():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
# Usage should be accumulated across iterations
|
||||
@ -407,6 +929,7 @@ async def test_runner_passes_cached_tokens_to_hook_context():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=UsageHook(),
|
||||
))
|
||||
|
||||
|
||||
@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(*, exec_config=None):
|
||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||
@ -186,7 +190,12 @@ class TestSubagentCancellation:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=MagicMock(),
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
@ -214,7 +223,12 @@ class TestSubagentCancellation:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=MagicMock(),
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -236,19 +250,24 @@ class TestSubagentCancellation:
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
reasoning_content="hidden reasoning",
|
||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[])
|
||||
provider.chat_with_retry = scripted_chat_with_retry
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
@ -273,6 +292,7 @@ class TestSubagentCancellation:
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
exec_config=ExecToolConfig(enable=False),
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
@ -304,20 +324,25 @@ class TestSubagentCancellation:
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return "first result"
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
@ -340,15 +365,20 @@ class TestSubagentCancellation:
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
started = asyncio.Event()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
started.set()
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
@ -356,7 +386,7 @@ class TestSubagentCancellation:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
task = asyncio.create_task(
|
||||
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
@ -364,7 +394,7 @@ class TestSubagentCancellation:
|
||||
mgr._running_tasks["sub-1"] = task
|
||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||
|
||||
await started.wait()
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
count = await mgr.cancel_by_session("test:c1")
|
||||
|
||||
|
||||
@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||
seen["config"] = self.config
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {"fakeplugin": _LoginPlugin},
|
||||
@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||
assert seen["force"] is True
|
||||
|
||||
|
||||
def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path):
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.schema import Config
|
||||
from typer.testing import CliRunner
|
||||
|
||||
runner = CliRunner()
|
||||
seen: dict[str, object] = {}
|
||||
config_path = tmp_path / "custom-config.json"
|
||||
|
||||
class _LoginPlugin(_FakePlugin):
|
||||
async def login(self, force: bool = False) -> bool:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {"fakeplugin": _LoginPlugin},
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["config_path"] == config_path.resolve()
|
||||
|
||||
|
||||
def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path):
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.schema import Config
|
||||
from typer.testing import CliRunner
|
||||
|
||||
runner = CliRunner()
|
||||
seen: dict[str, object] = {}
|
||||
config_path = tmp_path / "custom-config.json"
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
|
||||
|
||||
result = runner.invoke(app, ["channels", "status", "--config", str(config_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["config_path"] == config_path.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_skips_disabled_plugin():
|
||||
fake_config = SimpleNamespace(
|
||||
|
||||
@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None:
|
||||
typing_channel.typing_enter_hook = slow_typing
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
release.set()
|
||||
@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None:
|
||||
typing_channel.typing_enter_hook = slow_typing_progress
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
||||
|
||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
||||
await entered.wait()
|
||||
await asyncio.wait_for(entered.wait(), timeout=1.0)
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
|
||||
@ -3,16 +3,14 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("nio")
|
||||
pytest.importorskip("nh3")
|
||||
pytest.importorskip("mistune")
|
||||
from nio import RoomSendResponse
|
||||
|
||||
from nanobot.channels.matrix import _build_matrix_text_content
|
||||
|
||||
# Check optional matrix dependencies before importing
|
||||
try:
|
||||
import nh3 # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True)
|
||||
|
||||
import nanobot.channels.matrix as matrix_module
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
172
tests/channels/test_qq_ack_message.py
Normal file
172
tests/channels/test_qq_ack_message.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""Tests for QQ channel ack_message feature.
|
||||
|
||||
Covers the four verification points from the PR:
|
||||
1. C2C message: ack appears instantly
|
||||
2. Group message: ack appears instantly
|
||||
3. ack_message set to "": no ack sent
|
||||
4. Custom ack_message text: correct text delivered
|
||||
Each test also verifies that normal message processing is not blocked.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from nanobot.channels import qq
|
||||
|
||||
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import QQChannel, QQConfig
|
||||
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self) -> None:
|
||||
self.c2c_calls: list[dict] = []
|
||||
self.group_calls: list[dict] = []
|
||||
|
||||
async def post_c2c_message(self, **kwargs) -> None:
|
||||
self.c2c_calls.append(kwargs)
|
||||
|
||||
async def post_group_message(self, **kwargs) -> None:
|
||||
self.group_calls.append(kwargs)
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.api = _FakeApi()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ack_sent_on_c2c_message() -> None:
|
||||
"""Ack is sent immediately for C2C messages, then normal processing continues."""
|
||||
channel = QQChannel(
|
||||
QQConfig(
|
||||
app_id="app",
|
||||
secret="secret",
|
||||
allow_from=["*"],
|
||||
ack_message="⏳ Processing...",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._client = _FakeClient()
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg1",
|
||||
content="hello",
|
||||
author=SimpleNamespace(user_openid="user1"),
|
||||
attachments=[],
|
||||
)
|
||||
await channel._on_message(data, is_group=False)
|
||||
|
||||
assert len(channel._client.api.c2c_calls) >= 1
|
||||
ack_call = channel._client.api.c2c_calls[0]
|
||||
assert ack_call["content"] == "⏳ Processing..."
|
||||
assert ack_call["openid"] == "user1"
|
||||
assert ack_call["msg_id"] == "msg1"
|
||||
assert ack_call["msg_type"] == 0
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.content == "hello"
|
||||
assert msg.sender_id == "user1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ack_sent_on_group_message() -> None:
|
||||
"""Ack is sent immediately for group messages, then normal processing continues."""
|
||||
channel = QQChannel(
|
||||
QQConfig(
|
||||
app_id="app",
|
||||
secret="secret",
|
||||
allow_from=["*"],
|
||||
ack_message="⏳ Processing...",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._client = _FakeClient()
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg2",
|
||||
content="hello group",
|
||||
group_openid="group123",
|
||||
author=SimpleNamespace(member_openid="user1"),
|
||||
attachments=[],
|
||||
)
|
||||
await channel._on_message(data, is_group=True)
|
||||
|
||||
assert len(channel._client.api.group_calls) >= 1
|
||||
ack_call = channel._client.api.group_calls[0]
|
||||
assert ack_call["content"] == "⏳ Processing..."
|
||||
assert ack_call["group_openid"] == "group123"
|
||||
assert ack_call["msg_id"] == "msg2"
|
||||
assert ack_call["msg_type"] == 0
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.content == "hello group"
|
||||
assert msg.chat_id == "group123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ack_when_ack_message_empty() -> None:
|
||||
"""Setting ack_message to empty string disables the ack entirely."""
|
||||
channel = QQChannel(
|
||||
QQConfig(
|
||||
app_id="app",
|
||||
secret="secret",
|
||||
allow_from=["*"],
|
||||
ack_message="",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._client = _FakeClient()
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg3",
|
||||
content="hello",
|
||||
author=SimpleNamespace(user_openid="user1"),
|
||||
attachments=[],
|
||||
)
|
||||
await channel._on_message(data, is_group=False)
|
||||
|
||||
assert len(channel._client.api.c2c_calls) == 0
|
||||
assert len(channel._client.api.group_calls) == 0
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.content == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_ack_message_text() -> None:
|
||||
"""Custom Chinese ack_message text is delivered correctly."""
|
||||
custom = "正在处理中,请稍候..."
|
||||
channel = QQChannel(
|
||||
QQConfig(
|
||||
app_id="app",
|
||||
secret="secret",
|
||||
allow_from=["*"],
|
||||
ack_message=custom,
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._client = _FakeClient()
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg4",
|
||||
content="test input",
|
||||
author=SimpleNamespace(user_openid="user1"),
|
||||
attachments=[],
|
||||
)
|
||||
await channel._on_message(data, is_group=False)
|
||||
|
||||
assert len(channel._client.api.c2c_calls) >= 1
|
||||
ack_call = channel._client.api.c2c_calls[0]
|
||||
assert ack_call["content"] == custom
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.content == "test input"
|
||||
@ -647,43 +647,56 @@ async def test_group_policy_open_accepts_plain_group_message() -> None:
|
||||
assert channel._app.bot.get_me_calls == 0
|
||||
|
||||
|
||||
def test_extract_reply_context_no_reply() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_reply_context_no_reply() -> None:
|
||||
"""When there is no reply_to_message, _extract_reply_context returns None."""
|
||||
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||
message = SimpleNamespace(reply_to_message=None)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
assert await channel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
def test_extract_reply_context_with_text() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_reply_context_with_text() -> None:
|
||||
"""When reply has text, return prefixed string."""
|
||||
reply = SimpleNamespace(text="Hello world", caption=None)
|
||||
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
reply = SimpleNamespace(text="Hello world", caption=None, from_user=SimpleNamespace(id=2, username="testuser", first_name="Test"))
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
|
||||
assert await channel._extract_reply_context(message) == "[Reply to @testuser: Hello world]"
|
||||
|
||||
|
||||
def test_extract_reply_context_with_caption_only() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_reply_context_with_caption_only() -> None:
|
||||
"""When reply has only caption (no text), caption is used."""
|
||||
reply = SimpleNamespace(text=None, caption="Photo caption")
|
||||
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
reply = SimpleNamespace(text=None, caption="Photo caption", from_user=SimpleNamespace(id=2, username=None, first_name="Test"))
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
|
||||
assert await channel._extract_reply_context(message) == "[Reply to Test: Photo caption]"
|
||||
|
||||
|
||||
def test_extract_reply_context_truncation() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_reply_context_truncation() -> None:
|
||||
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
|
||||
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
|
||||
reply = SimpleNamespace(text=long_text, caption=None)
|
||||
reply = SimpleNamespace(text=long_text, caption=None, from_user=SimpleNamespace(id=2, username=None, first_name=None))
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
result = TelegramChannel._extract_reply_context(message)
|
||||
result = await channel._extract_reply_context(message)
|
||||
assert result is not None
|
||||
assert result.startswith("[Reply to: ")
|
||||
assert result.endswith("...]")
|
||||
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
|
||||
|
||||
|
||||
def test_extract_reply_context_no_text_returns_none() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_reply_context_no_text_returns_none() -> None:
|
||||
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
|
||||
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||
reply = SimpleNamespace(text=None, caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
assert await channel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -8,392 +8,401 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def test_azure_openai_provider_init():
|
||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init & validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_init_creates_sdk_client():
|
||||
"""Provider creates an AsyncOpenAI client with correct base_url."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||
assert provider.default_model == "gpt-4o-deployment"
|
||||
assert provider.api_version == "2024-10-21"
|
||||
# SDK client base_url ends with /openai/v1/
|
||||
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_azure_openai_provider_init_validation():
|
||||
"""Test AzureOpenAIProvider initialization validation."""
|
||||
# Missing api_key
|
||||
def test_init_base_url_no_trailing_slash():
|
||||
"""Trailing slashes are normalised before building base_url."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://res.openai.azure.com",
|
||||
)
|
||||
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_base_url_with_trailing_slash():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://res.openai.azure.com/",
|
||||
)
|
||||
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_validation_missing_key():
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||
|
||||
# Missing api_base
|
||||
|
||||
|
||||
def test_init_validation_missing_base():
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||
AzureOpenAIProvider(api_key="test", api_base="")
|
||||
|
||||
|
||||
def test_build_chat_url():
|
||||
"""Test Azure OpenAI URL building with different deployment names."""
|
||||
def test_no_api_version_in_base_url():
|
||||
"""The /openai/v1/ path should NOT contain an api-version query param."""
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
|
||||
base = str(provider._client.base_url)
|
||||
assert "api-version" not in base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _supports_temperature
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_supports_temperature_standard_model():
|
||||
assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True
|
||||
|
||||
|
||||
def test_supports_temperature_reasoning_model():
|
||||
assert AzureOpenAIProvider._supports_temperature("o3-mini") is False
|
||||
assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False
|
||||
assert AzureOpenAIProvider._supports_temperature("o4-mini") is False
|
||||
|
||||
|
||||
def test_supports_temperature_with_reasoning_effort():
|
||||
assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_body — Responses API body construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_body_basic():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test various deployment names
|
||||
test_cases = [
|
||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||
]
|
||||
|
||||
for deployment_name, expected_url in test_cases:
|
||||
url = provider._build_chat_url(deployment_name)
|
||||
assert url == expected_url
|
||||
messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
|
||||
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||
|
||||
|
||||
def test_build_chat_url_api_base_without_slash():
|
||||
"""Test URL building when api_base doesn't end with slash."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||
default_model="gpt-4o",
|
||||
assert body["model"] == "gpt-4o"
|
||||
assert body["instructions"] == "You are helpful."
|
||||
assert body["temperature"] == 0.7
|
||||
assert body["max_output_tokens"] == 4096
|
||||
assert body["store"] is False
|
||||
assert "reasoning" not in body
|
||||
# input should contain the converted user message only (system extracted)
|
||||
assert any(
|
||||
item.get("role") == "user"
|
||||
for item in body["input"]
|
||||
)
|
||||
|
||||
url = provider._build_chat_url("test-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
|
||||
|
||||
def test_build_headers():
|
||||
"""Test Azure OpenAI header building with api-key authentication."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-api-key-123",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
headers = provider._build_headers()
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||
assert "x-session-affinity" in headers
|
||||
def test_build_body_max_tokens_minimum():
|
||||
"""max_output_tokens should never be less than 1."""
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None)
|
||||
assert body["max_output_tokens"] == 1
|
||||
|
||||
|
||||
def test_prepare_request_payload():
|
||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||
|
||||
assert payload["messages"] == messages
|
||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||
assert payload["temperature"] == 0.8
|
||||
assert "tools" not in payload
|
||||
|
||||
# Test with tools
|
||||
def test_build_body_with_tools():
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||
assert payload_with_tools["tools"] == tools
|
||||
assert payload_with_tools["tool_choice"] == "auto"
|
||||
|
||||
# Test with reasoning_effort
|
||||
payload_with_reasoning = provider._prepare_request_payload(
|
||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||
body = provider._build_body(
|
||||
[{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
|
||||
)
|
||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||
assert "temperature" not in payload_with_reasoning
|
||||
assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
|
||||
assert body["tool_choice"] == "auto"
|
||||
|
||||
|
||||
def test_prepare_request_payload_sanitizes_messages():
|
||||
"""Test Azure payload strips non-standard message keys before sending."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
def test_build_body_with_reasoning():
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
|
||||
body = provider._build_body(
|
||||
[{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None,
|
||||
)
|
||||
assert body["reasoning"] == {"effort": "medium"}
|
||||
assert "reasoning.encrypted_content" in body.get("include", [])
|
||||
# temperature omitted for reasoning models
|
||||
assert "temperature" not in body
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
"reasoning_content": "hidden chain-of-thought",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
"extra_field": "should be removed",
|
||||
},
|
||||
]
|
||||
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||
def test_build_body_image_conversion():
|
||||
"""image_url content blocks should be converted to input_image."""
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
|
||||
],
|
||||
}]
|
||||
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||
user_item = body["input"][0]
|
||||
content_types = [b["type"] for b in user_item["content"]]
|
||||
assert "input_text" in content_types
|
||||
assert "input_image" in content_types
|
||||
image_block = next(b for b in user_item["content"] if b["type"] == "input_image")
|
||||
assert image_block["image_url"] == "https://example.com/img.png"
|
||||
|
||||
assert payload["messages"] == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
|
||||
def test_build_body_sanitizes_single_dict_content_block():
|
||||
"""Single content dicts should be preserved via shared message sanitization."""
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": {"type": "text", "text": "Hi from dict content"},
|
||||
}]
|
||||
|
||||
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||
|
||||
assert body["input"][0]["content"] == [{"type": "input_text", "text": "Hi from dict content"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat() — non-streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sdk_response(
|
||||
content="Hello!", tool_calls=None, status="completed",
|
||||
usage=None,
|
||||
):
|
||||
"""Build a mock that quacks like an openai Response object."""
|
||||
resp = MagicMock()
|
||||
resp.model_dump = MagicMock(return_value={
|
||||
"output": [
|
||||
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
|
||||
*([{
|
||||
"type": "function_call",
|
||||
"call_id": tc["call_id"], "id": tc["id"],
|
||||
"name": tc["name"], "arguments": tc["arguments"],
|
||||
} for tc in (tool_calls or [])]),
|
||||
],
|
||||
"status": status,
|
||||
"usage": {
|
||||
"input_tokens": (usage or {}).get("input_tokens", 10),
|
||||
"output_tokens": (usage or {}).get("output_tokens", 5),
|
||||
"total_tokens": (usage or {}).get("total_tokens", 15),
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
},
|
||||
]
|
||||
})
|
||||
return resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_success():
|
||||
"""Test successful chat request using model as deployment name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 18,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
# Test with specific model (deployment name)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages, model="custom-deployment")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 12
|
||||
assert result.usage["completion_tokens"] == 18
|
||||
assert result.usage["total_tokens"] == 30
|
||||
|
||||
# Verify URL was built with the provided model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
mock_resp = _make_sdk_response(content="Hello!")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello!"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_uses_default_model_when_no_model_provided():
|
||||
"""Test that chat uses default_model when no model is specified."""
|
||||
async def test_chat_uses_default_model():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="default-deployment",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
|
||||
)
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {"content": "Response", "role": "assistant"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
await provider.chat(messages) # No model specified
|
||||
|
||||
# Verify URL was built with default model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
mock_resp = _make_sdk_response(content="ok")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat([{"role": "user", "content": "test"}])
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["model"] == "my-deployment"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_custom_model():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
mock_resp = _make_sdk_response(content="ok")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy")
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["model"] == "custom-deploy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tool_calls():
|
||||
"""Test chat request with tool calls in response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response with tool calls
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_12345",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}'
|
||||
}
|
||||
}]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
mock_resp = _make_sdk_response(
|
||||
content=None,
|
||||
tool_calls=[{
|
||||
"call_id": "call_123", "id": "fc_1",
|
||||
"name": "get_weather", "arguments": '{"location": "SF"}',
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content is None
|
||||
assert result.finish_reason == "tool_calls"
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||
)
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await provider.chat(
|
||||
[{"role": "user", "content": "Weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_api_error():
|
||||
"""Test chat request API error handling."""
|
||||
async def test_chat_error_handling():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid authentication credentials"
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Azure OpenAI API Error 401" in result.content
|
||||
assert "Invalid authentication credentials" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_connection_error():
|
||||
"""Test chat request connection error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_parse_response_malformed():
|
||||
"""Test response parsing with malformed data."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test with missing choices
|
||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||
result = provider._parse_response(malformed_response)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error parsing Azure OpenAI response" in result.content
|
||||
assert "Connection failed" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_reasoning_param_format():
|
||||
"""reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
|
||||
)
|
||||
mock_resp = _make_sdk_response(content="thought")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat(
|
||||
[{"role": "user", "content": "think"}], reasoning_effort="medium",
|
||||
)
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["reasoning"] == {"effort": "medium"}
|
||||
assert "reasoning_effort" not in call_kwargs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat_stream()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_success():
|
||||
"""Streaming should call on_content_delta and return combined response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Build mock SDK stream events
|
||||
events = []
|
||||
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
|
||||
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
|
||||
resp_obj = MagicMock(status="completed")
|
||||
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||
events = [ev1, ev2, ev3]
|
||||
|
||||
async def mock_stream():
|
||||
for e in events:
|
||||
yield e
|
||||
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||
|
||||
deltas: list[str] = []
|
||||
|
||||
async def on_delta(text: str) -> None:
|
||||
deltas.append(text)
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
||||
)
|
||||
|
||||
assert result.content == "Hello world"
|
||||
assert result.finish_reason == "stop"
|
||||
assert deltas == ["Hello", " world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_with_tool_calls():
|
||||
"""Streaming tool calls should be accumulated correctly."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
|
||||
item_added.name = "get_weather"
|
||||
ev_added = MagicMock(type="response.output_item.added", item=item_added)
|
||||
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
|
||||
ev_args_done = MagicMock(
|
||||
type="response.function_call_arguments.done",
|
||||
call_id="call_1", arguments='{"location":"SF"}',
|
||||
)
|
||||
item_done = MagicMock(
|
||||
type="function_call", call_id="call_1", id="fc_1",
|
||||
arguments='{"location":"SF"}',
|
||||
)
|
||||
item_done.name = "get_weather"
|
||||
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
|
||||
resp_obj = MagicMock(status="completed")
|
||||
ev_completed = MagicMock(type="response.completed", response=resp_obj)
|
||||
|
||||
async def mock_stream():
|
||||
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
|
||||
yield e
|
||||
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_error():
|
||||
"""Streaming should return error when SDK raises."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
|
||||
|
||||
assert "Connection failed" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_default_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_default_model():
|
||||
"""Test get_default_model method."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="my-custom-deployment",
|
||||
api_key="k", api_base="https://r.com", default_model="my-deploy",
|
||||
)
|
||||
|
||||
assert provider.get_default_model() == "my-custom-deployment"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Running basic Azure OpenAI provider tests...")
|
||||
|
||||
# Test initialization
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
print("✅ Provider initialization successful")
|
||||
|
||||
# Test URL building
|
||||
url = provider._build_chat_url("my-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
print("✅ URL building works correctly")
|
||||
|
||||
# Test headers
|
||||
headers = provider._build_headers()
|
||||
assert headers["api-key"] == "test-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
print("✅ Header building works correctly")
|
||||
|
||||
# Test payload preparation
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||
print("✅ Payload preparation works correctly")
|
||||
|
||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||
assert provider.get_default_model() == "my-deploy"
|
||||
|
||||
@ -8,6 +8,7 @@ Validates that:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
class _StalledStream:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
await asyncio.sleep(3600)
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
def test_openrouter_spec_is_gateway() -> None:
|
||||
spec = find_by_name("openrouter")
|
||||
assert spec is not None
|
||||
@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None:
|
||||
spec=spec,
|
||||
)
|
||||
assert provider.get_default_model() == "gpt-4o"
|
||||
|
||||
|
||||
def test_openai_compat_strips_message_level_reasoning_fields() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "done",
|
||||
"reasoning_content": "hidden",
|
||||
"extra_content": {"debug": True},
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
"extra_content": {"google": {"thought_signature": "sig"}},
|
||||
}
|
||||
],
|
||||
}
|
||||
])
|
||||
|
||||
assert "reasoning_content" not in sanitized[0]
|
||||
assert "extra_content" not in sanitized[0]
|
||||
assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
|
||||
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
|
||||
mock_create = AsyncMock(return_value=_StalledStream())
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-4o",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat_stream(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
assert result.finish_reason == "error"
|
||||
assert result.content is not None
|
||||
assert "stream stalled" in result.content
|
||||
|
||||
522
tests/providers/test_openai_responses.py
Normal file
522
tests/providers/test_openai_responses.py
Normal file
@ -0,0 +1,522 @@
|
||||
"""Tests for the shared openai_responses converters and parsers."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_responses.converters import (
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
convert_user_message,
|
||||
split_tool_call_id,
|
||||
)
|
||||
from nanobot.providers.openai_responses.parsing import (
|
||||
consume_sdk_stream,
|
||||
map_finish_reason,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - split_tool_call_id
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestSplitToolCallId:
|
||||
def test_plain_id(self):
|
||||
assert split_tool_call_id("call_abc") == ("call_abc", None)
|
||||
|
||||
def test_compound_id(self):
|
||||
assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1")
|
||||
|
||||
def test_compound_empty_item_id(self):
|
||||
assert split_tool_call_id("call_abc|") == ("call_abc", None)
|
||||
|
||||
def test_none(self):
|
||||
assert split_tool_call_id(None) == ("call_0", None)
|
||||
|
||||
def test_empty_string(self):
|
||||
assert split_tool_call_id("") == ("call_0", None)
|
||||
|
||||
def test_non_string(self):
|
||||
assert split_tool_call_id(42) == ("call_0", None)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - convert_user_message
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConvertUserMessage:
|
||||
def test_string_content(self):
|
||||
result = convert_user_message("hello")
|
||||
assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}
|
||||
|
||||
def test_text_block(self):
|
||||
result = convert_user_message([{"type": "text", "text": "hi"}])
|
||||
assert result["content"] == [{"type": "input_text", "text": "hi"}]
|
||||
|
||||
def test_image_url_block(self):
|
||||
result = convert_user_message([
|
||||
{"type": "image_url", "image_url": {"url": "https://img.example/a.png"}},
|
||||
])
|
||||
assert result["content"] == [
|
||||
{"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"},
|
||||
]
|
||||
|
||||
def test_mixed_text_and_image(self):
|
||||
result = convert_user_message([
|
||||
{"type": "text", "text": "what's this?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://img.example/b.png"}},
|
||||
])
|
||||
assert len(result["content"]) == 2
|
||||
assert result["content"][0]["type"] == "input_text"
|
||||
assert result["content"][1]["type"] == "input_image"
|
||||
|
||||
def test_empty_list_falls_back(self):
|
||||
result = convert_user_message([])
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
def test_none_falls_back(self):
|
||||
result = convert_user_message(None)
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
def test_image_without_url_skipped(self):
|
||||
result = convert_user_message([{"type": "image_url", "image_url": {}}])
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
def test_meta_fields_not_leaked(self):
|
||||
"""_meta on content blocks must never appear in converted output."""
|
||||
result = convert_user_message([
|
||||
{"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}},
|
||||
])
|
||||
assert "_meta" not in result["content"][0]
|
||||
|
||||
def test_non_dict_items_skipped(self):
|
||||
result = convert_user_message(["just a string", 42])
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - convert_messages
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConvertMessages:
|
||||
def test_system_extracted_as_instructions(self):
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
instructions, items = convert_messages(msgs)
|
||||
assert instructions == "You are helpful."
|
||||
assert len(items) == 1
|
||||
assert items[0]["role"] == "user"
|
||||
|
||||
def test_multiple_system_messages_last_wins(self):
|
||||
msgs = [
|
||||
{"role": "system", "content": "first"},
|
||||
{"role": "system", "content": "second"},
|
||||
{"role": "user", "content": "x"},
|
||||
]
|
||||
instructions, _ = convert_messages(msgs)
|
||||
assert instructions == "second"
|
||||
|
||||
def test_user_message_converted(self):
|
||||
_, items = convert_messages([{"role": "user", "content": "hello"}])
|
||||
assert items[0]["role"] == "user"
|
||||
assert items[0]["content"][0]["type"] == "input_text"
|
||||
|
||||
def test_assistant_text_message(self):
|
||||
_, items = convert_messages([
|
||||
{"role": "assistant", "content": "I'll help"},
|
||||
])
|
||||
assert items[0]["type"] == "message"
|
||||
assert items[0]["role"] == "assistant"
|
||||
assert items[0]["content"][0]["type"] == "output_text"
|
||||
assert items[0]["content"][0]["text"] == "I'll help"
|
||||
|
||||
def test_assistant_empty_content_skipped(self):
|
||||
_, items = convert_messages([{"role": "assistant", "content": ""}])
|
||||
assert len(items) == 0
|
||||
|
||||
def test_assistant_with_tool_calls(self):
|
||||
_, items = convert_messages([{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_abc|fc_1",
|
||||
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
|
||||
}],
|
||||
}])
|
||||
assert items[0]["type"] == "function_call"
|
||||
assert items[0]["call_id"] == "call_abc"
|
||||
assert items[0]["id"] == "fc_1"
|
||||
assert items[0]["name"] == "get_weather"
|
||||
|
||||
def test_assistant_with_tool_calls_no_id(self):
|
||||
"""Fallback IDs when tool_call.id is missing."""
|
||||
_, items = convert_messages([{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}],
|
||||
}])
|
||||
assert items[0]["call_id"] == "call_0"
|
||||
assert items[0]["id"].startswith("fc_")
|
||||
|
||||
def test_tool_message(self):
|
||||
_, items = convert_messages([{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "result text",
|
||||
}])
|
||||
assert items[0]["type"] == "function_call_output"
|
||||
assert items[0]["call_id"] == "call_abc"
|
||||
assert items[0]["output"] == "result text"
|
||||
|
||||
def test_tool_message_dict_content(self):
|
||||
_, items = convert_messages([{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": {"key": "value"},
|
||||
}])
|
||||
assert items[0]["output"] == '{"key": "value"}'
|
||||
|
||||
def test_non_standard_keys_not_leaked(self):
|
||||
"""Extra keys on messages must not appear in converted items."""
|
||||
_, items = convert_messages([{
|
||||
"role": "user",
|
||||
"content": "hi",
|
||||
"extra_field": "should vanish",
|
||||
"_meta": {"path": "/tmp"},
|
||||
}])
|
||||
item = items[0]
|
||||
assert "extra_field" not in str(item)
|
||||
assert "_meta" not in str(item)
|
||||
|
||||
def test_full_conversation_roundtrip(self):
|
||||
"""System + user + assistant(tool_call) + tool -> correct structure."""
|
||||
msgs = [
|
||||
{"role": "system", "content": "Be concise."},
|
||||
{"role": "user", "content": "Weather in SF?"},
|
||||
{
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [{
|
||||
"id": "c1|fc1",
|
||||
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
|
||||
}],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'},
|
||||
]
|
||||
instructions, items = convert_messages(msgs)
|
||||
assert instructions == "Be concise."
|
||||
assert len(items) == 3 # user, function_call, function_call_output
|
||||
assert items[0]["role"] == "user"
|
||||
assert items[1]["type"] == "function_call"
|
||||
assert items[2]["type"] == "function_call_output"
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - convert_tools
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConvertTools:
|
||||
def test_standard_function_tool(self):
|
||||
tools = [{"type": "function", "function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
|
||||
}}]
|
||||
result = convert_tools(tools)
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == "function"
|
||||
assert result[0]["name"] == "get_weather"
|
||||
assert result[0]["description"] == "Get weather"
|
||||
assert "properties" in result[0]["parameters"]
|
||||
|
||||
def test_tool_without_name_skipped(self):
|
||||
tools = [{"type": "function", "function": {"parameters": {}}}]
|
||||
assert convert_tools(tools) == []
|
||||
|
||||
def test_tool_without_function_wrapper(self):
|
||||
"""Direct dict without type=function wrapper."""
|
||||
tools = [{"name": "f1", "description": "d", "parameters": {}}]
|
||||
result = convert_tools(tools)
|
||||
assert result[0]["name"] == "f1"
|
||||
|
||||
def test_missing_optional_fields_default(self):
|
||||
tools = [{"type": "function", "function": {"name": "f"}}]
|
||||
result = convert_tools(tools)
|
||||
assert result[0]["description"] == ""
|
||||
assert result[0]["parameters"] == {}
|
||||
|
||||
def test_multiple_tools(self):
|
||||
tools = [
|
||||
{"type": "function", "function": {"name": "a", "parameters": {}}},
|
||||
{"type": "function", "function": {"name": "b", "parameters": {}}},
|
||||
]
|
||||
assert len(convert_tools(tools)) == 2
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# parsing - map_finish_reason
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestMapFinishReason:
|
||||
def test_completed(self):
|
||||
assert map_finish_reason("completed") == "stop"
|
||||
|
||||
def test_incomplete(self):
|
||||
assert map_finish_reason("incomplete") == "length"
|
||||
|
||||
def test_failed(self):
|
||||
assert map_finish_reason("failed") == "error"
|
||||
|
||||
def test_cancelled(self):
|
||||
assert map_finish_reason("cancelled") == "error"
|
||||
|
||||
def test_none_defaults_to_stop(self):
|
||||
assert map_finish_reason(None) == "stop"
|
||||
|
||||
def test_unknown_defaults_to_stop(self):
|
||||
assert map_finish_reason("some_new_status") == "stop"
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# parsing - parse_response_output
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestParseResponseOutput:
|
||||
def test_text_response(self):
|
||||
resp = {
|
||||
"output": [{"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Hello!"}]}],
|
||||
"status": "completed",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
result = parse_response_output(resp)
|
||||
assert result.content == "Hello!"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
|
||||
assert result.tool_calls == []
|
||||
|
||||
def test_tool_call_response(self):
|
||||
resp = {
|
||||
"output": [{
|
||||
"type": "function_call",
|
||||
"call_id": "call_1", "id": "fc_1",
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "SF"}',
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {},
|
||||
}
|
||||
result = parse_response_output(resp)
|
||||
assert result.content is None
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"city": "SF"}
|
||||
assert result.tool_calls[0].id == "call_1|fc_1"
|
||||
|
||||
def test_malformed_tool_arguments_logged(self):
|
||||
"""Malformed JSON arguments should log a warning and fallback."""
|
||||
resp = {
|
||||
"output": [{
|
||||
"type": "function_call",
|
||||
"call_id": "c1", "id": "fc1",
|
||||
"name": "f", "arguments": "{bad json",
|
||||
}],
|
||||
"status": "completed", "usage": {},
|
||||
}
|
||||
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
||||
result = parse_response_output(resp)
|
||||
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||
|
||||
def test_reasoning_content_extracted(self):
|
||||
resp = {
|
||||
"output": [
|
||||
{"type": "reasoning", "summary": [
|
||||
{"type": "summary_text", "text": "I think "},
|
||||
{"type": "summary_text", "text": "therefore I am."},
|
||||
]},
|
||||
{"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "42"}]},
|
||||
],
|
||||
"status": "completed", "usage": {},
|
||||
}
|
||||
result = parse_response_output(resp)
|
||||
assert result.content == "42"
|
||||
assert result.reasoning_content == "I think therefore I am."
|
||||
|
||||
def test_empty_output(self):
|
||||
resp = {"output": [], "status": "completed", "usage": {}}
|
||||
result = parse_response_output(resp)
|
||||
assert result.content is None
|
||||
assert result.tool_calls == []
|
||||
|
||||
def test_incomplete_status(self):
|
||||
resp = {"output": [], "status": "incomplete", "usage": {}}
|
||||
result = parse_response_output(resp)
|
||||
assert result.finish_reason == "length"
|
||||
|
||||
def test_sdk_model_object(self):
|
||||
"""parse_response_output should handle SDK objects with model_dump()."""
|
||||
mock = MagicMock()
|
||||
mock.model_dump.return_value = {
|
||||
"output": [{"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "sdk"}]}],
|
||||
"status": "completed",
|
||||
"usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
|
||||
}
|
||||
result = parse_response_output(mock)
|
||||
assert result.content == "sdk"
|
||||
assert result.usage["prompt_tokens"] == 1
|
||||
|
||||
def test_usage_maps_responses_api_keys(self):
|
||||
"""Responses API uses input_tokens/output_tokens, not prompt_tokens/completion_tokens."""
|
||||
resp = {
|
||||
"output": [],
|
||||
"status": "completed",
|
||||
"usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
|
||||
}
|
||||
result = parse_response_output(resp)
|
||||
assert result.usage["prompt_tokens"] == 100
|
||||
assert result.usage["completion_tokens"] == 50
|
||||
assert result.usage["total_tokens"] == 150
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# parsing - consume_sdk_stream
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConsumeSdkStream:
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_stream(self):
|
||||
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
|
||||
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
|
||||
resp_obj = MagicMock(status="completed", usage=None, output=[])
|
||||
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||
|
||||
async def stream():
|
||||
for e in [ev1, ev2, ev3]:
|
||||
yield e
|
||||
|
||||
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
|
||||
assert content == "Hello world"
|
||||
assert tool_calls == []
|
||||
assert finish_reason == "stop"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_content_delta_called(self):
|
||||
ev1 = MagicMock(type="response.output_text.delta", delta="hi")
|
||||
resp_obj = MagicMock(status="completed", usage=None, output=[])
|
||||
ev2 = MagicMock(type="response.completed", response=resp_obj)
|
||||
deltas = []
|
||||
|
||||
async def cb(text):
|
||||
deltas.append(text)
|
||||
|
||||
async def stream():
|
||||
for e in [ev1, ev2]:
|
||||
yield e
|
||||
|
||||
await consume_sdk_stream(stream(), on_content_delta=cb)
|
||||
assert deltas == ["hi"]
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_call_stream(self):
    """A streamed function call surfaces as one tool call with parsed JSON args."""

    def make_item(arguments):
        # `name` is reserved by the MagicMock constructor, so set it afterwards.
        item = MagicMock(type="function_call", call_id="c1", id="fc1", arguments=arguments)
        item.name = "get_weather"
        return item

    events = [
        MagicMock(type="response.output_item.added", item=make_item("")),
        MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci'),
        MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}'),
        MagicMock(type="response.output_item.done", item=make_item('{"city":"SF"}')),
        MagicMock(
            type="response.completed",
            response=MagicMock(status="completed", usage=None, output=[]),
        ),
    ]

    async def fake_stream():
        for event in events:
            yield event

    content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(fake_stream())
    assert content == ""
    assert len(tool_calls) == 1
    assert tool_calls[0].name == "get_weather"
    assert tool_calls[0].arguments == {"city": "SF"}
@pytest.mark.asyncio
async def test_usage_extracted(self):
    """Token usage on the completed event is mapped to OpenAI-style keys."""
    completed = MagicMock(
        type="response.completed",
        response=MagicMock(
            status="completed",
            usage=MagicMock(input_tokens=10, output_tokens=5, total_tokens=15),
            output=[],
        ),
    )

    async def fake_stream():
        yield completed

    _, _, _, usage, _ = await consume_sdk_stream(fake_stream())
    assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
@pytest.mark.asyncio
async def test_reasoning_extracted(self):
    """Reasoning summary text in the completed response output is returned."""
    summary = MagicMock(type="summary_text", text="thinking...")
    completed = MagicMock(
        type="response.completed",
        response=MagicMock(
            status="completed",
            usage=None,
            output=[MagicMock(type="reasoning", summary=[summary])],
        ),
    )

    async def fake_stream():
        yield completed

    _, _, _, _, reasoning = await consume_sdk_stream(fake_stream())
    assert reasoning == "thinking..."
@pytest.mark.asyncio
async def test_error_event_raises(self):
    """An `error` stream event aborts consumption with a RuntimeError."""

    async def fake_stream():
        yield MagicMock(type="error", error="rate_limit_exceeded")

    with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
        await consume_sdk_stream(fake_stream())
@pytest.mark.asyncio
async def test_failed_event_raises(self):
    """A `response.failed` stream event aborts consumption with a RuntimeError."""

    async def fake_stream():
        yield MagicMock(type="response.failed", error="server_error")

    with pytest.raises(RuntimeError, match="Response failed.*server_error"):
        await consume_sdk_stream(fake_stream())
@pytest.mark.asyncio
async def test_malformed_tool_args_logged(self):
    """Malformed JSON in streaming tool args should log a warning."""

    def make_item(arguments):
        # `name` is reserved by the MagicMock constructor, so set it afterwards.
        item = MagicMock(type="function_call", call_id="c1", id="fc1", arguments=arguments)
        item.name = "f"
        return item

    events = [
        MagicMock(type="response.output_item.added", item=make_item("")),
        MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad"),
        MagicMock(type="response.output_item.done", item=make_item("{bad")),
        MagicMock(
            type="response.completed",
            response=MagicMock(status="completed", usage=None, output=[]),
        ),
    ]

    async def fake_stream():
        for event in events:
            yield event

    with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
        _, tool_calls, _, _, _ = await consume_sdk_stream(fake_stream())

    # Unparseable args fall back to a raw-string payload rather than raising.
    assert tool_calls[0].arguments == {"raw": "{bad"}
    mock_logger.warning.assert_called_once()
    assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
@ -211,3 +211,56 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
    """A 429 advertising a retry window waits exactly that long and reports it."""
    provider = ScriptedProvider([
        LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"),
        LLMResponse(content="ok"),
    ])
    sleep_calls: list[float] = []
    notices: list[str] = []

    async def capture_sleep(delay: float) -> None:
        sleep_calls.append(delay)

    async def capture_progress(msg: str) -> None:
        notices.append(msg)

    # Intercept the backoff sleep so the test doesn't actually wait.
    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", capture_sleep)

    response = await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        on_retry_wait=capture_progress,
    )

    assert response.content == "ok"
    # The advertised 7s window overrides the default backoff schedule.
    assert sleep_calls == [7.0]
    assert notices and "7s" in notices[0]
@pytest.mark.asyncio
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
    """Persistent mode gives up after ten identical transient failures in a row."""
    scripted = [LLMResponse(content="429 rate limit", finish_reason="error") for _ in range(10)]
    scripted.append(LLMResponse(content="ok"))
    provider = ScriptedProvider(scripted)
    sleep_calls: list[float] = []

    async def capture_sleep(delay: float) -> None:
        sleep_calls.append(delay)

    # Intercept the backoff sleep so the test doesn't actually wait.
    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", capture_sleep)

    response = await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        retry_mode="persistent",
    )

    # The 11th scripted response ("ok") is never reached.
    assert response.finish_reason == "error"
    assert response.content == "429 rate limit"
    assert provider.calls == 10
    # Nine sleeps between ten attempts: exponential 1, 2, 4 then capped at 4.
    assert sleep_calls == [1, 2, 4, 4, 4, 4, 4, 4, 4]
@ -347,6 +347,8 @@ async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def always_empty(content, session_key="", channel="", chat_id=""):
|
||||
@ -367,5 +369,5 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
|
||||
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
assert call_count == 2
|
||||
|
||||
@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None:
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||
task = asyncio.create_task(wrapper.execute())
|
||||
await started.wait()
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
task.cancel()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user