Merge remote-tracking branch 'origin/main' into feat/runtime-hardening

Xubin Ren 2026-04-02 10:40:49 +00:00
commit eefd7e60f2
31 changed files with 2385 additions and 789 deletions

View File

@@ -95,6 +95,15 @@ class _LoopHook(AgentHook):
logger.info("Tool call: {}({})", tc.name, args_str[:200])
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
+ async def after_iteration(self, context: AgentHookContext) -> None:
+     u = context.usage or {}
+     logger.debug(
+         "LLM usage: prompt={} completion={} cached={}",
+         u.get("prompt_tokens", 0),
+         u.get("completion_tokens", 0),
+         u.get("cached_tokens", 0),
+     )
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return self._loop._strip_think(content)

View File

@@ -77,7 +77,7 @@ class AgentRunner:
messages = list(spec.initial_messages)
final_content: str | None = None
tools_used: list[str] = []
- usage = {"prompt_tokens": 0, "completion_tokens": 0}
+ usage: dict[str, int] = {}
error: str | None = None
stop_reason = "completed"
tool_events: list[dict[str, str]] = []
@@ -122,13 +122,15 @@ class AgentRunner:
response = await self.provider.chat_with_retry(**kwargs)
raw_usage = response.usage or {}
- usage = {
-     "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
-     "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
- }
context.response = response
- context.usage = usage
+ context.usage = raw_usage
context.tool_calls = list(response.tool_calls)
+ # Accumulate standard fields into result usage.
+ usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
+ usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
+ cached = raw_usage.get("cached_tokens")
+ if cached:
+     usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
if response.has_tool_calls:
if hook.wants_streaming():
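Since `usage` now starts empty and is summed once per iteration, a multi-step tool loop reports totals across all LLM calls instead of overwriting them with the last call's numbers. A minimal sketch of the accumulation semantics in isolation (the helper name and inputs are illustrative, not part of the runner):

def accumulate_usage(per_iteration: list[dict[str, int]]) -> dict[str, int]:
    # Same pattern as the hunk above: sum prompt/completion tokens across
    # iterations; cached_tokens only appears when a provider reports it.
    usage: dict[str, int] = {}
    for raw_usage in per_iteration:
        usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
        usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
        cached = raw_usage.get("cached_tokens")
        if cached:
            usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
    return usage

assert accumulate_usage([
    {"prompt_tokens": 100, "completion_tokens": 20, "cached_tokens": 80},
    {"prompt_tokens": 150, "completion_tokens": 30},
]) == {"prompt_tokens": 250, "completion_tokens": 50, "cached_tokens": 80}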

View File

@@ -97,6 +97,11 @@ class CronTool(Tool):
f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}."
),
},
+ "deliver": {
+     "type": "boolean",
+     "description": "Whether to deliver the execution result to the user channel (default true)",
+     "default": True
+ },
"job_id": {"type": "string", "description": "Job ID (for remove)"},
},
"required": ["action"],
@@ -111,12 +116,13 @@ class CronTool(Tool):
tz: str | None = None,
at: str | None = None,
job_id: str | None = None,
+ deliver: bool = True,
**kwargs: Any,
) -> str:
if action == "add":
if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
- return self._add_job(message, every_seconds, cron_expr, tz, at)
+ return self._add_job(message, every_seconds, cron_expr, tz, at, deliver)
elif action == "list":
return self._list_jobs()
elif action == "remove":
@@ -130,6 +136,7 @@ class CronTool(Tool):
cron_expr: str | None,
tz: str | None,
at: str | None,
+ deliver: bool = True,
) -> str:
if not message:
return "Error: message is required for add"
@@ -171,7 +178,7 @@ class CronTool(Tool):
name=message[:30],
schedule=schedule,
message=message,
- deliver=True,
+ deliver=deliver,
channel=self._channel,
to=self._chat_id,
delete_after_run=delete_after,
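With the flag plumbed through, a job can now be scheduled to run silently. A hypothetical invocation, assuming the tool's async entry point is named `execute` (the method name is not visible in this hunk):

# Hypothetical tool call: a nightly job whose output stays out of the chat.
result = await cron_tool.execute(
    action="add",
    message="compact the vector store",
    cron_expr="0 3 * * *",
    deliver=False,  # run, but do not post the result to the user channel
)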

View File

@@ -86,7 +86,15 @@ class MessageTool(Tool):
) -> str:
channel = channel or self._default_channel
chat_id = chat_id or self._default_chat_id
- message_id = message_id or self._default_message_id
+ # Only inherit default message_id when targeting the same channel+chat.
+ # Cross-chat sends must not carry the original message_id, because
+ # some channels (e.g. Feishu) use it to determine the target
+ # conversation via their Reply API, which would route the message
+ # to the wrong chat entirely.
+ if channel == self._default_channel and chat_id == self._default_chat_id:
+     message_id = message_id or self._default_message_id
+ else:
+     message_id = None
if not channel or not chat_id:
return "Error: No target channel/chat specified"
@@ -101,7 +109,7 @@ class MessageTool(Tool):
media=media or [],
metadata={
"message_id": message_id,
- },
+ } if message_id else {},
)
try:
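The inheritance rule above reduces to a pure function; this sketch is illustrative only, with hypothetical parameter names:

def resolve_message_id(
    channel: str | None, chat_id: str | None, message_id: str | None,
    default_channel: str | None, default_chat_id: str | None,
    default_message_id: str | None,
) -> str | None:
    # Reply threading survives only for same-channel, same-chat sends;
    # cross-chat sends drop it so Reply-API channels cannot misroute.
    if channel == default_channel and chat_id == default_chat_id:
        return message_id or default_message_id
    return None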

View File

@@ -190,7 +190,9 @@ class ExecTool(Tool):
@staticmethod
def _extract_absolute_paths(command: str) -> list[str]:
- win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command)  # Windows: C:\...
+ # Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`
+ # NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
+ win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command)
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command)  # POSIX: /absolute only
home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command)  # POSIX/Windows home shortcut: ~
return win_paths + posix_paths + home_paths
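The effect of widening `+` to `*` is easy to check; this snippet just re-runs the Windows pattern from above:

import re

command = r"dir C:\ | cat C:\logs\app.log"
# With `*`, the bare drive root C:\ matches; with `+` it would not.
print(re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command))
# ['C:\\', 'C:\\logs\\app.log']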

View File

@@ -415,6 +415,9 @@ def _make_provider(config: Config):
api_base=p.api_base,
default_model=model,
)
+ elif backend == "github_copilot":
+     from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
+     provider = GitHubCopilotProvider(default_model=model)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
@ -1029,12 +1032,18 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status") @channels_app.command("status")
def channels_status(): def channels_status(
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Show channel status.""" """Show channel status."""
from nanobot.channels.registry import discover_all from nanobot.channels.registry import discover_all
from nanobot.config.loader import load_config from nanobot.config.loader import load_config, set_config_path
config = load_config() resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
if resolved_config_path is not None:
set_config_path(resolved_config_path)
config = load_config(resolved_config_path)
table = Table(title="Channel Status") table = Table(title="Channel Status")
table.add_column("Channel", style="cyan") table.add_column("Channel", style="cyan")
@ -1121,12 +1130,17 @@ def _get_bridge_dir() -> Path:
def channels_login( def channels_login(
channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"), channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"), force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
): ):
"""Authenticate with a channel via QR code or other interactive login.""" """Authenticate with a channel via QR code or other interactive login."""
from nanobot.channels.registry import discover_all from nanobot.channels.registry import discover_all
from nanobot.config.loader import load_config from nanobot.config.loader import load_config, set_config_path
config = load_config() resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
if resolved_config_path is not None:
set_config_path(resolved_config_path)
config = load_config(resolved_config_path)
channel_cfg = getattr(config.channels, channel_name, None) or {} channel_cfg = getattr(config.channels, channel_name, None) or {}
# Validate channel exists # Validate channel exists
@ -1298,26 +1312,16 @@ def _login_openai_codex() -> None:
@_register_login("github_copilot") @_register_login("github_copilot")
def _login_github_copilot() -> None: def _login_github_copilot() -> None:
import asyncio
from openai import AsyncOpenAI
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
async def _trigger():
client = AsyncOpenAI(
api_key="dummy",
base_url="https://api.githubcopilot.com",
)
await client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": "hi"}],
max_tokens=1,
)
try: try:
asyncio.run(_trigger()) from nanobot.providers.github_copilot_provider import login_github_copilot
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
token = login_github_copilot(
print_fn=lambda s: console.print(s),
prompt_fn=lambda s: typer.prompt(s),
)
account = token.account_id or "GitHub"
console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]")
except Exception as e: except Exception as e:
console.print(f"[red]Authentication error: {e}[/red]") console.print(f"[red]Authentication error: {e}[/red]")
raise typer.Exit(1) raise typer.Exit(1)

View File

@ -138,6 +138,10 @@ def _make_provider(config: Any) -> Any:
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
provider = OpenAICodexProvider(default_model=model) provider = OpenAICodexProvider(default_model=model)
elif backend == "github_copilot":
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
provider = GitHubCopilotProvider(default_model=model)
elif backend == "azure_openai": elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider

View File

@@ -13,6 +13,7 @@ __all__ = [
"AnthropicProvider",
"OpenAICompatProvider",
"OpenAICodexProvider",
+ "GitHubCopilotProvider",
"AzureOpenAIProvider",
]
@@ -20,12 +21,14 @@ _LAZY_IMPORTS = {
"AnthropicProvider": ".anthropic_provider",
"OpenAICompatProvider": ".openai_compat_provider",
"OpenAICodexProvider": ".openai_codex_provider",
+ "GitHubCopilotProvider": ".github_copilot_provider",
"AzureOpenAIProvider": ".azure_openai_provider",
}
if TYPE_CHECKING:
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
+ from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider

View File

@@ -372,15 +372,22 @@ class AnthropicProvider(LLMProvider):
usage: dict[str, int] = {}
if response.usage:
+ input_tokens = response.usage.input_tokens
+ cache_creation = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
+ cache_read = getattr(response.usage, "cache_read_input_tokens", 0) or 0
+ total_prompt_tokens = input_tokens + cache_creation + cache_read
usage = {
- "prompt_tokens": response.usage.input_tokens,
+ "prompt_tokens": total_prompt_tokens,
"completion_tokens": response.usage.output_tokens,
- "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+ "total_tokens": total_prompt_tokens + response.usage.output_tokens,
}
for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"):
val = getattr(response.usage, attr, 0)
if val:
usage[attr] = val
+ # Normalize to cached_tokens for downstream consistency.
+ if cache_read:
+     usage["cached_tokens"] = cache_read
return LLMResponse(
content="".join(content_parts) or None,

View File

@@ -1,31 +1,36 @@
- """Azure OpenAI provider implementation with API version 2024-10-21."""
+ """Azure OpenAI provider using the OpenAI SDK Responses API.
Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
routes to the Responses API (``/responses``). Reuses shared conversion
helpers from :mod:`nanobot.providers.openai_responses`.
"""
from __future__ import annotations
- import json
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
- from urllib.parse import urljoin
- import httpx
- import json_repair
- from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
- _AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
+ from openai import AsyncOpenAI
+ from nanobot.providers.base import LLMProvider, LLMResponse
+ from nanobot.providers.openai_responses import (
+     consume_sdk_stream,
+     convert_messages,
+     convert_tools,
+     parse_response_output,
+ )
class AzureOpenAIProvider(LLMProvider):
- """
- Azure OpenAI provider with API version 2024-10-21 compliance.
- Features:
- - Hardcoded API version 2024-10-21
- - Uses model field as Azure deployment name in URL path
- - Uses api-key header instead of Authorization Bearer
- - Uses max_completion_tokens instead of max_tokens
- - Direct HTTP calls, bypasses LiteLLM
- """
+ """Azure OpenAI provider backed by the Responses API.
+ Features:
+ - Uses the OpenAI Python SDK (``AsyncOpenAI``) with ``base_url = {endpoint}/openai/v1/``
+ - Calls ``client.responses.create()`` (Responses API)
+ - Reuses shared message/tool/SSE conversion from ``openai_responses``
+ """
def __init__(
@@ -36,40 +41,28 @@ class AzureOpenAIProvider(LLMProvider):
):
super().__init__(api_key, api_base)
self.default_model = default_model
- self.api_version = "2024-10-21"
- # Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
- # Ensure api_base ends with /
- if not api_base.endswith('/'):
-     api_base += '/'
+ # Normalise: ensure trailing slash
+ if not api_base.endswith("/"):
+     api_base += "/"
self.api_base = api_base
- def _build_chat_url(self, deployment_name: str) -> str:
-     """Build the Azure OpenAI chat completions URL."""
-     # Azure OpenAI URL format:
-     # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
-     base_url = self.api_base
-     if not base_url.endswith('/'):
-         base_url += '/'
-     url = urljoin(
-         base_url,
-         f"openai/deployments/{deployment_name}/chat/completions"
-     )
-     return f"{url}?api-version={self.api_version}"
+ # SDK client targeting the Azure Responses API endpoint
+ base_url = f"{api_base.rstrip('/')}/openai/v1/"
+ self._client = AsyncOpenAI(
+     api_key=api_key,
+     base_url=base_url,
+     default_headers={"x-session-affinity": uuid.uuid4().hex},
+ )
- def _build_headers(self) -> dict[str, str]:
-     """Build headers for Azure OpenAI API with api-key header."""
-     return {
-         "Content-Type": "application/json",
-         "api-key": self.api_key,  # Azure OpenAI uses api-key header, not Authorization
-         "x-session-affinity": uuid.uuid4().hex,  # For cache locality
-     }
+ # ------------------------------------------------------------------
+ # Helpers
+ # ------------------------------------------------------------------
@staticmethod
def _supports_temperature(
@@ -82,36 +75,51 @@ class AzureOpenAIProvider(LLMProvider):
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
- def _prepare_request_payload(
-     self,
-     deployment_name: str,
-     messages: list[dict[str, Any]],
-     tools: list[dict[str, Any]] | None = None,
-     max_tokens: int = 4096,
-     temperature: float = 0.7,
-     reasoning_effort: str | None = None,
-     tool_choice: str | dict[str, Any] | None = None,
- ) -> dict[str, Any]:
+ def _build_body(
+     self,
+     messages: list[dict[str, Any]],
+     tools: list[dict[str, Any]] | None,
+     model: str | None,
+     max_tokens: int,
+     temperature: float,
+     reasoning_effort: str | None,
+     tool_choice: str | dict[str, Any] | None,
+ ) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" """Build the Responses API request body from Chat-Completions-style args."""
payload: dict[str, Any] = { deployment = model or self.default_model
"messages": self._sanitize_request_messages( instructions, input_items = convert_messages(self._sanitize_empty_content(messages))
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS, body: dict[str, Any] = {
), "model": deployment,
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens "instructions": instructions or None,
"input": input_items,
"max_output_tokens": max(1, max_tokens),
"store": False,
"stream": False,
} }
if self._supports_temperature(deployment_name, reasoning_effort): if self._supports_temperature(deployment, reasoning_effort):
payload["temperature"] = temperature body["temperature"] = temperature
if reasoning_effort: if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort body["reasoning"] = {"effort": reasoning_effort}
body["include"] = ["reasoning.encrypted_content"]
if tools: if tools:
payload["tools"] = tools body["tools"] = convert_tools(tools)
payload["tool_choice"] = tool_choice or "auto" body["tool_choice"] = tool_choice or "auto"
return payload return body
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
return LLMResponse(content=msg, finish_reason="error")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def chat(
self,
@@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
""" body = self._build_body(
Send a chat completion request to Azure OpenAI. messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
tool_choice=tool_choice,
) )
try: try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client: response = await self._client.responses.create(**body)
response = await client.post(url, headers=headers, json=payload) return parse_response_output(response)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
except Exception as e: except Exception as e:
return LLMResponse( return self._handle_error(e)
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
async def chat_stream(
self,
@@ -221,89 +152,26 @@ class AzureOpenAIProvider(LLMProvider):
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Stream a chat completion via Azure OpenAI SSE.""" body = self._build_body(
deployment_name = model or self.default_model messages, tools, model, max_tokens, temperature,
url = self._build_chat_url(deployment_name) reasoning_effort, tool_choice,
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature,
reasoning_effort, tool_choice=tool_choice,
) )
payload["stream"] = True body["stream"] = True
try: try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client: stream = await self._client.responses.create(**body)
async with client.stream("POST", url, headers=headers, json=payload) as response: content, tool_calls, finish_reason, usage, reasoning_content = (
if response.status_code != 200: await consume_sdk_stream(stream, on_content_delta)
text = await response.aread()
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
finish_reason="error",
)
return await self._consume_stream(response, on_content_delta)
except Exception as e:
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
async def _consume_stream(
self,
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
content_parts: list[str] = []
tool_call_buffers: dict[int, dict[str, str]] = {}
finish_reason = "stop"
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
data = line[6:].strip()
if data == "[DONE]":
break
try:
chunk = json.loads(data)
except Exception:
continue
choices = chunk.get("choices") or []
if not choices:
continue
choice = choices[0]
if choice.get("finish_reason"):
finish_reason = choice["finish_reason"]
delta = choice.get("delta") or {}
text = delta.get("content")
if text:
content_parts.append(text)
if on_content_delta:
await on_content_delta(text)
for tc in delta.get("tool_calls") or []:
idx = tc.get("index", 0)
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
if tc.get("id"):
buf["id"] = tc["id"]
fn = tc.get("function") or {}
if fn.get("name"):
buf["name"] = fn["name"]
if fn.get("arguments"):
buf["arguments"] += fn["arguments"]
tool_calls = [
ToolCallRequest(
id=buf["id"], name=buf["name"],
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
) )
for buf in tool_call_buffers.values() return LLMResponse(
] content=content or None,
tool_calls=tool_calls,
return LLMResponse( finish_reason=finish_reason,
content="".join(content_parts) or None, usage=usage,
tool_calls=tool_calls, reasoning_content=reasoning_content,
finish_reason=finish_reason, )
) except Exception as e:
return self._handle_error(e)
def get_default_model(self) -> str: def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model return self.default_model
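A minimal construction sketch for the rewritten provider; the endpoint, key, and deployment name below are placeholders:

provider = AzureOpenAIProvider(
    api_key="<azure-api-key>",
    api_base="https://my-resource.openai.azure.com",  # trailing slash added in __init__
    default_model="gpt-4o-deployment",                # deployment name doubles as model
)
# Requests now go through AsyncOpenAI against
# https://my-resource.openai.azure.com/openai/v1/ (Responses API),
# instead of hand-rolled httpx calls to /openai/deployments/....
response = await provider.chat(messages=[{"role": "user", "content": "ping"}])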

View File

@@ -0,0 +1,257 @@
"""GitHub Copilot OAuth-backed provider."""
from __future__ import annotations
import time
import webbrowser
from collections.abc import Callable
import httpx
from oauth_cli_kit.models import OAuthToken
from oauth_cli_kit.storage import FileTokenStorage
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
DEFAULT_GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code"
DEFAULT_GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
DEFAULT_GITHUB_USER_URL = "https://api.github.com/user"
DEFAULT_COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token"
DEFAULT_COPILOT_BASE_URL = "https://api.githubcopilot.com"
GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"
GITHUB_COPILOT_SCOPE = "read:user"
TOKEN_FILENAME = "github-copilot.json"
TOKEN_APP_NAME = "nanobot"
USER_AGENT = "nanobot/0.1"
EDITOR_VERSION = "vscode/1.99.0"
EDITOR_PLUGIN_VERSION = "copilot-chat/0.26.0"
_EXPIRY_SKEW_SECONDS = 60
_LONG_LIVED_TOKEN_SECONDS = 315360000
def _storage() -> FileTokenStorage:
return FileTokenStorage(
token_filename=TOKEN_FILENAME,
app_name=TOKEN_APP_NAME,
import_codex_cli=False,
)
def _copilot_headers(token: str) -> dict[str, str]:
return {
"Authorization": f"token {token}",
"Accept": "application/json",
"User-Agent": USER_AGENT,
"Editor-Version": EDITOR_VERSION,
"Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
}
def _load_github_token() -> OAuthToken | None:
token = _storage().load()
if not token or not token.access:
return None
return token
def get_github_copilot_login_status() -> OAuthToken | None:
"""Return the persisted GitHub OAuth token if available."""
return _load_github_token()
def login_github_copilot(
print_fn: Callable[[str], None] | None = None,
prompt_fn: Callable[[str], str] | None = None,
) -> OAuthToken:
"""Run GitHub device flow and persist the GitHub OAuth token used for Copilot."""
del prompt_fn
printer = print_fn or print
timeout = httpx.Timeout(20.0, connect=20.0)
with httpx.Client(timeout=timeout, follow_redirects=True, trust_env=True) as client:
response = client.post(
DEFAULT_GITHUB_DEVICE_CODE_URL,
headers={"Accept": "application/json", "User-Agent": USER_AGENT},
data={"client_id": GITHUB_COPILOT_CLIENT_ID, "scope": GITHUB_COPILOT_SCOPE},
)
response.raise_for_status()
payload = response.json()
device_code = str(payload["device_code"])
user_code = str(payload["user_code"])
verify_url = str(payload.get("verification_uri") or payload.get("verification_uri_complete") or "")
verify_complete = str(payload.get("verification_uri_complete") or verify_url)
interval = max(1, int(payload.get("interval") or 5))
expires_in = int(payload.get("expires_in") or 900)
printer(f"Open: {verify_url}")
printer(f"Code: {user_code}")
if verify_complete:
try:
webbrowser.open(verify_complete)
except Exception:
pass
deadline = time.time() + expires_in
current_interval = interval
access_token = None
token_expires_in = _LONG_LIVED_TOKEN_SECONDS
while time.time() < deadline:
poll = client.post(
DEFAULT_GITHUB_ACCESS_TOKEN_URL,
headers={"Accept": "application/json", "User-Agent": USER_AGENT},
data={
"client_id": GITHUB_COPILOT_CLIENT_ID,
"device_code": device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
},
)
poll.raise_for_status()
poll_payload = poll.json()
access_token = poll_payload.get("access_token")
if access_token:
token_expires_in = int(poll_payload.get("expires_in") or _LONG_LIVED_TOKEN_SECONDS)
break
error = poll_payload.get("error")
if error == "authorization_pending":
time.sleep(current_interval)
continue
if error == "slow_down":
current_interval += 5
time.sleep(current_interval)
continue
if error == "expired_token":
raise RuntimeError("GitHub device code expired. Please run login again.")
if error == "access_denied":
raise RuntimeError("GitHub device flow was denied.")
if error:
desc = poll_payload.get("error_description") or error
raise RuntimeError(str(desc))
time.sleep(current_interval)
else:
raise RuntimeError("GitHub device flow timed out.")
user = client.get(
DEFAULT_GITHUB_USER_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github+json",
"User-Agent": USER_AGENT,
},
)
user.raise_for_status()
user_payload = user.json()
account_id = user_payload.get("login") or str(user_payload.get("id") or "") or None
expires_ms = int((time.time() + token_expires_in) * 1000)
token = OAuthToken(
access=str(access_token),
refresh="",
expires=expires_ms,
account_id=str(account_id) if account_id else None,
)
_storage().save(token)
return token
class GitHubCopilotProvider(OpenAICompatProvider):
"""Provider that exchanges a stored GitHub OAuth token for Copilot access tokens."""
def __init__(self, default_model: str = "github-copilot/gpt-4.1"):
from nanobot.providers.registry import find_by_name
self._copilot_access_token: str | None = None
self._copilot_expires_at: float = 0.0
super().__init__(
api_key="no-key",
api_base=DEFAULT_COPILOT_BASE_URL,
default_model=default_model,
extra_headers={
"Editor-Version": EDITOR_VERSION,
"Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
"User-Agent": USER_AGENT,
},
spec=find_by_name("github_copilot"),
)
async def _get_copilot_access_token(self) -> str:
now = time.time()
if self._copilot_access_token and now < self._copilot_expires_at - _EXPIRY_SKEW_SECONDS:
return self._copilot_access_token
github_token = _load_github_token()
if not github_token or not github_token.access:
raise RuntimeError("GitHub Copilot is not logged in. Run: nanobot provider login github-copilot")
timeout = httpx.Timeout(20.0, connect=20.0)
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True, trust_env=True) as client:
response = await client.get(
DEFAULT_COPILOT_TOKEN_URL,
headers=_copilot_headers(github_token.access),
)
response.raise_for_status()
payload = response.json()
token = payload.get("token")
if not token:
raise RuntimeError("GitHub Copilot token exchange returned no token.")
expires_at = payload.get("expires_at")
if isinstance(expires_at, (int, float)):
self._copilot_expires_at = float(expires_at)
else:
refresh_in = payload.get("refresh_in") or 1500
self._copilot_expires_at = time.time() + int(refresh_in)
self._copilot_access_token = str(token)
return self._copilot_access_token
async def _refresh_client_api_key(self) -> str:
token = await self._get_copilot_access_token()
self.api_key = token
self._client.api_key = token
return token
async def chat(
self,
messages: list[dict[str, object]],
tools: list[dict[str, object]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, object] | None = None,
):
await self._refresh_client_api_key()
return await super().chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
)
async def chat_stream(
self,
messages: list[dict[str, object]],
tools: list[dict[str, object]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, object] | None = None,
on_content_delta: Callable[[str], None] | None = None,
):
await self._refresh_client_api_key()
return await super().chat_stream(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
on_content_delta=on_content_delta,
)
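End to end: `login_github_copilot` persists a long-lived GitHub OAuth token once, and the provider exchanges it for short-lived Copilot tokens on demand. A usage sketch (model name and prompt are illustrative):

# One-time interactive login (opens a browser, polls GitHub until approved).
token = login_github_copilot(print_fn=print)
print(token.account_id)

# Afterwards the provider is a drop-in OpenAICompatProvider; chat() and
# chat_stream() refresh the Copilot access token before each request.
provider = GitHubCopilotProvider(default_model="github-copilot/gpt-4.1")
response = await provider.chat(messages=[{"role": "user", "content": "hello"}])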

View File

@@ -6,13 +6,18 @@ import asyncio
import hashlib
import json
from collections.abc import Awaitable, Callable
- from typing import Any, AsyncGenerator
+ from typing import Any
import httpx
from loguru import logger
from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses import (
consume_sse,
convert_messages,
convert_tools,
)
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
DEFAULT_ORIGINATOR = "nanobot"
@@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
) -> LLMResponse:
"""Shared request logic for both chat() and chat_stream()."""
model = model or self.default_model
- system_prompt, input_items = _convert_messages(messages)
+ system_prompt, input_items = convert_messages(messages)
token = await asyncio.to_thread(get_codex_token)
headers = _build_headers(token.account_id, token.access)
@@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools:
- body["tools"] = _convert_tools(tools)
+ body["tools"] = convert_tools(tools)
try:
try:
@@ -127,96 +132,7 @@
if response.status_code != 200:
text = await response.aread()
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
- return await _consume_sse(response, on_content_delta)
+ return await consume_sse(response, on_content_delta)
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert OpenAI function-calling schema to Codex flat format."""
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append({
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
})
return converted
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = ""
input_items: list[dict[str, Any]] = []
for idx, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(_convert_user_message(content))
continue
if role == "assistant":
if isinstance(content, str) and content:
input_items.append({
"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed", "id": f"msg_{idx}",
})
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
input_items.append({
"type": "function_call",
"id": item_id or f"fc_{idx}",
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
})
continue
if role == "tool":
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
return system_prompt, input_items
def _convert_user_message(content: Any) -> dict[str, Any]:
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
@@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
buffer: list[str] = []
async for line in response.aiter_lines():
if line == "":
if buffer:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer = []
if not data_lines:
continue
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
continue
try:
yield json.loads(data)
except Exception:
continue
continue
buffer.append(line)
async def _consume_sse(
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
async for event in _iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
}
elif event_type == "response.output_text.delta":
delta_text = event.get("delta") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
elif event_type == "response.output_item.done":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
try:
args = json.loads(args_raw)
except Exception:
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name"),
arguments=args,
)
)
elif event_type == "response.completed":
status = (event.get("response") or {}).get("status")
finish_reason = _map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
raise RuntimeError("Codex response failed")
return content, tool_calls, finish_reason
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
def _map_finish_reason(status: str | None) -> str:
return _FINISH_REASON_MAP.get(status or "completed", "stop")
def _friendly_error(status_code: int, raw: str) -> str:
if status_code == 429:
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."

View File

@@ -235,7 +235,9 @@ class OpenAICompatProvider(LLMProvider):
spec = self._spec
if spec and spec.supports_prompt_caching:
- messages, tools = self._apply_cache_control(messages, tools)
+ model_name = model or self.default_model
+ if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
+     messages, tools = self._apply_cache_control(messages, tools)
if spec and spec.strip_model_prefix:
model_name = model_name.split("/")[-1]
@@ -308,6 +310,13 @@ class OpenAICompatProvider(LLMProvider):
@classmethod
def _extract_usage(cls, response: Any) -> dict[str, int]:
"""Extract token usage from an OpenAI-compatible response.
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
responses. Provider-specific ``cached_tokens`` fields are normalised
under a single key; see the priority chain inside for details.
"""
# --- resolve usage object ---
usage_obj = None
response_map = cls._maybe_mapping(response)
if response_map is not None:
@@ -317,19 +326,53 @@ class OpenAICompatProvider(LLMProvider):
usage_map = cls._maybe_mapping(usage_obj)
if usage_map is not None:
- return {
+ result = {
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
"total_tokens": int(usage_map.get("total_tokens") or 0),
}
- if usage_obj:
-     return {
+ elif usage_obj:
+     result = {
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
}
- return {}
+ else:
+     return {}
# --- cached_tokens (normalised across providers) ---
# Try nested paths first (dict), fall back to attribute (SDK object).
# Priority order ensures the most specific field wins.
for path in (
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
("cached_tokens",), # StepFun/Moonshot (top-level)
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
):
cached = cls._get_nested_int(usage_map, path)
if not cached and usage_obj:
cached = cls._get_nested_int(usage_obj, path)
if cached:
result["cached_tokens"] = cached
break
return result
@staticmethod
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
"""Drill into *obj* by *path* segments and return an ``int`` value.
Supports both dict-key access and attribute access so it works
uniformly with raw JSON dicts **and** SDK Pydantic models.
"""
current = obj
for segment in path:
if current is None:
return 0
if isinstance(current, dict):
current = current.get(segment)
else:
current = getattr(current, segment, None)
return int(current or 0) if current is not None else 0
def _parse(self, response: Any) -> LLMResponse:
if isinstance(response, str):
@@ -603,4 +646,4 @@ class OpenAICompatProvider(LLMProvider):
return self._handle_error(e)
def get_default_model(self) -> str:
return self.default_model
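The priority chain in `_extract_usage` is easiest to see with sample payloads; these dicts are illustrative of the provider shapes named in the comments above:

# OpenAI-style: nested under prompt_tokens_details -> checked first.
openai_usage = {"prompt_tokens": 100, "completion_tokens": 5, "total_tokens": 105,
                "prompt_tokens_details": {"cached_tokens": 80}}
# Moonshot/StepFun-style: top-level cached_tokens -> second in priority.
moonshot_usage = {"prompt_tokens": 100, "completion_tokens": 5, "total_tokens": 105,
                  "cached_tokens": 64}
# DeepSeek-style: prompt_cache_hit_tokens -> last fallback.
deepseek_usage = {"prompt_tokens": 100, "completion_tokens": 5, "total_tokens": 105,
                  "prompt_cache_hit_tokens": 32}
# After _extract_usage, each normalises to a single key:
#   cached_tokens == 80, 64 and 32 respectively.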

View File

@@ -0,0 +1,29 @@
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""
from nanobot.providers.openai_responses.converters import (
convert_messages,
convert_tools,
convert_user_message,
split_tool_call_id,
)
from nanobot.providers.openai_responses.parsing import (
FINISH_REASON_MAP,
consume_sdk_stream,
consume_sse,
iter_sse,
map_finish_reason,
parse_response_output,
)
__all__ = [
"convert_messages",
"convert_tools",
"convert_user_message",
"split_tool_call_id",
"iter_sse",
"consume_sse",
"consume_sdk_stream",
"map_finish_reason",
"parse_response_output",
"FINISH_REASON_MAP",
]

View File

@@ -0,0 +1,110 @@
"""Convert Chat Completions messages/tools to Responses API format."""
from __future__ import annotations
import json
from typing import Any
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
"""Convert Chat Completions messages to Responses API input items.
Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted
from any ``system`` role message and *input_items* is the Responses API
``input`` array.
"""
system_prompt = ""
input_items: list[dict[str, Any]] = []
for idx, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(convert_user_message(content))
continue
if role == "assistant":
if isinstance(content, str) and content:
input_items.append({
"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed", "id": f"msg_{idx}",
})
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = split_tool_call_id(tool_call.get("id"))
input_items.append({
"type": "function_call",
"id": item_id or f"fc_{idx}",
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
})
continue
if role == "tool":
call_id, _ = split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
return system_prompt, input_items
def convert_user_message(content: Any) -> dict[str, Any]:
"""Convert a user message's content to Responses API format.
Handles plain strings, ``text`` blocks -> ``input_text``, and
``image_url`` blocks -> ``input_image``.
"""
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert OpenAI function-calling tool schema to Responses API flat format."""
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append({
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
})
return converted
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
"""Split a compound ``call_id|item_id`` string.
Returns ``(call_id, item_id)`` where *item_id* may be ``None``.
"""
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None

View File

@@ -0,0 +1,297 @@
"""Parse Responses API SSE streams and SDK response objects."""
from __future__ import annotations
import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator
import httpx
import json_repair
from loguru import logger
from nanobot.providers.base import LLMResponse, ToolCallRequest
FINISH_REASON_MAP = {
"completed": "stop",
"incomplete": "length",
"failed": "error",
"cancelled": "error",
}
def map_finish_reason(status: str | None) -> str:
"""Map a Responses API status string to a Chat-Completions-style finish_reason."""
return FINISH_REASON_MAP.get(status or "completed", "stop")
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
"""Yield parsed JSON events from a Responses API SSE stream."""
buffer: list[str] = []
def _flush() -> dict[str, Any] | None:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer.clear()
if not data_lines:
return None
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
return None
try:
return json.loads(data)
except Exception:
logger.warning("Failed to parse SSE event JSON: {}", data[:200])
return None
async for line in response.aiter_lines():
if line == "":
if buffer:
event = _flush()
if event is not None:
yield event
continue
buffer.append(line)
# Flush any remaining buffer at EOF (#10)
if buffer:
event = _flush()
if event is not None:
yield event
async def consume_sse(
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
"""Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``."""
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
async for event in iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
}
elif event_type == "response.output_text.delta":
delta_text = event.get("delta") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
elif event_type == "response.output_item.done":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or item.get("name"),
args_raw[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name") or "",
arguments=args,
)
)
elif event_type == "response.completed":
status = (event.get("response") or {}).get("status")
finish_reason = map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
detail = event.get("error") or event.get("message") or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason
def parse_response_output(response: Any) -> LLMResponse:
"""Parse an SDK ``Response`` object into an ``LLMResponse``."""
if not isinstance(response, dict):
dump = getattr(response, "model_dump", None)
response = dump() if callable(dump) else vars(response)
output = response.get("output") or []
content_parts: list[str] = []
tool_calls: list[ToolCallRequest] = []
reasoning_content: str | None = None
for item in output:
if not isinstance(item, dict):
dump = getattr(item, "model_dump", None)
item = dump() if callable(dump) else vars(item)
item_type = item.get("type")
if item_type == "message":
for block in item.get("content") or []:
if not isinstance(block, dict):
dump = getattr(block, "model_dump", None)
block = dump() if callable(dump) else vars(block)
if block.get("type") == "output_text":
content_parts.append(block.get("text") or "")
elif item_type == "reasoning":
for s in item.get("summary") or []:
if not isinstance(s, dict):
dump = getattr(s, "model_dump", None)
s = dump() if callable(dump) else vars(s)
if s.get("type") == "summary_text" and s.get("text"):
reasoning_content = (reasoning_content or "") + s["text"]
elif item_type == "function_call":
call_id = item.get("call_id") or ""
item_id = item.get("id") or "fc_0"
args_raw = item.get("arguments") or "{}"
try:
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
item.get("name"),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(ToolCallRequest(
id=f"{call_id}|{item_id}",
name=item.get("name") or "",
arguments=args if isinstance(args, dict) else {},
))
usage_raw = response.get("usage") or {}
if not isinstance(usage_raw, dict):
dump = getattr(usage_raw, "model_dump", None)
usage_raw = dump() if callable(dump) else vars(usage_raw)
usage = {}
if usage_raw:
usage = {
"prompt_tokens": int(usage_raw.get("input_tokens") or 0),
"completion_tokens": int(usage_raw.get("output_tokens") or 0),
"total_tokens": int(usage_raw.get("total_tokens") or 0),
}
status = response.get("status")
finish_reason = map_finish_reason(status)
return LLMResponse(
content="".join(content_parts) or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
)
async def consume_sdk_stream(
stream: Any,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
"""Consume an SDK async stream from ``client.responses.create(stream=True)``."""
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
usage: dict[str, int] = {}
reasoning_content: str | None = None
async for event in stream:
event_type = getattr(event, "type", None)
if event_type == "response.output_item.added":
item = getattr(event, "item", None)
if item and getattr(item, "type", None) == "function_call":
call_id = getattr(item, "call_id", None)
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": getattr(item, "id", None) or "fc_0",
"name": getattr(item, "name", None),
"arguments": getattr(item, "arguments", None) or "",
}
elif event_type == "response.output_text.delta":
delta_text = getattr(event, "delta", "") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
elif event_type == "response.function_call_arguments.done":
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""
elif event_type == "response.output_item.done":
item = getattr(event, "item", None)
if item and getattr(item, "type", None) == "function_call":
call_id = getattr(item, "call_id", None)
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or getattr(item, "name", None),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
name=buf.get("name") or getattr(item, "name", None) or "",
arguments=args,
)
)
elif event_type == "response.completed":
resp = getattr(event, "response", None)
status = getattr(resp, "status", None) if resp else None
finish_reason = map_finish_reason(status)
if resp:
usage_obj = getattr(resp, "usage", None)
if usage_obj:
usage = {
"prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
"completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
"total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
}
for out_item in getattr(resp, "output", None) or []:
if getattr(out_item, "type", None) == "reasoning":
for s in getattr(out_item, "summary", None) or []:
if getattr(s, "type", None) == "summary_text":
text = getattr(s, "text", None)
if text:
reasoning_content = (reasoning_content or "") + text
elif event_type in {"error", "response.failed"}:
detail = getattr(event, "error", None) or getattr(event, "message", None) or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason, usage, reasoning_content
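# Illustrative only: driving consume_sdk_stream with fabricated events
# (SimpleNamespace stands in for SDK event objects; the event types match
# the branches handled above):
#
#     from types import SimpleNamespace
#
#     async def fake_stream():
#         yield SimpleNamespace(type="response.output_text.delta", delta="Hi")
#         yield SimpleNamespace(
#             type="response.completed",
#             response=SimpleNamespace(status="completed", usage=None, output=[]),
#         )
#
#     content, calls, reason, usage, reasoning = await consume_sdk_stream(fake_stream())
#     # content == "Hi", reason == "stop", calls == [], usage == {}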

View File

@ -34,7 +34,7 @@ class ProviderSpec:
    display_name: str = ""  # shown in `nanobot status`
    # which provider implementation to use
-    # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex"
+    # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot"
    backend: str = "openai_compat"
    # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
@ -218,8 +218,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
        keywords=("github_copilot", "copilot"),
        env_key="",
        display_name="Github Copilot",
-        backend="openai_compat",
+        backend="github_copilot",
        default_api_base="https://api.githubcopilot.com",
+        strip_model_prefix=True,
        is_oauth=True,
    ),
    # DeepSeek: OpenAI-compatible at api.deepseek.com

View File

@ -405,14 +405,18 @@ def build_status_content(
    )
    last_in = last_usage.get("prompt_tokens", 0)
    last_out = last_usage.get("completion_tokens", 0)
+    cached = last_usage.get("cached_tokens", 0)
    ctx_total = max(context_window_tokens, 0)
    ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
    ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
    ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
+    token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
+    if cached and last_in:
+        token_line += f" ({cached * 100 // last_in}% cached)"
    return "\n".join([
        f"\U0001f408 nanobot v{version}",
        f"\U0001f9e0 Model: {model}",
-        f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
+        token_line,
        f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
        f"\U0001f4ac Session: {session_msg_count} messages",
        f"\u23f1 Uptime: {uptime}",

View File

@ -578,3 +578,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
    args = mgr._announce_result.await_args.args
    assert args[3] == "Task completed but no final response was generated."
    assert args[5] == "ok"
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
"""Runner should accumulate prompt/completion tokens across iterations
and preserve cached_tokens from provider responses."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
)
return LLMResponse(
content="done",
tool_calls=[],
usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="file content")
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=3,
))
# Usage should be accumulated across iterations
assert result.usage["prompt_tokens"] == 300 # 100 + 200
assert result.usage["completion_tokens"] == 30 # 10 + 20
assert result.usage["cached_tokens"] == 230 # 80 + 150
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
"""Hook context.usage should contain cached_tokens."""
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_usage: list[dict] = []
class UsageHook(AgentHook):
async def after_iteration(self, context: AgentHookContext) -> None:
captured_usage.append(dict(context.usage))
async def chat_with_retry(**kwargs):
return LLMResponse(
content="done",
tool_calls=[],
usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
hook=UsageHook(),
))
assert len(captured_usage) == 1
assert captured_usage[0]["cached_tokens"] == 150

View File

@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
            seen["config"] = self.config
            return True

-    monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
+    monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
    monkeypatch.setattr(
        "nanobot.channels.registry.discover_all",
        lambda: {"fakeplugin": _LoginPlugin},
@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
assert seen["force"] is True assert seen["force"] is True
def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path):
from nanobot.cli.commands import app
from nanobot.config.schema import Config
from typer.testing import CliRunner
runner = CliRunner()
seen: dict[str, object] = {}
config_path = tmp_path / "custom-config.json"
class _LoginPlugin(_FakePlugin):
async def login(self, force: bool = False) -> bool:
return True
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {"fakeplugin": _LoginPlugin},
)
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)])
assert result.exit_code == 0
assert seen["config_path"] == config_path.resolve()
def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path):
from nanobot.cli.commands import app
from nanobot.config.schema import Config
from typer.testing import CliRunner
runner = CliRunner()
seen: dict[str, object] = {}
config_path = tmp_path / "custom-config.json"
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
result = runner.invoke(app, ["channels", "status", "--config", str(config_path)])
assert result.exit_code == 0
assert seen["config_path"] == config_path.resolve()
@pytest.mark.asyncio
async def test_manager_skips_disabled_plugin():
    fake_config = SimpleNamespace(

View File

@ -3,16 +3,14 @@ from pathlib import Path
from types import SimpleNamespace

import pytest

+pytest.importorskip("nio")
+pytest.importorskip("nh3")
+pytest.importorskip("mistune")

from nio import RoomSendResponse

from nanobot.channels.matrix import _build_matrix_text_content

-# Check optional matrix dependencies before importing
-try:
-    import nh3  # noqa: F401
-except ImportError:
-    pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True)

import nanobot.channels.matrix as matrix_module
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus

View File

@ -317,6 +317,75 @@ def test_openai_compat_provider_passes_model_through():
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex" assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
def test_make_provider_uses_github_copilot_backend():
from nanobot.cli.commands import _make_provider
from nanobot.config.schema import Config
config = Config.model_validate(
{
"agents": {
"defaults": {
"provider": "github-copilot",
"model": "github-copilot/gpt-4.1",
}
}
}
)
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = _make_provider(config)
assert provider.__class__.__name__ == "GitHubCopilotProvider"
def test_github_copilot_provider_strips_prefixed_model_name():
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")
kwargs = provider._build_kwargs(
messages=[{"role": "user", "content": "hi"}],
tools=None,
model="github-copilot/gpt-5.1",
max_tokens=16,
temperature=0.1,
reasoning_effort=None,
tool_choice=None,
)
assert kwargs["model"] == "gpt-5.1"
@pytest.mark.asyncio
async def test_github_copilot_provider_refreshes_client_api_key_before_chat():
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
mock_client = MagicMock()
mock_client.api_key = "no-key"
mock_client.chat.completions.create = AsyncMock(return_value={
"choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
})
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI", return_value=mock_client):
provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")
provider._get_copilot_access_token = AsyncMock(return_value="copilot-access-token")
response = await provider.chat(
messages=[{"role": "user", "content": "hi"}],
model="github-copilot/gpt-5.1",
max_tokens=16,
temperature=0.1,
)
assert response.content == "ok"
assert provider._client.api_key == "copilot-access-token"
provider._get_copilot_access_token.assert_awaited_once()
mock_client.chat.completions.create.assert_awaited_once()
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
    assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
    assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"

View File

@ -152,10 +152,12 @@ class TestRestartCommand:
        ])

        await loop._run_agent_loop([])
-        assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
+        assert loop._last_usage["prompt_tokens"] == 9
+        assert loop._last_usage["completion_tokens"] == 4

        await loop._run_agent_loop([])
-        assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
+        assert loop._last_usage["prompt_tokens"] == 0
+        assert loop._last_usage["completion_tokens"] == 0

    @pytest.mark.asyncio
    async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):

View File

@ -285,6 +285,28 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
    assert job.schedule.at_ms == expected
def test_add_job_delivers_by_default(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool.set_context("telegram", "chat-1")
result = tool._add_job("Morning standup", 60, None, None, None)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
assert job.payload.deliver is True
def test_add_job_can_disable_delivery(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool.set_context("telegram", "chat-1")
result = tool._add_job("Background refresh", 60, None, None, None, deliver=False)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
assert job.payload.deliver is False
def test_list_excludes_disabled_jobs(tmp_path) -> None:
    tool = _make_tool(tmp_path)
    job = tool._cron.add_job(

View File

@ -1,6 +1,6 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names).""" """Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@ -8,392 +8,401 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse

-def test_azure_openai_provider_init():
-    """Test AzureOpenAIProvider initialization without deployment_name."""
+# ---------------------------------------------------------------------------
+# Init & validation
+# ---------------------------------------------------------------------------
+def test_init_creates_sdk_client():
+    """Provider creates an AsyncOpenAI client with correct base_url."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )
    assert provider.api_key == "test-key"
    assert provider.api_base == "https://test-resource.openai.azure.com/"
    assert provider.default_model == "gpt-4o-deployment"
-    assert provider.api_version == "2024-10-21"
+    # SDK client base_url ends with /openai/v1/
+    assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")

-def test_azure_openai_provider_init_validation():
-    """Test AzureOpenAIProvider initialization validation."""
-    # Missing api_key
+def test_init_base_url_no_trailing_slash():
+    """Trailing slashes are normalised before building base_url."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://res.openai.azure.com",
+    )
+    assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
+
+def test_init_base_url_with_trailing_slash():
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://res.openai.azure.com/",
+    )
+    assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
+
+def test_init_validation_missing_key():
    with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
        AzureOpenAIProvider(api_key="", api_base="https://test.com")

-    # Missing api_base
+def test_init_validation_missing_base():
    with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
        AzureOpenAIProvider(api_key="test", api_base="")

-def test_build_chat_url():
-    """Test Azure OpenAI URL building with different deployment names."""
+def test_no_api_version_in_base_url():
+    """The /openai/v1/ path should NOT contain an api-version query param."""
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
+    base = str(provider._client.base_url)
+    assert "api-version" not in base
+
+# ---------------------------------------------------------------------------
+# _supports_temperature
+# ---------------------------------------------------------------------------
+def test_supports_temperature_standard_model():
+    assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True
+
+def test_supports_temperature_reasoning_model():
+    assert AzureOpenAIProvider._supports_temperature("o3-mini") is False
+    assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False
+    assert AzureOpenAIProvider._supports_temperature("o4-mini") is False
+
+def test_supports_temperature_with_reasoning_effort():
+    assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False
+
+# ---------------------------------------------------------------------------
+# _build_body — Responses API body construction
+# ---------------------------------------------------------------------------
+def test_build_body_basic():
    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
+        api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
    )
-    # Test various deployment names
-    test_cases = [
-        ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
-        ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
-        ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
-    ]
-    for deployment_name, expected_url in test_cases:
-        url = provider._build_chat_url(deployment_name)
-        assert url == expected_url
+    messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
+    body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
+    assert body["model"] == "gpt-4o"

-def test_build_chat_url_api_base_without_slash():
-    """Test URL building when api_base doesn't end with slash."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",  # No trailing slash
-        default_model="gpt-4o",
+    assert body["instructions"] == "You are helpful."
+    assert body["temperature"] == 0.7
+    assert body["max_output_tokens"] == 4096
+    assert body["store"] is False
+    assert "reasoning" not in body
+    # input should contain the converted user message only (system extracted)
+    assert any(
+        item.get("role") == "user"
+        for item in body["input"]
    )
-    url = provider._build_chat_url("test-deployment")
-    expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
-    assert url == expected

-def test_build_headers():
-    """Test Azure OpenAI header building with api-key authentication."""
-    provider = AzureOpenAIProvider(
-        api_key="test-api-key-123",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-    headers = provider._build_headers()
-    assert headers["Content-Type"] == "application/json"
-    assert headers["api-key"] == "test-api-key-123"  # Azure OpenAI specific header
-    assert "x-session-affinity" in headers
+def test_build_body_max_tokens_minimum():
+    """max_output_tokens should never be less than 1."""
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
+    body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None)
+    assert body["max_output_tokens"] == 1

-def test_prepare_request_payload():
-    """Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-    messages = [{"role": "user", "content": "Hello"}]
-    payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
-    assert payload["messages"] == messages
-    assert payload["max_completion_tokens"] == 1500  # Azure API 2024-10-21 uses max_completion_tokens
-    assert payload["temperature"] == 0.8
-    assert "tools" not in payload
-    # Test with tools
+def test_build_body_with_tools():
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
    tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
-    payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
-    assert payload_with_tools["tools"] == tools
-    assert payload_with_tools["tool_choice"] == "auto"
-    # Test with reasoning_effort
-    payload_with_reasoning = provider._prepare_request_payload(
-        "gpt-5-chat", messages, reasoning_effort="medium"
+    body = provider._build_body(
+        [{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
    )
-    assert payload_with_reasoning["reasoning_effort"] == "medium"
-    assert "temperature" not in payload_with_reasoning
+    assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
+    assert body["tool_choice"] == "auto"

-def test_prepare_request_payload_sanitizes_messages():
-    """Test Azure payload strips non-standard message keys before sending."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
+def test_build_body_with_reasoning():
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
+    body = provider._build_body(
+        [{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None,
    )
+    assert body["reasoning"] == {"effort": "medium"}
+    assert "reasoning.encrypted_content" in body.get("include", [])
+    # temperature omitted for reasoning models
+    assert "temperature" not in body

-    messages = [
-        {
-            "role": "assistant",
-            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
-            "reasoning_content": "hidden chain-of-thought",
-        },
-        {
-            "role": "tool",
-            "tool_call_id": "call_123",
-            "name": "x",
-            "content": "ok",
-            "extra_field": "should be removed",
-        },
-    ]
-    payload = provider._prepare_request_payload("gpt-4o", messages)
+def test_build_body_image_conversion():
+    """image_url content blocks should be converted to input_image."""
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "What's in this image?"},
+            {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
+        ],
+    }]
+    body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
+    user_item = body["input"][0]
+    content_types = [b["type"] for b in user_item["content"]]
+    assert "input_text" in content_types
+    assert "input_image" in content_types
+    image_block = next(b for b in user_item["content"] if b["type"] == "input_image")
+    assert image_block["image_url"] == "https://example.com/img.png"

-    assert payload["messages"] == [
-        {
-            "role": "assistant",
-            "content": None,
-            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
+def test_build_body_sanitizes_single_dict_content_block():
+    """Single content dicts should be preserved via shared message sanitization."""
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
+    messages = [{
+        "role": "user",
+        "content": {"type": "text", "text": "Hi from dict content"},
+    }]
+    body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
+    assert body["input"][0]["content"] == [{"type": "input_text", "text": "Hi from dict content"}]

+# ---------------------------------------------------------------------------
+# chat() — non-streaming
+# ---------------------------------------------------------------------------
+def _make_sdk_response(
+    content="Hello!", tool_calls=None, status="completed",
+    usage=None,
+):
+    """Build a mock that quacks like an openai Response object."""
+    resp = MagicMock()
+    resp.model_dump = MagicMock(return_value={
+        "output": [
+            {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
+            *([{
+                "type": "function_call",
+                "call_id": tc["call_id"], "id": tc["id"],
+                "name": tc["name"], "arguments": tc["arguments"],
+            } for tc in (tool_calls or [])]),
+        ],
+        "status": status,
+        "usage": {
+            "input_tokens": (usage or {}).get("input_tokens", 10),
+            "output_tokens": (usage or {}).get("output_tokens", 5),
+            "total_tokens": (usage or {}).get("total_tokens", 15),
        },
-        {
-            "role": "tool",
-            "tool_call_id": "call_123",
-            "name": "x",
-            "content": "ok",
-        },
-    ]
+    })
+    return resp

@pytest.mark.asyncio
async def test_chat_success():
-    """Test successful chat request using model as deployment name."""
    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o-deployment",
+        api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
    )
-    # Mock response data
-    mock_response_data = {
-        "choices": [{
-            "message": {
-                "content": "Hello! How can I help you today?",
-                "role": "assistant"
-            },
-            "finish_reason": "stop"
-        }],
-        "usage": {
-            "prompt_tokens": 12,
-            "completion_tokens": 18,
-            "total_tokens": 30
-        }
-    }
-    with patch("httpx.AsyncClient") as mock_client:
-        mock_response = AsyncMock()
-        mock_response.status_code = 200
-        mock_response.json = Mock(return_value=mock_response_data)
-        mock_context = AsyncMock()
-        mock_context.post = AsyncMock(return_value=mock_response)
-        mock_client.return_value.__aenter__.return_value = mock_context
-        # Test with specific model (deployment name)
-        messages = [{"role": "user", "content": "Hello"}]
-        result = await provider.chat(messages, model="custom-deployment")
-        assert isinstance(result, LLMResponse)
-        assert result.content == "Hello! How can I help you today?"
-        assert result.finish_reason == "stop"
-        assert result.usage["prompt_tokens"] == 12
-        assert result.usage["completion_tokens"] == 18
-        assert result.usage["total_tokens"] == 30
-        # Verify URL was built with the provided model as deployment name
-        call_args = mock_context.post.call_args
-        expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
-        assert call_args[0][0] == expected_url
+    mock_resp = _make_sdk_response(content="Hello!")
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_resp)
+
+    result = await provider.chat([{"role": "user", "content": "Hi"}])
+
+    assert isinstance(result, LLMResponse)
+    assert result.content == "Hello!"
+    assert result.finish_reason == "stop"
+    assert result.usage["prompt_tokens"] == 10

@pytest.mark.asyncio
-async def test_chat_uses_default_model_when_no_model_provided():
-    """Test that chat uses default_model when no model is specified."""
+async def test_chat_uses_default_model():
    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="default-deployment",
+        api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
    )
-    mock_response_data = {
-        "choices": [{
-            "message": {"content": "Response", "role": "assistant"},
-            "finish_reason": "stop"
-        }],
-        "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
-    }
-    with patch("httpx.AsyncClient") as mock_client:
-        mock_response = AsyncMock()
-        mock_response.status_code = 200
-        mock_response.json = Mock(return_value=mock_response_data)
-        mock_context = AsyncMock()
-        mock_context.post = AsyncMock(return_value=mock_response)
-        mock_client.return_value.__aenter__.return_value = mock_context
-        messages = [{"role": "user", "content": "Test"}]
-        await provider.chat(messages)  # No model specified
-        # Verify URL was built with default model as deployment name
-        call_args = mock_context.post.call_args
-        expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
-        assert call_args[0][0] == expected_url
+    mock_resp = _make_sdk_response(content="ok")
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_resp)
+
+    await provider.chat([{"role": "user", "content": "test"}])
+
+    call_kwargs = provider._client.responses.create.call_args[1]
+    assert call_kwargs["model"] == "my-deployment"
+
+@pytest.mark.asyncio
+async def test_chat_custom_model():
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
+    )
+    mock_resp = _make_sdk_response(content="ok")
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_resp)
+
+    await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy")
+
+    call_kwargs = provider._client.responses.create.call_args[1]
+    assert call_kwargs["model"] == "custom-deploy"

@pytest.mark.asyncio
async def test_chat_with_tool_calls():
-    """Test chat request with tool calls in response."""
    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
    )
-    # Mock response with tool calls
-    mock_response_data = {
-        "choices": [{
-            "message": {
-                "content": None,
-                "role": "assistant",
-                "tool_calls": [{
-                    "id": "call_12345",
-                    "function": {
-                        "name": "get_weather",
-                        "arguments": '{"location": "San Francisco"}'
-                    }
-                }]
-            },
-            "finish_reason": "tool_calls"
+    mock_resp = _make_sdk_response(
+        content=None,
+        tool_calls=[{
+            "call_id": "call_123", "id": "fc_1",
+            "name": "get_weather", "arguments": '{"location": "SF"}',
        }],
-        "usage": {
-            "prompt_tokens": 20,
-            "completion_tokens": 15,
-            "total_tokens": 35
-        }
-    }
-    with patch("httpx.AsyncClient") as mock_client:
-        mock_response = AsyncMock()
-        mock_response.status_code = 200
-        mock_response.json = Mock(return_value=mock_response_data)
-        mock_context = AsyncMock()
-        mock_context.post = AsyncMock(return_value=mock_response)
-        mock_client.return_value.__aenter__.return_value = mock_context
-        messages = [{"role": "user", "content": "What's the weather?"}]
-        tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
-        result = await provider.chat(messages, tools=tools, model="weather-model")
-        assert isinstance(result, LLMResponse)
-        assert result.content is None
-        assert result.finish_reason == "tool_calls"
-        assert len(result.tool_calls) == 1
-        assert result.tool_calls[0].name == "get_weather"
-        assert result.tool_calls[0].arguments == {"location": "San Francisco"}
+    )
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_resp)
+
+    result = await provider.chat(
+        [{"role": "user", "content": "Weather?"}],
+        tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
+    )
+
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].name == "get_weather"
+    assert result.tool_calls[0].arguments == {"location": "SF"}

@pytest.mark.asyncio
-async def test_chat_api_error():
-    """Test chat request API error handling."""
+async def test_chat_error_handling():
    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
    )
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))

-    with patch("httpx.AsyncClient") as mock_client:
-        mock_response = AsyncMock()
-        mock_response.status_code = 401
-        mock_response.text = "Invalid authentication credentials"
-        mock_context = AsyncMock()
-        mock_context.post = AsyncMock(return_value=mock_response)
-        mock_client.return_value.__aenter__.return_value = mock_context
-        messages = [{"role": "user", "content": "Hello"}]
-        result = await provider.chat(messages)
-        assert isinstance(result, LLMResponse)
-        assert "Azure OpenAI API Error 401" in result.content
-        assert "Invalid authentication credentials" in result.content
-        assert result.finish_reason == "error"
+    result = await provider.chat([{"role": "user", "content": "Hi"}])

-@pytest.mark.asyncio
-async def test_chat_connection_error():
-    """Test chat request connection error handling."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-    with patch("httpx.AsyncClient") as mock_client:
-        mock_context = AsyncMock()
-        mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
-        mock_client.return_value.__aenter__.return_value = mock_context
-        messages = [{"role": "user", "content": "Hello"}]
-        result = await provider.chat(messages)
-        assert isinstance(result, LLMResponse)
-        assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
-        assert result.finish_reason == "error"
-
-def test_parse_response_malformed():
-    """Test response parsing with malformed data."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-    # Test with missing choices
-    malformed_response = {"usage": {"prompt_tokens": 10}}
-    result = provider._parse_response(malformed_response)
-    assert isinstance(result, LLMResponse)
-    assert "Error parsing Azure OpenAI response" in result.content
-    assert result.finish_reason == "error"
+    assert isinstance(result, LLMResponse)
+    assert "Connection failed" in result.content
+    assert result.finish_reason == "error"

+@pytest.mark.asyncio
+async def test_chat_reasoning_param_format():
+    """reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
+    )
+    mock_resp = _make_sdk_response(content="thought")
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_resp)
+
+    await provider.chat(
+        [{"role": "user", "content": "think"}], reasoning_effort="medium",
+    )
+
+    call_kwargs = provider._client.responses.create.call_args[1]
+    assert call_kwargs["reasoning"] == {"effort": "medium"}
+    assert "reasoning_effort" not in call_kwargs
+
+# ---------------------------------------------------------------------------
+# chat_stream()
+# ---------------------------------------------------------------------------
+@pytest.mark.asyncio
+async def test_chat_stream_success():
+    """Streaming should call on_content_delta and return combined response."""
+    provider = AzureOpenAIProvider(
+        api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
+    )
+    # Build mock SDK stream events
+    events = []
+    ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
+    ev2 = MagicMock(type="response.output_text.delta", delta=" world")
+    resp_obj = MagicMock(status="completed")
+    ev3 = MagicMock(type="response.completed", response=resp_obj)
+    events = [ev1, ev2, ev3]
+
+    async def mock_stream():
+        for e in events:
+            yield e
+
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_stream())
+
+    deltas: list[str] = []
+
+    async def on_delta(text: str) -> None:
+        deltas.append(text)
+
+    result = await provider.chat_stream(
+        [{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
+    )
+
+    assert result.content == "Hello world"
+    assert result.finish_reason == "stop"
+    assert deltas == ["Hello", " world"]
+
+@pytest.mark.asyncio
+async def test_chat_stream_with_tool_calls():
+    """Streaming tool calls should be accumulated correctly."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
+    )
+    item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
+    item_added.name = "get_weather"
+    ev_added = MagicMock(type="response.output_item.added", item=item_added)
+    ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
+    ev_args_done = MagicMock(
+        type="response.function_call_arguments.done",
+        call_id="call_1", arguments='{"location":"SF"}',
+    )
+    item_done = MagicMock(
+        type="function_call", call_id="call_1", id="fc_1",
+        arguments='{"location":"SF"}',
+    )
+    item_done.name = "get_weather"
+    ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
+    resp_obj = MagicMock(status="completed")
+    ev_completed = MagicMock(type="response.completed", response=resp_obj)
+
+    async def mock_stream():
+        for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
+            yield e
+
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_stream())
+
+    result = await provider.chat_stream(
+        [{"role": "user", "content": "weather?"}],
+        tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
+    )
+
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].name == "get_weather"
+    assert result.tool_calls[0].arguments == {"location": "SF"}
+
+@pytest.mark.asyncio
+async def test_chat_stream_error():
+    """Streaming should return error when SDK raises."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
+    )
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
+
+    result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
+
+    assert "Connection failed" in result.content
+    assert result.finish_reason == "error"
+
+# ---------------------------------------------------------------------------
+# get_default_model
+# ---------------------------------------------------------------------------
def test_get_default_model():
-    """Test get_default_model method."""
    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="my-custom-deployment",
+        api_key="k", api_base="https://r.com", default_model="my-deploy",
    )
-    assert provider.get_default_model() == "my-custom-deployment"
+    assert provider.get_default_model() == "my-deploy"

-if __name__ == "__main__":
-    # Run basic tests
-    print("Running basic Azure OpenAI provider tests...")
-
-    # Test initialization
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o-deployment",
-    )
-    print("✅ Provider initialization successful")
-
-    # Test URL building
-    url = provider._build_chat_url("my-deployment")
-    expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
-    assert url == expected
-    print("✅ URL building works correctly")
-
-    # Test headers
-    headers = provider._build_headers()
-    assert headers["api-key"] == "test-key"
-    assert headers["Content-Type"] == "application/json"
-    print("✅ Header building works correctly")
-
-    # Test payload preparation
-    messages = [{"role": "user", "content": "Test"}]
-    payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
-    assert payload["max_completion_tokens"] == 1000  # Azure 2024-10-21 format
-    print("✅ Payload preparation works correctly")
-
-    print("✅ All basic tests passed! Updated test file is working correctly.")

View File

@ -0,0 +1,233 @@
"""Tests for cached token extraction from OpenAI-compatible providers."""
from __future__ import annotations
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
class FakeUsage:
"""Mimics an OpenAI SDK usage object (has attributes, not dict keys)."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class FakePromptDetails:
"""Mimics prompt_tokens_details sub-object."""
def __init__(self, cached_tokens=0):
self.cached_tokens = cached_tokens
class _FakeSpec:
supports_prompt_caching = False
model_id_prefix = None
strip_model_prefix = False
max_completion_tokens = False
reasoning_effort = None
def _provider():
from unittest.mock import MagicMock
p = OpenAICompatProvider.__new__(OpenAICompatProvider)
p.client = MagicMock()
p.spec = _FakeSpec()
return p
# Minimal valid choice so _parse reaches _extract_usage.
_DICT_CHOICE = {"message": {"content": "Hello"}}
class _FakeMessage:
content = "Hello"
tool_calls = None
class _FakeChoice:
message = _FakeMessage()
finish_reason = "stop"
# --- dict-based response (raw JSON / mapping) ---
def test_extract_usage_openai_cached_tokens_dict():
"""prompt_tokens_details.cached_tokens from a dict response."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 2000,
"completion_tokens": 300,
"total_tokens": 2300,
"prompt_tokens_details": {"cached_tokens": 1200},
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
assert result.usage["prompt_tokens"] == 2000
def test_extract_usage_deepseek_cached_tokens_dict():
"""prompt_cache_hit_tokens from a DeepSeek dict response."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 1500,
"completion_tokens": 200,
"total_tokens": 1700,
"prompt_cache_hit_tokens": 1200,
"prompt_cache_miss_tokens": 300,
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
def test_extract_usage_no_cached_tokens_dict():
"""Response without any cache fields -> no cached_tokens key."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 1000,
"completion_tokens": 200,
"total_tokens": 1200,
}
}
result = p._parse(response)
assert "cached_tokens" not in result.usage
def test_extract_usage_openai_cached_zero_dict():
"""cached_tokens=0 should NOT be included (same as existing fields)."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 2000,
"completion_tokens": 300,
"total_tokens": 2300,
"prompt_tokens_details": {"cached_tokens": 0},
}
}
result = p._parse(response)
assert "cached_tokens" not in result.usage
# --- object-based response (OpenAI SDK Pydantic model) ---
def test_extract_usage_openai_cached_tokens_obj():
"""prompt_tokens_details.cached_tokens from an SDK object response."""
p = _provider()
usage_obj = FakeUsage(
prompt_tokens=2000,
completion_tokens=300,
total_tokens=2300,
prompt_tokens_details=FakePromptDetails(cached_tokens=1200),
)
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
def test_extract_usage_deepseek_cached_tokens_obj():
"""prompt_cache_hit_tokens from a DeepSeek SDK object response."""
p = _provider()
usage_obj = FakeUsage(
prompt_tokens=1500,
completion_tokens=200,
total_tokens=1700,
prompt_cache_hit_tokens=1200,
)
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
def test_extract_usage_stepfun_top_level_cached_tokens_dict():
"""StepFun/Moonshot: usage.cached_tokens at top level (not nested)."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 591,
"completion_tokens": 120,
"total_tokens": 711,
"cached_tokens": 512,
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 512
def test_extract_usage_stepfun_top_level_cached_tokens_obj():
"""StepFun/Moonshot: usage.cached_tokens as SDK object attribute."""
p = _provider()
usage_obj = FakeUsage(
prompt_tokens=591,
completion_tokens=120,
total_tokens=711,
cached_tokens=512,
)
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
result = p._parse(response)
assert result.usage["cached_tokens"] == 512
def test_extract_usage_priority_nested_over_top_level_dict():
"""When both nested and top-level cached_tokens exist, nested wins."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 2000,
"completion_tokens": 300,
"total_tokens": 2300,
"prompt_tokens_details": {"cached_tokens": 100},
"cached_tokens": 500,
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 100
def test_anthropic_maps_cache_fields_to_cached_tokens():
"""Anthropic's cache_read_input_tokens should map to cached_tokens."""
from nanobot.providers.anthropic_provider import AnthropicProvider
usage_obj = FakeUsage(
input_tokens=800,
output_tokens=200,
cache_creation_input_tokens=300,
cache_read_input_tokens=1200,
)
content_block = FakeUsage(type="text", text="hello")
response = FakeUsage(
id="msg_1",
type="message",
stop_reason="end_turn",
content=[content_block],
usage=usage_obj,
)
result = AnthropicProvider._parse_response(response)
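    # prompt_tokens aggregates input (800) + cache_creation (300) +
    # cache_read (1200) = 2300; total_tokens adds the 200 output tokens.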
assert result.usage["cached_tokens"] == 1200
assert result.usage["prompt_tokens"] == 2300
assert result.usage["total_tokens"] == 2500
assert result.usage["cache_creation_input_tokens"] == 300
def test_anthropic_no_cache_fields():
"""Anthropic response without cache fields should not have cached_tokens."""
from nanobot.providers.anthropic_provider import AnthropicProvider
usage_obj = FakeUsage(input_tokens=800, output_tokens=200)
content_block = FakeUsage(type="text", text="hello")
response = FakeUsage(
id="msg_1",
type="message",
stop_reason="end_turn",
content=[content_block],
usage=usage_obj,
)
result = AnthropicProvider._parse_response(response)
assert "cached_tokens" not in result.usage

View File

@ -0,0 +1,522 @@
"""Tests for the shared openai_responses converters and parsers."""
from unittest.mock import MagicMock, patch
import pytest
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses.converters import (
convert_messages,
convert_tools,
convert_user_message,
split_tool_call_id,
)
from nanobot.providers.openai_responses.parsing import (
consume_sdk_stream,
map_finish_reason,
parse_response_output,
)
# ======================================================================
# converters - split_tool_call_id
# ======================================================================
class TestSplitToolCallId:
def test_plain_id(self):
assert split_tool_call_id("call_abc") == ("call_abc", None)
def test_compound_id(self):
assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1")
def test_compound_empty_item_id(self):
assert split_tool_call_id("call_abc|") == ("call_abc", None)
def test_none(self):
assert split_tool_call_id(None) == ("call_0", None)
def test_empty_string(self):
assert split_tool_call_id("") == ("call_0", None)
def test_non_string(self):
assert split_tool_call_id(42) == ("call_0", None)
# ======================================================================
# converters - convert_user_message
# ======================================================================
class TestConvertUserMessage:
def test_string_content(self):
result = convert_user_message("hello")
assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}
def test_text_block(self):
result = convert_user_message([{"type": "text", "text": "hi"}])
assert result["content"] == [{"type": "input_text", "text": "hi"}]
def test_image_url_block(self):
result = convert_user_message([
{"type": "image_url", "image_url": {"url": "https://img.example/a.png"}},
])
assert result["content"] == [
{"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"},
]
def test_mixed_text_and_image(self):
result = convert_user_message([
{"type": "text", "text": "what's this?"},
{"type": "image_url", "image_url": {"url": "https://img.example/b.png"}},
])
assert len(result["content"]) == 2
assert result["content"][0]["type"] == "input_text"
assert result["content"][1]["type"] == "input_image"
def test_empty_list_falls_back(self):
result = convert_user_message([])
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_none_falls_back(self):
result = convert_user_message(None)
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_image_without_url_skipped(self):
result = convert_user_message([{"type": "image_url", "image_url": {}}])
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_meta_fields_not_leaked(self):
"""_meta on content blocks must never appear in converted output."""
result = convert_user_message([
{"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}},
])
assert "_meta" not in result["content"][0]
def test_non_dict_items_skipped(self):
result = convert_user_message(["just a string", 42])
assert result["content"] == [{"type": "input_text", "text": ""}]
# ======================================================================
# converters - convert_messages
# ======================================================================
class TestConvertMessages:
def test_system_extracted_as_instructions(self):
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi"},
]
instructions, items = convert_messages(msgs)
assert instructions == "You are helpful."
assert len(items) == 1
assert items[0]["role"] == "user"
def test_multiple_system_messages_last_wins(self):
msgs = [
{"role": "system", "content": "first"},
{"role": "system", "content": "second"},
{"role": "user", "content": "x"},
]
instructions, _ = convert_messages(msgs)
assert instructions == "second"
def test_user_message_converted(self):
_, items = convert_messages([{"role": "user", "content": "hello"}])
assert items[0]["role"] == "user"
assert items[0]["content"][0]["type"] == "input_text"
def test_assistant_text_message(self):
_, items = convert_messages([
{"role": "assistant", "content": "I'll help"},
])
assert items[0]["type"] == "message"
assert items[0]["role"] == "assistant"
assert items[0]["content"][0]["type"] == "output_text"
assert items[0]["content"][0]["text"] == "I'll help"
def test_assistant_empty_content_skipped(self):
_, items = convert_messages([{"role": "assistant", "content": ""}])
assert len(items) == 0
def test_assistant_with_tool_calls(self):
_, items = convert_messages([{
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_abc|fc_1",
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
}],
}])
assert items[0]["type"] == "function_call"
assert items[0]["call_id"] == "call_abc"
assert items[0]["id"] == "fc_1"
assert items[0]["name"] == "get_weather"
def test_assistant_with_tool_calls_no_id(self):
"""Fallback IDs when tool_call.id is missing."""
_, items = convert_messages([{
"role": "assistant",
"content": None,
"tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}],
}])
assert items[0]["call_id"] == "call_0"
assert items[0]["id"].startswith("fc_")
def test_tool_message(self):
_, items = convert_messages([{
"role": "tool",
"tool_call_id": "call_abc",
"content": "result text",
}])
assert items[0]["type"] == "function_call_output"
assert items[0]["call_id"] == "call_abc"
assert items[0]["output"] == "result text"
def test_tool_message_dict_content(self):
_, items = convert_messages([{
"role": "tool",
"tool_call_id": "call_1",
"content": {"key": "value"},
}])
assert items[0]["output"] == '{"key": "value"}'
def test_non_standard_keys_not_leaked(self):
"""Extra keys on messages must not appear in converted items."""
_, items = convert_messages([{
"role": "user",
"content": "hi",
"extra_field": "should vanish",
"_meta": {"path": "/tmp"},
}])
item = items[0]
assert "extra_field" not in str(item)
assert "_meta" not in str(item)
def test_full_conversation_roundtrip(self):
"""System + user + assistant(tool_call) + tool -> correct structure."""
msgs = [
{"role": "system", "content": "Be concise."},
{"role": "user", "content": "Weather in SF?"},
{
"role": "assistant", "content": None,
"tool_calls": [{
"id": "c1|fc1",
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
}],
},
{"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'},
]
instructions, items = convert_messages(msgs)
assert instructions == "Be concise."
assert len(items) == 3 # user, function_call, function_call_output
assert items[0]["role"] == "user"
assert items[1]["type"] == "function_call"
assert items[2]["type"] == "function_call_output"
# ======================================================================
# converters - convert_tools
# ======================================================================
class TestConvertTools:
def test_standard_function_tool(self):
tools = [{"type": "function", "function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}}]
result = convert_tools(tools)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["name"] == "get_weather"
assert result[0]["description"] == "Get weather"
assert "properties" in result[0]["parameters"]
def test_tool_without_name_skipped(self):
tools = [{"type": "function", "function": {"parameters": {}}}]
assert convert_tools(tools) == []
def test_tool_without_function_wrapper(self):
"""Direct dict without type=function wrapper."""
tools = [{"name": "f1", "description": "d", "parameters": {}}]
result = convert_tools(tools)
assert result[0]["name"] == "f1"
def test_missing_optional_fields_default(self):
tools = [{"type": "function", "function": {"name": "f"}}]
result = convert_tools(tools)
assert result[0]["description"] == ""
assert result[0]["parameters"] == {}
def test_multiple_tools(self):
tools = [
{"type": "function", "function": {"name": "a", "parameters": {}}},
{"type": "function", "function": {"name": "b", "parameters": {}}},
]
assert len(convert_tools(tools)) == 2
# ======================================================================
# parsing - map_finish_reason
# ======================================================================
class TestMapFinishReason:
def test_completed(self):
assert map_finish_reason("completed") == "stop"
def test_incomplete(self):
assert map_finish_reason("incomplete") == "length"
def test_failed(self):
assert map_finish_reason("failed") == "error"
def test_cancelled(self):
assert map_finish_reason("cancelled") == "error"
def test_none_defaults_to_stop(self):
assert map_finish_reason(None) == "stop"
def test_unknown_defaults_to_stop(self):
assert map_finish_reason("some_new_status") == "stop"
# ======================================================================
# parsing - parse_response_output
# ======================================================================
class TestParseResponseOutput:
def test_text_response(self):
resp = {
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "Hello!"}]}],
"status": "completed",
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
}
result = parse_response_output(resp)
assert result.content == "Hello!"
assert result.finish_reason == "stop"
assert result.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
assert result.tool_calls == []
def test_tool_call_response(self):
resp = {
"output": [{
"type": "function_call",
"call_id": "call_1", "id": "fc_1",
"name": "get_weather",
"arguments": '{"city": "SF"}',
}],
"status": "completed",
"usage": {},
}
result = parse_response_output(resp)
assert result.content is None
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"city": "SF"}
assert result.tool_calls[0].id == "call_1|fc_1"
def test_malformed_tool_arguments_logged(self):
"""Malformed JSON arguments should log a warning and fallback."""
resp = {
"output": [{
"type": "function_call",
"call_id": "c1", "id": "fc1",
"name": "f", "arguments": "{bad json",
}],
"status": "completed", "usage": {},
}
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
result = parse_response_output(resp)
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
def test_reasoning_content_extracted(self):
resp = {
"output": [
{"type": "reasoning", "summary": [
{"type": "summary_text", "text": "I think "},
{"type": "summary_text", "text": "therefore I am."},
]},
{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "42"}]},
],
"status": "completed", "usage": {},
}
result = parse_response_output(resp)
assert result.content == "42"
assert result.reasoning_content == "I think therefore I am."
def test_empty_output(self):
resp = {"output": [], "status": "completed", "usage": {}}
result = parse_response_output(resp)
assert result.content is None
assert result.tool_calls == []
def test_incomplete_status(self):
resp = {"output": [], "status": "incomplete", "usage": {}}
result = parse_response_output(resp)
assert result.finish_reason == "length"
def test_sdk_model_object(self):
"""parse_response_output should handle SDK objects with model_dump()."""
mock = MagicMock()
mock.model_dump.return_value = {
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "sdk"}]}],
"status": "completed",
"usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
}
result = parse_response_output(mock)
assert result.content == "sdk"
assert result.usage["prompt_tokens"] == 1
def test_usage_maps_responses_api_keys(self):
"""Responses API uses input_tokens/output_tokens, not prompt_tokens/completion_tokens."""
resp = {
"output": [],
"status": "completed",
"usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
}
result = parse_response_output(resp)
assert result.usage["prompt_tokens"] == 100
assert result.usage["completion_tokens"] == 50
assert result.usage["total_tokens"] == 150
# ======================================================================
# parsing - consume_sdk_stream
# ======================================================================
class TestConsumeSdkStream:
@pytest.mark.asyncio
async def test_text_stream(self):
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev3 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3]:
yield e
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
assert content == "Hello world"
assert tool_calls == []
assert finish_reason == "stop"
@pytest.mark.asyncio
async def test_on_content_delta_called(self):
ev1 = MagicMock(type="response.output_text.delta", delta="hi")
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev2 = MagicMock(type="response.completed", response=resp_obj)
deltas = []
async def cb(text):
deltas.append(text)
async def stream():
for e in [ev1, ev2]:
yield e
await consume_sdk_stream(stream(), on_content_delta=cb)
assert deltas == ["hi"]
@pytest.mark.asyncio
async def test_tool_call_stream(self):
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
item_added.name = "get_weather"
ev1 = MagicMock(type="response.output_item.added", item=item_added)
ev2 = MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci')
ev3 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}')
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments='{"city":"SF"}')
item_done.name = "get_weather"
ev4 = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev5 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3, ev4, ev5]:
yield e
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
assert content == ""
assert len(tool_calls) == 1
assert tool_calls[0].name == "get_weather"
assert tool_calls[0].arguments == {"city": "SF"}
@pytest.mark.asyncio
async def test_usage_extracted(self):
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
resp_obj = MagicMock(status="completed", usage=usage_obj, output=[])
ev = MagicMock(type="response.completed", response=resp_obj)
async def stream():
yield ev
_, _, _, usage, _ = await consume_sdk_stream(stream())
assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
@pytest.mark.asyncio
async def test_reasoning_extracted(self):
summary_item = MagicMock(type="summary_text", text="thinking...")
reasoning_item = MagicMock(type="reasoning", summary=[summary_item])
resp_obj = MagicMock(status="completed", usage=None, output=[reasoning_item])
ev = MagicMock(type="response.completed", response=resp_obj)
async def stream():
yield ev
_, _, _, _, reasoning = await consume_sdk_stream(stream())
assert reasoning == "thinking..."
@pytest.mark.asyncio
async def test_error_event_raises(self):
ev = MagicMock(type="error", error="rate_limit_exceeded")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio
async def test_failed_event_raises(self):
ev = MagicMock(type="response.failed", error="server_error")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed.*server_error"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio
async def test_malformed_tool_args_logged(self):
"""Malformed JSON in streaming tool args should log a warning."""
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
item_added.name = "f"
ev1 = MagicMock(type="response.output_item.added", item=item_added)
ev2 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad")
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="{bad")
item_done.name = "f"
ev3 = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev4 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3, ev4]:
yield e
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
assert tool_calls[0].arguments == {"raw": "{bad"}
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)

View File

@ -11,6 +11,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
providers = importlib.import_module("nanobot.providers") providers = importlib.import_module("nanobot.providers")
@ -18,6 +19,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
assert "nanobot.providers.anthropic_provider" not in sys.modules assert "nanobot.providers.anthropic_provider" not in sys.modules
assert "nanobot.providers.openai_compat_provider" not in sys.modules assert "nanobot.providers.openai_compat_provider" not in sys.modules
assert "nanobot.providers.openai_codex_provider" not in sys.modules assert "nanobot.providers.openai_codex_provider" not in sys.modules
assert "nanobot.providers.github_copilot_provider" not in sys.modules
assert "nanobot.providers.azure_openai_provider" not in sys.modules assert "nanobot.providers.azure_openai_provider" not in sys.modules
assert providers.__all__ == [ assert providers.__all__ == [
"LLMProvider", "LLMProvider",
@ -25,6 +27,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
"AnthropicProvider", "AnthropicProvider",
"OpenAICompatProvider", "OpenAICompatProvider",
"OpenAICodexProvider", "OpenAICodexProvider",
"GitHubCopilotProvider",
"AzureOpenAIProvider", "AzureOpenAIProvider",
] ]

View File

@ -0,0 +1,59 @@
"""Tests for build_status_content cache hit rate display."""
from nanobot.utils.helpers import build_status_content
def test_status_shows_cache_hit_rate():
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200},
context_window_tokens=128000,
session_msg_count=10,
context_tokens_estimate=5000,
)
assert "60% cached" in content
assert "2000 in / 300 out" in content
def test_status_no_cache_info():
"""Without cached_tokens, display should not show cache percentage."""
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 2000, "completion_tokens": 300},
context_window_tokens=128000,
session_msg_count=10,
context_tokens_estimate=5000,
)
assert "cached" not in content.lower()
assert "2000 in / 300 out" in content
def test_status_zero_cached_tokens():
"""cached_tokens=0 should not show cache percentage."""
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0},
context_window_tokens=128000,
session_msg_count=10,
context_tokens_estimate=5000,
)
assert "cached" not in content.lower()
def test_status_100_percent_cached():
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000},
context_window_tokens=128000,
session_msg_count=5,
context_tokens_estimate=3000,
)
assert "100% cached" in content

View File

@ -125,6 +125,27 @@ def test_workspace_override(tmp_path):
assert bot._loop.workspace == custom_ws
def test_sdk_make_provider_uses_github_copilot_backend():
from nanobot.config.schema import Config
from nanobot.nanobot import _make_provider
config = Config.model_validate(
{
"agents": {
"defaults": {
"provider": "github-copilot",
"model": "github-copilot/gpt-4.1",
}
}
}
)
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = _make_provider(config)
assert provider.__class__.__name__ == "GitHubCopilotProvider"
@pytest.mark.asyncio
async def test_run_custom_session_key(tmp_path):
from nanobot.bus.events import OutboundMessage

View File

@ -95,6 +95,14 @@ def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
assert paths == [r"C:\user\workspace\txt"] assert paths == [r"C:\user\workspace\txt"]
def test_exec_extract_absolute_paths_captures_windows_drive_root_path() -> None:
"""Windows drive root paths like `E:\\` must be extracted for workspace guarding."""
# Note: raw strings cannot end with a single backslash.
cmd = "dir E:\\"
paths = ExecTool._extract_absolute_paths(cmd)
assert paths == ["E:\\"]
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
cmd = ".venv/bin/python script.py"
paths = ExecTool._extract_absolute_paths(cmd)
@ -134,6 +142,45 @@ def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
assert error == "Error: Command blocked by safety guard (path outside working dir)" assert error == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_windows_drive_root_outside_workspace(monkeypatch) -> None:
import nanobot.agent.tools.shell as shell_mod
class FakeWindowsPath:
def __init__(self, raw: str) -> None:
self.raw = raw.rstrip("\\") + ("\\" if raw.endswith("\\") else "")
def resolve(self) -> "FakeWindowsPath":
return self
def expanduser(self) -> "FakeWindowsPath":
return self
def is_absolute(self) -> bool:
return len(self.raw) >= 3 and self.raw[1:3] == ":\\"
@property
def parents(self) -> list["FakeWindowsPath"]:
if not self.is_absolute():
return []
trimmed = self.raw.rstrip("\\")
if len(trimmed) <= 2:
return []
idx = trimmed.rfind("\\")
if idx <= 2:
return [FakeWindowsPath(trimmed[:2] + "\\")]
parent = FakeWindowsPath(trimmed[:idx])
return [parent, *parent.parents]
def __eq__(self, other: object) -> bool:
return isinstance(other, FakeWindowsPath) and self.raw.lower() == other.raw.lower()
monkeypatch.setattr(shell_mod, "Path", FakeWindowsPath)
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command("dir E:\\", "E:\\workspace")
assert error == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---