mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-02 07:45:54 +00:00
Merge remote-tracking branch 'origin/main' into pr-2646
This commit is contained in:
commit
f409337fcf
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
|||||||
.assets
|
.assets
|
||||||
.docs
|
.docs
|
||||||
.env
|
.env
|
||||||
|
.web
|
||||||
*.pyc
|
*.pyc
|
||||||
dist/
|
dist/
|
||||||
build/
|
build/
|
||||||
|
|||||||
17
README.md
17
README.md
@ -20,13 +20,20 @@
|
|||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
> [!IMPORTANT]
|
- **2026-04-02** 🧱 **Long-running tasks** run more reliably — core runtime hardening.
|
||||||
> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We have fully removed the `litellm` since **v0.1.4.post6**.
|
- **2026-04-01** 🔑 GitHub Copilot auth restored; stricter workspace paths; OpenRouter Claude caching fix.
|
||||||
|
- **2026-03-31** 🛰️ WeChat multimodal alignment, Discord/Matrix polish, Python SDK facade, MCP and tool fixes.
|
||||||
|
- **2026-03-30** 🧩 OpenAI-compatible API tightened; composable agent lifecycle hooks.
|
||||||
|
- **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API.
|
||||||
|
- **2026-03-28** 📚 Provider docs refresh; skill template wording fix.
|
||||||
- **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
|
- **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
|
||||||
- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
|
- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
|
||||||
- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
|
- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
|
||||||
- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
|
- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Earlier news</summary>
|
||||||
|
|
||||||
- **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI.
|
- **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI.
|
||||||
- **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command.
|
- **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command.
|
||||||
- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
|
- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
|
||||||
@ -34,10 +41,6 @@
|
|||||||
- **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly.
|
- **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly.
|
||||||
- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details.
|
- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details.
|
||||||
- **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable.
|
- **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable.
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary>Earlier news</summary>
|
|
||||||
|
|
||||||
- **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details.
|
- **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details.
|
||||||
- **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility.
|
- **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility.
|
||||||
- **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling.
|
- **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling.
|
||||||
|
|||||||
@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
|||||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||||
|
if isinstance(left, str) and isinstance(right, str):
|
||||||
|
return f"{left}\n\n{right}" if left else right
|
||||||
|
|
||||||
|
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value]
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
return [{"type": "text", "text": str(value)}]
|
||||||
|
|
||||||
|
return _to_blocks(left) + _to_blocks(right)
|
||||||
|
|
||||||
def _load_bootstrap_files(self) -> str:
|
def _load_bootstrap_files(self) -> str:
|
||||||
"""Load all bootstrap files from workspace."""
|
"""Load all bootstrap files from workspace."""
|
||||||
parts = []
|
parts = []
|
||||||
@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
|||||||
merged = f"{runtime_ctx}\n\n{user_content}"
|
merged = f"{runtime_ctx}\n\n{user_content}"
|
||||||
else:
|
else:
|
||||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||||
|
messages = [
|
||||||
return [
|
|
||||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||||
*history,
|
*history,
|
||||||
{"role": current_role, "content": merged},
|
|
||||||
]
|
]
|
||||||
|
if messages[-1].get("role") == current_role:
|
||||||
|
last = dict(messages[-1])
|
||||||
|
last["content"] = self._merge_message_content(last.get("content"), merged)
|
||||||
|
messages[-1] = last
|
||||||
|
return messages
|
||||||
|
messages.append({"role": current_role, "content": merged})
|
||||||
|
return messages
|
||||||
|
|
||||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||||
"""Build user message content with optional base64-encoded images."""
|
"""Build user message content with optional base64-encoded images."""
|
||||||
|
|||||||
@ -29,8 +29,11 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
|||||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
from nanobot.utils.helpers import image_placeholder_text, truncate_text
|
||||||
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
|
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
|
||||||
@ -38,11 +41,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class _LoopHook(AgentHook):
|
class _LoopHook(AgentHook):
|
||||||
"""Core lifecycle hook for the main agent loop.
|
"""Core hook for the main loop."""
|
||||||
|
|
||||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
|
||||||
and think-tag stripping for the built-in agent path.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -111,11 +110,7 @@ class _LoopHook(AgentHook):
|
|||||||
|
|
||||||
|
|
||||||
class _LoopHookChain(AgentHook):
|
class _LoopHookChain(AgentHook):
|
||||||
"""Run the core loop hook first, then best-effort extra hooks.
|
"""Run the core hook before extra hooks."""
|
||||||
|
|
||||||
This preserves the historical failure behavior of ``_LoopHook`` while still
|
|
||||||
letting user-supplied hooks opt into ``CompositeHook`` isolation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_primary", "_extras")
|
__slots__ = ("_primary", "_extras")
|
||||||
|
|
||||||
@ -163,7 +158,7 @@ class AgentLoop:
|
|||||||
5. Sends responses back
|
5. Sends responses back
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TOOL_RESULT_MAX_CHARS = 16_000
|
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -171,8 +166,11 @@ class AgentLoop:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_iterations: int = 40,
|
max_iterations: int | None = None,
|
||||||
context_window_tokens: int = 65_536,
|
context_window_tokens: int | None = None,
|
||||||
|
context_block_limit: int | None = None,
|
||||||
|
max_tool_result_chars: int | None = None,
|
||||||
|
provider_retry_mode: str = "standard",
|
||||||
web_search_config: WebSearchConfig | None = None,
|
web_search_config: WebSearchConfig | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
@ -186,13 +184,27 @@ class AgentLoop:
|
|||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||||
|
|
||||||
|
defaults = AgentDefaults()
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.channels_config = channels_config
|
self.channels_config = channels_config
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = (
|
||||||
self.context_window_tokens = context_window_tokens
|
max_iterations if max_iterations is not None else defaults.max_tool_iterations
|
||||||
|
)
|
||||||
|
self.context_window_tokens = (
|
||||||
|
context_window_tokens
|
||||||
|
if context_window_tokens is not None
|
||||||
|
else defaults.context_window_tokens
|
||||||
|
)
|
||||||
|
self.context_block_limit = context_block_limit
|
||||||
|
self.max_tool_result_chars = (
|
||||||
|
max_tool_result_chars
|
||||||
|
if max_tool_result_chars is not None
|
||||||
|
else defaults.max_tool_result_chars
|
||||||
|
)
|
||||||
|
self.provider_retry_mode = provider_retry_mode
|
||||||
self.web_search_config = web_search_config or WebSearchConfig()
|
self.web_search_config = web_search_config or WebSearchConfig()
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@ -211,6 +223,7 @@ class AgentLoop:
|
|||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
bus=bus,
|
bus=bus,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
web_search_config=self.web_search_config,
|
web_search_config=self.web_search_config,
|
||||||
web_proxy=web_proxy,
|
web_proxy=web_proxy,
|
||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
@ -322,6 +335,7 @@ class AgentLoop:
|
|||||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
*,
|
*,
|
||||||
|
session: Session | None = None,
|
||||||
channel: str = "cli",
|
channel: str = "cli",
|
||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
@ -348,14 +362,27 @@ class AgentLoop:
|
|||||||
else loop_hook
|
else loop_hook
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||||
|
if session is None:
|
||||||
|
return
|
||||||
|
self._set_runtime_checkpoint(session, payload)
|
||||||
|
|
||||||
result = await self.runner.run(AgentRunSpec(
|
result = await self.runner.run(AgentRunSpec(
|
||||||
initial_messages=initial_messages,
|
initial_messages=initial_messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_iterations=self.max_iterations,
|
max_iterations=self.max_iterations,
|
||||||
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
hook=hook,
|
hook=hook,
|
||||||
error_message="Sorry, I encountered an error calling the AI model.",
|
error_message="Sorry, I encountered an error calling the AI model.",
|
||||||
concurrent_tools=True,
|
concurrent_tools=True,
|
||||||
|
workspace=self.workspace,
|
||||||
|
session_key=session.key if session else None,
|
||||||
|
context_window_tokens=self.context_window_tokens,
|
||||||
|
context_block_limit=self.context_block_limit,
|
||||||
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
|
progress_callback=on_progress,
|
||||||
|
checkpoint_callback=_checkpoint,
|
||||||
))
|
))
|
||||||
self._last_usage = result.usage
|
self._last_usage = result.usage
|
||||||
if result.stop_reason == "max_iterations":
|
if result.stop_reason == "max_iterations":
|
||||||
@ -493,6 +520,8 @@ class AgentLoop:
|
|||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
if self._restore_runtime_checkpoint(session):
|
||||||
|
self.sessions.save(session)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
@ -503,10 +532,11 @@ class AgentLoop:
|
|||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
messages, channel=channel, chat_id=chat_id,
|
messages, session=session, channel=channel, chat_id=chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
@ -517,6 +547,8 @@ class AgentLoop:
|
|||||||
|
|
||||||
key = session_key or msg.session_key
|
key = session_key or msg.session_key
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
if self._restore_runtime_checkpoint(session):
|
||||||
|
self.sessions.save(session)
|
||||||
|
|
||||||
# Slash commands
|
# Slash commands
|
||||||
raw = msg.content.strip()
|
raw = msg.content.strip()
|
||||||
@ -552,14 +584,16 @@ class AgentLoop:
|
|||||||
on_progress=on_progress or _bus_progress,
|
on_progress=on_progress or _bus_progress,
|
||||||
on_stream=on_stream,
|
on_stream=on_stream,
|
||||||
on_stream_end=on_stream_end,
|
on_stream_end=on_stream_end,
|
||||||
|
session=session,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None or not final_content.strip():
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
|
|
||||||
@ -577,12 +611,6 @@ class AgentLoop:
|
|||||||
metadata=meta,
|
metadata=meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
|
|
||||||
"""Convert an inline image block into a compact text placeholder."""
|
|
||||||
path = (block.get("_meta") or {}).get("path", "")
|
|
||||||
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
|
|
||||||
|
|
||||||
def _sanitize_persisted_blocks(
|
def _sanitize_persisted_blocks(
|
||||||
self,
|
self,
|
||||||
content: list[dict[str, Any]],
|
content: list[dict[str, Any]],
|
||||||
@ -609,13 +637,14 @@ class AgentLoop:
|
|||||||
block.get("type") == "image_url"
|
block.get("type") == "image_url"
|
||||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||||
):
|
):
|
||||||
filtered.append(self._image_placeholder(block))
|
path = (block.get("_meta") or {}).get("path", "")
|
||||||
|
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||||
text = block["text"]
|
text = block["text"]
|
||||||
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS:
|
if truncate_text and len(text) > self.max_tool_result_chars:
|
||||||
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
text = truncate_text(text, self.max_tool_result_chars)
|
||||||
filtered.append({**block, "text": text})
|
filtered.append({**block, "text": text})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -632,8 +661,8 @@ class AgentLoop:
|
|||||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||||
continue # skip empty assistant messages — they poison session context
|
continue # skip empty assistant messages — they poison session context
|
||||||
if role == "tool":
|
if role == "tool":
|
||||||
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
entry["content"] = truncate_text(content, self.max_tool_result_chars)
|
||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
||||||
if not filtered:
|
if not filtered:
|
||||||
@ -656,6 +685,78 @@ class AgentLoop:
|
|||||||
session.messages.append(entry)
|
session.messages.append(entry)
|
||||||
session.updated_at = datetime.now()
|
session.updated_at = datetime.now()
|
||||||
|
|
||||||
|
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
    """Store the in-flight turn snapshot on the session and save it immediately.

    The payload is placed in ``session.metadata`` under
    ``_RUNTIME_CHECKPOINT_KEY`` (replacing any previous snapshot), after which
    the session is handed to ``self.sessions.save``.
    """
    checkpoint_key = self._RUNTIME_CHECKPOINT_KEY
    metadata = session.metadata
    metadata[checkpoint_key] = payload
    self.sessions.save(session)
|
||||||
|
|
||||||
|
def _clear_runtime_checkpoint(self, session: Session) -> None:
    """Remove any stored runtime checkpoint from the session metadata.

    Only ``session.metadata`` is modified; this method does not save the
    session. Calling it when no checkpoint is present is a no-op.
    """
    # pop() with a default already tolerates a missing key, so no
    # separate membership check (and second lookup) is needed.
    session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
|
||||||
|
return (
|
||||||
|
message.get("role"),
|
||||||
|
message.get("content"),
|
||||||
|
message.get("tool_call_id"),
|
||||||
|
message.get("name"),
|
||||||
|
message.get("tool_calls"),
|
||||||
|
message.get("reasoning_content"),
|
||||||
|
message.get("thinking_blocks"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _restore_runtime_checkpoint(self, session: Session) -> bool:
    """Materialize an unfinished turn into session history before a new request.

    Reads the checkpoint stored in ``session.metadata`` under
    ``_RUNTIME_CHECKPOINT_KEY`` and rebuilds, in order: the assistant
    message, every completed tool result, and one synthetic error result
    for each tool call that was still pending. Messages already present at
    the tail of ``session.messages`` are not appended again. The checkpoint
    is then removed from the metadata.

    The session is not saved here.

    Returns:
        True when a dict checkpoint was found and consumed, False when
        there was nothing usable to restore (the session is left untouched).
    """
    from datetime import datetime

    checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
    if not isinstance(checkpoint, dict):
        return False

    def _sequence_or_empty(value: Any) -> list[Any]:
        """Return list/tuple payload fields as a list; anything else as empty."""
        return list(value) if isinstance(value, (list, tuple)) else []

    assistant_message = checkpoint.get("assistant_message")
    completed_tool_results = _sequence_or_empty(checkpoint.get("completed_tool_results"))
    pending_tool_calls = _sequence_or_empty(checkpoint.get("pending_tool_calls"))

    # All restored messages are materialized at the same moment, so one
    # timestamp is computed once and shared.
    restored_at = datetime.now().isoformat()

    restored_messages: list[dict[str, Any]] = []
    if isinstance(assistant_message, dict):
        restored = dict(assistant_message)
        restored.setdefault("timestamp", restored_at)
        restored_messages.append(restored)
    for message in completed_tool_results:
        if isinstance(message, dict):
            restored = dict(message)
            restored.setdefault("timestamp", restored_at)
            restored_messages.append(restored)
    for tool_call in pending_tool_calls:
        if not isinstance(tool_call, dict):
            continue
        # The payload may be malformed; only read the name when "function"
        # really is a mapping, otherwise fall back to the generic label.
        function = tool_call.get("function")
        name = (function.get("name") if isinstance(function, dict) else None) or "tool"
        restored_messages.append({
            "role": "tool",
            "tool_call_id": tool_call.get("id"),
            "name": name,
            "content": "Error: Task interrupted before this tool finished.",
            "timestamp": restored_at,
        })

    # Find the longest prefix of the restored messages that already sits at
    # the end of the history, so those messages are not duplicated.
    overlap = 0
    max_overlap = min(len(session.messages), len(restored_messages))
    for size in range(max_overlap, 0, -1):
        existing = session.messages[-size:]
        candidate = restored_messages[:size]
        if all(
            self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
            for left, right in zip(existing, candidate)
        ):
            overlap = size
            break
    session.messages.extend(restored_messages[overlap:])

    self._clear_runtime_checkpoint(session)
    return True
|
||||||
|
|
||||||
async def process_direct(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|||||||
@ -4,20 +4,36 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
||||||
from nanobot.utils.helpers import build_assistant_message
|
from nanobot.utils.helpers import (
|
||||||
|
build_assistant_message,
|
||||||
|
estimate_message_tokens,
|
||||||
|
estimate_prompt_tokens_chain,
|
||||||
|
find_legal_message_start,
|
||||||
|
maybe_persist_tool_result,
|
||||||
|
truncate_text,
|
||||||
|
)
|
||||||
|
from nanobot.utils.runtime import (
|
||||||
|
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||||
|
build_finalization_retry_message,
|
||||||
|
ensure_nonempty_tool_result,
|
||||||
|
is_blank_text,
|
||||||
|
repeated_external_lookup_error,
|
||||||
|
)
|
||||||
|
|
||||||
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
||||||
"I reached the maximum number of tool call iterations ({max_iterations}) "
|
"I reached the maximum number of tool call iterations ({max_iterations}) "
|
||||||
"without completing the task. You can try breaking the task into smaller steps."
|
"without completing the task. You can try breaking the task into smaller steps."
|
||||||
)
|
)
|
||||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||||
|
_SNIP_SAFETY_BUFFER = 1024
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class AgentRunSpec:
|
class AgentRunSpec:
|
||||||
"""Configuration for a single agent execution."""
|
"""Configuration for a single agent execution."""
|
||||||
@ -26,6 +42,7 @@ class AgentRunSpec:
|
|||||||
tools: ToolRegistry
|
tools: ToolRegistry
|
||||||
model: str
|
model: str
|
||||||
max_iterations: int
|
max_iterations: int
|
||||||
|
max_tool_result_chars: int
|
||||||
temperature: float | None = None
|
temperature: float | None = None
|
||||||
max_tokens: int | None = None
|
max_tokens: int | None = None
|
||||||
reasoning_effort: str | None = None
|
reasoning_effort: str | None = None
|
||||||
@ -34,6 +51,13 @@ class AgentRunSpec:
|
|||||||
max_iterations_message: str | None = None
|
max_iterations_message: str | None = None
|
||||||
concurrent_tools: bool = False
|
concurrent_tools: bool = False
|
||||||
fail_on_tool_error: bool = False
|
fail_on_tool_error: bool = False
|
||||||
|
workspace: Path | None = None
|
||||||
|
session_key: str | None = None
|
||||||
|
context_window_tokens: int | None = None
|
||||||
|
context_block_limit: int | None = None
|
||||||
|
provider_retry_mode: str = "standard"
|
||||||
|
progress_callback: Any | None = None
|
||||||
|
checkpoint_callback: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -60,91 +84,142 @@ class AgentRunner:
|
|||||||
messages = list(spec.initial_messages)
|
messages = list(spec.initial_messages)
|
||||||
final_content: str | None = None
|
final_content: str | None = None
|
||||||
tools_used: list[str] = []
|
tools_used: list[str] = []
|
||||||
usage: dict[str, int] = {}
|
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
stop_reason = "completed"
|
stop_reason = "completed"
|
||||||
tool_events: list[dict[str, str]] = []
|
tool_events: list[dict[str, str]] = []
|
||||||
|
external_lookup_counts: dict[str, int] = {}
|
||||||
|
|
||||||
for iteration in range(spec.max_iterations):
|
for iteration in range(spec.max_iterations):
|
||||||
|
try:
|
||||||
|
messages = self._apply_tool_result_budget(spec, messages)
|
||||||
|
messages_for_model = self._snip_history(spec, messages)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Context governance failed on turn {} for {}: {}; using raw messages",
|
||||||
|
iteration,
|
||||||
|
spec.session_key or "default",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
messages_for_model = messages
|
||||||
context = AgentHookContext(iteration=iteration, messages=messages)
|
context = AgentHookContext(iteration=iteration, messages=messages)
|
||||||
await hook.before_iteration(context)
|
await hook.before_iteration(context)
|
||||||
kwargs: dict[str, Any] = {
|
response = await self._request_model(spec, messages_for_model, hook, context)
|
||||||
"messages": messages,
|
raw_usage = self._usage_dict(response.usage)
|
||||||
"tools": spec.tools.get_definitions(),
|
|
||||||
"model": spec.model,
|
|
||||||
}
|
|
||||||
if spec.temperature is not None:
|
|
||||||
kwargs["temperature"] = spec.temperature
|
|
||||||
if spec.max_tokens is not None:
|
|
||||||
kwargs["max_tokens"] = spec.max_tokens
|
|
||||||
if spec.reasoning_effort is not None:
|
|
||||||
kwargs["reasoning_effort"] = spec.reasoning_effort
|
|
||||||
|
|
||||||
if hook.wants_streaming():
|
|
||||||
async def _stream(delta: str) -> None:
|
|
||||||
await hook.on_stream(context, delta)
|
|
||||||
|
|
||||||
response = await self.provider.chat_stream_with_retry(
|
|
||||||
**kwargs,
|
|
||||||
on_content_delta=_stream,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await self.provider.chat_with_retry(**kwargs)
|
|
||||||
|
|
||||||
raw_usage = response.usage or {}
|
|
||||||
context.response = response
|
context.response = response
|
||||||
context.usage = raw_usage
|
context.usage = dict(raw_usage)
|
||||||
context.tool_calls = list(response.tool_calls)
|
context.tool_calls = list(response.tool_calls)
|
||||||
# Accumulate standard fields into result usage.
|
self._accumulate_usage(usage, raw_usage)
|
||||||
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
|
|
||||||
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
|
|
||||||
cached = raw_usage.get("cached_tokens")
|
|
||||||
if cached:
|
|
||||||
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
|
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
if hook.wants_streaming():
|
if hook.wants_streaming():
|
||||||
await hook.on_stream_end(context, resuming=True)
|
await hook.on_stream_end(context, resuming=True)
|
||||||
|
|
||||||
messages.append(build_assistant_message(
|
assistant_message = build_assistant_message(
|
||||||
response.content or "",
|
response.content or "",
|
||||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
))
|
)
|
||||||
|
messages.append(assistant_message)
|
||||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "awaiting_tools",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
"completed_tool_results": [],
|
||||||
|
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
await hook.before_execute_tools(context)
|
await hook.before_execute_tools(context)
|
||||||
|
|
||||||
results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls)
|
results, new_events, fatal_error = await self._execute_tools(
|
||||||
|
spec,
|
||||||
|
response.tool_calls,
|
||||||
|
external_lookup_counts,
|
||||||
|
)
|
||||||
tool_events.extend(new_events)
|
tool_events.extend(new_events)
|
||||||
context.tool_results = list(results)
|
context.tool_results = list(results)
|
||||||
context.tool_events = list(new_events)
|
context.tool_events = list(new_events)
|
||||||
if fatal_error is not None:
|
if fatal_error is not None:
|
||||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||||
|
final_content = error
|
||||||
stop_reason = "tool_error"
|
stop_reason = "tool_error"
|
||||||
|
self._append_final_message(messages, final_content)
|
||||||
|
context.final_content = final_content
|
||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
break
|
break
|
||||||
|
completed_tool_results: list[dict[str, Any]] = []
|
||||||
for tool_call, result in zip(response.tool_calls, results):
|
for tool_call, result in zip(response.tool_calls, results):
|
||||||
messages.append({
|
tool_message = {
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tool_call.id,
|
"tool_call_id": tool_call.id,
|
||||||
"name": tool_call.name,
|
"name": tool_call.name,
|
||||||
"content": result,
|
"content": self._normalize_tool_result(
|
||||||
})
|
spec,
|
||||||
|
tool_call.id,
|
||||||
|
tool_call.name,
|
||||||
|
result,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
messages.append(tool_message)
|
||||||
|
completed_tool_results.append(tool_message)
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "tools_completed",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
"completed_tool_results": completed_tool_results,
|
||||||
|
"pending_tool_calls": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
clean = hook.finalize_content(context, response.content)
|
||||||
|
if response.finish_reason != "error" and is_blank_text(clean):
|
||||||
|
logger.warning(
|
||||||
|
"Empty final response on turn {} for {}; retrying with explicit finalization prompt",
|
||||||
|
iteration,
|
||||||
|
spec.session_key or "default",
|
||||||
|
)
|
||||||
|
if hook.wants_streaming():
|
||||||
|
await hook.on_stream_end(context, resuming=False)
|
||||||
|
response = await self._request_finalization_retry(spec, messages_for_model)
|
||||||
|
retry_usage = self._usage_dict(response.usage)
|
||||||
|
self._accumulate_usage(usage, retry_usage)
|
||||||
|
raw_usage = self._merge_usage(raw_usage, retry_usage)
|
||||||
|
context.response = response
|
||||||
|
context.usage = dict(raw_usage)
|
||||||
|
context.tool_calls = list(response.tool_calls)
|
||||||
|
clean = hook.finalize_content(context, response.content)
|
||||||
|
|
||||||
if hook.wants_streaming():
|
if hook.wants_streaming():
|
||||||
await hook.on_stream_end(context, resuming=False)
|
await hook.on_stream_end(context, resuming=False)
|
||||||
|
|
||||||
clean = hook.finalize_content(context, response.content)
|
|
||||||
if response.finish_reason == "error":
|
if response.finish_reason == "error":
|
||||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||||
stop_reason = "error"
|
stop_reason = "error"
|
||||||
error = final_content
|
error = final_content
|
||||||
|
self._append_final_message(messages, final_content)
|
||||||
|
context.final_content = final_content
|
||||||
|
context.error = error
|
||||||
|
context.stop_reason = stop_reason
|
||||||
|
await hook.after_iteration(context)
|
||||||
|
break
|
||||||
|
if is_blank_text(clean):
|
||||||
|
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
stop_reason = "empty_final_response"
|
||||||
|
error = final_content
|
||||||
|
self._append_final_message(messages, final_content)
|
||||||
context.final_content = final_content
|
context.final_content = final_content
|
||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
@ -156,6 +231,17 @@ class AgentRunner:
|
|||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
))
|
))
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "final_response",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": messages[-1],
|
||||||
|
"completed_tool_results": [],
|
||||||
|
"pending_tool_calls": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
final_content = clean
|
final_content = clean
|
||||||
context.final_content = final_content
|
context.final_content = final_content
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
@ -165,6 +251,7 @@ class AgentRunner:
|
|||||||
stop_reason = "max_iterations"
|
stop_reason = "max_iterations"
|
||||||
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
|
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
|
||||||
final_content = template.format(max_iterations=spec.max_iterations)
|
final_content = template.format(max_iterations=spec.max_iterations)
|
||||||
|
self._append_final_message(messages, final_content)
|
||||||
|
|
||||||
return AgentRunResult(
|
return AgentRunResult(
|
||||||
final_content=final_content,
|
final_content=final_content,
|
||||||
@ -176,21 +263,101 @@ class AgentRunner:
|
|||||||
tool_events=tool_events,
|
tool_events=tool_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _build_request_kwargs(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
tools: list[dict[str, Any]] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
|
"model": spec.model,
|
||||||
|
"retry_mode": spec.provider_retry_mode,
|
||||||
|
"on_retry_wait": spec.progress_callback,
|
||||||
|
}
|
||||||
|
if spec.temperature is not None:
|
||||||
|
kwargs["temperature"] = spec.temperature
|
||||||
|
if spec.max_tokens is not None:
|
||||||
|
kwargs["max_tokens"] = spec.max_tokens
|
||||||
|
if spec.reasoning_effort is not None:
|
||||||
|
kwargs["reasoning_effort"] = spec.reasoning_effort
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
async def _request_model(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
hook: AgentHook,
|
||||||
|
context: AgentHookContext,
|
||||||
|
):
|
||||||
|
kwargs = self._build_request_kwargs(
|
||||||
|
spec,
|
||||||
|
messages,
|
||||||
|
tools=spec.tools.get_definitions(),
|
||||||
|
)
|
||||||
|
if hook.wants_streaming():
|
||||||
|
async def _stream(delta: str) -> None:
|
||||||
|
await hook.on_stream(context, delta)
|
||||||
|
|
||||||
|
return await self.provider.chat_stream_with_retry(
|
||||||
|
**kwargs,
|
||||||
|
on_content_delta=_stream,
|
||||||
|
)
|
||||||
|
return await self.provider.chat_with_retry(**kwargs)
|
||||||
|
|
||||||
|
async def _request_finalization_retry(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
):
|
||||||
|
retry_messages = list(messages)
|
||||||
|
retry_messages.append(build_finalization_retry_message())
|
||||||
|
kwargs = self._build_request_kwargs(spec, retry_messages, tools=None)
|
||||||
|
return await self.provider.chat_with_retry(**kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _usage_dict(usage: dict[str, Any] | None) -> dict[str, int]:
|
||||||
|
if not usage:
|
||||||
|
return {}
|
||||||
|
result: dict[str, int] = {}
|
||||||
|
for key, value in usage.items():
|
||||||
|
try:
|
||||||
|
result[key] = int(value or 0)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _accumulate_usage(target: dict[str, int], addition: dict[str, int]) -> None:
|
||||||
|
for key, value in addition.items():
|
||||||
|
target[key] = target.get(key, 0) + value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_usage(left: dict[str, int], right: dict[str, int]) -> dict[str, int]:
|
||||||
|
merged = dict(left)
|
||||||
|
for key, value in right.items():
|
||||||
|
merged[key] = merged.get(key, 0) + value
|
||||||
|
return merged
|
||||||
|
|
||||||
async def _execute_tools(
|
async def _execute_tools(
|
||||||
self,
|
self,
|
||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
tool_calls: list[ToolCallRequest],
|
tool_calls: list[ToolCallRequest],
|
||||||
|
external_lookup_counts: dict[str, int],
|
||||||
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
|
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
|
||||||
if spec.concurrent_tools:
|
batches = self._partition_tool_batches(spec, tool_calls)
|
||||||
tool_results = await asyncio.gather(*(
|
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
||||||
self._run_tool(spec, tool_call)
|
for batch in batches:
|
||||||
for tool_call in tool_calls
|
if spec.concurrent_tools and len(batch) > 1:
|
||||||
))
|
tool_results.extend(await asyncio.gather(*(
|
||||||
|
self._run_tool(spec, tool_call, external_lookup_counts)
|
||||||
|
for tool_call in batch
|
||||||
|
)))
|
||||||
else:
|
else:
|
||||||
tool_results = [
|
for tool_call in batch:
|
||||||
await self._run_tool(spec, tool_call)
|
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
|
||||||
for tool_call in tool_calls
|
|
||||||
]
|
|
||||||
|
|
||||||
results: list[Any] = []
|
results: list[Any] = []
|
||||||
events: list[dict[str, str]] = []
|
events: list[dict[str, str]] = []
|
||||||
@ -206,9 +373,44 @@ class AgentRunner:
|
|||||||
self,
|
self,
|
||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
tool_call: ToolCallRequest,
|
tool_call: ToolCallRequest,
|
||||||
|
external_lookup_counts: dict[str, int],
|
||||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||||
|
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
|
lookup_error = repeated_external_lookup_error(
|
||||||
|
tool_call.name,
|
||||||
|
tool_call.arguments,
|
||||||
|
external_lookup_counts,
|
||||||
|
)
|
||||||
|
if lookup_error:
|
||||||
|
event = {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"status": "error",
|
||||||
|
"detail": "repeated external lookup blocked",
|
||||||
|
}
|
||||||
|
if spec.fail_on_tool_error:
|
||||||
|
return lookup_error + _HINT, event, RuntimeError(lookup_error)
|
||||||
|
return lookup_error + _HINT, event, None
|
||||||
|
prepare_call = getattr(spec.tools, "prepare_call", None)
|
||||||
|
tool, params, prep_error = None, tool_call.arguments, None
|
||||||
|
if callable(prepare_call):
|
||||||
try:
|
try:
|
||||||
result = await spec.tools.execute(tool_call.name, tool_call.arguments)
|
prepared = prepare_call(tool_call.name, tool_call.arguments)
|
||||||
|
if isinstance(prepared, tuple) and len(prepared) == 3:
|
||||||
|
tool, params, prep_error = prepared
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if prep_error:
|
||||||
|
event = {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"status": "error",
|
||||||
|
"detail": prep_error.split(": ", 1)[-1][:120],
|
||||||
|
}
|
||||||
|
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||||
|
try:
|
||||||
|
if tool is not None:
|
||||||
|
result = await tool.execute(**params)
|
||||||
|
else:
|
||||||
|
result = await spec.tools.execute(tool_call.name, params)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except BaseException as exc:
|
except BaseException as exc:
|
||||||
@ -221,14 +423,178 @@ class AgentRunner:
|
|||||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||||
|
|
||||||
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
|
event = {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"status": "error",
|
||||||
|
"detail": result.replace("\n", " ").strip()[:120],
|
||||||
|
}
|
||||||
|
if spec.fail_on_tool_error:
|
||||||
|
return result + _HINT, event, RuntimeError(result)
|
||||||
|
return result + _HINT, event, None
|
||||||
|
|
||||||
detail = "" if result is None else str(result)
|
detail = "" if result is None else str(result)
|
||||||
detail = detail.replace("\n", " ").strip()
|
detail = detail.replace("\n", " ").strip()
|
||||||
if not detail:
|
if not detail:
|
||||||
detail = "(empty)"
|
detail = "(empty)"
|
||||||
elif len(detail) > 120:
|
elif len(detail) > 120:
|
||||||
detail = detail[:120] + "..."
|
detail = detail[:120] + "..."
|
||||||
return result, {
|
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
|
||||||
"name": tool_call.name,
|
|
||||||
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok",
|
async def _emit_checkpoint(
|
||||||
"detail": detail,
|
self,
|
||||||
}, None
|
spec: AgentRunSpec,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
callback = spec.checkpoint_callback
|
||||||
|
if callback is not None:
|
||||||
|
await callback(payload)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None:
|
||||||
|
if not content:
|
||||||
|
return
|
||||||
|
if (
|
||||||
|
messages
|
||||||
|
and messages[-1].get("role") == "assistant"
|
||||||
|
and not messages[-1].get("tool_calls")
|
||||||
|
):
|
||||||
|
if messages[-1].get("content") == content:
|
||||||
|
return
|
||||||
|
messages[-1] = build_assistant_message(content)
|
||||||
|
return
|
||||||
|
messages.append(build_assistant_message(content))
|
||||||
|
|
||||||
|
def _normalize_tool_result(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
result: Any,
|
||||||
|
) -> Any:
|
||||||
|
result = ensure_nonempty_tool_result(tool_name, result)
|
||||||
|
try:
|
||||||
|
content = maybe_persist_tool_result(
|
||||||
|
spec.workspace,
|
||||||
|
spec.session_key,
|
||||||
|
tool_call_id,
|
||||||
|
result,
|
||||||
|
max_chars=spec.max_tool_result_chars,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Tool result persist failed for {} in {}: {}; using raw result",
|
||||||
|
tool_call_id,
|
||||||
|
spec.session_key or "default",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
content = result
|
||||||
|
if isinstance(content, str) and len(content) > spec.max_tool_result_chars:
|
||||||
|
return truncate_text(content, spec.max_tool_result_chars)
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _apply_tool_result_budget(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
updated = messages
|
||||||
|
for idx, message in enumerate(messages):
|
||||||
|
if message.get("role") != "tool":
|
||||||
|
continue
|
||||||
|
normalized = self._normalize_tool_result(
|
||||||
|
spec,
|
||||||
|
str(message.get("tool_call_id") or f"tool_{idx}"),
|
||||||
|
str(message.get("name") or "tool"),
|
||||||
|
message.get("content"),
|
||||||
|
)
|
||||||
|
if normalized != message.get("content"):
|
||||||
|
if updated is messages:
|
||||||
|
updated = [dict(m) for m in messages]
|
||||||
|
updated[idx]["content"] = normalized
|
||||||
|
return updated
|
||||||
|
|
||||||
|
def _snip_history(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
if not messages or not spec.context_window_tokens:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
|
||||||
|
max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else (
|
||||||
|
provider_max_tokens if isinstance(provider_max_tokens, int) else 4096
|
||||||
|
)
|
||||||
|
budget = spec.context_block_limit or (
|
||||||
|
spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER
|
||||||
|
)
|
||||||
|
if budget <= 0:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
estimate, _ = estimate_prompt_tokens_chain(
|
||||||
|
self.provider,
|
||||||
|
spec.model,
|
||||||
|
messages,
|
||||||
|
spec.tools.get_definitions(),
|
||||||
|
)
|
||||||
|
if estimate <= budget:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"]
|
||||||
|
non_system = [dict(msg) for msg in messages if msg.get("role") != "system"]
|
||||||
|
if not non_system:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
|
||||||
|
remaining_budget = max(128, budget - system_tokens)
|
||||||
|
kept: list[dict[str, Any]] = []
|
||||||
|
kept_tokens = 0
|
||||||
|
for message in reversed(non_system):
|
||||||
|
msg_tokens = estimate_message_tokens(message)
|
||||||
|
if kept and kept_tokens + msg_tokens > remaining_budget:
|
||||||
|
break
|
||||||
|
kept.append(message)
|
||||||
|
kept_tokens += msg_tokens
|
||||||
|
kept.reverse()
|
||||||
|
|
||||||
|
if kept:
|
||||||
|
for i, message in enumerate(kept):
|
||||||
|
if message.get("role") == "user":
|
||||||
|
kept = kept[i:]
|
||||||
|
break
|
||||||
|
start = find_legal_message_start(kept)
|
||||||
|
if start:
|
||||||
|
kept = kept[start:]
|
||||||
|
if not kept:
|
||||||
|
kept = non_system[-min(len(non_system), 4) :]
|
||||||
|
start = find_legal_message_start(kept)
|
||||||
|
if start:
|
||||||
|
kept = kept[start:]
|
||||||
|
return system_messages + kept
|
||||||
|
|
||||||
|
def _partition_tool_batches(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
tool_calls: list[ToolCallRequest],
|
||||||
|
) -> list[list[ToolCallRequest]]:
|
||||||
|
if not spec.concurrent_tools:
|
||||||
|
return [[tool_call] for tool_call in tool_calls]
|
||||||
|
|
||||||
|
batches: list[list[ToolCallRequest]] = []
|
||||||
|
current: list[ToolCallRequest] = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
get_tool = getattr(spec.tools, "get", None)
|
||||||
|
tool = get_tool(tool_call.name) if callable(get_tool) else None
|
||||||
|
can_batch = bool(tool and tool.concurrency_safe)
|
||||||
|
if can_batch:
|
||||||
|
current.append(tool_call)
|
||||||
|
continue
|
||||||
|
if current:
|
||||||
|
batches.append(current)
|
||||||
|
current = []
|
||||||
|
batches.append([tool_call])
|
||||||
|
if current:
|
||||||
|
batches.append(current)
|
||||||
|
return batches
|
||||||
|
|
||||||
|
|||||||
@ -44,6 +44,7 @@ class SubagentManager:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
|
max_tool_result_chars: int,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
web_search_config: "WebSearchConfig | None" = None,
|
web_search_config: "WebSearchConfig | None" = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
@ -56,6 +57,7 @@ class SubagentManager:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
|
self.max_tool_result_chars = max_tool_result_chars
|
||||||
self.web_search_config = web_search_config or WebSearchConfig()
|
self.web_search_config = web_search_config or WebSearchConfig()
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@ -136,6 +138,7 @@ class SubagentManager:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_iterations=15,
|
max_iterations=15,
|
||||||
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
hook=_SubagentHook(task_id),
|
hook=_SubagentHook(task_id),
|
||||||
max_iterations_message="Task completed but no final response was generated.",
|
max_iterations_message="Task completed but no final response was generated.",
|
||||||
error_message=None,
|
error_message=None,
|
||||||
|
|||||||
@ -53,6 +53,21 @@ class Tool(ABC):
|
|||||||
"""JSON Schema for tool parameters."""
|
"""JSON Schema for tool parameters."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
"""Whether this tool is side-effect free and safe to parallelize."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def concurrency_safe(self) -> bool:
|
||||||
|
"""Whether this tool can run alongside other concurrency-safe tools."""
|
||||||
|
return self.read_only and not self.exclusive
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
"""Whether this tool should run alone even if concurrency is enabled."""
|
||||||
|
return False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, **kwargs: Any) -> Any:
|
async def execute(self, **kwargs: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -73,6 +73,10 @@ class ReadFileTool(_FsTool):
|
|||||||
"Use offset and limit to paginate through large files."
|
"Use offset and limit to paginate through large files."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@ -344,6 +348,10 @@ class ListDirTool(_FsTool):
|
|||||||
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -84,6 +84,9 @@ class MessageTool(Tool):
|
|||||||
media: list[str] | None = None,
|
media: list[str] | None = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
|
from nanobot.utils.helpers import strip_think
|
||||||
|
content = strip_think(content)
|
||||||
|
|
||||||
channel = channel or self._default_channel
|
channel = channel or self._default_channel
|
||||||
chat_id = chat_id or self._default_chat_id
|
chat_id = chat_id or self._default_chat_id
|
||||||
# Only inherit default message_id when targeting the same channel+chat.
|
# Only inherit default message_id when targeting the same channel+chat.
|
||||||
|
|||||||
@ -35,22 +35,35 @@ class ToolRegistry:
|
|||||||
"""Get all tool definitions in OpenAI format."""
|
"""Get all tool definitions in OpenAI format."""
|
||||||
return [tool.to_schema() for tool in self._tools.values()]
|
return [tool.to_schema() for tool in self._tools.values()]
|
||||||
|
|
||||||
|
def prepare_call(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
params: dict[str, Any],
|
||||||
|
) -> tuple[Tool | None, dict[str, Any], str | None]:
|
||||||
|
"""Resolve, cast, and validate one tool call."""
|
||||||
|
tool = self._tools.get(name)
|
||||||
|
if not tool:
|
||||||
|
return None, params, (
|
||||||
|
f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
cast_params = tool.cast_params(params)
|
||||||
|
errors = tool.validate_params(cast_params)
|
||||||
|
if errors:
|
||||||
|
return tool, cast_params, (
|
||||||
|
f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
|
||||||
|
)
|
||||||
|
return tool, cast_params, None
|
||||||
|
|
||||||
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
||||||
"""Execute a tool by name with given parameters."""
|
"""Execute a tool by name with given parameters."""
|
||||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
|
tool, params, error = self.prepare_call(name, params)
|
||||||
tool = self._tools.get(name)
|
if error:
|
||||||
if not tool:
|
return error + _HINT
|
||||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Attempt to cast parameters to match schema types
|
assert tool is not None # guarded by prepare_call()
|
||||||
params = tool.cast_params(params)
|
|
||||||
|
|
||||||
# Validate parameters
|
|
||||||
errors = tool.validate_params(params)
|
|
||||||
if errors:
|
|
||||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
|
||||||
result = await tool.execute(**params)
|
result = await tool.execute(**params)
|
||||||
if isinstance(result, str) and result.startswith("Error"):
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
return result + _HINT
|
return result + _HINT
|
||||||
|
|||||||
@ -52,6 +52,10 @@ class ExecTool(Tool):
|
|||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Execute a shell command and return its output. Use with caution."
|
return "Execute a shell command and return its output. Use with caution."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -92,6 +92,10 @@ class WebSearchTool(Tool):
|
|||||||
self.config = config if config is not None else WebSearchConfig()
|
self.config = config if config is not None else WebSearchConfig()
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||||
provider = self.config.provider.strip().lower() or "brave"
|
provider = self.config.provider.strip().lower() or "brave"
|
||||||
n = min(max(count or self.config.max_results, 1), 10)
|
n = min(max(count or self.config.max_results, 1), 10)
|
||||||
@ -234,6 +238,10 @@ class WebFetchTool(Tool):
|
|||||||
self.max_chars = max_chars
|
self.max_chars = max_chars
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
|
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
|
||||||
max_chars = maxChars or self.max_chars
|
max_chars = maxChars or self.max_chars
|
||||||
is_valid, error_msg = _validate_url_safe(url)
|
is_valid, error_msg = _validate_url_safe(url)
|
||||||
|
|||||||
@ -14,6 +14,8 @@ from typing import Any
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
API_SESSION_KEY = "api:default"
|
API_SESSION_KEY = "api:default"
|
||||||
API_CHAT_ID = "default"
|
API_CHAT_ID = "default"
|
||||||
|
|
||||||
@ -98,7 +100,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
||||||
|
|
||||||
_FALLBACK = "I've completed processing but have no response to give."
|
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with session_lock:
|
async with session_lock:
|
||||||
|
|||||||
@ -134,6 +134,7 @@ class QQConfig(Base):
|
|||||||
secret: str = ""
|
secret: str = ""
|
||||||
allow_from: list[str] = Field(default_factory=list)
|
allow_from: list[str] = Field(default_factory=list)
|
||||||
msg_format: Literal["plain", "markdown"] = "plain"
|
msg_format: Literal["plain", "markdown"] = "plain"
|
||||||
|
ack_message: str = "⏳ Processing..."
|
||||||
|
|
||||||
# Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq").
|
# Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq").
|
||||||
media_dir: str = ""
|
media_dir: str = ""
|
||||||
@ -484,6 +485,17 @@ class QQChannel(BaseChannel):
|
|||||||
if not content and not media_paths:
|
if not content and not media_paths:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self.config.ack_message:
|
||||||
|
try:
|
||||||
|
await self._send_text_only(
|
||||||
|
chat_id=chat_id,
|
||||||
|
is_group=is_group,
|
||||||
|
msg_id=data.id,
|
||||||
|
content=self.config.ack_message,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=user_id,
|
sender_id=user_id,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
|
|||||||
@ -275,13 +275,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._app = builder.build()
|
self._app = builder.build()
|
||||||
self._app.add_error_handler(self._on_error)
|
self._app.add_error_handler(self._on_error)
|
||||||
|
|
||||||
# Add command handlers
|
# Add command handlers (using Regex to support @username suffixes before bot initialization)
|
||||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
self._app.add_handler(MessageHandler(filters.Regex(r"^/start(?:@\w+)?$"), self._on_start))
|
||||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
self._app.add_handler(MessageHandler(filters.Regex(r"^/(new|stop|restart|status)(?:@\w+)?$"), self._forward_command))
|
||||||
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
self._app.add_handler(MessageHandler(filters.Regex(r"^/help(?:@\w+)?$"), self._on_help))
|
||||||
self._app.add_handler(CommandHandler("restart", self._forward_command))
|
|
||||||
self._app.add_handler(CommandHandler("status", self._forward_command))
|
|
||||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
|
||||||
|
|
||||||
# Add message handler for text, photos, voice, documents
|
# Add message handler for text, photos, voice, documents
|
||||||
self._app.add_handler(
|
self._app.add_handler(
|
||||||
@ -313,7 +310,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
# Start polling (this runs until stopped)
|
# Start polling (this runs until stopped)
|
||||||
await self._app.updater.start_polling(
|
await self._app.updater.start_polling(
|
||||||
allowed_updates=["message"],
|
allowed_updates=["message"],
|
||||||
drop_pending_updates=True # Ignore old messages on startup
|
drop_pending_updates=False # Process pending messages on startup
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keep running until stopped
|
# Keep running until stopped
|
||||||
@ -362,9 +359,14 @@ class TelegramChannel(BaseChannel):
|
|||||||
logger.warning("Telegram bot not running")
|
logger.warning("Telegram bot not running")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Only stop typing indicator for final responses
|
# Only stop typing indicator and remove reaction for final responses
|
||||||
if not msg.metadata.get("_progress", False):
|
if not msg.metadata.get("_progress", False):
|
||||||
self._stop_typing(msg.chat_id)
|
self._stop_typing(msg.chat_id)
|
||||||
|
if reply_to_message_id := msg.metadata.get("message_id"):
|
||||||
|
try:
|
||||||
|
await self._remove_reaction(msg.chat_id, int(reply_to_message_id))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chat_id = int(msg.chat_id)
|
chat_id = int(msg.chat_id)
|
||||||
@ -435,7 +437,9 @@ class TelegramChannel(BaseChannel):
|
|||||||
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||||
|
|
||||||
async def _call_with_retry(self, fn, *args, **kwargs):
|
async def _call_with_retry(self, fn, *args, **kwargs):
|
||||||
"""Call an async Telegram API function with retry on pool/network timeout."""
|
"""Call an async Telegram API function with retry on pool/network timeout and RetryAfter."""
|
||||||
|
from telegram.error import RetryAfter
|
||||||
|
|
||||||
for attempt in range(1, _SEND_MAX_RETRIES + 1):
|
for attempt in range(1, _SEND_MAX_RETRIES + 1):
|
||||||
try:
|
try:
|
||||||
return await fn(*args, **kwargs)
|
return await fn(*args, **kwargs)
|
||||||
@ -448,6 +452,15 @@ class TelegramChannel(BaseChannel):
|
|||||||
attempt, _SEND_MAX_RETRIES, delay,
|
attempt, _SEND_MAX_RETRIES, delay,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
except RetryAfter as e:
|
||||||
|
if attempt == _SEND_MAX_RETRIES:
|
||||||
|
raise
|
||||||
|
delay = float(e.retry_after)
|
||||||
|
logger.warning(
|
||||||
|
"Telegram Flood Control (attempt {}/{}), retrying in {:.1f}s",
|
||||||
|
attempt, _SEND_MAX_RETRIES, delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
async def _send_text(
|
async def _send_text(
|
||||||
self,
|
self,
|
||||||
@ -498,6 +511,11 @@ class TelegramChannel(BaseChannel):
|
|||||||
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
|
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
|
||||||
return
|
return
|
||||||
self._stop_typing(chat_id)
|
self._stop_typing(chat_id)
|
||||||
|
if reply_to_message_id := meta.get("message_id"):
|
||||||
|
try:
|
||||||
|
await self._remove_reaction(chat_id, int(reply_to_message_id))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
html = _markdown_to_telegram_html(buf.text)
|
html = _markdown_to_telegram_html(buf.text)
|
||||||
await self._call_with_retry(
|
await self._call_with_retry(
|
||||||
@ -619,8 +637,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
"reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
|
"reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
async def _extract_reply_context(self, message) -> str | None:
|
||||||
def _extract_reply_context(message) -> str | None:
|
|
||||||
"""Extract text from the message being replied to, if any."""
|
"""Extract text from the message being replied to, if any."""
|
||||||
reply = getattr(message, "reply_to_message", None)
|
reply = getattr(message, "reply_to_message", None)
|
||||||
if not reply:
|
if not reply:
|
||||||
@ -628,7 +645,21 @@ class TelegramChannel(BaseChannel):
|
|||||||
text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
|
text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
|
||||||
if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
|
if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
|
||||||
text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
|
text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
|
||||||
return f"[Reply to: {text}]" if text else None
|
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bot_id, _ = await self._ensure_bot_identity()
|
||||||
|
reply_user = getattr(reply, "from_user", None)
|
||||||
|
|
||||||
|
if bot_id and reply_user and getattr(reply_user, "id", None) == bot_id:
|
||||||
|
return f"[Reply to bot: {text}]"
|
||||||
|
elif reply_user and getattr(reply_user, "username", None):
|
||||||
|
return f"[Reply to @{reply_user.username}: {text}]"
|
||||||
|
elif reply_user and getattr(reply_user, "first_name", None):
|
||||||
|
return f"[Reply to {reply_user.first_name}: {text}]"
|
||||||
|
else:
|
||||||
|
return f"[Reply to: {text}]"
|
||||||
|
|
||||||
async def _download_message_media(
|
async def _download_message_media(
|
||||||
self, msg, *, add_failure_content: bool = False
|
self, msg, *, add_failure_content: bool = False
|
||||||
@ -765,10 +796,18 @@ class TelegramChannel(BaseChannel):
|
|||||||
message = update.message
|
message = update.message
|
||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
self._remember_thread_context(message)
|
self._remember_thread_context(message)
|
||||||
|
|
||||||
|
# Strip @bot_username suffix if present
|
||||||
|
content = message.text or ""
|
||||||
|
if content.startswith("/") and "@" in content:
|
||||||
|
cmd_part, *rest = content.split(" ", 1)
|
||||||
|
cmd_part = cmd_part.split("@")[0]
|
||||||
|
content = f"{cmd_part} {rest[0]}" if rest else cmd_part
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=self._sender_id(user),
|
sender_id=self._sender_id(user),
|
||||||
chat_id=str(message.chat_id),
|
chat_id=str(message.chat_id),
|
||||||
content=message.text or "",
|
content=content,
|
||||||
metadata=self._build_message_metadata(message, user),
|
metadata=self._build_message_metadata(message, user),
|
||||||
session_key=self._derive_topic_session_key(message),
|
session_key=self._derive_topic_session_key(message),
|
||||||
)
|
)
|
||||||
@ -812,7 +851,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
# Reply context: text and/or media from the replied-to message
|
# Reply context: text and/or media from the replied-to message
|
||||||
reply = getattr(message, "reply_to_message", None)
|
reply = getattr(message, "reply_to_message", None)
|
||||||
if reply is not None:
|
if reply is not None:
|
||||||
reply_ctx = self._extract_reply_context(message)
|
reply_ctx = await self._extract_reply_context(message)
|
||||||
reply_media, reply_media_parts = await self._download_message_media(reply)
|
reply_media, reply_media_parts = await self._download_message_media(reply)
|
||||||
if reply_media:
|
if reply_media:
|
||||||
media_paths = reply_media + media_paths
|
media_paths = reply_media + media_paths
|
||||||
@ -903,6 +942,19 @@ class TelegramChannel(BaseChannel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Telegram reaction failed: {}", e)
|
logger.debug("Telegram reaction failed: {}", e)
|
||||||
|
|
||||||
|
async def _remove_reaction(self, chat_id: str, message_id: int) -> None:
|
||||||
|
"""Remove emoji reaction from a message (best-effort, non-blocking)."""
|
||||||
|
if not self._app:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await self._app.bot.set_message_reaction(
|
||||||
|
chat_id=int(chat_id),
|
||||||
|
message_id=message_id,
|
||||||
|
reaction=[],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Telegram reaction removal failed: {}", e)
|
||||||
|
|
||||||
async def _typing_loop(self, chat_id: str) -> None:
|
async def _typing_loop(self, chat_id: str) -> None:
|
||||||
"""Repeatedly send 'typing' action until cancelled."""
|
"""Repeatedly send 'typing' action until cancelled."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -542,6 +542,9 @@ def serve(
|
|||||||
model=runtime_config.agents.defaults.model,
|
model=runtime_config.agents.defaults.model,
|
||||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||||
|
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
||||||
web_search_config=runtime_config.tools.web.search,
|
web_search_config=runtime_config.tools.web.search,
|
||||||
web_proxy=runtime_config.tools.web.proxy or None,
|
web_proxy=runtime_config.tools.web.proxy or None,
|
||||||
exec_config=runtime_config.tools.exec,
|
exec_config=runtime_config.tools.exec,
|
||||||
@ -629,6 +632,9 @@ def gateway(
|
|||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
|
context_block_limit=config.agents.defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||||
web_search_config=config.tools.web.search,
|
web_search_config=config.tools.web.search,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
@ -835,6 +841,9 @@ def agent(
|
|||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
|
context_block_limit=config.agents.defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||||
web_search_config=config.tools.web.search,
|
web_search_config=config.tools.web.search,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
@ -1023,12 +1032,18 @@ app.add_typer(channels_app, name="channels")
|
|||||||
|
|
||||||
|
|
||||||
@channels_app.command("status")
|
@channels_app.command("status")
|
||||||
def channels_status():
|
def channels_status(
|
||||||
|
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||||
|
):
|
||||||
"""Show channel status."""
|
"""Show channel status."""
|
||||||
from nanobot.channels.registry import discover_all
|
from nanobot.channels.registry import discover_all
|
||||||
from nanobot.config.loader import load_config
|
from nanobot.config.loader import load_config, set_config_path
|
||||||
|
|
||||||
config = load_config()
|
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
|
||||||
|
if resolved_config_path is not None:
|
||||||
|
set_config_path(resolved_config_path)
|
||||||
|
|
||||||
|
config = load_config(resolved_config_path)
|
||||||
|
|
||||||
table = Table(title="Channel Status")
|
table = Table(title="Channel Status")
|
||||||
table.add_column("Channel", style="cyan")
|
table.add_column("Channel", style="cyan")
|
||||||
@ -1115,12 +1130,17 @@ def _get_bridge_dir() -> Path:
|
|||||||
def channels_login(
|
def channels_login(
|
||||||
channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
|
channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
|
||||||
force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
|
force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
|
||||||
|
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||||
):
|
):
|
||||||
"""Authenticate with a channel via QR code or other interactive login."""
|
"""Authenticate with a channel via QR code or other interactive login."""
|
||||||
from nanobot.channels.registry import discover_all
|
from nanobot.channels.registry import discover_all
|
||||||
from nanobot.config.loader import load_config
|
from nanobot.config.loader import load_config, set_config_path
|
||||||
|
|
||||||
config = load_config()
|
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
|
||||||
|
if resolved_config_path is not None:
|
||||||
|
set_config_path(resolved_config_path)
|
||||||
|
|
||||||
|
config = load_config(resolved_config_path)
|
||||||
channel_cfg = getattr(config.channels, channel_name, None) or {}
|
channel_cfg = getattr(config.channels, channel_name, None) or {}
|
||||||
|
|
||||||
# Validate channel exists
|
# Validate channel exists
|
||||||
|
|||||||
@ -26,7 +26,10 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
|
|||||||
sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key)
|
sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key)
|
||||||
total = cancelled + sub_cancelled
|
total = cancelled + sub_cancelled
|
||||||
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content=content)
|
return OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||||
|
metadata=dict(msg.metadata or {})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
|
async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
|
||||||
@ -38,7 +41,10 @@ async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
|
|||||||
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
|
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
|
||||||
|
|
||||||
asyncio.create_task(_do_restart())
|
asyncio.create_task(_do_restart())
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="Restarting...")
|
return OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
|
||||||
|
metadata=dict(msg.metadata or {})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
||||||
@ -62,7 +68,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
|||||||
session_msg_count=len(session.get_history(max_messages=0)),
|
session_msg_count=len(session.get_history(max_messages=0)),
|
||||||
context_tokens_estimate=ctx_est,
|
context_tokens_estimate=ctx_est,
|
||||||
),
|
),
|
||||||
metadata={"render_as": "text"},
|
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -79,6 +85,7 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
|||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
|
channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
|
||||||
content="New session started.",
|
content="New session started.",
|
||||||
|
metadata=dict(ctx.msg.metadata or {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -88,7 +95,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
|||||||
channel=ctx.msg.channel,
|
channel=ctx.msg.channel,
|
||||||
chat_id=ctx.msg.chat_id,
|
chat_id=ctx.msg.chat_id,
|
||||||
content=build_help_text(),
|
content=build_help_text(),
|
||||||
metadata={"render_as": "text"},
|
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -38,8 +38,11 @@ class AgentDefaults(Base):
|
|||||||
)
|
)
|
||||||
max_tokens: int = 8192
|
max_tokens: int = 8192
|
||||||
context_window_tokens: int = 65_536
|
context_window_tokens: int = 65_536
|
||||||
|
context_block_limit: int | None = None
|
||||||
temperature: float = 0.1
|
temperature: float = 0.1
|
||||||
max_tool_iterations: int = 40
|
max_tool_iterations: int = 200
|
||||||
|
max_tool_result_chars: int = 16_000
|
||||||
|
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||||
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
|
||||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||||
|
|
||||||
|
|||||||
@ -73,6 +73,9 @@ class Nanobot:
|
|||||||
model=defaults.model,
|
model=defaults.model,
|
||||||
max_iterations=defaults.max_tool_iterations,
|
max_iterations=defaults.max_tool_iterations,
|
||||||
context_window_tokens=defaults.context_window_tokens,
|
context_window_tokens=defaults.context_window_tokens,
|
||||||
|
context_block_limit=defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=defaults.provider_retry_mode,
|
||||||
web_search_config=config.tools.web.search,
|
web_search_config=config.tools.web.search,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
@ -434,13 +436,33 @@ class AnthropicProvider(LLMProvider):
|
|||||||
messages, tools, model, max_tokens, temperature,
|
messages, tools, model, max_tokens, temperature,
|
||||||
reasoning_effort, tool_choice,
|
reasoning_effort, tool_choice,
|
||||||
)
|
)
|
||||||
|
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||||
try:
|
try:
|
||||||
async with self._client.messages.stream(**kwargs) as stream:
|
async with self._client.messages.stream(**kwargs) as stream:
|
||||||
if on_content_delta:
|
if on_content_delta:
|
||||||
async for text in stream.text_stream:
|
stream_iter = stream.text_stream.__aiter__()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
text = await asyncio.wait_for(
|
||||||
|
stream_iter.__anext__(),
|
||||||
|
timeout=idle_timeout_s,
|
||||||
|
)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
await on_content_delta(text)
|
await on_content_delta(text)
|
||||||
response = await stream.get_final_message()
|
response = await asyncio.wait_for(
|
||||||
|
stream.get_final_message(),
|
||||||
|
timeout=idle_timeout_s,
|
||||||
|
)
|
||||||
return self._parse_response(response)
|
return self._parse_response(response)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return LLMResponse(
|
||||||
|
content=(
|
||||||
|
f"Error calling LLM: stream stalled for more than "
|
||||||
|
f"{idle_timeout_s} seconds"
|
||||||
|
),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||||
|
|
||||||
|
|||||||
@ -1,31 +1,36 @@
|
|||||||
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
"""Azure OpenAI provider using the OpenAI SDK Responses API.
|
||||||
|
|
||||||
|
Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
|
||||||
|
routes to the Responses API (``/responses``). Reuses shared conversion
|
||||||
|
helpers from :mod:`nanobot.providers.openai_responses`.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import httpx
|
from openai import AsyncOpenAI
|
||||||
import json_repair
|
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
|
from nanobot.providers.openai_responses import (
|
||||||
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
consume_sdk_stream,
|
||||||
|
convert_messages,
|
||||||
|
convert_tools,
|
||||||
|
parse_response_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIProvider(LLMProvider):
|
class AzureOpenAIProvider(LLMProvider):
|
||||||
"""
|
"""Azure OpenAI provider backed by the Responses API.
|
||||||
Azure OpenAI provider with API version 2024-10-21 compliance.
|
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
- Hardcoded API version 2024-10-21
|
- Uses the OpenAI Python SDK (``AsyncOpenAI``) with
|
||||||
- Uses model field as Azure deployment name in URL path
|
``base_url = {endpoint}/openai/v1/``
|
||||||
- Uses api-key header instead of Authorization Bearer
|
- Calls ``client.responses.create()`` (Responses API)
|
||||||
- Uses max_completion_tokens instead of max_tokens
|
- Reuses shared message/tool/SSE conversion from
|
||||||
- Direct HTTP calls, bypasses LiteLLM
|
``openai_responses``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -36,40 +41,28 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
):
|
):
|
||||||
super().__init__(api_key, api_base)
|
super().__init__(api_key, api_base)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
self.api_version = "2024-10-21"
|
|
||||||
|
|
||||||
# Validate required parameters
|
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("Azure OpenAI api_key is required")
|
raise ValueError("Azure OpenAI api_key is required")
|
||||||
if not api_base:
|
if not api_base:
|
||||||
raise ValueError("Azure OpenAI api_base is required")
|
raise ValueError("Azure OpenAI api_base is required")
|
||||||
|
|
||||||
# Ensure api_base ends with /
|
# Normalise: ensure trailing slash
|
||||||
if not api_base.endswith('/'):
|
if not api_base.endswith("/"):
|
||||||
api_base += '/'
|
api_base += "/"
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
|
|
||||||
def _build_chat_url(self, deployment_name: str) -> str:
|
# SDK client targeting the Azure Responses API endpoint
|
||||||
"""Build the Azure OpenAI chat completions URL."""
|
base_url = f"{api_base.rstrip('/')}/openai/v1/"
|
||||||
# Azure OpenAI URL format:
|
self._client = AsyncOpenAI(
|
||||||
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
api_key=api_key,
|
||||||
base_url = self.api_base
|
base_url=base_url,
|
||||||
if not base_url.endswith('/'):
|
default_headers={"x-session-affinity": uuid.uuid4().hex},
|
||||||
base_url += '/'
|
|
||||||
|
|
||||||
url = urljoin(
|
|
||||||
base_url,
|
|
||||||
f"openai/deployments/{deployment_name}/chat/completions"
|
|
||||||
)
|
)
|
||||||
return f"{url}?api-version={self.api_version}"
|
|
||||||
|
|
||||||
def _build_headers(self) -> dict[str, str]:
|
# ------------------------------------------------------------------
|
||||||
"""Build headers for Azure OpenAI API with api-key header."""
|
# Helpers
|
||||||
return {
|
# ------------------------------------------------------------------
|
||||||
"Content-Type": "application/json",
|
|
||||||
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
|
||||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_temperature(
|
def _supports_temperature(
|
||||||
@ -82,36 +75,51 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
name = deployment_name.lower()
|
name = deployment_name.lower()
|
||||||
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||||
|
|
||||||
def _prepare_request_payload(
|
def _build_body(
|
||||||
self,
|
self,
|
||||||
deployment_name: str,
|
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None,
|
||||||
max_tokens: int = 4096,
|
model: str | None,
|
||||||
temperature: float = 0.7,
|
max_tokens: int,
|
||||||
reasoning_effort: str | None = None,
|
temperature: float,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
reasoning_effort: str | None,
|
||||||
|
tool_choice: str | dict[str, Any] | None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
"""Build the Responses API request body from Chat-Completions-style args."""
|
||||||
payload: dict[str, Any] = {
|
deployment = model or self.default_model
|
||||||
"messages": self._sanitize_request_messages(
|
instructions, input_items = convert_messages(self._sanitize_empty_content(messages))
|
||||||
self._sanitize_empty_content(messages),
|
|
||||||
_AZURE_MSG_KEYS,
|
body: dict[str, Any] = {
|
||||||
),
|
"model": deployment,
|
||||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
"instructions": instructions or None,
|
||||||
|
"input": input_items,
|
||||||
|
"max_output_tokens": max(1, max_tokens),
|
||||||
|
"store": False,
|
||||||
|
"stream": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._supports_temperature(deployment_name, reasoning_effort):
|
if self._supports_temperature(deployment, reasoning_effort):
|
||||||
payload["temperature"] = temperature
|
body["temperature"] = temperature
|
||||||
|
|
||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
payload["reasoning_effort"] = reasoning_effort
|
body["reasoning"] = {"effort": reasoning_effort}
|
||||||
|
body["include"] = ["reasoning.encrypted_content"]
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
payload["tools"] = tools
|
body["tools"] = convert_tools(tools)
|
||||||
payload["tool_choice"] = tool_choice or "auto"
|
body["tool_choice"] = tool_choice or "auto"
|
||||||
|
|
||||||
return payload
|
return body
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _handle_error(e: Exception) -> LLMResponse:
|
||||||
|
body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
|
||||||
|
msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
|
||||||
|
return LLMResponse(content=msg, finish_reason="error")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
body = self._build_body(
|
||||||
Send a chat completion request to Azure OpenAI.
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
Args:
|
|
||||||
messages: List of message dicts with 'role' and 'content'.
|
|
||||||
tools: Optional list of tool definitions in OpenAI format.
|
|
||||||
model: Model identifier (used as deployment name).
|
|
||||||
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
|
||||||
temperature: Sampling temperature.
|
|
||||||
reasoning_effort: Optional reasoning effort parameter.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LLMResponse with content and/or tool calls.
|
|
||||||
"""
|
|
||||||
deployment_name = model or self.default_model
|
|
||||||
url = self._build_chat_url(deployment_name)
|
|
||||||
headers = self._build_headers()
|
|
||||||
payload = self._prepare_request_payload(
|
|
||||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
response = await self._client.responses.create(**body)
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
return parse_response_output(response)
|
||||||
if response.status_code != 200:
|
|
||||||
return LLMResponse(
|
|
||||||
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
response_data = response.json()
|
|
||||||
return self._parse_response(response_data)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return LLMResponse(
|
return self._handle_error(e)
|
||||||
content=f"Error calling Azure OpenAI: {repr(e)}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
|
||||||
"""Parse Azure OpenAI response into our standard format."""
|
|
||||||
try:
|
|
||||||
choice = response["choices"][0]
|
|
||||||
message = choice["message"]
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
if message.get("tool_calls"):
|
|
||||||
for tc in message["tool_calls"]:
|
|
||||||
# Parse arguments from JSON string if needed
|
|
||||||
args = tc["function"]["arguments"]
|
|
||||||
if isinstance(args, str):
|
|
||||||
args = json_repair.loads(args)
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCallRequest(
|
|
||||||
id=tc["id"],
|
|
||||||
name=tc["function"]["name"],
|
|
||||||
arguments=args,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
usage = {}
|
|
||||||
if response.get("usage"):
|
|
||||||
usage_data = response["usage"]
|
|
||||||
usage = {
|
|
||||||
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
|
||||||
"completion_tokens": usage_data.get("completion_tokens", 0),
|
|
||||||
"total_tokens": usage_data.get("total_tokens", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
reasoning_content = message.get("reasoning_content") or None
|
|
||||||
|
|
||||||
return LLMResponse(
|
|
||||||
content=message.get("content"),
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
finish_reason=choice.get("finish_reason", "stop"),
|
|
||||||
usage=usage,
|
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
except (KeyError, IndexError) as e:
|
|
||||||
return LLMResponse(
|
|
||||||
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self,
|
self,
|
||||||
@ -221,89 +152,26 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Stream a chat completion via Azure OpenAI SSE."""
|
body = self._build_body(
|
||||||
deployment_name = model or self.default_model
|
messages, tools, model, max_tokens, temperature,
|
||||||
url = self._build_chat_url(deployment_name)
|
reasoning_effort, tool_choice,
|
||||||
headers = self._build_headers()
|
|
||||||
payload = self._prepare_request_payload(
|
|
||||||
deployment_name, messages, tools, max_tokens, temperature,
|
|
||||||
reasoning_effort, tool_choice=tool_choice,
|
|
||||||
)
|
)
|
||||||
payload["stream"] = True
|
body["stream"] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
stream = await self._client.responses.create(**body)
|
||||||
async with client.stream("POST", url, headers=headers, json=payload) as response:
|
content, tool_calls, finish_reason, usage, reasoning_content = (
|
||||||
if response.status_code != 200:
|
await consume_sdk_stream(stream, on_content_delta)
|
||||||
text = await response.aread()
|
|
||||||
return LLMResponse(
|
|
||||||
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
)
|
||||||
return await self._consume_stream(response, on_content_delta)
|
|
||||||
except Exception as e:
|
|
||||||
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
|
|
||||||
|
|
||||||
async def _consume_stream(
|
|
||||||
self,
|
|
||||||
response: httpx.Response,
|
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
|
|
||||||
content_parts: list[str] = []
|
|
||||||
tool_call_buffers: dict[int, dict[str, str]] = {}
|
|
||||||
finish_reason = "stop"
|
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if not line.startswith("data: "):
|
|
||||||
continue
|
|
||||||
data = line[6:].strip()
|
|
||||||
if data == "[DONE]":
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
chunk = json.loads(data)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
choices = chunk.get("choices") or []
|
|
||||||
if not choices:
|
|
||||||
continue
|
|
||||||
choice = choices[0]
|
|
||||||
if choice.get("finish_reason"):
|
|
||||||
finish_reason = choice["finish_reason"]
|
|
||||||
delta = choice.get("delta") or {}
|
|
||||||
|
|
||||||
text = delta.get("content")
|
|
||||||
if text:
|
|
||||||
content_parts.append(text)
|
|
||||||
if on_content_delta:
|
|
||||||
await on_content_delta(text)
|
|
||||||
|
|
||||||
for tc in delta.get("tool_calls") or []:
|
|
||||||
idx = tc.get("index", 0)
|
|
||||||
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
|
|
||||||
if tc.get("id"):
|
|
||||||
buf["id"] = tc["id"]
|
|
||||||
fn = tc.get("function") or {}
|
|
||||||
if fn.get("name"):
|
|
||||||
buf["name"] = fn["name"]
|
|
||||||
if fn.get("arguments"):
|
|
||||||
buf["arguments"] += fn["arguments"]
|
|
||||||
|
|
||||||
tool_calls = [
|
|
||||||
ToolCallRequest(
|
|
||||||
id=buf["id"], name=buf["name"],
|
|
||||||
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
|
|
||||||
)
|
|
||||||
for buf in tool_call_buffers.values()
|
|
||||||
]
|
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content="".join(content_parts) or None,
|
content=content or None,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return self._handle_error(e)
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
"""Get the default model (also used as default deployment name)."""
|
|
||||||
return self.default_model
|
return self.default_model
|
||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@ -9,6 +10,8 @@ from typing import Any
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.utils.helpers import image_placeholder_text
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallRequest:
|
class ToolCallRequest:
|
||||||
@ -57,13 +60,7 @@ class LLMResponse:
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GenerationSettings:
|
class GenerationSettings:
|
||||||
"""Default generation parameters for LLM calls.
|
"""Default generation settings."""
|
||||||
|
|
||||||
Stored on the provider so every call site inherits the same defaults
|
|
||||||
without having to pass temperature / max_tokens / reasoning_effort
|
|
||||||
through every layer. Individual call sites can still override by
|
|
||||||
passing explicit keyword arguments to chat() / chat_with_retry().
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
max_tokens: int = 4096
|
max_tokens: int = 4096
|
||||||
@ -71,14 +68,12 @@ class GenerationSettings:
|
|||||||
|
|
||||||
|
|
||||||
class LLMProvider(ABC):
|
class LLMProvider(ABC):
|
||||||
"""
|
"""Base class for LLM providers."""
|
||||||
Abstract base class for LLM providers.
|
|
||||||
|
|
||||||
Implementations should handle the specifics of each provider's API
|
|
||||||
while maintaining a consistent interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||||
|
_PERSISTENT_MAX_DELAY = 60
|
||||||
|
_PERSISTENT_IDENTICAL_ERROR_LIMIT = 10
|
||||||
|
_RETRY_HEARTBEAT_CHUNK = 30
|
||||||
_TRANSIENT_ERROR_MARKERS = (
|
_TRANSIENT_ERROR_MARKERS = (
|
||||||
"429",
|
"429",
|
||||||
"rate limit",
|
"rate limit",
|
||||||
@ -208,7 +203,7 @@ class LLMProvider(ABC):
|
|||||||
for b in content:
|
for b in content:
|
||||||
if isinstance(b, dict) and b.get("type") == "image_url":
|
if isinstance(b, dict) and b.get("type") == "image_url":
|
||||||
path = (b.get("_meta") or {}).get("path", "")
|
path = (b.get("_meta") or {}).get("path", "")
|
||||||
placeholder = f"[image: {path}]" if path else "[image omitted]"
|
placeholder = image_placeholder_text(path, empty="[image omitted]")
|
||||||
new_content.append({"type": "text", "text": placeholder})
|
new_content.append({"type": "text", "text": placeholder})
|
||||||
found = True
|
found = True
|
||||||
else:
|
else:
|
||||||
@ -273,6 +268,8 @@ class LLMProvider(ABC):
|
|||||||
reasoning_effort: object = _SENTINEL,
|
reasoning_effort: object = _SENTINEL,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
retry_mode: str = "standard",
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat_stream() with retry on transient provider failures."""
|
"""Call chat_stream() with retry on transient provider failures."""
|
||||||
if max_tokens is self._SENTINEL:
|
if max_tokens is self._SENTINEL:
|
||||||
@ -288,28 +285,13 @@ class LLMProvider(ABC):
|
|||||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||||
on_content_delta=on_content_delta,
|
on_content_delta=on_content_delta,
|
||||||
)
|
)
|
||||||
|
return await self._run_with_retry(
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
self._safe_chat_stream,
|
||||||
response = await self._safe_chat_stream(**kw)
|
kw,
|
||||||
|
messages,
|
||||||
if response.finish_reason != "error":
|
retry_mode=retry_mode,
|
||||||
return response
|
on_retry_wait=on_retry_wait,
|
||||||
|
|
||||||
if not self._is_transient_error(response.content):
|
|
||||||
stripped = self._strip_image_content(messages)
|
|
||||||
if stripped is not None:
|
|
||||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
|
||||||
return await self._safe_chat_stream(**{**kw, "messages": stripped})
|
|
||||||
return response
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
|
||||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
|
||||||
(response.content or "")[:120].lower(),
|
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
return await self._safe_chat_stream(**kw)
|
|
||||||
|
|
||||||
async def chat_with_retry(
|
async def chat_with_retry(
|
||||||
self,
|
self,
|
||||||
@ -320,6 +302,8 @@ class LLMProvider(ABC):
|
|||||||
temperature: object = _SENTINEL,
|
temperature: object = _SENTINEL,
|
||||||
reasoning_effort: object = _SENTINEL,
|
reasoning_effort: object = _SENTINEL,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
retry_mode: str = "standard",
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat() with retry on transient provider failures.
|
"""Call chat() with retry on transient provider failures.
|
||||||
|
|
||||||
@ -339,28 +323,118 @@ class LLMProvider(ABC):
|
|||||||
max_tokens=max_tokens, temperature=temperature,
|
max_tokens=max_tokens, temperature=temperature,
|
||||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
return await self._run_with_retry(
|
||||||
|
self._safe_chat,
|
||||||
|
kw,
|
||||||
|
messages,
|
||||||
|
retry_mode=retry_mode,
|
||||||
|
on_retry_wait=on_retry_wait,
|
||||||
|
)
|
||||||
|
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
@classmethod
|
||||||
response = await self._safe_chat(**kw)
|
def _extract_retry_after(cls, content: str | None) -> float | None:
|
||||||
|
text = (content or "").lower()
|
||||||
|
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
value = float(match.group(1))
|
||||||
|
unit = (match.group(2) or "s").lower()
|
||||||
|
if unit in {"ms", "milliseconds"}:
|
||||||
|
return max(0.1, value / 1000.0)
|
||||||
|
if unit in {"m", "min", "minutes"}:
|
||||||
|
return value * 60.0
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def _sleep_with_heartbeat(
|
||||||
|
self,
|
||||||
|
delay: float,
|
||||||
|
*,
|
||||||
|
attempt: int,
|
||||||
|
persistent: bool,
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
remaining = max(0.0, delay)
|
||||||
|
while remaining > 0:
|
||||||
|
if on_retry_wait:
|
||||||
|
kind = "persistent retry" if persistent else "retry"
|
||||||
|
await on_retry_wait(
|
||||||
|
f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
|
||||||
|
f"(attempt {attempt})."
|
||||||
|
)
|
||||||
|
chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
|
||||||
|
await asyncio.sleep(chunk)
|
||||||
|
remaining -= chunk
|
||||||
|
|
||||||
|
async def _run_with_retry(
|
||||||
|
self,
|
||||||
|
call: Callable[..., Awaitable[LLMResponse]],
|
||||||
|
kw: dict[str, Any],
|
||||||
|
original_messages: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
retry_mode: str,
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
attempt = 0
|
||||||
|
delays = list(self._CHAT_RETRY_DELAYS)
|
||||||
|
persistent = retry_mode == "persistent"
|
||||||
|
last_response: LLMResponse | None = None
|
||||||
|
last_error_key: str | None = None
|
||||||
|
identical_error_count = 0
|
||||||
|
while True:
|
||||||
|
attempt += 1
|
||||||
|
response = await call(**kw)
|
||||||
if response.finish_reason != "error":
|
if response.finish_reason != "error":
|
||||||
return response
|
return response
|
||||||
|
last_response = response
|
||||||
|
error_key = ((response.content or "").strip().lower() or None)
|
||||||
|
if error_key and error_key == last_error_key:
|
||||||
|
identical_error_count += 1
|
||||||
|
else:
|
||||||
|
last_error_key = error_key
|
||||||
|
identical_error_count = 1 if error_key else 0
|
||||||
|
|
||||||
if not self._is_transient_error(response.content):
|
if not self._is_transient_error(response.content):
|
||||||
stripped = self._strip_image_content(messages)
|
stripped = self._strip_image_content(original_messages)
|
||||||
if stripped is not None:
|
if stripped is not None and stripped != kw["messages"]:
|
||||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
logger.warning(
|
||||||
return await self._safe_chat(**{**kw, "messages": stripped})
|
"Non-transient LLM error with image content, retrying without images"
|
||||||
|
)
|
||||||
|
retry_kw = dict(kw)
|
||||||
|
retry_kw["messages"] = stripped
|
||||||
|
return await call(**retry_kw)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
"Stopping persistent retry after {} identical transient errors: {}",
|
||||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
identical_error_count,
|
||||||
(response.content or "")[:120].lower(),
|
(response.content or "")[:120].lower(),
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
return response
|
||||||
|
|
||||||
return await self._safe_chat(**kw)
|
if not persistent and attempt > len(delays):
|
||||||
|
break
|
||||||
|
|
||||||
|
base_delay = delays[min(attempt - 1, len(delays) - 1)]
|
||||||
|
delay = self._extract_retry_after(response.content) or base_delay
|
||||||
|
if persistent:
|
||||||
|
delay = min(delay, self._PERSISTENT_MAX_DELAY)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"LLM transient error (attempt {}{}), retrying in {}s: {}",
|
||||||
|
attempt,
|
||||||
|
"+" if persistent and attempt > len(delays) else f"/{len(delays)}",
|
||||||
|
int(round(delay)),
|
||||||
|
(response.content or "")[:120].lower(),
|
||||||
|
)
|
||||||
|
await self._sleep_with_heartbeat(
|
||||||
|
delay,
|
||||||
|
attempt=attempt,
|
||||||
|
persistent=persistent,
|
||||||
|
on_retry_wait=on_retry_wait,
|
||||||
|
)
|
||||||
|
|
||||||
|
return last_response if last_response is not None else await call(**kw)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
|
|||||||
@ -6,13 +6,18 @@ import asyncio
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from oauth_cli_kit import get_token as get_codex_token
|
from oauth_cli_kit import get_token as get_codex_token
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.providers.openai_responses import (
|
||||||
|
consume_sse,
|
||||||
|
convert_messages,
|
||||||
|
convert_tools,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
DEFAULT_ORIGINATOR = "nanobot"
|
DEFAULT_ORIGINATOR = "nanobot"
|
||||||
@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Shared request logic for both chat() and chat_stream()."""
|
"""Shared request logic for both chat() and chat_stream()."""
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
system_prompt, input_items = _convert_messages(messages)
|
system_prompt, input_items = convert_messages(messages)
|
||||||
|
|
||||||
token = await asyncio.to_thread(get_codex_token)
|
token = await asyncio.to_thread(get_codex_token)
|
||||||
headers = _build_headers(token.account_id, token.access)
|
headers = _build_headers(token.account_id, token.access)
|
||||||
@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
body["reasoning"] = {"effort": reasoning_effort}
|
body["reasoning"] = {"effort": reasoning_effort}
|
||||||
if tools:
|
if tools:
|
||||||
body["tools"] = _convert_tools(tools)
|
body["tools"] = convert_tools(tools)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
@ -127,96 +132,7 @@ async def _request_codex(
|
|||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
text = await response.aread()
|
text = await response.aread()
|
||||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||||
return await _consume_sse(response, on_content_delta)
|
return await consume_sse(response, on_content_delta)
|
||||||
|
|
||||||
|
|
||||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
"""Convert OpenAI function-calling schema to Codex flat format."""
|
|
||||||
converted: list[dict[str, Any]] = []
|
|
||||||
for tool in tools:
|
|
||||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
|
||||||
name = fn.get("name")
|
|
||||||
if not name:
|
|
||||||
continue
|
|
||||||
params = fn.get("parameters") or {}
|
|
||||||
converted.append({
|
|
||||||
"type": "function",
|
|
||||||
"name": name,
|
|
||||||
"description": fn.get("description") or "",
|
|
||||||
"parameters": params if isinstance(params, dict) else {},
|
|
||||||
})
|
|
||||||
return converted
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
|
||||||
system_prompt = ""
|
|
||||||
input_items: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
for idx, msg in enumerate(messages):
|
|
||||||
role = msg.get("role")
|
|
||||||
content = msg.get("content")
|
|
||||||
|
|
||||||
if role == "system":
|
|
||||||
system_prompt = content if isinstance(content, str) else ""
|
|
||||||
continue
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
input_items.append(_convert_user_message(content))
|
|
||||||
continue
|
|
||||||
|
|
||||||
if role == "assistant":
|
|
||||||
if isinstance(content, str) and content:
|
|
||||||
input_items.append({
|
|
||||||
"type": "message", "role": "assistant",
|
|
||||||
"content": [{"type": "output_text", "text": content}],
|
|
||||||
"status": "completed", "id": f"msg_{idx}",
|
|
||||||
})
|
|
||||||
for tool_call in msg.get("tool_calls", []) or []:
|
|
||||||
fn = tool_call.get("function") or {}
|
|
||||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
|
||||||
input_items.append({
|
|
||||||
"type": "function_call",
|
|
||||||
"id": item_id or f"fc_{idx}",
|
|
||||||
"call_id": call_id or f"call_{idx}",
|
|
||||||
"name": fn.get("name"),
|
|
||||||
"arguments": fn.get("arguments") or "{}",
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
if role == "tool":
|
|
||||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
|
||||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
|
||||||
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
|
|
||||||
|
|
||||||
return system_prompt, input_items
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
|
||||||
if isinstance(content, list):
|
|
||||||
converted: list[dict[str, Any]] = []
|
|
||||||
for item in content:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
if item.get("type") == "text":
|
|
||||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
|
||||||
elif item.get("type") == "image_url":
|
|
||||||
url = (item.get("image_url") or {}).get("url")
|
|
||||||
if url:
|
|
||||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
|
||||||
if converted:
|
|
||||||
return {"role": "user", "content": converted}
|
|
||||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
|
||||||
|
|
||||||
|
|
||||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
|
||||||
if isinstance(tool_call_id, str) and tool_call_id:
|
|
||||||
if "|" in tool_call_id:
|
|
||||||
call_id, item_id = tool_call_id.split("|", 1)
|
|
||||||
return call_id, item_id or None
|
|
||||||
return tool_call_id, None
|
|
||||||
return "call_0", None
|
|
||||||
|
|
||||||
|
|
||||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||||
@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
|||||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
|
||||||
buffer: list[str] = []
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if line == "":
|
|
||||||
if buffer:
|
|
||||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
|
||||||
buffer = []
|
|
||||||
if not data_lines:
|
|
||||||
continue
|
|
||||||
data = "\n".join(data_lines).strip()
|
|
||||||
if not data or data == "[DONE]":
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
yield json.loads(data)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
continue
|
|
||||||
buffer.append(line)
|
|
||||||
|
|
||||||
|
|
||||||
async def _consume_sse(
|
|
||||||
response: httpx.Response,
|
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
|
||||||
) -> tuple[str, list[ToolCallRequest], str]:
|
|
||||||
content = ""
|
|
||||||
tool_calls: list[ToolCallRequest] = []
|
|
||||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
|
||||||
finish_reason = "stop"
|
|
||||||
|
|
||||||
async for event in _iter_sse(response):
|
|
||||||
event_type = event.get("type")
|
|
||||||
if event_type == "response.output_item.added":
|
|
||||||
item = event.get("item") or {}
|
|
||||||
if item.get("type") == "function_call":
|
|
||||||
call_id = item.get("call_id")
|
|
||||||
if not call_id:
|
|
||||||
continue
|
|
||||||
tool_call_buffers[call_id] = {
|
|
||||||
"id": item.get("id") or "fc_0",
|
|
||||||
"name": item.get("name"),
|
|
||||||
"arguments": item.get("arguments") or "",
|
|
||||||
}
|
|
||||||
elif event_type == "response.output_text.delta":
|
|
||||||
delta_text = event.get("delta") or ""
|
|
||||||
content += delta_text
|
|
||||||
if on_content_delta and delta_text:
|
|
||||||
await on_content_delta(delta_text)
|
|
||||||
elif event_type == "response.function_call_arguments.delta":
|
|
||||||
call_id = event.get("call_id")
|
|
||||||
if call_id and call_id in tool_call_buffers:
|
|
||||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
|
||||||
elif event_type == "response.function_call_arguments.done":
|
|
||||||
call_id = event.get("call_id")
|
|
||||||
if call_id and call_id in tool_call_buffers:
|
|
||||||
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
|
|
||||||
elif event_type == "response.output_item.done":
|
|
||||||
item = event.get("item") or {}
|
|
||||||
if item.get("type") == "function_call":
|
|
||||||
call_id = item.get("call_id")
|
|
||||||
if not call_id:
|
|
||||||
continue
|
|
||||||
buf = tool_call_buffers.get(call_id) or {}
|
|
||||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
|
||||||
try:
|
|
||||||
args = json.loads(args_raw)
|
|
||||||
except Exception:
|
|
||||||
args = {"raw": args_raw}
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCallRequest(
|
|
||||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
|
||||||
name=buf.get("name") or item.get("name"),
|
|
||||||
arguments=args,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif event_type == "response.completed":
|
|
||||||
status = (event.get("response") or {}).get("status")
|
|
||||||
finish_reason = _map_finish_reason(status)
|
|
||||||
elif event_type in {"error", "response.failed"}:
|
|
||||||
raise RuntimeError("Codex response failed")
|
|
||||||
|
|
||||||
return content, tool_calls, finish_reason
|
|
||||||
|
|
||||||
|
|
||||||
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
|
|
||||||
|
|
||||||
|
|
||||||
def _map_finish_reason(status: str | None) -> str:
|
|
||||||
return _FINISH_REASON_MAP.get(status or "completed", "stop")
|
|
||||||
|
|
||||||
|
|
||||||
def _friendly_error(status_code: int, raw: str) -> str:
|
def _friendly_error(status_code: int, raw: str) -> str:
|
||||||
if status_code == 429:
|
if status_code == 429:
|
||||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
@ -20,7 +21,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
_ALLOWED_MSG_KEYS = frozenset({
|
_ALLOWED_MSG_KEYS = frozenset({
|
||||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||||
"reasoning_content", "extra_content",
|
|
||||||
})
|
})
|
||||||
_ALNUM = string.ascii_letters + string.digits
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
@ -615,16 +615,33 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
)
|
)
|
||||||
kwargs["stream"] = True
|
kwargs["stream"] = True
|
||||||
kwargs["stream_options"] = {"include_usage": True}
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||||
try:
|
try:
|
||||||
stream = await self._client.chat.completions.create(**kwargs)
|
stream = await self._client.chat.completions.create(**kwargs)
|
||||||
chunks: list[Any] = []
|
chunks: list[Any] = []
|
||||||
async for chunk in stream:
|
stream_iter = stream.__aiter__()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = await asyncio.wait_for(
|
||||||
|
stream_iter.__anext__(),
|
||||||
|
timeout=idle_timeout_s,
|
||||||
|
)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
if on_content_delta and chunk.choices:
|
if on_content_delta and chunk.choices:
|
||||||
text = getattr(chunk.choices[0].delta, "content", None)
|
text = getattr(chunk.choices[0].delta, "content", None)
|
||||||
if text:
|
if text:
|
||||||
await on_content_delta(text)
|
await on_content_delta(text)
|
||||||
return self._parse_chunks(chunks)
|
return self._parse_chunks(chunks)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return LLMResponse(
|
||||||
|
content=(
|
||||||
|
f"Error calling LLM: stream stalled for more than "
|
||||||
|
f"{idle_timeout_s} seconds"
|
||||||
|
),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self._handle_error(e)
|
return self._handle_error(e)
|
||||||
|
|
||||||
|
|||||||
29
nanobot/providers/openai_responses/__init__.py
Normal file
29
nanobot/providers/openai_responses/__init__.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""
|
||||||
|
|
||||||
|
from nanobot.providers.openai_responses.converters import (
|
||||||
|
convert_messages,
|
||||||
|
convert_tools,
|
||||||
|
convert_user_message,
|
||||||
|
split_tool_call_id,
|
||||||
|
)
|
||||||
|
from nanobot.providers.openai_responses.parsing import (
|
||||||
|
FINISH_REASON_MAP,
|
||||||
|
consume_sdk_stream,
|
||||||
|
consume_sse,
|
||||||
|
iter_sse,
|
||||||
|
map_finish_reason,
|
||||||
|
parse_response_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"convert_messages",
|
||||||
|
"convert_tools",
|
||||||
|
"convert_user_message",
|
||||||
|
"split_tool_call_id",
|
||||||
|
"iter_sse",
|
||||||
|
"consume_sse",
|
||||||
|
"consume_sdk_stream",
|
||||||
|
"map_finish_reason",
|
||||||
|
"parse_response_output",
|
||||||
|
"FINISH_REASON_MAP",
|
||||||
|
]
|
||||||
110
nanobot/providers/openai_responses/converters.py
Normal file
110
nanobot/providers/openai_responses/converters.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
"""Convert Chat Completions messages/tools to Responses API format."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Translate Chat Completions messages into Responses API input items.

    Returns ``(system_prompt, input_items)``. *system_prompt* is the string
    content of the last ``system`` message seen (``""`` when there is none or
    its content is not a string). *input_items* holds, in order:

    - one user item per ``user`` message (built by ``convert_user_message``),
    - a ``message`` item for each ``assistant`` message with non-empty string
      content, followed by one ``function_call`` item per tool call,
    - one ``function_call_output`` item per ``tool`` message.

    Messages with any other role are ignored.
    """
    system_prompt = ""
    items: list[dict[str, Any]] = []

    for position, message in enumerate(messages):
        role = message.get("role")
        body = message.get("content")

        if role == "system":
            system_prompt = body if isinstance(body, str) else ""

        elif role == "user":
            items.append(convert_user_message(body))

        elif role == "assistant":
            if isinstance(body, str) and body:
                items.append({
                    "type": "message", "role": "assistant",
                    "content": [{"type": "output_text", "text": body}],
                    "status": "completed", "id": f"msg_{position}",
                })
            for call in msg_tool_calls(message):
                function = call.get("function") or {}
                call_id, item_id = split_tool_call_id(call.get("id"))
                items.append({
                    "type": "function_call",
                    "id": item_id or f"fc_{position}",
                    "call_id": call_id or f"call_{position}",
                    "name": function.get("name"),
                    "arguments": function.get("arguments") or "{}",
                })

        elif role == "tool":
            call_id, _ = split_tool_call_id(message.get("tool_call_id"))
            # Non-string tool results are serialized so the API always receives text.
            if isinstance(body, str):
                output = body
            else:
                output = json.dumps(body, ensure_ascii=False)
            items.append({"type": "function_call_output", "call_id": call_id, "output": output})

    return system_prompt, items


def msg_tool_calls(message: dict[str, Any]) -> list[Any]:
    """Return the message's ``tool_calls`` list, treating a missing or null value as empty."""
    return message.get("tool_calls", []) or []
|
||||||
|
|
||||||
|
|
||||||
|
def convert_user_message(content: Any) -> dict[str, Any]:
    """Convert a user message's content to Responses API format.

    Handles plain strings, ``text`` blocks -> ``input_text``, and
    ``image_url`` blocks -> ``input_image``. An ``image_url`` value may be
    either the usual ``{"url": ...}`` object or a bare URL string. List
    entries that are not dicts, blocks of any other type, and image blocks
    without a URL are skipped.

    Returns a ``{"role": "user", "content": [...]}`` item. When nothing
    usable is found (unsupported content type or no convertible blocks) the
    item carries a single empty ``input_text`` block.
    """
    if isinstance(content, str):
        return {"role": "user", "content": [{"type": "input_text", "text": content}]}

    converted: list[dict[str, Any]] = []
    if isinstance(content, list):
        for item in content:
            if not isinstance(item, dict):
                continue
            item_type = item.get("type")
            if item_type == "text":
                # A null ``text`` becomes an empty string instead of leaking None.
                converted.append({"type": "input_text", "text": item.get("text") or ""})
            elif item_type == "image_url":
                image = item.get("image_url")
                if isinstance(image, dict):
                    url = image.get("url")
                elif isinstance(image, str):
                    # Tolerate the shorthand form where the URL is given directly.
                    url = image
                else:
                    url = None
                if url:
                    converted.append({"type": "input_image", "image_url": url, "detail": "auto"})

    if converted:
        return {"role": "user", "content": converted}
    return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Flatten OpenAI function-calling tool schemas into Responses API tool entries.

    A tool of ``type == "function"`` is read from its nested ``function``
    object; any other tool dict is read as-is. Tools without a name are
    dropped, a missing description becomes ``""``, and ``parameters`` that
    are missing or not a dict become ``{}``.
    """
    flattened: list[dict[str, Any]] = []
    for spec in tools:
        if spec.get("type") == "function":
            function = spec.get("function") or {}
        else:
            function = spec

        tool_name = function.get("name")
        if not tool_name:
            continue

        schema = function.get("parameters") or {}
        if not isinstance(schema, dict):
            schema = {}

        flattened.append({
            "type": "function",
            "name": tool_name,
            "description": function.get("description") or "",
            "parameters": schema,
        })
    return flattened
|
||||||
|
|
||||||
|
|
||||||
|
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
    """Break a compound ``call_id|item_id`` identifier into its two parts.

    Returns ``(call_id, item_id)``. *item_id* is ``None`` when the input has
    no ``|`` separator or nothing follows it. Only the first ``|`` splits, so
    later ones stay inside *item_id*. Anything that is not a non-empty string
    yields ``("call_0", None)``.
    """
    if not isinstance(tool_call_id, str) or not tool_call_id:
        return "call_0", None

    call_part, separator, item_part = tool_call_id.partition("|")
    if not separator:
        return tool_call_id, None
    return call_part, item_part or None
|
||||||
297
nanobot/providers/openai_responses/parsing.py
Normal file
297
nanobot/providers/openai_responses/parsing.py
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
"""Parse Responses API SSE streams and SDK response objects."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import json_repair
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
# Responses API status -> Chat-Completions-style finish_reason.
FINISH_REASON_MAP = {
    "completed": "stop",
    "incomplete": "length",
    "failed": "error",
    "cancelled": "error",
}


def map_finish_reason(status: str | None) -> str:
    """Translate a Responses API status into a Chat-Completions-style finish_reason.

    A missing or empty status is treated as ``"completed"``; a status that is
    not in ``FINISH_REASON_MAP`` maps to ``"stop"``.
    """
    effective_status = status if status else "completed"
    return FINISH_REASON_MAP.get(effective_status, "stop")
|
||||||
|
|
||||||
|
|
||||||
|
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
    """Yield parsed JSON events from a Responses API SSE stream.

    Lines from ``response.aiter_lines()`` are buffered until a blank line
    ends the event (or the stream ends). The ``data:`` lines of an event are
    joined with newlines and decoded as JSON; other lines are ignored.

    Nothing is yielded for an event that has no data, is the ``[DONE]``
    sentinel, fails to decode, or decodes to something other than a JSON
    object. The last two cases are logged as warnings.
    """
    pending: list[str] = []

    def _flush() -> dict[str, Any] | None:
        """Decode the buffered lines into one event and reset the buffer."""
        data_lines = [line[5:].strip() for line in pending if line.startswith("data:")]
        pending.clear()
        if not data_lines:
            return None
        data = "\n".join(data_lines).strip()
        if not data or data == "[DONE]":
            return None
        try:
            event = json.loads(data)
        except Exception:
            logger.warning("Failed to parse SSE event JSON: {}", data[:200])
            return None
        if not isinstance(event, dict):
            # Consumers call ``event.get(...)``; a bare list/number/string would crash them.
            logger.warning("Ignoring non-object SSE event payload: {}", data[:200])
            return None
        return event

    async for line in response.aiter_lines():
        if line == "":
            if pending:
                event = _flush()
                if event is not None:
                    yield event
            continue
        pending.append(line)

    # Flush any remaining buffer at EOF (#10)
    if pending:
        event = _flush()
        if event is not None:
            yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def consume_sse(
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``.

    Text deltas are concatenated into *content* and, when *on_content_delta*
    is given, each non-empty delta is forwarded to it. Function calls are
    tracked per ``call_id`` from ``response.output_item.added`` and emitted
    when the matching ``response.output_item.done`` arrives. Their arguments
    are parsed as JSON, repaired with ``json_repair`` when invalid, and
    wrapped as ``{"raw": ...}`` if the result is not an object.

    *finish_reason* defaults to ``"stop"`` and is updated from the terminal
    ``response.completed`` or ``response.incomplete`` event.

    Raises:
        RuntimeError: when the stream carries an ``error`` or
            ``response.failed`` event.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in iter_sse(response):
        event_type = event.get("type")
        if event_type == "response.output_item.added":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                tool_call_buffers[call_id] = {
                    "id": item.get("id") or "fc_0",
                    "name": item.get("name"),
                    "arguments": item.get("arguments") or "",
                }
        elif event_type == "response.output_text.delta":
            delta_text = event.get("delta") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
        elif event_type == "response.function_call_arguments.done":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                # The final payload supersedes whatever the deltas accumulated.
                tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                # Prefer streamed arguments, fall back to the finished item's copy.
                args_raw = buf.get("arguments") or item.get("arguments") or "{}"
                try:
                    args = json.loads(args_raw)
                except Exception:
                    logger.warning(
                        "Failed to parse tool call arguments for '{}': {}",
                        buf.get("name") or item.get("name"),
                        args_raw[:200],
                    )
                    args = json_repair.loads(args_raw)
                if not isinstance(args, dict):
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
                        name=buf.get("name") or item.get("name") or "",
                        arguments=args,
                    )
                )
        elif event_type in {"response.completed", "response.incomplete"}:
            status = (event.get("response") or {}).get("status")
            if not status and event_type == "response.incomplete":
                # A truncated response must not default to the "completed" mapping.
                status = "incomplete"
            finish_reason = map_finish_reason(status)
        elif event_type in {"error", "response.failed"}:
            detail = event.get("error") or event.get("message") or event
            raise RuntimeError(f"Response failed: {str(detail)[:500]}")

    return content, tool_calls, finish_reason
|
||||||
|
|
||||||
|
|
||||||
|
def parse_response_output(response: Any) -> LLMResponse:
    """Build an ``LLMResponse`` from an SDK ``Response`` object or equivalent dict.

    The response and its nested output items, content blocks, reasoning
    summaries and usage may each be a dict or an object. Objects are
    converted with ``model_dump()`` when they have it, otherwise ``vars()``.

    - ``message`` items contribute their ``output_text`` blocks to the content.
    - ``reasoning`` items contribute their ``summary_text`` entries to
      ``reasoning_content``.
    - ``function_call`` items become ``ToolCallRequest`` entries with id
      ``"<call_id>|<item_id>"``. String arguments are parsed as JSON and
      repaired with ``json_repair`` when invalid; a non-object result is
      wrapped as ``{"raw": ...}``.
    """

    def _as_mapping(obj: Any) -> dict[str, Any]:
        """Return *obj* unchanged if it is a dict, else its dumped/attribute form."""
        if isinstance(obj, dict):
            return obj
        dump = getattr(obj, "model_dump", None)
        return dump() if callable(dump) else vars(obj)

    payload = _as_mapping(response)

    text_chunks: list[str] = []
    calls: list[ToolCallRequest] = []
    reasoning: str | None = None

    for raw_item in payload.get("output") or []:
        item = _as_mapping(raw_item)
        kind = item.get("type")

        if kind == "message":
            for raw_block in item.get("content") or []:
                block = _as_mapping(raw_block)
                if block.get("type") == "output_text":
                    text_chunks.append(block.get("text") or "")

        elif kind == "reasoning":
            for raw_summary in item.get("summary") or []:
                summary = _as_mapping(raw_summary)
                if summary.get("type") == "summary_text" and summary.get("text"):
                    reasoning = (reasoning or "") + summary["text"]

        elif kind == "function_call":
            call_id = item.get("call_id") or ""
            item_id = item.get("id") or "fc_0"
            raw_args = item.get("arguments") or "{}"
            if isinstance(raw_args, str):
                try:
                    parsed = json.loads(raw_args)
                except Exception:
                    logger.warning(
                        "Failed to parse tool call arguments for '{}': {}",
                        item.get("name"),
                        str(raw_args)[:200],
                    )
                    parsed = json_repair.loads(raw_args)
            else:
                # Already-decoded arguments are taken as they are.
                parsed = raw_args
            if not isinstance(parsed, dict):
                parsed = {"raw": raw_args}
            calls.append(ToolCallRequest(
                id=f"{call_id}|{item_id}",
                name=item.get("name") or "",
                arguments=parsed,
            ))

    usage_source = _as_mapping(payload.get("usage") or {})
    if usage_source:
        usage = {
            "prompt_tokens": int(usage_source.get("input_tokens") or 0),
            "completion_tokens": int(usage_source.get("output_tokens") or 0),
            "total_tokens": int(usage_source.get("total_tokens") or 0),
        }
    else:
        usage = {}

    return LLMResponse(
        content="".join(text_chunks) or None,
        tool_calls=calls,
        finish_reason=map_finish_reason(payload.get("status")),
        usage=usage,
        reasoning_content=reasoning,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
    """Consume an SDK async stream from ``client.responses.create(stream=True)``.

    Returns ``(content, tool_calls, finish_reason, usage, reasoning_content)``.

    Text deltas are concatenated into *content* and, when *on_content_delta*
    is given, each non-empty delta is forwarded to it. Function calls are
    tracked per ``call_id`` and emitted on ``response.output_item.done``.
    Their arguments are parsed as JSON, repaired with ``json_repair`` when
    invalid, and wrapped as ``{"raw": ...}`` if the result is not an object.

    The terminal ``response.completed`` or ``response.incomplete`` event
    supplies the finish reason, token usage and any reasoning summary text.

    Raises:
        RuntimeError: when the stream carries an ``error`` or
            ``response.failed`` event.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}
    reasoning_content: str | None = None

    async for event in stream:
        event_type = getattr(event, "type", None)
        if event_type == "response.output_item.added":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                tool_call_buffers[call_id] = {
                    "id": getattr(item, "id", None) or "fc_0",
                    "name": getattr(item, "name", None),
                    "arguments": getattr(item, "arguments", None) or "",
                }
        elif event_type == "response.output_text.delta":
            delta_text = getattr(event, "delta", "") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
        elif event_type == "response.function_call_arguments.done":
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                # The final payload supersedes whatever the deltas accumulated.
                tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""
        elif event_type == "response.output_item.done":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                # Prefer streamed arguments, fall back to the finished item's copy.
                args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
                try:
                    args = json.loads(args_raw)
                except Exception:
                    logger.warning(
                        "Failed to parse tool call arguments for '{}': {}",
                        buf.get("name") or getattr(item, "name", None),
                        str(args_raw)[:200],
                    )
                    args = json_repair.loads(args_raw)
                if not isinstance(args, dict):
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
                        name=buf.get("name") or getattr(item, "name", None) or "",
                        arguments=args,
                    )
                )
        elif event_type in {"response.completed", "response.incomplete"}:
            resp = getattr(event, "response", None)
            status = getattr(resp, "status", None) if resp else None
            if not status and event_type == "response.incomplete":
                # A truncated response must not default to the "completed" mapping.
                status = "incomplete"
            finish_reason = map_finish_reason(status)
            if resp:
                usage_obj = getattr(resp, "usage", None)
                if usage_obj:
                    usage = {
                        "prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
                        "completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
                        "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
                    }
                for out_item in getattr(resp, "output", None) or []:
                    if getattr(out_item, "type", None) == "reasoning":
                        for s in getattr(out_item, "summary", None) or []:
                            if getattr(s, "type", None) == "summary_text":
                                text = getattr(s, "text", None)
                                if text:
                                    reasoning_content = (reasoning_content or "") + text
        elif event_type in {"error", "response.failed"}:
            detail = getattr(event, "error", None) or getattr(event, "message", None) or event
            raise RuntimeError(f"Response failed: {str(detail)[:500]}")

    return content, tool_calls, finish_reason, usage, reasoning_content
|
||||||
@ -10,20 +10,12 @@ from typing import Any
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.config.paths import get_legacy_sessions_dir
|
from nanobot.config.paths import get_legacy_sessions_dir
|
||||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Session:
|
class Session:
|
||||||
"""
|
"""A conversation session."""
|
||||||
A conversation session.
|
|
||||||
|
|
||||||
Stores messages in JSONL format for easy reading and persistence.
|
|
||||||
|
|
||||||
Important: Messages are append-only for LLM cache efficiency.
|
|
||||||
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
|
||||||
but does NOT modify the messages list or get_history() output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
key: str # channel:chat_id
|
key: str # channel:chat_id
|
||||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||||
@ -43,43 +35,19 @@ class Session:
|
|||||||
self.messages.append(msg)
|
self.messages.append(msg)
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
|
|
||||||
"""Find first index where every tool result has a matching assistant tool_call."""
|
|
||||||
declared: set[str] = set()
|
|
||||||
start = 0
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
role = msg.get("role")
|
|
||||||
if role == "assistant":
|
|
||||||
for tc in msg.get("tool_calls") or []:
|
|
||||||
if isinstance(tc, dict) and tc.get("id"):
|
|
||||||
declared.add(str(tc["id"]))
|
|
||||||
elif role == "tool":
|
|
||||||
tid = msg.get("tool_call_id")
|
|
||||||
if tid and str(tid) not in declared:
|
|
||||||
start = i + 1
|
|
||||||
declared.clear()
|
|
||||||
for prev in messages[start:i + 1]:
|
|
||||||
if prev.get("role") == "assistant":
|
|
||||||
for tc in prev.get("tool_calls") or []:
|
|
||||||
if isinstance(tc, dict) and tc.get("id"):
|
|
||||||
declared.add(str(tc["id"]))
|
|
||||||
return start
|
|
||||||
|
|
||||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||||
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||||
unconsolidated = self.messages[self.last_consolidated:]
|
unconsolidated = self.messages[self.last_consolidated:]
|
||||||
sliced = unconsolidated[-max_messages:]
|
sliced = unconsolidated[-max_messages:]
|
||||||
|
|
||||||
# Drop leading non-user messages to avoid starting mid-turn when possible.
|
# Avoid starting mid-turn when possible.
|
||||||
for i, message in enumerate(sliced):
|
for i, message in enumerate(sliced):
|
||||||
if message.get("role") == "user":
|
if message.get("role") == "user":
|
||||||
sliced = sliced[i:]
|
sliced = sliced[i:]
|
||||||
break
|
break
|
||||||
|
|
||||||
# Some providers reject orphan tool results if the matching assistant
|
# Drop orphan tool results at the front.
|
||||||
# tool_calls message fell outside the fixed-size history window.
|
start = find_legal_message_start(sliced)
|
||||||
start = self._find_legal_start(sliced)
|
|
||||||
if start:
|
if start:
|
||||||
sliced = sliced[start:]
|
sliced = sliced[start:]
|
||||||
|
|
||||||
@ -115,7 +83,7 @@ class Session:
|
|||||||
retained = self.messages[start_idx:]
|
retained = self.messages[start_idx:]
|
||||||
|
|
||||||
# Mirror get_history(): avoid persisting orphan tool results at the front.
|
# Mirror get_history(): avoid persisting orphan tool results at the front.
|
||||||
start = self._find_legal_start(retained)
|
start = find_legal_message_start(retained)
|
||||||
if start:
|
if start:
|
||||||
retained = retained[start:]
|
retained = retained[start:]
|
||||||
|
|
||||||
|
|||||||
@ -3,12 +3,15 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
def strip_think(text: str) -> str:
|
def strip_think(text: str) -> str:
|
||||||
@ -56,11 +59,7 @@ def timestamp() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def current_time_str(timezone: str | None = None) -> str:
|
def current_time_str(timezone: str | None = None) -> str:
|
||||||
"""Human-readable current time with weekday and UTC offset.
|
"""Return the current time string."""
|
||||||
|
|
||||||
When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time
|
|
||||||
is converted to that zone. Otherwise falls back to the host local time.
|
|
||||||
"""
|
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -76,12 +75,164 @@ def current_time_str(timezone: str | None = None) -> str:
|
|||||||
|
|
||||||
|
|
||||||
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
|
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
|
||||||
|
_TOOL_RESULT_PREVIEW_CHARS = 1200
|
||||||
|
_TOOL_RESULTS_DIR = ".nanobot/tool-results"
|
||||||
|
_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60
|
||||||
|
_TOOL_RESULT_MAX_BUCKETS = 32
|
||||||
|
|
||||||
def safe_filename(name: str) -> str:
|
def safe_filename(name: str) -> str:
|
||||||
"""Replace unsafe path characters with underscores."""
|
"""Replace unsafe path characters with underscores."""
|
||||||
return _UNSAFE_CHARS.sub("_", name).strip()
|
return _UNSAFE_CHARS.sub("_", name).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str:
|
||||||
|
"""Build an image placeholder string."""
|
||||||
|
return f"[image: {path}]" if path else empty
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_text(text: str, max_chars: int) -> str:
    """Cut *text* to *max_chars* characters and append a fixed truncation marker.

    The text is returned unchanged when it already fits or when *max_chars*
    is zero or negative (no limit).
    """
    over_limit = max_chars > 0 and len(text) > max_chars
    if not over_limit:
        return text
    head = text[:max_chars]
    return head + "\n... (truncated)"
|
||||||
|
|
||||||
|
|
||||||
|
def find_legal_message_start(messages: list[dict[str, Any]]) -> int:
    """Find the first index whose tool results have matching assistant calls.

    Scans *messages* in order, remembering every tool-call id declared by an
    ``assistant`` message. When a ``tool`` message refers to an id that has
    not been declared inside the current window, the window restarts just
    after that message.

    Returns the start index of the final window: ``0`` when no orphan tool
    result exists, ``len(messages)`` when the last message is an orphan.
    """
    declared: set[str] = set()
    start = 0
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role == "assistant":
            for tc in msg.get("tool_calls") or []:
                if isinstance(tc, dict) and tc.get("id"):
                    declared.add(str(tc["id"]))
        elif role == "tool":
            tid = msg.get("tool_call_id")
            if tid and str(tid) not in declared:
                # Orphan result: restart after it. Ids declared earlier now lie
                # outside the window, so none of them carry over.
                start = i + 1
                declared.clear()
    return start
|
||||||
|
|
||||||
|
|
||||||
|
def stringify_text_blocks(content: list[dict[str, Any]]) -> str | None:
|
||||||
|
parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
return None
|
||||||
|
if block.get("type") != "text":
|
||||||
|
return None
|
||||||
|
text = block.get("text")
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return None
|
||||||
|
parts.append(text)
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_tool_result_reference(
    filepath: Path,
    *,
    original_size: int,
    preview: str,
    truncated_preview: bool,
) -> str:
    """Format the reference text that stands in for a persisted tool output.

    The text names the saved file, the original size in characters and the
    preview. When *truncated_preview* is true, a trailing hint to read the
    saved file is added.
    """
    if truncated_preview:
        hint = "\n...\n(Read the saved file if you need the full output.)"
    else:
        hint = ""
    header = (
        f"[tool output persisted]\n"
        f"Full output saved to: {filepath}\n"
        f"Original size: {original_size} chars\n"
        f"Preview:\n{preview}"
    )
    return header + hint
|
||||||
|
|
||||||
|
|
||||||
|
def _bucket_mtime(path: Path) -> float:
    """Return the modification time of *path*, or ``0.0`` if it cannot be stat'ed."""
    try:
        info = path.stat()
    except OSError:
        return 0.0
    return info.st_mtime
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None:
    """Prune sibling bucket directories under *root*, never touching *current_bucket*.

    Two passes: first remove buckets whose mtime is older than
    ``_TOOL_RESULT_RETENTION_SECS``. Then, if more than
    ``_TOOL_RESULT_MAX_BUCKETS - 1`` siblings remain, remove the oldest ones
    until that many are left. Removal errors are ignored.
    """
    others: list[Path] = []
    for entry in root.iterdir():
        if entry.is_dir() and entry != current_bucket:
            others.append(entry)

    expiry = time.time() - _TOOL_RESULT_RETENTION_SECS
    for bucket in others:
        if _bucket_mtime(bucket) < expiry:
            shutil.rmtree(bucket, ignore_errors=True)

    # One slot of the budget is reserved for the current bucket.
    budget = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0)
    survivors = [bucket for bucket in others if bucket.exists()]
    if len(survivors) <= budget:
        return

    newest_first = sorted(survivors, key=_bucket_mtime, reverse=True)
    for bucket in newest_first[budget:]:
        shutil.rmtree(bucket, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_text_atomic(path: Path, content: str) -> None:
    """Write *content* to *path* as UTF-8 via a uniquely named temporary sibling file.

    The content goes to a hidden ``.<name>.<uuid>.tmp`` file next to *path*,
    which is then moved over *path* with ``Path.replace``. A temporary file
    left behind by a failure is removed.
    """
    staging_name = f".{path.name}.{uuid.uuid4().hex}.tmp"
    staging = path.with_name(staging_name)
    try:
        staging.write_text(content, encoding="utf-8")
        staging.replace(path)
    finally:
        # After a successful replace the staging file no longer exists.
        if staging.exists():
            staging.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_persist_tool_result(
|
||||||
|
workspace: Path | None,
|
||||||
|
session_key: str | None,
|
||||||
|
tool_call_id: str,
|
||||||
|
content: Any,
|
||||||
|
*,
|
||||||
|
max_chars: int,
|
||||||
|
) -> Any:
|
||||||
|
"""Persist oversized tool output and replace it with a stable reference string."""
|
||||||
|
if workspace is None or max_chars <= 0:
|
||||||
|
return content
|
||||||
|
|
||||||
|
text_payload: str | None = None
|
||||||
|
suffix = "txt"
|
||||||
|
if isinstance(content, str):
|
||||||
|
text_payload = content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
text_payload = stringify_text_blocks(content)
|
||||||
|
if text_payload is None:
|
||||||
|
return content
|
||||||
|
suffix = "json"
|
||||||
|
else:
|
||||||
|
return content
|
||||||
|
|
||||||
|
if len(text_payload) <= max_chars:
|
||||||
|
return content
|
||||||
|
|
||||||
|
root = ensure_dir(workspace / _TOOL_RESULTS_DIR)
|
||||||
|
bucket = ensure_dir(root / safe_filename(session_key or "default"))
|
||||||
|
try:
|
||||||
|
_cleanup_tool_result_buckets(root, bucket)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to clean stale tool result buckets in {}: {}", root, exc)
|
||||||
|
path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
|
||||||
|
if not path.exists():
|
||||||
|
if suffix == "json" and isinstance(content, list):
|
||||||
|
_write_text_atomic(path, json.dumps(content, ensure_ascii=False, indent=2))
|
||||||
|
else:
|
||||||
|
_write_text_atomic(path, text_payload)
|
||||||
|
|
||||||
|
preview = text_payload[:_TOOL_RESULT_PREVIEW_CHARS]
|
||||||
|
return _render_tool_result_reference(
|
||||||
|
path,
|
||||||
|
original_size=len(text_payload),
|
||||||
|
preview=preview,
|
||||||
|
truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Split content into chunks within max_len, preferring line breaks.
|
Split content into chunks within max_len, preferring line breaks.
|
||||||
|
|||||||
88
nanobot/utils/runtime.py
Normal file
88
nanobot/utils/runtime.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
"""Runtime-specific helper functions and constants."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.utils.helpers import stringify_text_blocks
|
||||||
|
|
||||||
|
_MAX_REPEAT_EXTERNAL_LOOKUPS = 2
|
||||||
|
|
||||||
|
EMPTY_FINAL_RESPONSE_MESSAGE = (
|
||||||
|
"I completed the tool steps but couldn't produce a final answer. "
|
||||||
|
"Please try again or narrow the task."
|
||||||
|
)
|
||||||
|
|
||||||
|
FINALIZATION_RETRY_PROMPT = (
|
||||||
|
"You have already finished the tool work. Do not call any more tools. "
|
||||||
|
"Using only the conversation and tool results above, provide the final answer for the user now."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def empty_tool_result_message(tool_name: str) -> str:
    """Build the short marker text used for a tool that produced no visible output."""
    marker = f"({tool_name} completed with no output)"
    return marker
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_nonempty_tool_result(tool_name: str, content: Any) -> Any:
    """Return *content*, or a short marker string when it is effectively empty.

    Treated as empty:
      * ``None``;
      * a string that is blank after ``strip()``;
      * an empty list;
      * a list for which ``stringify_text_blocks`` returns a string that is
        blank after ``strip()``.

    Anything else is handed back unchanged.

    Args:
        tool_name: Tool name used to build the marker.
        content: Raw tool result.

    Returns:
        The original *content*, or ``empty_tool_result_message(tool_name)``.
    """
    if content is None:
        is_empty = True
    elif isinstance(content, str):
        is_empty = not content.strip()
    elif isinstance(content, list):
        if content:
            # NOTE(review): a None result presumably means the list is not
            # purely text blocks; such lists are kept as-is. Confirm in helpers.
            flattened = stringify_text_blocks(content)
            is_empty = flattened is not None and not flattened.strip()
        else:
            is_empty = True
    else:
        is_empty = False

    return empty_tool_result_message(tool_name) if is_empty else content
|
||||||
|
|
||||||
|
|
||||||
|
def is_blank_text(content: str | None) -> bool:
    """Report whether *content* is ``None`` or contains only whitespace.

    Args:
        content: Text to inspect, or ``None``.

    Returns:
        ``True`` for ``None``, ``""`` and whitespace-only strings; ``False`` otherwise.
    """
    if content is None:
        return True
    return not content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def build_finalization_retry_message() -> dict[str, str]:
    """Create the user-role message that asks for a final answer without tools.

    Returns:
        A new dict with ``role`` set to ``"user"`` and ``content`` set to
        ``FINALIZATION_RETRY_PROMPT``.
    """
    retry_message: dict[str, str] = {}
    retry_message["role"] = "user"
    retry_message["content"] = FINALIZATION_RETRY_PROMPT
    return retry_message
|
||||||
|
|
||||||
|
|
||||||
|
def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None:
    """Return a stable signature for external lookups that should be throttled.

    Only ``web_fetch`` and ``web_search`` calls are given a signature.

    * ``web_fetch``: the ``url`` argument is stripped, then only the scheme
      and authority (host) are lowercased. Path, query and fragment keep
      their case because they are case-sensitive; lowercasing the whole URL
      would make two different resources look like a repeated lookup. If no
      authority can be parsed (scheme-less or malformed input) the whole
      string is lowercased instead.
    * ``web_search``: the ``query`` argument (or ``search_term`` when
      ``query`` is missing or falsy) is stripped and lowercased.

    Args:
        tool_name: Name of the tool being called.
        arguments: Tool call arguments.

    Returns:
        ``"web_fetch:<url>"`` or ``"web_search:<query>"``, or ``None`` when the
        tool is not throttled or the relevant argument is missing or blank.
    """
    # Imported locally so this function does not depend on the module import list.
    from urllib.parse import urlsplit, urlunsplit

    if tool_name == "web_fetch":
        url = str(arguments.get("url") or "").strip()
        if url:
            try:
                parts = urlsplit(url)
            except ValueError:
                # e.g. a malformed IPv6 literal; fall back to the coarse form.
                parts = None
            if parts is not None and parts.netloc:
                normalized = urlunsplit(
                    parts._replace(
                        scheme=parts.scheme.lower(),
                        netloc=parts.netloc.lower(),
                    )
                )
            else:
                normalized = url.lower()
            return f"web_fetch:{normalized}"
    if tool_name == "web_search":
        query = str(arguments.get("query") or arguments.get("search_term") or "").strip()
        if query:
            return f"web_search:{query.lower()}"
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def repeated_external_lookup_error(
    tool_name: str,
    arguments: dict[str, Any],
    seen_counts: dict[str, int],
) -> str | None:
    """Count an external lookup and return an error once it repeats too often.

    Args:
        tool_name: Name of the tool being called.
        arguments: Tool call arguments.
        seen_counts: Mutable per-signature counter. It is updated in place for
            every call that has a signature, including blocked ones.

    Returns:
        ``None`` when the call has no signature or is still within
        ``_MAX_REPEAT_EXTERNAL_LOOKUPS`` attempts; otherwise an error string.
    """
    lookup_key = external_lookup_signature(tool_name, arguments)
    if lookup_key is None:
        return None

    attempt = seen_counts.get(lookup_key, 0) + 1
    seen_counts[lookup_key] = attempt

    if attempt > _MAX_REPEAT_EXTERNAL_LOOKUPS:
        # Truncate the signature so a very long URL or query cannot flood the log.
        logger.warning(
            "Blocking repeated external lookup {} on attempt {}",
            lookup_key[:160],
            attempt,
        )
        return (
            "Error: repeated external lookup blocked. "
            "Use the results you already have to answer, or try a meaningfully different source."
        )
    return None
|
||||||
@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
|||||||
assert "Channel: cli" in user_content
|
assert "Channel: cli" in user_content
|
||||||
assert "Chat ID: direct" in user_content
|
assert "Chat ID: direct" in user_content
|
||||||
assert "Return exactly: OK" in user_content
|
assert "Return exactly: OK" in user_content
|
||||||
|
|
||||||
|
|
||||||
|
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
    """An assistant-role current message must not sit next to another assistant message."""
    builder = ContextBuilder(_make_workspace(tmp_path))

    built = builder.build_messages(
        history=[{"role": "assistant", "content": "previous result"}],
        current_message="subagent result",
        channel="cli",
        chat_id="direct",
        current_role="assistant",
    )

    # Compare each message's role with its successor's.
    roles = [message.get("role") for message in built]
    for earlier, later in zip(roles, roles[1:]):
        assert not (earlier == later == "assistant")
|
||||||
|
|||||||
@ -5,7 +5,9 @@ from nanobot.session.manager import Session
|
|||||||
|
|
||||||
def _mk_loop() -> AgentLoop:
|
def _mk_loop() -> AgentLoop:
|
||||||
loop = AgentLoop.__new__(AgentLoop)
|
loop = AgentLoop.__new__(AgentLoop)
|
||||||
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
|
||||||
return loop
|
return loop
|
||||||
|
|
||||||
|
|
||||||
@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert session.messages[0]["content"] == content
|
assert session.messages[0]["content"] == content
|
||||||
|
|
||||||
|
|
||||||
|
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
    """Restoring a checkpoint replays the assistant turn, the finished tool
    result, and a placeholder result for the tool that never finished."""

    def tool_call(call_id: str, tool: str) -> dict:
        # Fresh dict per use so no objects are shared between checkpoint sections.
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": tool, "arguments": "{}"},
        }

    checkpoint = {
        "assistant_message": {
            "role": "assistant",
            "content": "working",
            "tool_calls": [
                tool_call("call_done", "read_file"),
                tool_call("call_pending", "exec"),
            ],
        },
        "completed_tool_results": [
            {
                "role": "tool",
                "tool_call_id": "call_done",
                "name": "read_file",
                "content": "ok",
            }
        ],
        "pending_tool_calls": [tool_call("call_pending", "exec")],
    }
    loop = _mk_loop()
    session = Session(
        key="test:checkpoint",
        metadata={AgentLoop._RUNTIME_CHECKPOINT_KEY: checkpoint},
    )

    restored = loop._restore_runtime_checkpoint(session)

    assert restored is True
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    assert session.messages[0]["role"] == "assistant"
    assert session.messages[1]["tool_call_id"] == "call_done"
    assert session.messages[2]["tool_call_id"] == "call_pending"
    assert "interrupted before this tool finished" in session.messages[2]["content"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
    """Messages already at the session tail are not appended again on restore."""

    def tool_calls() -> list[dict]:
        # New objects on every call so session history and checkpoint never alias.
        return [
            {
                "id": "call_done",
                "type": "function",
                "function": {"name": "read_file", "arguments": "{}"},
            },
            {
                "id": "call_pending",
                "type": "function",
                "function": {"name": "exec", "arguments": "{}"},
            },
        ]

    def assistant_turn() -> dict:
        return {"role": "assistant", "content": "working", "tool_calls": tool_calls()}

    def done_result() -> dict:
        return {
            "role": "tool",
            "tool_call_id": "call_done",
            "name": "read_file",
            "content": "ok",
        }

    loop = _mk_loop()
    session = Session(
        key="test:checkpoint-overlap",
        messages=[assistant_turn(), done_result()],
        metadata={
            AgentLoop._RUNTIME_CHECKPOINT_KEY: {
                "assistant_message": assistant_turn(),
                "completed_tool_results": [done_result()],
                "pending_tool_calls": [tool_calls()[1]],
            }
        },
    )

    restored = loop._restore_runtime_checkpoint(session)

    assert restored is True
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    assert len(session.messages) == 3
    assert session.messages[0]["role"] == "assistant"
    assert session.messages[1]["tool_call_id"] == "call_done"
    assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||||
|
|||||||
@ -2,12 +2,20 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
def _make_loop(tmp_path):
|
def _make_loop(tmp_path):
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=3,
|
max_iterations=3,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
))
|
))
|
||||||
|
|
||||||
assert result.final_content == "done"
|
assert result.final_content == "done"
|
||||||
@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=3,
|
max_iterations=3,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
hook=RecordingHook(),
|
hook=RecordingHook(),
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=1,
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
hook=StreamingHook(),
|
hook=StreamingHook(),
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=2,
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
))
|
))
|
||||||
|
|
||||||
assert result.stop_reason == "max_iterations"
|
assert result.stop_reason == "max_iterations"
|
||||||
@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback():
|
|||||||
"I reached the maximum number of tool call iterations (2) "
|
"I reached the maximum number of tool call iterations (2) "
|
||||||
"without completing the task. You can try breaking the task into smaller steps."
|
"without completing the task. You can try breaking the task into smaller steps."
|
||||||
)
|
)
|
||||||
|
assert result.messages[-1]["role"] == "assistant"
|
||||||
|
assert result.messages[-1]["content"] == result.final_content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_returns_structured_tool_error():
|
async def test_runner_returns_structured_tool_error():
|
||||||
@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=2,
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
fail_on_tool_error=True,
|
fail_on_tool_error=True,
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -258,6 +272,457 @@ async def test_runner_returns_structured_tool_error():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
    """A tool result above the size limit is written to disk and replaced by a reference."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    second_call_messages: list[dict] = []
    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            # Snapshot what the model is shown after the tool has run.
            second_call_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="x" * 20_000)

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        workspace=tmp_path,
        session_key="test:runner",
        max_tool_result_chars=2048,
    )
    result = await runner.run(spec)

    assert result.final_content == "done"
    tool_message = next(msg for msg in second_call_messages if msg.get("role") == "tool")
    assert "[tool output persisted]" in tool_message["content"]
    assert "tool-results" in tool_message["content"]
    persisted_file = tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt"
    assert persisted_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a result removes a stale session bucket and keeps a recent one."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    results_root = tmp_path / ".nanobot" / "tool-results"
    old_bucket = results_root / "old_session"
    recent_bucket = results_root / "recent_session"
    for bucket in (old_bucket, recent_bucket):
        bucket.mkdir(parents=True)
    (old_bucket / "old.txt").write_text("old", encoding="utf-8")
    (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")

    # Back-date the old bucket and its file by eight days. This happens after
    # the write above, which would otherwise refresh the directory mtime.
    stale = time.time() - (8 * 24 * 60 * 60)
    for target in (old_bucket, old_bucket / "old.txt"):
        os.utime(target, (stale, stale))

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert not old_bucket.exists()
    assert recent_bucket.exists()
    assert (results_root / "current_session" / "call_big.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """After persisting, the session bucket holds the result and no ``*.tmp`` files."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    bucket = tmp_path / ".nanobot" / "tool-results" / "current_session"

    maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert (bucket / "call_big.txt").exists()
    leftovers = list(bucket.glob("*.tmp"))
    assert leftovers == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
    """A failing bucket cleanup is logged as a warning and does not stop persistence."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    warnings: list[str] = []

    def failing_cleanup(*_args, **_kwargs):
        raise OSError("busy")

    def record_warning(message, *args):
        # Render the template the same way the test expects to read it back.
        warnings.append(message.format(*args))

    monkeypatch.setattr("nanobot.utils.helpers._cleanup_tool_result_buckets", failing_cleanup)
    monkeypatch.setattr("nanobot.utils.helpers.logger.warning", record_warning)

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert warnings and "Failed to clean stale tool result buckets" in warnings[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_replaces_empty_tool_result_with_marker():
    """An empty-string tool result reaches the model as the "no output" marker."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    second_call_messages: list[dict] = []
    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            second_call_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})],
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="")

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == "done"
    tool_message = next(msg for msg in second_call_messages if msg.get("role") == "tool")
    assert tool_message["content"] == "(noop completed with no output)"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
    """If ``_snip_history`` raises, the provider still receives the untouched messages."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    seen_by_provider: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        seen_by_provider[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
    ]

    runner = AgentRunner(provider)
    # Make history trimming blow up to exercise the fallback path.
    runner._snip_history = MagicMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    spec = AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == "done"
    assert seen_by_provider == initial_messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_retries_empty_final_response_with_summary_prompt():
    """An empty final response triggers one tool-less retry, and usage is summed."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    recorded: list[dict] = []

    async def chat_with_retry(*, messages, tools=None, **kwargs):
        recorded.append({"messages": messages, "tools": tools})
        is_first_call = len(recorded) == 1
        if is_first_call:
            # First reply: no text and no tool calls.
            return LLMResponse(
                content=None,
                tool_calls=[],
                usage={"prompt_tokens": 10, "completion_tokens": 1},
            )
        return LLMResponse(
            content="final answer",
            tool_calls=[],
            usage={"prompt_tokens": 3, "completion_tokens": 7},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == "final answer"
    assert len(recorded) == 2
    retry_call = recorded[1]
    assert retry_call["tools"] is None
    assert "Do not call any more tools" in retry_call["messages"][-1]["content"]
    assert result.usage["prompt_tokens"] == 13
    assert result.usage["completion_tokens"] == 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_uses_specific_message_after_empty_finalization_retry():
    """When every response is empty, the runner reports the dedicated fallback message."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE

    async def chat_with_retry(*, messages, **kwargs):
        # Always empty: no text, no tool calls.
        return LLMResponse(content=None, tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE
    assert result.stop_reason == "empty_final_response"
|
||||||
|
|
||||||
|
|
||||||
|
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
    """A tool result whose assistant tool-call message was trimmed away is dropped too."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    tools = MagicMock()
    tools.get_definitions.return_value = []
    runner = AgentRunner(provider)

    system_msg = {"role": "system", "content": "system"}
    final_msg = {"role": "assistant", "content": "after tool"}
    messages = [
        system_msg,
        {"role": "user", "content": "old user"},
        {
            "role": "assistant",
            "content": "tool call",
            "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "tool output"},
        final_msg,
    ]
    spec = AgentRunSpec(
        initial_messages=messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    # Stubbed token estimates, keyed by message content, so the trimming
    # decision is deterministic.
    token_sizes = {
        "old user": 120,
        "tool call": 120,
        "tool output": 40,
        "after tool": 40,
        "system": 0,
    }

    def fake_prompt_estimate(*_args, **_kwargs):
        return (500, None)

    def fake_message_estimate(msg):
        return token_sizes.get(str(msg.get("content")), 40)

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", fake_prompt_estimate)
    monkeypatch.setattr("nanobot.agent.runner.estimate_message_tokens", fake_message_estimate)

    trimmed = runner._snip_history(spec, messages)

    assert trimmed == [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "after tool"},
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
    """If persisting a tool result raises, the raw result is still passed to the model."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    second_call_messages: list[dict] = []
    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            second_call_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    broken_persist = patch(
        "nanobot.agent.runner.maybe_persist_tool_result",
        side_effect=RuntimeError("disk full"),
    )
    with broken_persist:
        result = await runner.run(AgentRunSpec(
            initial_messages=[{"role": "user", "content": "do task"}],
            tools=tools,
            model="test-model",
            max_iterations=2,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in second_call_messages if msg.get("role") == "tool")
    assert tool_message["content"] == "tool result"
|
||||||
|
|
||||||
|
|
||||||
|
class _DelayTool(Tool):
    """Test tool that sleeps for a fixed delay and records start/end events.

    Events are appended to a caller-supplied list as ``start:<name>`` and
    ``end:<name>`` so tests can assert on scheduling order.
    """

    def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
        # The list is stored by reference so several tools can log into one timeline.
        self._shared_events = shared_events
        self._read_only = read_only
        self._delay = delay
        self._name = name

    @property
    def name(self) -> str:
        """Name given at construction."""
        return self._name

    @property
    def description(self) -> str:
        """Reuses the tool name as its description."""
        return self._name

    @property
    def parameters(self) -> dict:
        """Object schema with no properties and no required fields."""
        return {"type": "object", "properties": {}, "required": []}

    @property
    def read_only(self) -> bool:
        """Read-only flag given at construction."""
        return self._read_only

    async def execute(self, **kwargs):
        """Record start, sleep for the configured delay, record end, return the name."""
        timeline = self._shared_events
        timeline.append(f"start:{self._name}")
        await asyncio.sleep(self._delay)
        timeline.append(f"end:{self._name}")
        return self._name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
    """Both read-only tools start before either ends; the write tool runs after both end."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tools = ToolRegistry()
    shared_events: list[str] = []
    read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
    read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events)
    write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events)
    for tool in (read_a, read_b, write_a):
        tools.register(tool)

    runner = AgentRunner(MagicMock())
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    requested_calls = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
        ToolCallRequest(id="rw1", name="write_a", arguments={}),
    ]
    await runner._execute_tools(spec, requested_calls, {})

    assert shared_events[0:2] == ["start:read_a", "start:read_b"]
    assert "end:read_a" in shared_events and "end:read_b" in shared_events
    write_started_at = shared_events.index("start:write_a")
    assert shared_events.index("end:read_a") < write_started_at
    assert shared_events.index("end:read_b") < write_started_at
    assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches():
    """The third identical ``web_fetch`` is not executed and gets a "blocked" tool result."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    final_call_messages: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] > 3:
            final_call_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        # The first three turns all request the same URL.
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id=f"call_{call_count['n']}", name="web_fetch", arguments={"url": "https://example.com"})],
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="page content")

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "research task"}],
        tools=tools,
        model="test-model",
        max_iterations=4,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == "done"
    assert tools.execute.await_count == 2
    third_call_results = [
        msg for msg in final_call_messages
        if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_3"
    ]
    blocked_tool_message = third_call_results[0]
    assert "repeated external lookup blocked" in blocked_tool_message["content"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
@ -307,6 +772,57 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
|
|||||||
assert endings == [False]
|
assert endings == [False]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_retries_think_only_final_response(tmp_path):
    """A reply made only of a ``<think>`` block is retried and the second reply is used."""
    loop = _make_loop(tmp_path)
    calls_made = 0

    async def chat_with_retry(**kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            return LLMResponse(content="Recovered answer", tool_calls=[], usage={})
        return LLMResponse(content="<think>hidden</think>", tool_calls=[], usage={})

    loop.provider.chat_with_retry = chat_with_retry

    final_content, _, _ = await loop._run_agent_loop([])

    assert final_content == "Recovered answer"
    assert calls_made == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_tool_error_sets_final_content():
    """With ``fail_on_tool_error`` set, a raising tool ends the run with the error text."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(*, messages, **kwargs):
        # Always ask for the same tool call; the tool itself raises.
        requested = ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})
        return LLMResponse(content="working", tool_calls=[requested], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(side_effect=RuntimeError("boom"))

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    result = await runner.run(spec)

    assert result.final_content == "Error: RuntimeError: boom"
    assert result.stop_reason == "tool_error"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
|
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
@ -317,15 +833,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||||
content="working",
|
content="working",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
))
|
))
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
return "tool result"
|
return "tool result"
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
@ -369,6 +890,7 @@ async def test_runner_accumulates_usage_and_preserves_cached_tokens():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=3,
|
max_iterations=3,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
))
|
))
|
||||||
|
|
||||||
# Usage should be accumulated across iterations
|
# Usage should be accumulated across iterations
|
||||||
@ -407,6 +929,7 @@ async def test_runner_passes_cached_tokens_to_hook_context():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=1,
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
hook=UsageHook(),
|
hook=UsageHook(),
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
def _make_loop(*, exec_config=None):
|
def _make_loop(*, exec_config=None):
|
||||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||||
@ -186,7 +190,12 @@ class TestSubagentCancellation:
|
|||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=MagicMock(),
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
|
|
||||||
cancelled = asyncio.Event()
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
@ -214,7 +223,12 @@ class TestSubagentCancellation:
|
|||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=MagicMock(),
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -236,19 +250,24 @@ class TestSubagentCancellation:
|
|||||||
if call_count["n"] == 1:
|
if call_count["n"] == 1:
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content="thinking",
|
content="thinking",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
reasoning_content="hidden reasoning",
|
reasoning_content="hidden reasoning",
|
||||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||||
)
|
)
|
||||||
captured_second_call[:] = messages
|
captured_second_call[:] = messages
|
||||||
return LLMResponse(content="done", tool_calls=[])
|
return LLMResponse(content="done", tool_calls=[])
|
||||||
provider.chat_with_retry = scripted_chat_with_retry
|
provider.chat_with_retry = scripted_chat_with_retry
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
return "tool result"
|
return "tool result"
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
@ -273,6 +292,7 @@ class TestSubagentCancellation:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=tmp_path,
|
workspace=tmp_path,
|
||||||
bus=bus,
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
exec_config=ExecToolConfig(enable=False),
|
exec_config=ExecToolConfig(enable=False),
|
||||||
)
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
@ -304,20 +324,25 @@ class TestSubagentCancellation:
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||||
content="thinking",
|
content="thinking",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
))
|
))
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
|
|
||||||
calls = {"n": 0}
|
calls = {"n": 0}
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
calls["n"] += 1
|
calls["n"] += 1
|
||||||
if calls["n"] == 1:
|
if calls["n"] == 1:
|
||||||
return "first result"
|
return "first result"
|
||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
@ -340,15 +365,20 @@ class TestSubagentCancellation:
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||||
content="thinking",
|
content="thinking",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
))
|
))
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
|
|
||||||
started = asyncio.Event()
|
started = asyncio.Event()
|
||||||
cancelled = asyncio.Event()
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
started.set()
|
started.set()
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
@ -356,7 +386,7 @@ class TestSubagentCancellation:
|
|||||||
cancelled.set()
|
cancelled.set()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
@ -364,7 +394,7 @@ class TestSubagentCancellation:
|
|||||||
mgr._running_tasks["sub-1"] = task
|
mgr._running_tasks["sub-1"] = task
|
||||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||||
|
|
||||||
await started.wait()
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
|
||||||
count = await mgr.cancel_by_session("test:c1")
|
count = await mgr.cancel_by_session("test:c1")
|
||||||
|
|
||||||
|
|||||||
@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
|||||||
seen["config"] = self.config
|
seen["config"] = self.config
|
||||||
return True
|
return True
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.registry.discover_all",
|
"nanobot.channels.registry.discover_all",
|
||||||
lambda: {"fakeplugin": _LoginPlugin},
|
lambda: {"fakeplugin": _LoginPlugin},
|
||||||
@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
|||||||
assert seen["force"] is True
|
assert seen["force"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path):
|
||||||
|
from nanobot.cli.commands import app
|
||||||
|
from nanobot.config.schema import Config
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
seen: dict[str, object] = {}
|
||||||
|
config_path = tmp_path / "custom-config.json"
|
||||||
|
|
||||||
|
class _LoginPlugin(_FakePlugin):
|
||||||
|
async def login(self, force: bool = False) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.config.loader.set_config_path",
|
||||||
|
lambda path: seen.__setitem__("config_path", path),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.registry.discover_all",
|
||||||
|
lambda: {"fakeplugin": _LoginPlugin},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert seen["config_path"] == config_path.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path):
|
||||||
|
from nanobot.cli.commands import app
|
||||||
|
from nanobot.config.schema import Config
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
seen: dict[str, object] = {}
|
||||||
|
config_path = tmp_path / "custom-config.json"
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.config.loader.set_config_path",
|
||||||
|
lambda path: seen.__setitem__("config_path", path),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["channels", "status", "--config", str(config_path)])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert seen["config_path"] == config_path.resolve()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_manager_skips_disabled_plugin():
|
async def test_manager_skips_disabled_plugin():
|
||||||
fake_config = SimpleNamespace(
|
fake_config = SimpleNamespace(
|
||||||
|
|||||||
@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None:
|
|||||||
typing_channel.typing_enter_hook = slow_typing
|
typing_channel.typing_enter_hook = slow_typing
|
||||||
|
|
||||||
await channel._start_typing(typing_channel)
|
await channel._start_typing(typing_channel)
|
||||||
await start.wait()
|
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||||
|
|
||||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||||
release.set()
|
release.set()
|
||||||
@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None:
|
|||||||
typing_channel.typing_enter_hook = slow_typing_progress
|
typing_channel.typing_enter_hook = slow_typing_progress
|
||||||
|
|
||||||
await channel._start_typing(typing_channel)
|
await channel._start_typing(typing_channel)
|
||||||
await start.wait()
|
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||||
|
|
||||||
await channel.send(
|
await channel.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
|||||||
|
|
||||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||||
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
||||||
await entered.wait()
|
await asyncio.wait_for(entered.wait(), timeout=1.0)
|
||||||
|
|
||||||
assert "123" in channel._typing_tasks
|
assert "123" in channel._typing_tasks
|
||||||
|
|
||||||
|
|||||||
@ -3,16 +3,14 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
pytest.importorskip("nio")
|
||||||
|
pytest.importorskip("nh3")
|
||||||
|
pytest.importorskip("mistune")
|
||||||
from nio import RoomSendResponse
|
from nio import RoomSendResponse
|
||||||
|
|
||||||
from nanobot.channels.matrix import _build_matrix_text_content
|
from nanobot.channels.matrix import _build_matrix_text_content
|
||||||
|
|
||||||
# Check optional matrix dependencies before importing
|
|
||||||
try:
|
|
||||||
import nh3 # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True)
|
|
||||||
|
|
||||||
import nanobot.channels.matrix as matrix_module
|
import nanobot.channels.matrix as matrix_module
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|||||||
172
tests/channels/test_qq_ack_message.py
Normal file
172
tests/channels/test_qq_ack_message.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
"""Tests for QQ channel ack_message feature.
|
||||||
|
|
||||||
|
Covers the four verification points from the PR:
|
||||||
|
1. C2C message: ack appears instantly
|
||||||
|
2. Group message: ack appears instantly
|
||||||
|
3. ack_message set to "": no ack sent
|
||||||
|
4. Custom ack_message text: correct text delivered
|
||||||
|
Each test also verifies that normal message processing is not blocked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nanobot.channels import qq
|
||||||
|
|
||||||
|
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||||
|
except ImportError:
|
||||||
|
QQ_AVAILABLE = False
|
||||||
|
|
||||||
|
if not QQ_AVAILABLE:
|
||||||
|
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||||
|
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.qq import QQChannel, QQConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeApi:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.c2c_calls: list[dict] = []
|
||||||
|
self.group_calls: list[dict] = []
|
||||||
|
|
||||||
|
async def post_c2c_message(self, **kwargs) -> None:
|
||||||
|
self.c2c_calls.append(kwargs)
|
||||||
|
|
||||||
|
async def post_group_message(self, **kwargs) -> None:
|
||||||
|
self.group_calls.append(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeClient:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.api = _FakeApi()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ack_sent_on_c2c_message() -> None:
|
||||||
|
"""Ack is sent immediately for C2C messages, then normal processing continues."""
|
||||||
|
channel = QQChannel(
|
||||||
|
QQConfig(
|
||||||
|
app_id="app",
|
||||||
|
secret="secret",
|
||||||
|
allow_from=["*"],
|
||||||
|
ack_message="⏳ Processing...",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
data = SimpleNamespace(
|
||||||
|
id="msg1",
|
||||||
|
content="hello",
|
||||||
|
author=SimpleNamespace(user_openid="user1"),
|
||||||
|
attachments=[],
|
||||||
|
)
|
||||||
|
await channel._on_message(data, is_group=False)
|
||||||
|
|
||||||
|
assert len(channel._client.api.c2c_calls) >= 1
|
||||||
|
ack_call = channel._client.api.c2c_calls[0]
|
||||||
|
assert ack_call["content"] == "⏳ Processing..."
|
||||||
|
assert ack_call["openid"] == "user1"
|
||||||
|
assert ack_call["msg_id"] == "msg1"
|
||||||
|
assert ack_call["msg_type"] == 0
|
||||||
|
|
||||||
|
msg = await channel.bus.consume_inbound()
|
||||||
|
assert msg.content == "hello"
|
||||||
|
assert msg.sender_id == "user1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ack_sent_on_group_message() -> None:
|
||||||
|
"""Ack is sent immediately for group messages, then normal processing continues."""
|
||||||
|
channel = QQChannel(
|
||||||
|
QQConfig(
|
||||||
|
app_id="app",
|
||||||
|
secret="secret",
|
||||||
|
allow_from=["*"],
|
||||||
|
ack_message="⏳ Processing...",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
data = SimpleNamespace(
|
||||||
|
id="msg2",
|
||||||
|
content="hello group",
|
||||||
|
group_openid="group123",
|
||||||
|
author=SimpleNamespace(member_openid="user1"),
|
||||||
|
attachments=[],
|
||||||
|
)
|
||||||
|
await channel._on_message(data, is_group=True)
|
||||||
|
|
||||||
|
assert len(channel._client.api.group_calls) >= 1
|
||||||
|
ack_call = channel._client.api.group_calls[0]
|
||||||
|
assert ack_call["content"] == "⏳ Processing..."
|
||||||
|
assert ack_call["group_openid"] == "group123"
|
||||||
|
assert ack_call["msg_id"] == "msg2"
|
||||||
|
assert ack_call["msg_type"] == 0
|
||||||
|
|
||||||
|
msg = await channel.bus.consume_inbound()
|
||||||
|
assert msg.content == "hello group"
|
||||||
|
assert msg.chat_id == "group123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_ack_when_ack_message_empty() -> None:
|
||||||
|
"""Setting ack_message to empty string disables the ack entirely."""
|
||||||
|
channel = QQChannel(
|
||||||
|
QQConfig(
|
||||||
|
app_id="app",
|
||||||
|
secret="secret",
|
||||||
|
allow_from=["*"],
|
||||||
|
ack_message="",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
data = SimpleNamespace(
|
||||||
|
id="msg3",
|
||||||
|
content="hello",
|
||||||
|
author=SimpleNamespace(user_openid="user1"),
|
||||||
|
attachments=[],
|
||||||
|
)
|
||||||
|
await channel._on_message(data, is_group=False)
|
||||||
|
|
||||||
|
assert len(channel._client.api.c2c_calls) == 0
|
||||||
|
assert len(channel._client.api.group_calls) == 0
|
||||||
|
|
||||||
|
msg = await channel.bus.consume_inbound()
|
||||||
|
assert msg.content == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_custom_ack_message_text() -> None:
|
||||||
|
"""Custom Chinese ack_message text is delivered correctly."""
|
||||||
|
custom = "正在处理中,请稍候..."
|
||||||
|
channel = QQChannel(
|
||||||
|
QQConfig(
|
||||||
|
app_id="app",
|
||||||
|
secret="secret",
|
||||||
|
allow_from=["*"],
|
||||||
|
ack_message=custom,
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
data = SimpleNamespace(
|
||||||
|
id="msg4",
|
||||||
|
content="test input",
|
||||||
|
author=SimpleNamespace(user_openid="user1"),
|
||||||
|
attachments=[],
|
||||||
|
)
|
||||||
|
await channel._on_message(data, is_group=False)
|
||||||
|
|
||||||
|
assert len(channel._client.api.c2c_calls) >= 1
|
||||||
|
ack_call = channel._client.api.c2c_calls[0]
|
||||||
|
assert ack_call["content"] == custom
|
||||||
|
|
||||||
|
msg = await channel.bus.consume_inbound()
|
||||||
|
assert msg.content == "test input"
|
||||||
@ -647,43 +647,56 @@ async def test_group_policy_open_accepts_plain_group_message() -> None:
|
|||||||
assert channel._app.bot.get_me_calls == 0
|
assert channel._app.bot.get_me_calls == 0
|
||||||
|
|
||||||
|
|
||||||
def test_extract_reply_context_no_reply() -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_reply_context_no_reply() -> None:
|
||||||
"""When there is no reply_to_message, _extract_reply_context returns None."""
|
"""When there is no reply_to_message, _extract_reply_context returns None."""
|
||||||
|
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||||
message = SimpleNamespace(reply_to_message=None)
|
message = SimpleNamespace(reply_to_message=None)
|
||||||
assert TelegramChannel._extract_reply_context(message) is None
|
assert await channel._extract_reply_context(message) is None
|
||||||
|
|
||||||
|
|
||||||
def test_extract_reply_context_with_text() -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_reply_context_with_text() -> None:
|
||||||
"""When reply has text, return prefixed string."""
|
"""When reply has text, return prefixed string."""
|
||||||
reply = SimpleNamespace(text="Hello world", caption=None)
|
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
reply = SimpleNamespace(text="Hello world", caption=None, from_user=SimpleNamespace(id=2, username="testuser", first_name="Test"))
|
||||||
message = SimpleNamespace(reply_to_message=reply)
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
|
assert await channel._extract_reply_context(message) == "[Reply to @testuser: Hello world]"
|
||||||
|
|
||||||
|
|
||||||
def test_extract_reply_context_with_caption_only() -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_reply_context_with_caption_only() -> None:
|
||||||
"""When reply has only caption (no text), caption is used."""
|
"""When reply has only caption (no text), caption is used."""
|
||||||
reply = SimpleNamespace(text=None, caption="Photo caption")
|
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
reply = SimpleNamespace(text=None, caption="Photo caption", from_user=SimpleNamespace(id=2, username=None, first_name="Test"))
|
||||||
message = SimpleNamespace(reply_to_message=reply)
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
|
assert await channel._extract_reply_context(message) == "[Reply to Test: Photo caption]"
|
||||||
|
|
||||||
|
|
||||||
def test_extract_reply_context_truncation() -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_reply_context_truncation() -> None:
|
||||||
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
|
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
|
||||||
|
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
|
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
|
||||||
reply = SimpleNamespace(text=long_text, caption=None)
|
reply = SimpleNamespace(text=long_text, caption=None, from_user=SimpleNamespace(id=2, username=None, first_name=None))
|
||||||
message = SimpleNamespace(reply_to_message=reply)
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
result = TelegramChannel._extract_reply_context(message)
|
result = await channel._extract_reply_context(message)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.startswith("[Reply to: ")
|
assert result.startswith("[Reply to: ")
|
||||||
assert result.endswith("...]")
|
assert result.endswith("...]")
|
||||||
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
|
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
|
||||||
|
|
||||||
|
|
||||||
def test_extract_reply_context_no_text_returns_none() -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_reply_context_no_text_returns_none() -> None:
|
||||||
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
|
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
|
||||||
|
channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus())
|
||||||
reply = SimpleNamespace(text=None, caption=None)
|
reply = SimpleNamespace(text=None, caption=None)
|
||||||
message = SimpleNamespace(reply_to_message=reply)
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
assert TelegramChannel._extract_reply_context(message) is None
|
assert await channel._extract_reply_context(message) is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -8,392 +8,401 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
|||||||
from nanobot.providers.base import LLMResponse
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
def test_azure_openai_provider_init():
|
# ---------------------------------------------------------------------------
|
||||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
# Init & validation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_creates_sdk_client():
|
||||||
|
"""Provider creates an AsyncOpenAI client with correct base_url."""
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="test-key",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
api_base="https://test-resource.openai.azure.com",
|
||||||
default_model="gpt-4o-deployment",
|
default_model="gpt-4o-deployment",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert provider.api_key == "test-key"
|
assert provider.api_key == "test-key"
|
||||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||||
assert provider.default_model == "gpt-4o-deployment"
|
assert provider.default_model == "gpt-4o-deployment"
|
||||||
assert provider.api_version == "2024-10-21"
|
# SDK client base_url ends with /openai/v1/
|
||||||
|
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||||
|
|
||||||
|
|
||||||
def test_azure_openai_provider_init_validation():
|
def test_init_base_url_no_trailing_slash():
|
||||||
"""Test AzureOpenAIProvider initialization validation."""
|
"""Trailing slashes are normalised before building base_url."""
|
||||||
# Missing api_key
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="k", api_base="https://res.openai.azure.com",
|
||||||
|
)
|
||||||
|
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_base_url_with_trailing_slash():
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="k", api_base="https://res.openai.azure.com/",
|
||||||
|
)
|
||||||
|
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_validation_missing_key():
|
||||||
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||||
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||||
|
|
||||||
# Missing api_base
|
|
||||||
|
def test_init_validation_missing_base():
|
||||||
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||||
AzureOpenAIProvider(api_key="test", api_base="")
|
AzureOpenAIProvider(api_key="test", api_base="")
|
||||||
|
|
||||||
|
|
||||||
def test_build_chat_url():
|
def test_no_api_version_in_base_url():
|
||||||
"""Test Azure OpenAI URL building with different deployment names."""
|
"""The /openai/v1/ path should NOT contain an api-version query param."""
|
||||||
|
provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
|
||||||
|
base = str(provider._client.base_url)
|
||||||
|
assert "api-version" not in base
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _supports_temperature
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_supports_temperature_standard_model():
|
||||||
|
assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_supports_temperature_reasoning_model():
|
||||||
|
assert AzureOpenAIProvider._supports_temperature("o3-mini") is False
|
||||||
|
assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False
|
||||||
|
assert AzureOpenAIProvider._supports_temperature("o4-mini") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_supports_temperature_with_reasoning_effort():
|
||||||
|
assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _build_body — Responses API body construction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_body_basic():
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
)
|
||||||
default_model="gpt-4o",
|
messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
|
||||||
|
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||||
|
|
||||||
|
assert body["model"] == "gpt-4o"
|
||||||
|
assert body["instructions"] == "You are helpful."
|
||||||
|
assert body["temperature"] == 0.7
|
||||||
|
assert body["max_output_tokens"] == 4096
|
||||||
|
assert body["store"] is False
|
||||||
|
assert "reasoning" not in body
|
||||||
|
# input should contain the converted user message only (system extracted)
|
||||||
|
assert any(
|
||||||
|
item.get("role") == "user"
|
||||||
|
for item in body["input"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test various deployment names
|
|
||||||
test_cases = [
|
|
||||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
|
||||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
|
||||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for deployment_name, expected_url in test_cases:
|
def test_build_body_max_tokens_minimum():
|
||||||
url = provider._build_chat_url(deployment_name)
|
"""max_output_tokens should never be less than 1."""
|
||||||
assert url == expected_url
|
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||||
|
body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None)
|
||||||
|
assert body["max_output_tokens"] == 1
|
||||||
|
|
||||||
|
|
||||||
def test_build_chat_url_api_base_without_slash():
|
def test_build_body_with_tools():
|
||||||
"""Test URL building when api_base doesn't end with slash."""
|
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key="test-key",
|
|
||||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
|
||||||
default_model="gpt-4o",
|
|
||||||
)
|
|
||||||
|
|
||||||
url = provider._build_chat_url("test-deployment")
|
|
||||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
|
||||||
assert url == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_headers():
|
|
||||||
"""Test Azure OpenAI header building with api-key authentication."""
|
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key="test-api-key-123",
|
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="gpt-4o",
|
|
||||||
)
|
|
||||||
|
|
||||||
headers = provider._build_headers()
|
|
||||||
assert headers["Content-Type"] == "application/json"
|
|
||||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
|
||||||
assert "x-session-affinity" in headers
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_request_payload():
|
|
||||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key="test-key",
|
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="gpt-4o",
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
|
||||||
|
|
||||||
assert payload["messages"] == messages
|
|
||||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
|
||||||
assert payload["temperature"] == 0.8
|
|
||||||
assert "tools" not in payload
|
|
||||||
|
|
||||||
# Test with tools
|
|
||||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
body = provider._build_body(
|
||||||
assert payload_with_tools["tools"] == tools
|
[{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
|
||||||
assert payload_with_tools["tool_choice"] == "auto"
|
|
||||||
|
|
||||||
# Test with reasoning_effort
|
|
||||||
payload_with_reasoning = provider._prepare_request_payload(
|
|
||||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
|
||||||
)
|
)
|
||||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
|
||||||
assert "temperature" not in payload_with_reasoning
|
assert body["tool_choice"] == "auto"
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_request_payload_sanitizes_messages():
|
def test_build_body_with_reasoning():
|
||||||
"""Test Azure payload strips non-standard message keys before sending."""
|
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
|
||||||
provider = AzureOpenAIProvider(
|
body = provider._build_body(
|
||||||
api_key="test-key",
|
[{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None,
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="gpt-4o",
|
|
||||||
)
|
)
|
||||||
|
assert body["reasoning"] == {"effort": "medium"}
|
||||||
|
assert "reasoning.encrypted_content" in body.get("include", [])
|
||||||
|
# temperature omitted for reasoning models
|
||||||
|
assert "temperature" not in body
|
||||||
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
|
||||||
"reasoning_content": "hidden chain-of-thought",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": "call_123",
|
|
||||||
"name": "x",
|
|
||||||
"content": "ok",
|
|
||||||
"extra_field": "should be removed",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
def test_build_body_image_conversion():
|
||||||
|
"""image_url content blocks should be converted to input_image."""
|
||||||
|
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||||
|
messages = [{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
|
||||||
|
],
|
||||||
|
}]
|
||||||
|
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||||
|
user_item = body["input"][0]
|
||||||
|
content_types = [b["type"] for b in user_item["content"]]
|
||||||
|
assert "input_text" in content_types
|
||||||
|
assert "input_image" in content_types
|
||||||
|
image_block = next(b for b in user_item["content"] if b["type"] == "input_image")
|
||||||
|
assert image_block["image_url"] == "https://example.com/img.png"
|
||||||
|
|
||||||
assert payload["messages"] == [
|
|
||||||
{
|
def test_build_body_sanitizes_single_dict_content_block():
|
||||||
"role": "assistant",
|
"""Single content dicts should be preserved via shared message sanitization."""
|
||||||
"content": None,
|
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
messages = [{
|
||||||
|
"role": "user",
|
||||||
|
"content": {"type": "text", "text": "Hi from dict content"},
|
||||||
|
}]
|
||||||
|
|
||||||
|
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||||
|
|
||||||
|
assert body["input"][0]["content"] == [{"type": "input_text", "text": "Hi from dict content"}]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# chat() — non-streaming
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sdk_response(
|
||||||
|
content="Hello!", tool_calls=None, status="completed",
|
||||||
|
usage=None,
|
||||||
|
):
|
||||||
|
"""Build a mock that quacks like an openai Response object."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.model_dump = MagicMock(return_value={
|
||||||
|
"output": [
|
||||||
|
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
|
||||||
|
*([{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": tc["call_id"], "id": tc["id"],
|
||||||
|
"name": tc["name"], "arguments": tc["arguments"],
|
||||||
|
} for tc in (tool_calls or [])]),
|
||||||
|
],
|
||||||
|
"status": status,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": (usage or {}).get("input_tokens", 10),
|
||||||
|
"output_tokens": (usage or {}).get("output_tokens", 5),
|
||||||
|
"total_tokens": (usage or {}).get("total_tokens", 15),
|
||||||
},
|
},
|
||||||
{
|
})
|
||||||
"role": "tool",
|
return resp
|
||||||
"tool_call_id": "call_123",
|
|
||||||
"name": "x",
|
|
||||||
"content": "ok",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_success():
|
async def test_chat_success():
|
||||||
"""Test successful chat request using model as deployment name."""
|
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="gpt-4o-deployment",
|
|
||||||
)
|
)
|
||||||
|
mock_resp = _make_sdk_response(content="Hello!")
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
# Mock response data
|
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||||
mock_response_data = {
|
|
||||||
"choices": [{
|
|
||||||
"message": {
|
|
||||||
"content": "Hello! How can I help you today?",
|
|
||||||
"role": "assistant"
|
|
||||||
},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 12,
|
|
||||||
"completion_tokens": 18,
|
|
||||||
"total_tokens": 30
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_response.json = Mock(return_value=mock_response_data)
|
|
||||||
|
|
||||||
mock_context = AsyncMock()
|
|
||||||
mock_context.post = AsyncMock(return_value=mock_response)
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_context
|
|
||||||
|
|
||||||
# Test with specific model (deployment name)
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
result = await provider.chat(messages, model="custom-deployment")
|
|
||||||
|
|
||||||
assert isinstance(result, LLMResponse)
|
assert isinstance(result, LLMResponse)
|
||||||
assert result.content == "Hello! How can I help you today?"
|
assert result.content == "Hello!"
|
||||||
assert result.finish_reason == "stop"
|
assert result.finish_reason == "stop"
|
||||||
assert result.usage["prompt_tokens"] == 12
|
assert result.usage["prompt_tokens"] == 10
|
||||||
assert result.usage["completion_tokens"] == 18
|
|
||||||
assert result.usage["total_tokens"] == 30
|
|
||||||
|
|
||||||
# Verify URL was built with the provided model as deployment name
|
|
||||||
call_args = mock_context.post.call_args
|
|
||||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
|
||||||
assert call_args[0][0] == expected_url
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_uses_default_model_when_no_model_provided():
|
async def test_chat_uses_default_model():
|
||||||
"""Test that chat uses default_model when no model is specified."""
|
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="default-deployment",
|
|
||||||
)
|
)
|
||||||
|
mock_resp = _make_sdk_response(content="ok")
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
mock_response_data = {
|
await provider.chat([{"role": "user", "content": "test"}])
|
||||||
"choices": [{
|
|
||||||
"message": {"content": "Response", "role": "assistant"},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}],
|
|
||||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
call_kwargs = provider._client.responses.create.call_args[1]
|
||||||
mock_response = AsyncMock()
|
assert call_kwargs["model"] == "my-deployment"
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_response.json = Mock(return_value=mock_response_data)
|
|
||||||
|
|
||||||
mock_context = AsyncMock()
|
|
||||||
mock_context.post = AsyncMock(return_value=mock_response)
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_context
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "Test"}]
|
@pytest.mark.asyncio
|
||||||
await provider.chat(messages) # No model specified
|
async def test_chat_custom_model():
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
mock_resp = _make_sdk_response(content="ok")
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
# Verify URL was built with default model as deployment name
|
await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy")
|
||||||
call_args = mock_context.post.call_args
|
|
||||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
call_kwargs = provider._client.responses.create.call_args[1]
|
||||||
assert call_args[0][0] == expected_url
|
assert call_kwargs["model"] == "custom-deploy"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_with_tool_calls():
|
async def test_chat_with_tool_calls():
|
||||||
"""Test chat request with tool calls in response."""
|
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
)
|
||||||
default_model="gpt-4o",
|
mock_resp = _make_sdk_response(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[{
|
||||||
|
"call_id": "call_123", "id": "fc_1",
|
||||||
|
"name": "get_weather", "arguments": '{"location": "SF"}',
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
|
result = await provider.chat(
|
||||||
|
[{"role": "user", "content": "Weather?"}],
|
||||||
|
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock response with tool calls
|
|
||||||
mock_response_data = {
|
|
||||||
"choices": [{
|
|
||||||
"message": {
|
|
||||||
"content": None,
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [{
|
|
||||||
"id": "call_12345",
|
|
||||||
"function": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"arguments": '{"location": "San Francisco"}'
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": "tool_calls"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 20,
|
|
||||||
"completion_tokens": 15,
|
|
||||||
"total_tokens": 35
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_response.json = Mock(return_value=mock_response_data)
|
|
||||||
|
|
||||||
mock_context = AsyncMock()
|
|
||||||
mock_context.post = AsyncMock(return_value=mock_response)
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_context
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
|
||||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
|
||||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
|
||||||
|
|
||||||
assert isinstance(result, LLMResponse)
|
|
||||||
assert result.content is None
|
|
||||||
assert result.finish_reason == "tool_calls"
|
|
||||||
assert len(result.tool_calls) == 1
|
assert len(result.tool_calls) == 1
|
||||||
assert result.tool_calls[0].name == "get_weather"
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_api_error():
|
async def test_chat_error_handling():
|
||||||
"""Test chat request API error handling."""
|
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="gpt-4o",
|
|
||||||
)
|
)
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status_code = 401
|
|
||||||
mock_response.text = "Invalid authentication credentials"
|
|
||||||
|
|
||||||
mock_context = AsyncMock()
|
|
||||||
mock_context.post = AsyncMock(return_value=mock_response)
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_context
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
result = await provider.chat(messages)
|
|
||||||
|
|
||||||
assert isinstance(result, LLMResponse)
|
assert isinstance(result, LLMResponse)
|
||||||
assert "Azure OpenAI API Error 401" in result.content
|
assert "Connection failed" in result.content
|
||||||
assert "Invalid authentication credentials" in result.content
|
|
||||||
assert result.finish_reason == "error"
|
assert result.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_connection_error():
|
async def test_chat_reasoning_param_format():
|
||||||
"""Test chat request connection error handling."""
|
"""reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
)
|
||||||
default_model="gpt-4o",
|
mock_resp = _make_sdk_response(content="thought")
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
|
await provider.chat(
|
||||||
|
[{"role": "user", "content": "think"}], reasoning_effort="medium",
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
call_kwargs = provider._client.responses.create.call_args[1]
|
||||||
mock_context = AsyncMock()
|
assert call_kwargs["reasoning"] == {"effort": "medium"}
|
||||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
assert "reasoning_effort" not in call_kwargs
|
||||||
mock_client.return_value.__aenter__.return_value = mock_context
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
result = await provider.chat(messages)
|
|
||||||
|
|
||||||
assert isinstance(result, LLMResponse)
|
# ---------------------------------------------------------------------------
|
||||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
# chat_stream()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_stream_success():
|
||||||
|
"""Streaming should call on_content_delta and return combined response."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build mock SDK stream events
|
||||||
|
events = []
|
||||||
|
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
|
||||||
|
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
|
||||||
|
resp_obj = MagicMock(status="completed")
|
||||||
|
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
events = [ev1, ev2, ev3]
|
||||||
|
|
||||||
|
async def mock_stream():
|
||||||
|
for e in events:
|
||||||
|
yield e
|
||||||
|
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||||
|
|
||||||
|
deltas: list[str] = []
|
||||||
|
|
||||||
|
async def on_delta(text: str) -> None:
|
||||||
|
deltas.append(text)
|
||||||
|
|
||||||
|
result = await provider.chat_stream(
|
||||||
|
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.content == "Hello world"
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert deltas == ["Hello", " world"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_stream_with_tool_calls():
|
||||||
|
"""Streaming tool calls should be accumulated correctly."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
|
||||||
|
item_added.name = "get_weather"
|
||||||
|
ev_added = MagicMock(type="response.output_item.added", item=item_added)
|
||||||
|
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
|
||||||
|
ev_args_done = MagicMock(
|
||||||
|
type="response.function_call_arguments.done",
|
||||||
|
call_id="call_1", arguments='{"location":"SF"}',
|
||||||
|
)
|
||||||
|
item_done = MagicMock(
|
||||||
|
type="function_call", call_id="call_1", id="fc_1",
|
||||||
|
arguments='{"location":"SF"}',
|
||||||
|
)
|
||||||
|
item_done.name = "get_weather"
|
||||||
|
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
|
||||||
|
resp_obj = MagicMock(status="completed")
|
||||||
|
ev_completed = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
|
async def mock_stream():
|
||||||
|
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
|
||||||
|
yield e
|
||||||
|
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||||
|
|
||||||
|
result = await provider.chat_stream(
|
||||||
|
[{"role": "user", "content": "weather?"}],
|
||||||
|
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
|
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_stream_error():
|
||||||
|
"""Streaming should return error when SDK raises."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||||
|
|
||||||
|
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
|
||||||
|
|
||||||
|
assert "Connection failed" in result.content
|
||||||
assert result.finish_reason == "error"
|
assert result.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
def test_parse_response_malformed():
|
# ---------------------------------------------------------------------------
|
||||||
"""Test response parsing with malformed data."""
|
# get_default_model
|
||||||
provider = AzureOpenAIProvider(
|
# ---------------------------------------------------------------------------
|
||||||
api_key="test-key",
|
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="gpt-4o",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test with missing choices
|
|
||||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
|
||||||
result = provider._parse_response(malformed_response)
|
|
||||||
|
|
||||||
assert isinstance(result, LLMResponse)
|
|
||||||
assert "Error parsing Azure OpenAI response" in result.content
|
|
||||||
assert result.finish_reason == "error"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_default_model():
|
def test_get_default_model():
|
||||||
"""Test get_default_model method."""
|
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="test-key",
|
api_key="k", api_base="https://r.com", default_model="my-deploy",
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="my-custom-deployment",
|
|
||||||
)
|
)
|
||||||
|
assert provider.get_default_model() == "my-deploy"
|
||||||
assert provider.get_default_model() == "my-custom-deployment"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Run basic tests
|
|
||||||
print("Running basic Azure OpenAI provider tests...")
|
|
||||||
|
|
||||||
# Test initialization
|
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key="test-key",
|
|
||||||
api_base="https://test-resource.openai.azure.com",
|
|
||||||
default_model="gpt-4o-deployment",
|
|
||||||
)
|
|
||||||
print("✅ Provider initialization successful")
|
|
||||||
|
|
||||||
# Test URL building
|
|
||||||
url = provider._build_chat_url("my-deployment")
|
|
||||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
|
||||||
assert url == expected
|
|
||||||
print("✅ URL building works correctly")
|
|
||||||
|
|
||||||
# Test headers
|
|
||||||
headers = provider._build_headers()
|
|
||||||
assert headers["api-key"] == "test-key"
|
|
||||||
assert headers["Content-Type"] == "application/json"
|
|
||||||
print("✅ Header building works correctly")
|
|
||||||
|
|
||||||
# Test payload preparation
|
|
||||||
messages = [{"role": "user", "content": "Test"}]
|
|
||||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
|
||||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
|
||||||
print("✅ Payload preparation works correctly")
|
|
||||||
|
|
||||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ Validates that:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
|||||||
return SimpleNamespace(choices=[choice], usage=usage)
|
return SimpleNamespace(choices=[choice], usage=usage)
|
||||||
|
|
||||||
|
|
||||||
|
class _StalledStream:
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
def test_openrouter_spec_is_gateway() -> None:
|
def test_openrouter_spec_is_gateway() -> None:
|
||||||
spec = find_by_name("openrouter")
|
spec = find_by_name("openrouter")
|
||||||
assert spec is not None
|
assert spec is not None
|
||||||
@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None:
|
|||||||
spec=spec,
|
spec=spec,
|
||||||
)
|
)
|
||||||
assert provider.get_default_model() == "gpt-4o"
|
assert provider.get_default_model() == "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_strips_message_level_reasoning_fields() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
sanitized = provider._sanitize_messages([
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "done",
|
||||||
|
"reasoning_content": "hidden",
|
||||||
|
"extra_content": {"debug": True},
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "fn", "arguments": "{}"},
|
||||||
|
"extra_content": {"google": {"thought_signature": "sig"}},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
|
assert "reasoning_content" not in sanitized[0]
|
||||||
|
assert "extra_content" not in sanitized[0]
|
||||||
|
assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
|
||||||
|
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
|
||||||
|
mock_create = AsyncMock(return_value=_StalledStream())
|
||||||
|
spec = find_by_name("openai")
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||||
|
client_instance = MockClient.return_value
|
||||||
|
client_instance.chat.completions.create = mock_create
|
||||||
|
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="sk-test-key",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
result = await provider.chat_stream(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.finish_reason == "error"
|
||||||
|
assert result.content is not None
|
||||||
|
assert "stream stalled" in result.content
|
||||||
|
|||||||
522
tests/providers/test_openai_responses.py
Normal file
522
tests/providers/test_openai_responses.py
Normal file
@ -0,0 +1,522 @@
|
|||||||
|
"""Tests for the shared openai_responses converters and parsers."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.providers.openai_responses.converters import (
|
||||||
|
convert_messages,
|
||||||
|
convert_tools,
|
||||||
|
convert_user_message,
|
||||||
|
split_tool_call_id,
|
||||||
|
)
|
||||||
|
from nanobot.providers.openai_responses.parsing import (
|
||||||
|
consume_sdk_stream,
|
||||||
|
map_finish_reason,
|
||||||
|
parse_response_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# converters - split_tool_call_id
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitToolCallId:
|
||||||
|
def test_plain_id(self):
|
||||||
|
assert split_tool_call_id("call_abc") == ("call_abc", None)
|
||||||
|
|
||||||
|
def test_compound_id(self):
|
||||||
|
assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1")
|
||||||
|
|
||||||
|
def test_compound_empty_item_id(self):
|
||||||
|
assert split_tool_call_id("call_abc|") == ("call_abc", None)
|
||||||
|
|
||||||
|
def test_none(self):
|
||||||
|
assert split_tool_call_id(None) == ("call_0", None)
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
assert split_tool_call_id("") == ("call_0", None)
|
||||||
|
|
||||||
|
def test_non_string(self):
|
||||||
|
assert split_tool_call_id(42) == ("call_0", None)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# converters - convert_user_message
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertUserMessage:
|
||||||
|
def test_string_content(self):
|
||||||
|
result = convert_user_message("hello")
|
||||||
|
assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}
|
||||||
|
|
||||||
|
def test_text_block(self):
|
||||||
|
result = convert_user_message([{"type": "text", "text": "hi"}])
|
||||||
|
assert result["content"] == [{"type": "input_text", "text": "hi"}]
|
||||||
|
|
||||||
|
def test_image_url_block(self):
|
||||||
|
result = convert_user_message([
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://img.example/a.png"}},
|
||||||
|
])
|
||||||
|
assert result["content"] == [
|
||||||
|
{"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_mixed_text_and_image(self):
|
||||||
|
result = convert_user_message([
|
||||||
|
{"type": "text", "text": "what's this?"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://img.example/b.png"}},
|
||||||
|
])
|
||||||
|
assert len(result["content"]) == 2
|
||||||
|
assert result["content"][0]["type"] == "input_text"
|
||||||
|
assert result["content"][1]["type"] == "input_image"
|
||||||
|
|
||||||
|
def test_empty_list_falls_back(self):
|
||||||
|
result = convert_user_message([])
|
||||||
|
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||||
|
|
||||||
|
def test_none_falls_back(self):
|
||||||
|
result = convert_user_message(None)
|
||||||
|
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||||
|
|
||||||
|
def test_image_without_url_skipped(self):
|
||||||
|
result = convert_user_message([{"type": "image_url", "image_url": {}}])
|
||||||
|
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||||
|
|
||||||
|
def test_meta_fields_not_leaked(self):
|
||||||
|
"""_meta on content blocks must never appear in converted output."""
|
||||||
|
result = convert_user_message([
|
||||||
|
{"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}},
|
||||||
|
])
|
||||||
|
assert "_meta" not in result["content"][0]
|
||||||
|
|
||||||
|
def test_non_dict_items_skipped(self):
|
||||||
|
result = convert_user_message(["just a string", 42])
|
||||||
|
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# converters - convert_messages
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertMessages:
|
||||||
|
def test_system_extracted_as_instructions(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
]
|
||||||
|
instructions, items = convert_messages(msgs)
|
||||||
|
assert instructions == "You are helpful."
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0]["role"] == "user"
|
||||||
|
|
||||||
|
def test_multiple_system_messages_last_wins(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "first"},
|
||||||
|
{"role": "system", "content": "second"},
|
||||||
|
{"role": "user", "content": "x"},
|
||||||
|
]
|
||||||
|
instructions, _ = convert_messages(msgs)
|
||||||
|
assert instructions == "second"
|
||||||
|
|
||||||
|
def test_user_message_converted(self):
|
||||||
|
_, items = convert_messages([{"role": "user", "content": "hello"}])
|
||||||
|
assert items[0]["role"] == "user"
|
||||||
|
assert items[0]["content"][0]["type"] == "input_text"
|
||||||
|
|
||||||
|
def test_assistant_text_message(self):
|
||||||
|
_, items = convert_messages([
|
||||||
|
{"role": "assistant", "content": "I'll help"},
|
||||||
|
])
|
||||||
|
assert items[0]["type"] == "message"
|
||||||
|
assert items[0]["role"] == "assistant"
|
||||||
|
assert items[0]["content"][0]["type"] == "output_text"
|
||||||
|
assert items[0]["content"][0]["text"] == "I'll help"
|
||||||
|
|
||||||
|
def test_assistant_empty_content_skipped(self):
|
||||||
|
_, items = convert_messages([{"role": "assistant", "content": ""}])
|
||||||
|
assert len(items) == 0
|
||||||
|
|
||||||
|
def test_assistant_with_tool_calls(self):
|
||||||
|
_, items = convert_messages([{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_abc|fc_1",
|
||||||
|
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
|
||||||
|
}],
|
||||||
|
}])
|
||||||
|
assert items[0]["type"] == "function_call"
|
||||||
|
assert items[0]["call_id"] == "call_abc"
|
||||||
|
assert items[0]["id"] == "fc_1"
|
||||||
|
assert items[0]["name"] == "get_weather"
|
||||||
|
|
||||||
|
def test_assistant_with_tool_calls_no_id(self):
|
||||||
|
"""Fallback IDs when tool_call.id is missing."""
|
||||||
|
_, items = convert_messages([{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}],
|
||||||
|
}])
|
||||||
|
assert items[0]["call_id"] == "call_0"
|
||||||
|
assert items[0]["id"].startswith("fc_")
|
||||||
|
|
||||||
|
def test_tool_message(self):
|
||||||
|
_, items = convert_messages([{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_abc",
|
||||||
|
"content": "result text",
|
||||||
|
}])
|
||||||
|
assert items[0]["type"] == "function_call_output"
|
||||||
|
assert items[0]["call_id"] == "call_abc"
|
||||||
|
assert items[0]["output"] == "result text"
|
||||||
|
|
||||||
|
def test_tool_message_dict_content(self):
|
||||||
|
_, items = convert_messages([{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": {"key": "value"},
|
||||||
|
}])
|
||||||
|
assert items[0]["output"] == '{"key": "value"}'
|
||||||
|
|
||||||
|
def test_non_standard_keys_not_leaked(self):
|
||||||
|
"""Extra keys on messages must not appear in converted items."""
|
||||||
|
_, items = convert_messages([{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hi",
|
||||||
|
"extra_field": "should vanish",
|
||||||
|
"_meta": {"path": "/tmp"},
|
||||||
|
}])
|
||||||
|
item = items[0]
|
||||||
|
assert "extra_field" not in str(item)
|
||||||
|
assert "_meta" not in str(item)
|
||||||
|
|
||||||
|
def test_full_conversation_roundtrip(self):
|
||||||
|
"""System + user + assistant(tool_call) + tool -> correct structure."""
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "Be concise."},
|
||||||
|
{"role": "user", "content": "Weather in SF?"},
|
||||||
|
{
|
||||||
|
"role": "assistant", "content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "c1|fc1",
|
||||||
|
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'},
|
||||||
|
]
|
||||||
|
instructions, items = convert_messages(msgs)
|
||||||
|
assert instructions == "Be concise."
|
||||||
|
assert len(items) == 3 # user, function_call, function_call_output
|
||||||
|
assert items[0]["role"] == "user"
|
||||||
|
assert items[1]["type"] == "function_call"
|
||||||
|
assert items[2]["type"] == "function_call_output"
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# converters - convert_tools
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertTools:
|
||||||
|
def test_standard_function_tool(self):
|
||||||
|
tools = [{"type": "function", "function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather",
|
||||||
|
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
|
||||||
|
}}]
|
||||||
|
result = convert_tools(tools)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["type"] == "function"
|
||||||
|
assert result[0]["name"] == "get_weather"
|
||||||
|
assert result[0]["description"] == "Get weather"
|
||||||
|
assert "properties" in result[0]["parameters"]
|
||||||
|
|
||||||
|
def test_tool_without_name_skipped(self):
|
||||||
|
tools = [{"type": "function", "function": {"parameters": {}}}]
|
||||||
|
assert convert_tools(tools) == []
|
||||||
|
|
||||||
|
def test_tool_without_function_wrapper(self):
|
||||||
|
"""Direct dict without type=function wrapper."""
|
||||||
|
tools = [{"name": "f1", "description": "d", "parameters": {}}]
|
||||||
|
result = convert_tools(tools)
|
||||||
|
assert result[0]["name"] == "f1"
|
||||||
|
|
||||||
|
def test_missing_optional_fields_default(self):
|
||||||
|
tools = [{"type": "function", "function": {"name": "f"}}]
|
||||||
|
result = convert_tools(tools)
|
||||||
|
assert result[0]["description"] == ""
|
||||||
|
assert result[0]["parameters"] == {}
|
||||||
|
|
||||||
|
def test_multiple_tools(self):
|
||||||
|
tools = [
|
||||||
|
{"type": "function", "function": {"name": "a", "parameters": {}}},
|
||||||
|
{"type": "function", "function": {"name": "b", "parameters": {}}},
|
||||||
|
]
|
||||||
|
assert len(convert_tools(tools)) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# parsing - map_finish_reason
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestMapFinishReason:
|
||||||
|
def test_completed(self):
|
||||||
|
assert map_finish_reason("completed") == "stop"
|
||||||
|
|
||||||
|
def test_incomplete(self):
|
||||||
|
assert map_finish_reason("incomplete") == "length"
|
||||||
|
|
||||||
|
def test_failed(self):
|
||||||
|
assert map_finish_reason("failed") == "error"
|
||||||
|
|
||||||
|
def test_cancelled(self):
|
||||||
|
assert map_finish_reason("cancelled") == "error"
|
||||||
|
|
||||||
|
def test_none_defaults_to_stop(self):
|
||||||
|
assert map_finish_reason(None) == "stop"
|
||||||
|
|
||||||
|
def test_unknown_defaults_to_stop(self):
|
||||||
|
assert map_finish_reason("some_new_status") == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# parsing - parse_response_output
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseResponseOutput:
|
||||||
|
def test_text_response(self):
|
||||||
|
resp = {
|
||||||
|
"output": [{"type": "message", "role": "assistant",
|
||||||
|
"content": [{"type": "output_text", "text": "Hello!"}]}],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
assert result.content == "Hello!"
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
|
||||||
|
assert result.tool_calls == []
|
||||||
|
|
||||||
|
def test_tool_call_response(self):
|
||||||
|
resp = {
|
||||||
|
"output": [{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_1", "id": "fc_1",
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"city": "SF"}',
|
||||||
|
}],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {},
|
||||||
|
}
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
assert result.content is None
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
|
assert result.tool_calls[0].arguments == {"city": "SF"}
|
||||||
|
assert result.tool_calls[0].id == "call_1|fc_1"
|
||||||
|
|
||||||
|
def test_malformed_tool_arguments_logged(self):
|
||||||
|
"""Malformed JSON arguments should log a warning and fallback."""
|
||||||
|
resp = {
|
||||||
|
"output": [{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "c1", "id": "fc1",
|
||||||
|
"name": "f", "arguments": "{bad json",
|
||||||
|
}],
|
||||||
|
"status": "completed", "usage": {},
|
||||||
|
}
|
||||||
|
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
|
||||||
|
mock_logger.warning.assert_called_once()
|
||||||
|
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||||
|
|
||||||
|
def test_reasoning_content_extracted(self):
|
||||||
|
resp = {
|
||||||
|
"output": [
|
||||||
|
{"type": "reasoning", "summary": [
|
||||||
|
{"type": "summary_text", "text": "I think "},
|
||||||
|
{"type": "summary_text", "text": "therefore I am."},
|
||||||
|
]},
|
||||||
|
{"type": "message", "role": "assistant",
|
||||||
|
"content": [{"type": "output_text", "text": "42"}]},
|
||||||
|
],
|
||||||
|
"status": "completed", "usage": {},
|
||||||
|
}
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
assert result.content == "42"
|
||||||
|
assert result.reasoning_content == "I think therefore I am."
|
||||||
|
|
||||||
|
def test_empty_output(self):
|
||||||
|
resp = {"output": [], "status": "completed", "usage": {}}
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
assert result.content is None
|
||||||
|
assert result.tool_calls == []
|
||||||
|
|
||||||
|
def test_incomplete_status(self):
|
||||||
|
resp = {"output": [], "status": "incomplete", "usage": {}}
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
assert result.finish_reason == "length"
|
||||||
|
|
||||||
|
def test_sdk_model_object(self):
|
||||||
|
"""parse_response_output should handle SDK objects with model_dump()."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.model_dump.return_value = {
|
||||||
|
"output": [{"type": "message", "role": "assistant",
|
||||||
|
"content": [{"type": "output_text", "text": "sdk"}]}],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
|
||||||
|
}
|
||||||
|
result = parse_response_output(mock)
|
||||||
|
assert result.content == "sdk"
|
||||||
|
assert result.usage["prompt_tokens"] == 1
|
||||||
|
|
||||||
|
def test_usage_maps_responses_api_keys(self):
|
||||||
|
"""Responses API uses input_tokens/output_tokens, not prompt_tokens/completion_tokens."""
|
||||||
|
resp = {
|
||||||
|
"output": [],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
|
||||||
|
}
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
assert result.usage["prompt_tokens"] == 100
|
||||||
|
assert result.usage["completion_tokens"] == 50
|
||||||
|
assert result.usage["total_tokens"] == 150
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# parsing - consume_sdk_stream
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsumeSdkStream:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_stream(self):
|
||||||
|
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
|
||||||
|
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
|
||||||
|
resp_obj = MagicMock(status="completed", usage=None, output=[])
|
||||||
|
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
for e in [ev1, ev2, ev3]:
|
||||||
|
yield e
|
||||||
|
|
||||||
|
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
|
||||||
|
assert content == "Hello world"
|
||||||
|
assert tool_calls == []
|
||||||
|
assert finish_reason == "stop"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_content_delta_called(self):
|
||||||
|
ev1 = MagicMock(type="response.output_text.delta", delta="hi")
|
||||||
|
resp_obj = MagicMock(status="completed", usage=None, output=[])
|
||||||
|
ev2 = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
deltas = []
|
||||||
|
|
||||||
|
async def cb(text):
|
||||||
|
deltas.append(text)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
for e in [ev1, ev2]:
|
||||||
|
yield e
|
||||||
|
|
||||||
|
await consume_sdk_stream(stream(), on_content_delta=cb)
|
||||||
|
assert deltas == ["hi"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_call_stream(self):
|
||||||
|
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
|
||||||
|
item_added.name = "get_weather"
|
||||||
|
ev1 = MagicMock(type="response.output_item.added", item=item_added)
|
||||||
|
ev2 = MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci')
|
||||||
|
ev3 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}')
|
||||||
|
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments='{"city":"SF"}')
|
||||||
|
item_done.name = "get_weather"
|
||||||
|
ev4 = MagicMock(type="response.output_item.done", item=item_done)
|
||||||
|
resp_obj = MagicMock(status="completed", usage=None, output=[])
|
||||||
|
ev5 = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
for e in [ev1, ev2, ev3, ev4, ev5]:
|
||||||
|
yield e
|
||||||
|
|
||||||
|
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
|
||||||
|
assert content == ""
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0].name == "get_weather"
|
||||||
|
assert tool_calls[0].arguments == {"city": "SF"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_usage_extracted(self):
|
||||||
|
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
|
||||||
|
resp_obj = MagicMock(status="completed", usage=usage_obj, output=[])
|
||||||
|
ev = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
yield ev
|
||||||
|
|
||||||
|
_, _, _, usage, _ = await consume_sdk_stream(stream())
|
||||||
|
assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reasoning_extracted(self):
|
||||||
|
summary_item = MagicMock(type="summary_text", text="thinking...")
|
||||||
|
reasoning_item = MagicMock(type="reasoning", summary=[summary_item])
|
||||||
|
resp_obj = MagicMock(status="completed", usage=None, output=[reasoning_item])
|
||||||
|
ev = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
yield ev
|
||||||
|
|
||||||
|
_, _, _, _, reasoning = await consume_sdk_stream(stream())
|
||||||
|
assert reasoning == "thinking..."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_event_raises(self):
|
||||||
|
ev = MagicMock(type="error", error="rate_limit_exceeded")
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
yield ev
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
|
||||||
|
await consume_sdk_stream(stream())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failed_event_raises(self):
|
||||||
|
ev = MagicMock(type="response.failed", error="server_error")
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
yield ev
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Response failed.*server_error"):
|
||||||
|
await consume_sdk_stream(stream())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_malformed_tool_args_logged(self):
|
||||||
|
"""Malformed JSON in streaming tool args should log a warning."""
|
||||||
|
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
|
||||||
|
item_added.name = "f"
|
||||||
|
ev1 = MagicMock(type="response.output_item.added", item=item_added)
|
||||||
|
ev2 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad")
|
||||||
|
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="{bad")
|
||||||
|
item_done.name = "f"
|
||||||
|
ev3 = MagicMock(type="response.output_item.done", item=item_done)
|
||||||
|
resp_obj = MagicMock(status="completed", usage=None, output=[])
|
||||||
|
ev4 = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
for e in [ev1, ev2, ev3, ev4]:
|
||||||
|
yield e
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
||||||
|
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
|
||||||
|
assert tool_calls[0].arguments == {"raw": "{bad"}
|
||||||
|
mock_logger.warning.assert_called_once()
|
||||||
|
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||||
@ -211,3 +211,56 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
|
|||||||
content = msg.get("content")
|
content = msg.get("content")
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"),
|
||||||
|
LLMResponse(content="ok"),
|
||||||
|
])
|
||||||
|
delays: list[float] = []
|
||||||
|
progress: list[str] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: float) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
async def _progress(msg: str) -> None:
|
||||||
|
progress.append(msg)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
on_retry_wait=_progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content == "ok"
|
||||||
|
assert delays == [7.0]
|
||||||
|
assert progress and "7s" in progress[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
*[LLMResponse(content="429 rate limit", finish_reason="error") for _ in range(10)],
|
||||||
|
LLMResponse(content="ok"),
|
||||||
|
])
|
||||||
|
delays: list[float] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: float) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
retry_mode="persistent",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.finish_reason == "error"
|
||||||
|
assert response.content == "429 rate limit"
|
||||||
|
assert provider.calls == 10
|
||||||
|
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -347,6 +347,8 @@ async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
|||||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_response_falls_back(aiohttp_client) -> None:
|
async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||||
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def always_empty(content, session_key="", channel="", chat_id=""):
|
async def always_empty(content, session_key="", channel="", chat_id=""):
|
||||||
@ -367,5 +369,5 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
|||||||
)
|
)
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
body = await resp.json()
|
body = await resp.json()
|
||||||
assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
|
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
assert call_count == 2
|
assert call_count == 2
|
||||||
|
|||||||
@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None:
|
|||||||
|
|
||||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||||
task = asyncio.create_task(wrapper.execute())
|
task = asyncio.create_task(wrapper.execute())
|
||||||
await started.wait()
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user