Merge remote-tracking branch 'origin/main' into feat/runtime-hardening

This commit is contained in:
Xubin Ren 2026-04-02 10:40:49 +00:00
commit eefd7e60f2
31 changed files with 2385 additions and 789 deletions

View File

@ -95,6 +95,15 @@ class _LoopHook(AgentHook):
logger.info("Tool call: {}({})", tc.name, args_str[:200])
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
async def after_iteration(self, context: AgentHookContext) -> None:
u = context.usage or {}
logger.debug(
"LLM usage: prompt={} completion={} cached={}",
u.get("prompt_tokens", 0),
u.get("completion_tokens", 0),
u.get("cached_tokens", 0),
)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return self._loop._strip_think(content)

View File

@ -77,7 +77,7 @@ class AgentRunner:
messages = list(spec.initial_messages)
final_content: str | None = None
tools_used: list[str] = []
usage = {"prompt_tokens": 0, "completion_tokens": 0}
usage: dict[str, int] = {}
error: str | None = None
stop_reason = "completed"
tool_events: list[dict[str, str]] = []
@ -122,13 +122,15 @@ class AgentRunner:
response = await self.provider.chat_with_retry(**kwargs)
raw_usage = response.usage or {}
usage = {
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
}
context.response = response
context.usage = usage
context.usage = raw_usage
context.tool_calls = list(response.tool_calls)
# Accumulate standard fields into result usage.
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
cached = raw_usage.get("cached_tokens")
if cached:
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
if response.has_tool_calls:
if hook.wants_streaming():

View File

@ -97,6 +97,11 @@ class CronTool(Tool):
f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}."
),
},
"deliver": {
"type": "boolean",
"description": "Whether to deliver the execution result to the user channel (default true)",
"default": True
},
"job_id": {"type": "string", "description": "Job ID (for remove)"},
},
"required": ["action"],
@ -111,12 +116,13 @@ class CronTool(Tool):
tz: str | None = None,
at: str | None = None,
job_id: str | None = None,
deliver: bool = True,
**kwargs: Any,
) -> str:
if action == "add":
if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
return self._add_job(message, every_seconds, cron_expr, tz, at)
return self._add_job(message, every_seconds, cron_expr, tz, at, deliver)
elif action == "list":
return self._list_jobs()
elif action == "remove":
@ -130,6 +136,7 @@ class CronTool(Tool):
cron_expr: str | None,
tz: str | None,
at: str | None,
deliver: bool = True,
) -> str:
if not message:
return "Error: message is required for add"
@ -171,7 +178,7 @@ class CronTool(Tool):
name=message[:30],
schedule=schedule,
message=message,
deliver=True,
deliver=deliver,
channel=self._channel,
to=self._chat_id,
delete_after_run=delete_after,

View File

@ -86,7 +86,15 @@ class MessageTool(Tool):
) -> str:
channel = channel or self._default_channel
chat_id = chat_id or self._default_chat_id
message_id = message_id or self._default_message_id
# Only inherit default message_id when targeting the same channel+chat.
# Cross-chat sends must not carry the original message_id, because
# some channels (e.g. Feishu) use it to determine the target
# conversation via their Reply API, which would route the message
# to the wrong chat entirely.
if channel == self._default_channel and chat_id == self._default_chat_id:
message_id = message_id or self._default_message_id
else:
message_id = None
if not channel or not chat_id:
return "Error: No target channel/chat specified"
@ -101,7 +109,7 @@ class MessageTool(Tool):
media=media or [],
metadata={
"message_id": message_id,
},
} if message_id else {},
)
try:

View File

@ -190,7 +190,9 @@ class ExecTool(Tool):
@staticmethod
def _extract_absolute_paths(command: str) -> list[str]:
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command)
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
return win_paths + posix_paths + home_paths

View File

@ -415,6 +415,9 @@ def _make_provider(config: Config):
api_base=p.api_base,
default_model=model,
)
elif backend == "github_copilot":
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
provider = GitHubCopilotProvider(default_model=model)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
@ -1029,12 +1032,18 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status")
def channels_status():
def channels_status(
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Show channel status."""
from nanobot.channels.registry import discover_all
from nanobot.config.loader import load_config
from nanobot.config.loader import load_config, set_config_path
config = load_config()
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
if resolved_config_path is not None:
set_config_path(resolved_config_path)
config = load_config(resolved_config_path)
table = Table(title="Channel Status")
table.add_column("Channel", style="cyan")
@ -1121,12 +1130,17 @@ def _get_bridge_dir() -> Path:
def channels_login(
channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Authenticate with a channel via QR code or other interactive login."""
from nanobot.channels.registry import discover_all
from nanobot.config.loader import load_config
from nanobot.config.loader import load_config, set_config_path
config = load_config()
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
if resolved_config_path is not None:
set_config_path(resolved_config_path)
config = load_config(resolved_config_path)
channel_cfg = getattr(config.channels, channel_name, None) or {}
# Validate channel exists
@ -1298,26 +1312,16 @@ def _login_openai_codex() -> None:
@_register_login("github_copilot")
def _login_github_copilot() -> None:
import asyncio
from openai import AsyncOpenAI
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
async def _trigger():
client = AsyncOpenAI(
api_key="dummy",
base_url="https://api.githubcopilot.com",
)
await client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": "hi"}],
max_tokens=1,
)
try:
asyncio.run(_trigger())
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
from nanobot.providers.github_copilot_provider import login_github_copilot
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
token = login_github_copilot(
print_fn=lambda s: console.print(s),
prompt_fn=lambda s: typer.prompt(s),
)
account = token.account_id or "GitHub"
console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]")
except Exception as e:
console.print(f"[red]Authentication error: {e}[/red]")
raise typer.Exit(1)

View File

@ -138,6 +138,10 @@ def _make_provider(config: Any) -> Any:
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
provider = OpenAICodexProvider(default_model=model)
elif backend == "github_copilot":
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
provider = GitHubCopilotProvider(default_model=model)
elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider

View File

@ -13,6 +13,7 @@ __all__ = [
"AnthropicProvider",
"OpenAICompatProvider",
"OpenAICodexProvider",
"GitHubCopilotProvider",
"AzureOpenAIProvider",
]
@ -20,12 +21,14 @@ _LAZY_IMPORTS = {
"AnthropicProvider": ".anthropic_provider",
"OpenAICompatProvider": ".openai_compat_provider",
"OpenAICodexProvider": ".openai_codex_provider",
"GitHubCopilotProvider": ".github_copilot_provider",
"AzureOpenAIProvider": ".azure_openai_provider",
}
if TYPE_CHECKING:
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider

View File

@ -372,15 +372,22 @@ class AnthropicProvider(LLMProvider):
usage: dict[str, int] = {}
if response.usage:
input_tokens = response.usage.input_tokens
cache_creation = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
cache_read = getattr(response.usage, "cache_read_input_tokens", 0) or 0
total_prompt_tokens = input_tokens + cache_creation + cache_read
usage = {
"prompt_tokens": response.usage.input_tokens,
"prompt_tokens": total_prompt_tokens,
"completion_tokens": response.usage.output_tokens,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
"total_tokens": total_prompt_tokens + response.usage.output_tokens,
}
for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"):
val = getattr(response.usage, attr, 0)
if val:
usage[attr] = val
# Normalize to cached_tokens for downstream consistency.
if cache_read:
usage["cached_tokens"] = cache_read
return LLMResponse(
content="".join(content_parts) or None,

View File

@ -1,31 +1,36 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
"""Azure OpenAI provider using the OpenAI SDK Responses API.
Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
routes to the Responses API (``/responses``). Reuses shared conversion
helpers from :mod:`nanobot.providers.openai_responses`.
"""
from __future__ import annotations
import json
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.openai_responses import (
consume_sdk_stream,
convert_messages,
convert_tools,
parse_response_output,
)
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
"""Azure OpenAI provider backed by the Responses API.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
- Uses the OpenAI Python SDK (``AsyncOpenAI``) with
``base_url = {endpoint}/openai/v1/``
- Calls ``client.responses.create()`` (Responses API)
- Reuses shared message/tool/SSE conversion from
``openai_responses``
"""
def __init__(
@ -36,40 +41,28 @@ class AzureOpenAIProvider(LLMProvider):
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
# Normalise: ensure trailing slash
if not api_base.endswith("/"):
api_base += "/"
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
# SDK client targeting the Azure Responses API endpoint
base_url = f"{api_base.rstrip('/')}/openai/v1/"
self._client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
default_headers={"x-session-affinity": uuid.uuid4().hex},
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _supports_temperature(
@ -82,36 +75,51 @@ class AzureOpenAIProvider(LLMProvider):
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
def _build_body(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None,
model: str | None,
max_tokens: int,
temperature: float,
reasoning_effort: str | None,
tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
"""Build the Responses API request body from Chat-Completions-style args."""
deployment = model or self.default_model
instructions, input_items = convert_messages(self._sanitize_empty_content(messages))
body: dict[str, Any] = {
"model": deployment,
"instructions": instructions or None,
"input": input_items,
"max_output_tokens": max(1, max_tokens),
"store": False,
"stream": False,
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if self._supports_temperature(deployment, reasoning_effort):
body["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
body["reasoning"] = {"effort": reasoning_effort}
body["include"] = ["reasoning.encrypted_content"]
if tools:
payload["tools"] = tools
payload["tool_choice"] = tool_choice or "auto"
body["tools"] = convert_tools(tools)
body["tool_choice"] = tool_choice or "auto"
return payload
return body
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
return LLMResponse(content=msg, finish_reason="error")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def chat(
self,
@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
tool_choice=tool_choice,
body = self._build_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
response = await self._client.responses.create(**body)
return parse_response_output(response)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
return self._handle_error(e)
async def chat_stream(
self,
@ -221,89 +152,26 @@ class AzureOpenAIProvider(LLMProvider):
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Stream a chat completion via Azure OpenAI SSE."""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature,
reasoning_effort, tool_choice=tool_choice,
body = self._build_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
payload["stream"] = True
body["stream"] = True
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
async with client.stream("POST", url, headers=headers, json=payload) as response:
if response.status_code != 200:
text = await response.aread()
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
finish_reason="error",
)
return await self._consume_stream(response, on_content_delta)
except Exception as e:
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
async def _consume_stream(
self,
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
content_parts: list[str] = []
tool_call_buffers: dict[int, dict[str, str]] = {}
finish_reason = "stop"
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
data = line[6:].strip()
if data == "[DONE]":
break
try:
chunk = json.loads(data)
except Exception:
continue
choices = chunk.get("choices") or []
if not choices:
continue
choice = choices[0]
if choice.get("finish_reason"):
finish_reason = choice["finish_reason"]
delta = choice.get("delta") or {}
text = delta.get("content")
if text:
content_parts.append(text)
if on_content_delta:
await on_content_delta(text)
for tc in delta.get("tool_calls") or []:
idx = tc.get("index", 0)
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
if tc.get("id"):
buf["id"] = tc["id"]
fn = tc.get("function") or {}
if fn.get("name"):
buf["name"] = fn["name"]
if fn.get("arguments"):
buf["arguments"] += fn["arguments"]
tool_calls = [
ToolCallRequest(
id=buf["id"], name=buf["name"],
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
stream = await self._client.responses.create(**body)
content, tool_calls, finish_reason, usage, reasoning_content = (
await consume_sdk_stream(stream, on_content_delta)
)
for buf in tool_call_buffers.values()
]
return LLMResponse(
content="".join(content_parts) or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content,
)
except Exception as e:
return self._handle_error(e)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model

View File

@ -0,0 +1,257 @@
"""GitHub Copilot OAuth-backed provider."""
from __future__ import annotations
import time
import webbrowser
from collections.abc import Callable
import httpx
from oauth_cli_kit.models import OAuthToken
from oauth_cli_kit.storage import FileTokenStorage
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
DEFAULT_GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code"
DEFAULT_GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
DEFAULT_GITHUB_USER_URL = "https://api.github.com/user"
DEFAULT_COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token"
DEFAULT_COPILOT_BASE_URL = "https://api.githubcopilot.com"
GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"
GITHUB_COPILOT_SCOPE = "read:user"
TOKEN_FILENAME = "github-copilot.json"
TOKEN_APP_NAME = "nanobot"
USER_AGENT = "nanobot/0.1"
EDITOR_VERSION = "vscode/1.99.0"
EDITOR_PLUGIN_VERSION = "copilot-chat/0.26.0"
_EXPIRY_SKEW_SECONDS = 60
_LONG_LIVED_TOKEN_SECONDS = 315360000
def _storage() -> FileTokenStorage:
return FileTokenStorage(
token_filename=TOKEN_FILENAME,
app_name=TOKEN_APP_NAME,
import_codex_cli=False,
)
def _copilot_headers(token: str) -> dict[str, str]:
return {
"Authorization": f"token {token}",
"Accept": "application/json",
"User-Agent": USER_AGENT,
"Editor-Version": EDITOR_VERSION,
"Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
}
def _load_github_token() -> OAuthToken | None:
token = _storage().load()
if not token or not token.access:
return None
return token
def get_github_copilot_login_status() -> OAuthToken | None:
"""Return the persisted GitHub OAuth token if available."""
return _load_github_token()
def login_github_copilot(
print_fn: Callable[[str], None] | None = None,
prompt_fn: Callable[[str], str] | None = None,
) -> OAuthToken:
"""Run GitHub device flow and persist the GitHub OAuth token used for Copilot."""
del prompt_fn
printer = print_fn or print
timeout = httpx.Timeout(20.0, connect=20.0)
with httpx.Client(timeout=timeout, follow_redirects=True, trust_env=True) as client:
response = client.post(
DEFAULT_GITHUB_DEVICE_CODE_URL,
headers={"Accept": "application/json", "User-Agent": USER_AGENT},
data={"client_id": GITHUB_COPILOT_CLIENT_ID, "scope": GITHUB_COPILOT_SCOPE},
)
response.raise_for_status()
payload = response.json()
device_code = str(payload["device_code"])
user_code = str(payload["user_code"])
verify_url = str(payload.get("verification_uri") or payload.get("verification_uri_complete") or "")
verify_complete = str(payload.get("verification_uri_complete") or verify_url)
interval = max(1, int(payload.get("interval") or 5))
expires_in = int(payload.get("expires_in") or 900)
printer(f"Open: {verify_url}")
printer(f"Code: {user_code}")
if verify_complete:
try:
webbrowser.open(verify_complete)
except Exception:
pass
deadline = time.time() + expires_in
current_interval = interval
access_token = None
token_expires_in = _LONG_LIVED_TOKEN_SECONDS
while time.time() < deadline:
poll = client.post(
DEFAULT_GITHUB_ACCESS_TOKEN_URL,
headers={"Accept": "application/json", "User-Agent": USER_AGENT},
data={
"client_id": GITHUB_COPILOT_CLIENT_ID,
"device_code": device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
},
)
poll.raise_for_status()
poll_payload = poll.json()
access_token = poll_payload.get("access_token")
if access_token:
token_expires_in = int(poll_payload.get("expires_in") or _LONG_LIVED_TOKEN_SECONDS)
break
error = poll_payload.get("error")
if error == "authorization_pending":
time.sleep(current_interval)
continue
if error == "slow_down":
current_interval += 5
time.sleep(current_interval)
continue
if error == "expired_token":
raise RuntimeError("GitHub device code expired. Please run login again.")
if error == "access_denied":
raise RuntimeError("GitHub device flow was denied.")
if error:
desc = poll_payload.get("error_description") or error
raise RuntimeError(str(desc))
time.sleep(current_interval)
else:
raise RuntimeError("GitHub device flow timed out.")
user = client.get(
DEFAULT_GITHUB_USER_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github+json",
"User-Agent": USER_AGENT,
},
)
user.raise_for_status()
user_payload = user.json()
account_id = user_payload.get("login") or str(user_payload.get("id") or "") or None
expires_ms = int((time.time() + token_expires_in) * 1000)
token = OAuthToken(
access=str(access_token),
refresh="",
expires=expires_ms,
account_id=str(account_id) if account_id else None,
)
_storage().save(token)
return token
class GitHubCopilotProvider(OpenAICompatProvider):
"""Provider that exchanges a stored GitHub OAuth token for Copilot access tokens."""
def __init__(self, default_model: str = "github-copilot/gpt-4.1"):
from nanobot.providers.registry import find_by_name
self._copilot_access_token: str | None = None
self._copilot_expires_at: float = 0.0
super().__init__(
api_key="no-key",
api_base=DEFAULT_COPILOT_BASE_URL,
default_model=default_model,
extra_headers={
"Editor-Version": EDITOR_VERSION,
"Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
"User-Agent": USER_AGENT,
},
spec=find_by_name("github_copilot"),
)
async def _get_copilot_access_token(self) -> str:
now = time.time()
if self._copilot_access_token and now < self._copilot_expires_at - _EXPIRY_SKEW_SECONDS:
return self._copilot_access_token
github_token = _load_github_token()
if not github_token or not github_token.access:
raise RuntimeError("GitHub Copilot is not logged in. Run: nanobot provider login github-copilot")
timeout = httpx.Timeout(20.0, connect=20.0)
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True, trust_env=True) as client:
response = await client.get(
DEFAULT_COPILOT_TOKEN_URL,
headers=_copilot_headers(github_token.access),
)
response.raise_for_status()
payload = response.json()
token = payload.get("token")
if not token:
raise RuntimeError("GitHub Copilot token exchange returned no token.")
expires_at = payload.get("expires_at")
if isinstance(expires_at, (int, float)):
self._copilot_expires_at = float(expires_at)
else:
refresh_in = payload.get("refresh_in") or 1500
self._copilot_expires_at = time.time() + int(refresh_in)
self._copilot_access_token = str(token)
return self._copilot_access_token
async def _refresh_client_api_key(self) -> str:
token = await self._get_copilot_access_token()
self.api_key = token
self._client.api_key = token
return token
async def chat(
self,
messages: list[dict[str, object]],
tools: list[dict[str, object]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, object] | None = None,
):
await self._refresh_client_api_key()
return await super().chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
)
async def chat_stream(
self,
messages: list[dict[str, object]],
tools: list[dict[str, object]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, object] | None = None,
on_content_delta: Callable[[str], None] | None = None,
):
await self._refresh_client_api_key()
return await super().chat_stream(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
on_content_delta=on_content_delta,
)

View File

@ -6,13 +6,18 @@ import asyncio
import hashlib
import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator
from typing import Any
import httpx
from loguru import logger
from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses import (
consume_sse,
convert_messages,
convert_tools,
)
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
DEFAULT_ORIGINATOR = "nanobot"
@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
) -> LLMResponse:
"""Shared request logic for both chat() and chat_stream()."""
model = model or self.default_model
system_prompt, input_items = _convert_messages(messages)
system_prompt, input_items = convert_messages(messages)
token = await asyncio.to_thread(get_codex_token)
headers = _build_headers(token.account_id, token.access)
@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools:
body["tools"] = _convert_tools(tools)
body["tools"] = convert_tools(tools)
try:
try:
@ -127,96 +132,7 @@ async def _request_codex(
if response.status_code != 200:
text = await response.aread()
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
return await _consume_sse(response, on_content_delta)
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert OpenAI function-calling schema to Codex flat format."""
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append({
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
})
return converted
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = ""
input_items: list[dict[str, Any]] = []
for idx, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(_convert_user_message(content))
continue
if role == "assistant":
if isinstance(content, str) and content:
input_items.append({
"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed", "id": f"msg_{idx}",
})
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
input_items.append({
"type": "function_call",
"id": item_id or f"fc_{idx}",
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
})
continue
if role == "tool":
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
return system_prompt, input_items
def _convert_user_message(content: Any) -> dict[str, Any]:
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None
return await consume_sse(response, on_content_delta)
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
buffer: list[str] = []
async for line in response.aiter_lines():
if line == "":
if buffer:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer = []
if not data_lines:
continue
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
continue
try:
yield json.loads(data)
except Exception:
continue
continue
buffer.append(line)
async def _consume_sse(
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
async for event in _iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
}
elif event_type == "response.output_text.delta":
delta_text = event.get("delta") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
elif event_type == "response.output_item.done":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
try:
args = json.loads(args_raw)
except Exception:
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name"),
arguments=args,
)
)
elif event_type == "response.completed":
status = (event.get("response") or {}).get("status")
finish_reason = _map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
raise RuntimeError("Codex response failed")
return content, tool_calls, finish_reason
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
def _map_finish_reason(status: str | None) -> str:
return _FINISH_REASON_MAP.get(status or "completed", "stop")
def _friendly_error(status_code: int, raw: str) -> str:
if status_code == 429:
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."

View File

@ -235,7 +235,9 @@ class OpenAICompatProvider(LLMProvider):
spec = self._spec
if spec and spec.supports_prompt_caching:
messages, tools = self._apply_cache_control(messages, tools)
model_name = model or self.default_model
if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
messages, tools = self._apply_cache_control(messages, tools)
if spec and spec.strip_model_prefix:
model_name = model_name.split("/")[-1]
@ -308,6 +310,13 @@ class OpenAICompatProvider(LLMProvider):
@classmethod
def _extract_usage(cls, response: Any) -> dict[str, int]:
"""Extract token usage from an OpenAI-compatible response.
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
responses. Provider-specific ``cached_tokens`` fields are normalised
under a single key; see the priority chain inside for details.
"""
# --- resolve usage object ---
usage_obj = None
response_map = cls._maybe_mapping(response)
if response_map is not None:
@ -317,19 +326,53 @@ class OpenAICompatProvider(LLMProvider):
usage_map = cls._maybe_mapping(usage_obj)
if usage_map is not None:
return {
result = {
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
"total_tokens": int(usage_map.get("total_tokens") or 0),
}
if usage_obj:
return {
elif usage_obj:
result = {
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
}
return {}
else:
return {}
# --- cached_tokens (normalised across providers) ---
# Try nested paths first (dict), fall back to attribute (SDK object).
# Priority order ensures the most specific field wins.
for path in (
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
("cached_tokens",), # StepFun/Moonshot (top-level)
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
):
cached = cls._get_nested_int(usage_map, path)
if not cached and usage_obj:
cached = cls._get_nested_int(usage_obj, path)
if cached:
result["cached_tokens"] = cached
break
return result
@staticmethod
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
"""Drill into *obj* by *path* segments and return an ``int`` value.
Supports both dict-key access and attribute access so it works
uniformly with raw JSON dicts **and** SDK Pydantic models.
"""
current = obj
for segment in path:
if current is None:
return 0
if isinstance(current, dict):
current = current.get(segment)
else:
current = getattr(current, segment, None)
return int(current or 0) if current is not None else 0
def _parse(self, response: Any) -> LLMResponse:
if isinstance(response, str):
@ -603,4 +646,4 @@ class OpenAICompatProvider(LLMProvider):
return self._handle_error(e)
def get_default_model(self) -> str:
return self.default_model
return self.default_model

View File

@ -0,0 +1,29 @@
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""
from nanobot.providers.openai_responses.converters import (
convert_messages,
convert_tools,
convert_user_message,
split_tool_call_id,
)
from nanobot.providers.openai_responses.parsing import (
FINISH_REASON_MAP,
consume_sdk_stream,
consume_sse,
iter_sse,
map_finish_reason,
parse_response_output,
)
__all__ = [
"convert_messages",
"convert_tools",
"convert_user_message",
"split_tool_call_id",
"iter_sse",
"consume_sse",
"consume_sdk_stream",
"map_finish_reason",
"parse_response_output",
"FINISH_REASON_MAP",
]

View File

@ -0,0 +1,110 @@
"""Convert Chat Completions messages/tools to Responses API format."""
from __future__ import annotations
import json
from typing import Any
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
"""Convert Chat Completions messages to Responses API input items.
Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted
from any ``system`` role message and *input_items* is the Responses API
``input`` array.
"""
system_prompt = ""
input_items: list[dict[str, Any]] = []
for idx, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(convert_user_message(content))
continue
if role == "assistant":
if isinstance(content, str) and content:
input_items.append({
"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed", "id": f"msg_{idx}",
})
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = split_tool_call_id(tool_call.get("id"))
input_items.append({
"type": "function_call",
"id": item_id or f"fc_{idx}",
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
})
continue
if role == "tool":
call_id, _ = split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
return system_prompt, input_items
def convert_user_message(content: Any) -> dict[str, Any]:
"""Convert a user message's content to Responses API format.
Handles plain strings, ``text`` blocks -> ``input_text``, and
``image_url`` blocks -> ``input_image``.
"""
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert OpenAI function-calling tool schema to Responses API flat format."""
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append({
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
})
return converted
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
"""Split a compound ``call_id|item_id`` string.
Returns ``(call_id, item_id)`` where *item_id* may be ``None``.
"""
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None

View File

@ -0,0 +1,297 @@
"""Parse Responses API SSE streams and SDK response objects."""
from __future__ import annotations
import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator
import httpx
import json_repair
from loguru import logger
from nanobot.providers.base import LLMResponse, ToolCallRequest
FINISH_REASON_MAP = {
"completed": "stop",
"incomplete": "length",
"failed": "error",
"cancelled": "error",
}
def map_finish_reason(status: str | None) -> str:
"""Map a Responses API status string to a Chat-Completions-style finish_reason."""
return FINISH_REASON_MAP.get(status or "completed", "stop")
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
"""Yield parsed JSON events from a Responses API SSE stream."""
buffer: list[str] = []
def _flush() -> dict[str, Any] | None:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer.clear()
if not data_lines:
return None
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
return None
try:
return json.loads(data)
except Exception:
logger.warning("Failed to parse SSE event JSON: {}", data[:200])
return None
async for line in response.aiter_lines():
if line == "":
if buffer:
event = _flush()
if event is not None:
yield event
continue
buffer.append(line)
# Flush any remaining buffer at EOF (#10)
if buffer:
event = _flush()
if event is not None:
yield event
async def consume_sse(
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
"""Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``."""
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
async for event in iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
}
elif event_type == "response.output_text.delta":
delta_text = event.get("delta") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
elif event_type == "response.output_item.done":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or item.get("name"),
args_raw[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name") or "",
arguments=args,
)
)
elif event_type == "response.completed":
status = (event.get("response") or {}).get("status")
finish_reason = map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
detail = event.get("error") or event.get("message") or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason
def parse_response_output(response: Any) -> LLMResponse:
"""Parse an SDK ``Response`` object into an ``LLMResponse``."""
if not isinstance(response, dict):
dump = getattr(response, "model_dump", None)
response = dump() if callable(dump) else vars(response)
output = response.get("output") or []
content_parts: list[str] = []
tool_calls: list[ToolCallRequest] = []
reasoning_content: str | None = None
for item in output:
if not isinstance(item, dict):
dump = getattr(item, "model_dump", None)
item = dump() if callable(dump) else vars(item)
item_type = item.get("type")
if item_type == "message":
for block in item.get("content") or []:
if not isinstance(block, dict):
dump = getattr(block, "model_dump", None)
block = dump() if callable(dump) else vars(block)
if block.get("type") == "output_text":
content_parts.append(block.get("text") or "")
elif item_type == "reasoning":
for s in item.get("summary") or []:
if not isinstance(s, dict):
dump = getattr(s, "model_dump", None)
s = dump() if callable(dump) else vars(s)
if s.get("type") == "summary_text" and s.get("text"):
reasoning_content = (reasoning_content or "") + s["text"]
elif item_type == "function_call":
call_id = item.get("call_id") or ""
item_id = item.get("id") or "fc_0"
args_raw = item.get("arguments") or "{}"
try:
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
item.get("name"),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(ToolCallRequest(
id=f"{call_id}|{item_id}",
name=item.get("name") or "",
arguments=args if isinstance(args, dict) else {},
))
usage_raw = response.get("usage") or {}
if not isinstance(usage_raw, dict):
dump = getattr(usage_raw, "model_dump", None)
usage_raw = dump() if callable(dump) else vars(usage_raw)
usage = {}
if usage_raw:
usage = {
"prompt_tokens": int(usage_raw.get("input_tokens") or 0),
"completion_tokens": int(usage_raw.get("output_tokens") or 0),
"total_tokens": int(usage_raw.get("total_tokens") or 0),
}
status = response.get("status")
finish_reason = map_finish_reason(status)
return LLMResponse(
content="".join(content_parts) or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
)
async def consume_sdk_stream(
stream: Any,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
"""Consume an SDK async stream from ``client.responses.create(stream=True)``."""
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
usage: dict[str, int] = {}
reasoning_content: str | None = None
async for event in stream:
event_type = getattr(event, "type", None)
if event_type == "response.output_item.added":
item = getattr(event, "item", None)
if item and getattr(item, "type", None) == "function_call":
call_id = getattr(item, "call_id", None)
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": getattr(item, "id", None) or "fc_0",
"name": getattr(item, "name", None),
"arguments": getattr(item, "arguments", None) or "",
}
elif event_type == "response.output_text.delta":
delta_text = getattr(event, "delta", "") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta":
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
elif event_type == "response.function_call_arguments.done":
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""
elif event_type == "response.output_item.done":
item = getattr(event, "item", None)
if item and getattr(item, "type", None) == "function_call":
call_id = getattr(item, "call_id", None)
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or getattr(item, "name", None),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
name=buf.get("name") or getattr(item, "name", None) or "",
arguments=args,
)
)
elif event_type == "response.completed":
resp = getattr(event, "response", None)
status = getattr(resp, "status", None) if resp else None
finish_reason = map_finish_reason(status)
if resp:
usage_obj = getattr(resp, "usage", None)
if usage_obj:
usage = {
"prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
"completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
"total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
}
for out_item in getattr(resp, "output", None) or []:
if getattr(out_item, "type", None) == "reasoning":
for s in getattr(out_item, "summary", None) or []:
if getattr(s, "type", None) == "summary_text":
text = getattr(s, "text", None)
if text:
reasoning_content = (reasoning_content or "") + text
elif event_type in {"error", "response.failed"}:
detail = getattr(event, "error", None) or getattr(event, "message", None) or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason, usage, reasoning_content

View File

@ -34,7 +34,7 @@ class ProviderSpec:
display_name: str = "" # shown in `nanobot status`
# which provider implementation to use
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex"
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot"
backend: str = "openai_compat"
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
@ -218,8 +218,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("github_copilot", "copilot"),
env_key="",
display_name="Github Copilot",
backend="openai_compat",
backend="github_copilot",
default_api_base="https://api.githubcopilot.com",
strip_model_prefix=True,
is_oauth=True,
),
# DeepSeek: OpenAI-compatible at api.deepseek.com

View File

@ -405,14 +405,18 @@ def build_status_content(
)
last_in = last_usage.get("prompt_tokens", 0)
last_out = last_usage.get("completion_tokens", 0)
cached = last_usage.get("cached_tokens", 0)
ctx_total = max(context_window_tokens, 0)
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
if cached and last_in:
token_line += f" ({cached * 100 // last_in}% cached)"
return "\n".join([
f"\U0001f408 nanobot v{version}",
f"\U0001f9e0 Model: {model}",
f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
token_line,
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
f"\U0001f4ac Session: {session_msg_count} messages",
f"\u23f1 Uptime: {uptime}",

View File

@ -578,3 +578,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
args = mgr._announce_result.await_args.args
assert args[3] == "Task completed but no final response was generated."
assert args[5] == "ok"
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
"""Runner should accumulate prompt/completion tokens across iterations
and preserve cached_tokens from provider responses."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
)
return LLMResponse(
content="done",
tool_calls=[],
usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="file content")
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=3,
))
# Usage should be accumulated across iterations
assert result.usage["prompt_tokens"] == 300 # 100 + 200
assert result.usage["completion_tokens"] == 30 # 10 + 20
assert result.usage["cached_tokens"] == 230 # 80 + 150
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
"""Hook context.usage should contain cached_tokens."""
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_usage: list[dict] = []
class UsageHook(AgentHook):
async def after_iteration(self, context: AgentHookContext) -> None:
captured_usage.append(dict(context.usage))
async def chat_with_retry(**kwargs):
return LLMResponse(
content="done",
tool_calls=[],
usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
hook=UsageHook(),
))
assert len(captured_usage) == 1
assert captured_usage[0]["cached_tokens"] == 150

View File

@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
seen["config"] = self.config
return True
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {"fakeplugin": _LoginPlugin},
@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
assert seen["force"] is True
def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path):
from nanobot.cli.commands import app
from nanobot.config.schema import Config
from typer.testing import CliRunner
runner = CliRunner()
seen: dict[str, object] = {}
config_path = tmp_path / "custom-config.json"
class _LoginPlugin(_FakePlugin):
async def login(self, force: bool = False) -> bool:
return True
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {"fakeplugin": _LoginPlugin},
)
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)])
assert result.exit_code == 0
assert seen["config_path"] == config_path.resolve()
def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path):
from nanobot.cli.commands import app
from nanobot.config.schema import Config
from typer.testing import CliRunner
runner = CliRunner()
seen: dict[str, object] = {}
config_path = tmp_path / "custom-config.json"
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
result = runner.invoke(app, ["channels", "status", "--config", str(config_path)])
assert result.exit_code == 0
assert seen["config_path"] == config_path.resolve()
@pytest.mark.asyncio
async def test_manager_skips_disabled_plugin():
fake_config = SimpleNamespace(

View File

@ -3,16 +3,14 @@ from pathlib import Path
from types import SimpleNamespace
import pytest
pytest.importorskip("nio")
pytest.importorskip("nh3")
pytest.importorskip("mistune")
from nio import RoomSendResponse
from nanobot.channels.matrix import _build_matrix_text_content
# Check optional matrix dependencies before importing
try:
import nh3 # noqa: F401
except ImportError:
pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True)
import nanobot.channels.matrix as matrix_module
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus

View File

@ -317,6 +317,75 @@ def test_openai_compat_provider_passes_model_through():
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
def test_make_provider_uses_github_copilot_backend():
from nanobot.cli.commands import _make_provider
from nanobot.config.schema import Config
config = Config.model_validate(
{
"agents": {
"defaults": {
"provider": "github-copilot",
"model": "github-copilot/gpt-4.1",
}
}
}
)
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = _make_provider(config)
assert provider.__class__.__name__ == "GitHubCopilotProvider"
def test_github_copilot_provider_strips_prefixed_model_name():
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")
kwargs = provider._build_kwargs(
messages=[{"role": "user", "content": "hi"}],
tools=None,
model="github-copilot/gpt-5.1",
max_tokens=16,
temperature=0.1,
reasoning_effort=None,
tool_choice=None,
)
assert kwargs["model"] == "gpt-5.1"
@pytest.mark.asyncio
async def test_github_copilot_provider_refreshes_client_api_key_before_chat():
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
mock_client = MagicMock()
mock_client.api_key = "no-key"
mock_client.chat.completions.create = AsyncMock(return_value={
"choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
})
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI", return_value=mock_client):
provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")
provider._get_copilot_access_token = AsyncMock(return_value="copilot-access-token")
response = await provider.chat(
messages=[{"role": "user", "content": "hi"}],
model="github-copilot/gpt-5.1",
max_tokens=16,
temperature=0.1,
)
assert response.content == "ok"
assert provider._client.api_key == "copilot-access-token"
provider._get_copilot_access_token.assert_awaited_once()
mock_client.chat.completions.create.assert_awaited_once()
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"

View File

@ -152,10 +152,12 @@ class TestRestartCommand:
])
await loop._run_agent_loop([])
assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
assert loop._last_usage["prompt_tokens"] == 9
assert loop._last_usage["completion_tokens"] == 4
await loop._run_agent_loop([])
assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
assert loop._last_usage["prompt_tokens"] == 0
assert loop._last_usage["completion_tokens"] == 0
@pytest.mark.asyncio
async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):

View File

@ -285,6 +285,28 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
assert job.schedule.at_ms == expected
def test_add_job_delivers_by_default(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool.set_context("telegram", "chat-1")
result = tool._add_job("Morning standup", 60, None, None, None)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
assert job.payload.deliver is True
def test_add_job_can_disable_delivery(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool.set_context("telegram", "chat-1")
result = tool._add_job("Background refresh", 60, None, None, None, deliver=False)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
assert job.payload.deliver is False
def test_list_excludes_disabled_jobs(tmp_path) -> None:
tool = _make_tool(tmp_path)
job = tool._cron.add_job(

View File

@ -1,6 +1,6 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -8,392 +8,401 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse
def test_azure_openai_provider_init():
"""Test AzureOpenAIProvider initialization without deployment_name."""
# ---------------------------------------------------------------------------
# Init & validation
# ---------------------------------------------------------------------------
def test_init_creates_sdk_client():
"""Provider creates an AsyncOpenAI client with correct base_url."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
assert provider.api_key == "test-key"
assert provider.api_base == "https://test-resource.openai.azure.com/"
assert provider.default_model == "gpt-4o-deployment"
assert provider.api_version == "2024-10-21"
# SDK client base_url ends with /openai/v1/
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
def test_azure_openai_provider_init_validation():
"""Test AzureOpenAIProvider initialization validation."""
# Missing api_key
def test_init_base_url_no_trailing_slash():
"""Trailing slashes are normalised before building base_url."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://res.openai.azure.com",
)
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
def test_init_base_url_with_trailing_slash():
provider = AzureOpenAIProvider(
api_key="k", api_base="https://res.openai.azure.com/",
)
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
def test_init_validation_missing_key():
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
AzureOpenAIProvider(api_key="", api_base="https://test.com")
# Missing api_base
def test_init_validation_missing_base():
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
AzureOpenAIProvider(api_key="test", api_base="")
def test_build_chat_url():
"""Test Azure OpenAI URL building with different deployment names."""
def test_no_api_version_in_base_url():
"""The /openai/v1/ path should NOT contain an api-version query param."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
base = str(provider._client.base_url)
assert "api-version" not in base
# ---------------------------------------------------------------------------
# _supports_temperature
# ---------------------------------------------------------------------------
def test_supports_temperature_standard_model():
assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True
def test_supports_temperature_reasoning_model():
assert AzureOpenAIProvider._supports_temperature("o3-mini") is False
assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False
assert AzureOpenAIProvider._supports_temperature("o4-mini") is False
def test_supports_temperature_with_reasoning_effort():
assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False
# ---------------------------------------------------------------------------
# _build_body — Responses API body construction
# ---------------------------------------------------------------------------
def test_build_body_basic():
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
)
# Test various deployment names
test_cases = [
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
]
for deployment_name, expected_url in test_cases:
url = provider._build_chat_url(deployment_name)
assert url == expected_url
messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
def test_build_chat_url_api_base_without_slash():
"""Test URL building when api_base doesn't end with slash."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com", # No trailing slash
default_model="gpt-4o",
assert body["model"] == "gpt-4o"
assert body["instructions"] == "You are helpful."
assert body["temperature"] == 0.7
assert body["max_output_tokens"] == 4096
assert body["store"] is False
assert "reasoning" not in body
# input should contain the converted user message only (system extracted)
assert any(
item.get("role") == "user"
for item in body["input"]
)
url = provider._build_chat_url("test-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
def test_build_headers():
"""Test Azure OpenAI header building with api-key authentication."""
provider = AzureOpenAIProvider(
api_key="test-api-key-123",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
headers = provider._build_headers()
assert headers["Content-Type"] == "application/json"
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
assert "x-session-affinity" in headers
def test_build_body_max_tokens_minimum():
"""max_output_tokens should never be less than 1."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None)
assert body["max_output_tokens"] == 1
def test_prepare_request_payload():
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert payload["temperature"] == 0.8
assert "tools" not in payload
# Test with tools
def test_build_body_with_tools():
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
body = provider._build_body(
[{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
)
assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
assert body["tool_choice"] == "auto"
def test_prepare_request_payload_sanitizes_messages():
"""Test Azure payload strips non-standard message keys before sending."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
def test_build_body_with_reasoning():
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
body = provider._build_body(
[{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None,
)
assert body["reasoning"] == {"effort": "medium"}
assert "reasoning.encrypted_content" in body.get("include", [])
# temperature omitted for reasoning models
assert "temperature" not in body
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
"reasoning_content": "hidden chain-of-thought",
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
"extra_field": "should be removed",
},
]
payload = provider._prepare_request_payload("gpt-4o", messages)
def test_build_body_image_conversion():
"""image_url content blocks should be converted to input_image."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
messages = [{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
],
}]
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
user_item = body["input"][0]
content_types = [b["type"] for b in user_item["content"]]
assert "input_text" in content_types
assert "input_image" in content_types
image_block = next(b for b in user_item["content"] if b["type"] == "input_image")
assert image_block["image_url"] == "https://example.com/img.png"
assert payload["messages"] == [
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
def test_build_body_sanitizes_single_dict_content_block():
"""Single content dicts should be preserved via shared message sanitization."""
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
messages = [{
"role": "user",
"content": {"type": "text", "text": "Hi from dict content"},
}]
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
assert body["input"][0]["content"] == [{"type": "input_text", "text": "Hi from dict content"}]
# ---------------------------------------------------------------------------
# chat() — non-streaming
# ---------------------------------------------------------------------------
def _make_sdk_response(
content="Hello!", tool_calls=None, status="completed",
usage=None,
):
"""Build a mock that quacks like an openai Response object."""
resp = MagicMock()
resp.model_dump = MagicMock(return_value={
"output": [
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
*([{
"type": "function_call",
"call_id": tc["call_id"], "id": tc["id"],
"name": tc["name"], "arguments": tc["arguments"],
} for tc in (tool_calls or [])]),
],
"status": status,
"usage": {
"input_tokens": (usage or {}).get("input_tokens", 10),
"output_tokens": (usage or {}).get("output_tokens", 5),
"total_tokens": (usage or {}).get("total_tokens", 15),
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
})
return resp
@pytest.mark.asyncio
async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
# Mock response data
mock_response_data = {
"choices": [{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
mock_resp = _make_sdk_response(content="Hello!")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
result = await provider.chat([{"role": "user", "content": "Hi"}])
assert isinstance(result, LLMResponse)
assert result.content == "Hello!"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 10
@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
"""Test that chat uses default_model when no model is specified."""
async def test_chat_uses_default_model():
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
)
mock_response_data = {
"choices": [{
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified
# Verify URL was built with default model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
mock_resp = _make_sdk_response(content="ok")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
await provider.chat([{"role": "user", "content": "test"}])
call_kwargs = provider._client.responses.create.call_args[1]
assert call_kwargs["model"] == "my-deployment"
@pytest.mark.asyncio
async def test_chat_custom_model():
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
mock_resp = _make_sdk_response(content="ok")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy")
call_kwargs = provider._client.responses.create.call_args[1]
assert call_kwargs["model"] == "custom-deploy"
@pytest.mark.asyncio
async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
# Mock response with tool calls
mock_response_data = {
"choices": [{
"message": {
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
mock_resp = _make_sdk_response(
content=None,
tool_calls=[{
"call_id": "call_123", "id": "fc_1",
"name": "get_weather", "arguments": '{"location": "SF"}',
}],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
)
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
result = await provider.chat(
[{"role": "user", "content": "Weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "SF"}
@pytest.mark.asyncio
async def test_chat_api_error():
"""Test chat request API error handling."""
async def test_chat_error_handling():
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
result = await provider.chat([{"role": "user", "content": "Hi"}])
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content
assert "Connection failed" in result.content
assert result.finish_reason == "error"
@pytest.mark.asyncio
async def test_chat_reasoning_param_format():
"""reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
)
mock_resp = _make_sdk_response(content="thought")
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
await provider.chat(
[{"role": "user", "content": "think"}], reasoning_effort="medium",
)
call_kwargs = provider._client.responses.create.call_args[1]
assert call_kwargs["reasoning"] == {"effort": "medium"}
assert "reasoning_effort" not in call_kwargs
# ---------------------------------------------------------------------------
# chat_stream()
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_stream_success():
"""Streaming should call on_content_delta and return combined response."""
provider = AzureOpenAIProvider(
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
# Build mock SDK stream events
events = []
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
resp_obj = MagicMock(status="completed")
ev3 = MagicMock(type="response.completed", response=resp_obj)
events = [ev1, ev2, ev3]
async def mock_stream():
for e in events:
yield e
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_stream())
deltas: list[str] = []
async def on_delta(text: str) -> None:
deltas.append(text)
result = await provider.chat_stream(
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
)
assert result.content == "Hello world"
assert result.finish_reason == "stop"
assert deltas == ["Hello", " world"]
@pytest.mark.asyncio
async def test_chat_stream_with_tool_calls():
"""Streaming tool calls should be accumulated correctly."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
item_added.name = "get_weather"
ev_added = MagicMock(type="response.output_item.added", item=item_added)
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
ev_args_done = MagicMock(
type="response.function_call_arguments.done",
call_id="call_1", arguments='{"location":"SF"}',
)
item_done = MagicMock(
type="function_call", call_id="call_1", id="fc_1",
arguments='{"location":"SF"}',
)
item_done.name = "get_weather"
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed")
ev_completed = MagicMock(type="response.completed", response=resp_obj)
async def mock_stream():
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
yield e
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_stream())
result = await provider.chat_stream(
[{"role": "user", "content": "weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "SF"}
@pytest.mark.asyncio
async def test_chat_stream_error():
"""Streaming should return error when SDK raises."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
assert "Connection failed" in result.content
assert result.finish_reason == "error"
# ---------------------------------------------------------------------------
# get_default_model
# ---------------------------------------------------------------------------
def test_get_default_model():
"""Test get_default_model method."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="my-custom-deployment",
api_key="k", api_base="https://r.com", default_model="my-deploy",
)
assert provider.get_default_model() == "my-custom-deployment"
if __name__ == "__main__":
# Run basic tests
print("Running basic Azure OpenAI provider tests...")
# Test initialization
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
print("✅ Provider initialization successful")
# Test URL building
url = provider._build_chat_url("my-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
print("✅ URL building works correctly")
# Test headers
headers = provider._build_headers()
assert headers["api-key"] == "test-key"
assert headers["Content-Type"] == "application/json"
print("✅ Header building works correctly")
# Test payload preparation
messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly")
print("✅ All basic tests passed! Updated test file is working correctly.")
assert provider.get_default_model() == "my-deploy"

View File

@ -0,0 +1,233 @@
"""Tests for cached token extraction from OpenAI-compatible providers."""
from __future__ import annotations
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
class FakeUsage:
"""Mimics an OpenAI SDK usage object (has attributes, not dict keys)."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class FakePromptDetails:
"""Mimics prompt_tokens_details sub-object."""
def __init__(self, cached_tokens=0):
self.cached_tokens = cached_tokens
class _FakeSpec:
supports_prompt_caching = False
model_id_prefix = None
strip_model_prefix = False
max_completion_tokens = False
reasoning_effort = None
def _provider():
from unittest.mock import MagicMock
p = OpenAICompatProvider.__new__(OpenAICompatProvider)
p.client = MagicMock()
p.spec = _FakeSpec()
return p
# Minimal valid choice so _parse reaches _extract_usage.
_DICT_CHOICE = {"message": {"content": "Hello"}}
class _FakeMessage:
content = "Hello"
tool_calls = None
class _FakeChoice:
message = _FakeMessage()
finish_reason = "stop"
# --- dict-based response (raw JSON / mapping) ---
def test_extract_usage_openai_cached_tokens_dict():
"""prompt_tokens_details.cached_tokens from a dict response."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 2000,
"completion_tokens": 300,
"total_tokens": 2300,
"prompt_tokens_details": {"cached_tokens": 1200},
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
assert result.usage["prompt_tokens"] == 2000
def test_extract_usage_deepseek_cached_tokens_dict():
"""prompt_cache_hit_tokens from a DeepSeek dict response."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 1500,
"completion_tokens": 200,
"total_tokens": 1700,
"prompt_cache_hit_tokens": 1200,
"prompt_cache_miss_tokens": 300,
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
def test_extract_usage_no_cached_tokens_dict():
"""Response without any cache fields -> no cached_tokens key."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 1000,
"completion_tokens": 200,
"total_tokens": 1200,
}
}
result = p._parse(response)
assert "cached_tokens" not in result.usage
def test_extract_usage_openai_cached_zero_dict():
"""cached_tokens=0 should NOT be included (same as existing fields)."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 2000,
"completion_tokens": 300,
"total_tokens": 2300,
"prompt_tokens_details": {"cached_tokens": 0},
}
}
result = p._parse(response)
assert "cached_tokens" not in result.usage
# --- object-based response (OpenAI SDK Pydantic model) ---
def test_extract_usage_openai_cached_tokens_obj():
"""prompt_tokens_details.cached_tokens from an SDK object response."""
p = _provider()
usage_obj = FakeUsage(
prompt_tokens=2000,
completion_tokens=300,
total_tokens=2300,
prompt_tokens_details=FakePromptDetails(cached_tokens=1200),
)
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
def test_extract_usage_deepseek_cached_tokens_obj():
"""prompt_cache_hit_tokens from a DeepSeek SDK object response."""
p = _provider()
usage_obj = FakeUsage(
prompt_tokens=1500,
completion_tokens=200,
total_tokens=1700,
prompt_cache_hit_tokens=1200,
)
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
result = p._parse(response)
assert result.usage["cached_tokens"] == 1200
def test_extract_usage_stepfun_top_level_cached_tokens_dict():
"""StepFun/Moonshot: usage.cached_tokens at top level (not nested)."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 591,
"completion_tokens": 120,
"total_tokens": 711,
"cached_tokens": 512,
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 512
def test_extract_usage_stepfun_top_level_cached_tokens_obj():
"""StepFun/Moonshot: usage.cached_tokens as SDK object attribute."""
p = _provider()
usage_obj = FakeUsage(
prompt_tokens=591,
completion_tokens=120,
total_tokens=711,
cached_tokens=512,
)
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
result = p._parse(response)
assert result.usage["cached_tokens"] == 512
def test_extract_usage_priority_nested_over_top_level_dict():
"""When both nested and top-level cached_tokens exist, nested wins."""
p = _provider()
response = {
"choices": [_DICT_CHOICE],
"usage": {
"prompt_tokens": 2000,
"completion_tokens": 300,
"total_tokens": 2300,
"prompt_tokens_details": {"cached_tokens": 100},
"cached_tokens": 500,
}
}
result = p._parse(response)
assert result.usage["cached_tokens"] == 100
def test_anthropic_maps_cache_fields_to_cached_tokens():
"""Anthropic's cache_read_input_tokens should map to cached_tokens."""
from nanobot.providers.anthropic_provider import AnthropicProvider
usage_obj = FakeUsage(
input_tokens=800,
output_tokens=200,
cache_creation_input_tokens=300,
cache_read_input_tokens=1200,
)
content_block = FakeUsage(type="text", text="hello")
response = FakeUsage(
id="msg_1",
type="message",
stop_reason="end_turn",
content=[content_block],
usage=usage_obj,
)
result = AnthropicProvider._parse_response(response)
assert result.usage["cached_tokens"] == 1200
assert result.usage["prompt_tokens"] == 2300
assert result.usage["total_tokens"] == 2500
assert result.usage["cache_creation_input_tokens"] == 300
def test_anthropic_no_cache_fields():
"""Anthropic response without cache fields should not have cached_tokens."""
from nanobot.providers.anthropic_provider import AnthropicProvider
usage_obj = FakeUsage(input_tokens=800, output_tokens=200)
content_block = FakeUsage(type="text", text="hello")
response = FakeUsage(
id="msg_1",
type="message",
stop_reason="end_turn",
content=[content_block],
usage=usage_obj,
)
result = AnthropicProvider._parse_response(response)
assert "cached_tokens" not in result.usage

View File

@ -0,0 +1,522 @@
"""Tests for the shared openai_responses converters and parsers."""
from unittest.mock import MagicMock, patch
import pytest
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses.converters import (
convert_messages,
convert_tools,
convert_user_message,
split_tool_call_id,
)
from nanobot.providers.openai_responses.parsing import (
consume_sdk_stream,
map_finish_reason,
parse_response_output,
)
# ======================================================================
# converters - split_tool_call_id
# ======================================================================
class TestSplitToolCallId:
def test_plain_id(self):
assert split_tool_call_id("call_abc") == ("call_abc", None)
def test_compound_id(self):
assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1")
def test_compound_empty_item_id(self):
assert split_tool_call_id("call_abc|") == ("call_abc", None)
def test_none(self):
assert split_tool_call_id(None) == ("call_0", None)
def test_empty_string(self):
assert split_tool_call_id("") == ("call_0", None)
def test_non_string(self):
assert split_tool_call_id(42) == ("call_0", None)
# ======================================================================
# converters - convert_user_message
# ======================================================================
class TestConvertUserMessage:
def test_string_content(self):
result = convert_user_message("hello")
assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}
def test_text_block(self):
result = convert_user_message([{"type": "text", "text": "hi"}])
assert result["content"] == [{"type": "input_text", "text": "hi"}]
def test_image_url_block(self):
result = convert_user_message([
{"type": "image_url", "image_url": {"url": "https://img.example/a.png"}},
])
assert result["content"] == [
{"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"},
]
def test_mixed_text_and_image(self):
result = convert_user_message([
{"type": "text", "text": "what's this?"},
{"type": "image_url", "image_url": {"url": "https://img.example/b.png"}},
])
assert len(result["content"]) == 2
assert result["content"][0]["type"] == "input_text"
assert result["content"][1]["type"] == "input_image"
def test_empty_list_falls_back(self):
result = convert_user_message([])
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_none_falls_back(self):
result = convert_user_message(None)
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_image_without_url_skipped(self):
result = convert_user_message([{"type": "image_url", "image_url": {}}])
assert result["content"] == [{"type": "input_text", "text": ""}]
def test_meta_fields_not_leaked(self):
"""_meta on content blocks must never appear in converted output."""
result = convert_user_message([
{"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}},
])
assert "_meta" not in result["content"][0]
def test_non_dict_items_skipped(self):
result = convert_user_message(["just a string", 42])
assert result["content"] == [{"type": "input_text", "text": ""}]
# ======================================================================
# converters - convert_messages
# ======================================================================
class TestConvertMessages:
def test_system_extracted_as_instructions(self):
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi"},
]
instructions, items = convert_messages(msgs)
assert instructions == "You are helpful."
assert len(items) == 1
assert items[0]["role"] == "user"
def test_multiple_system_messages_last_wins(self):
msgs = [
{"role": "system", "content": "first"},
{"role": "system", "content": "second"},
{"role": "user", "content": "x"},
]
instructions, _ = convert_messages(msgs)
assert instructions == "second"
def test_user_message_converted(self):
_, items = convert_messages([{"role": "user", "content": "hello"}])
assert items[0]["role"] == "user"
assert items[0]["content"][0]["type"] == "input_text"
def test_assistant_text_message(self):
_, items = convert_messages([
{"role": "assistant", "content": "I'll help"},
])
assert items[0]["type"] == "message"
assert items[0]["role"] == "assistant"
assert items[0]["content"][0]["type"] == "output_text"
assert items[0]["content"][0]["text"] == "I'll help"
def test_assistant_empty_content_skipped(self):
_, items = convert_messages([{"role": "assistant", "content": ""}])
assert len(items) == 0
def test_assistant_with_tool_calls(self):
_, items = convert_messages([{
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_abc|fc_1",
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
}],
}])
assert items[0]["type"] == "function_call"
assert items[0]["call_id"] == "call_abc"
assert items[0]["id"] == "fc_1"
assert items[0]["name"] == "get_weather"
def test_assistant_with_tool_calls_no_id(self):
"""Fallback IDs when tool_call.id is missing."""
_, items = convert_messages([{
"role": "assistant",
"content": None,
"tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}],
}])
assert items[0]["call_id"] == "call_0"
assert items[0]["id"].startswith("fc_")
def test_tool_message(self):
_, items = convert_messages([{
"role": "tool",
"tool_call_id": "call_abc",
"content": "result text",
}])
assert items[0]["type"] == "function_call_output"
assert items[0]["call_id"] == "call_abc"
assert items[0]["output"] == "result text"
def test_tool_message_dict_content(self):
_, items = convert_messages([{
"role": "tool",
"tool_call_id": "call_1",
"content": {"key": "value"},
}])
assert items[0]["output"] == '{"key": "value"}'
def test_non_standard_keys_not_leaked(self):
"""Extra keys on messages must not appear in converted items."""
_, items = convert_messages([{
"role": "user",
"content": "hi",
"extra_field": "should vanish",
"_meta": {"path": "/tmp"},
}])
item = items[0]
assert "extra_field" not in str(item)
assert "_meta" not in str(item)
def test_full_conversation_roundtrip(self):
"""System + user + assistant(tool_call) + tool -> correct structure."""
msgs = [
{"role": "system", "content": "Be concise."},
{"role": "user", "content": "Weather in SF?"},
{
"role": "assistant", "content": None,
"tool_calls": [{
"id": "c1|fc1",
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
}],
},
{"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'},
]
instructions, items = convert_messages(msgs)
assert instructions == "Be concise."
assert len(items) == 3 # user, function_call, function_call_output
assert items[0]["role"] == "user"
assert items[1]["type"] == "function_call"
assert items[2]["type"] == "function_call_output"
# ======================================================================
# converters - convert_tools
# ======================================================================
class TestConvertTools:
def test_standard_function_tool(self):
tools = [{"type": "function", "function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}}]
result = convert_tools(tools)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["name"] == "get_weather"
assert result[0]["description"] == "Get weather"
assert "properties" in result[0]["parameters"]
def test_tool_without_name_skipped(self):
tools = [{"type": "function", "function": {"parameters": {}}}]
assert convert_tools(tools) == []
def test_tool_without_function_wrapper(self):
"""Direct dict without type=function wrapper."""
tools = [{"name": "f1", "description": "d", "parameters": {}}]
result = convert_tools(tools)
assert result[0]["name"] == "f1"
def test_missing_optional_fields_default(self):
tools = [{"type": "function", "function": {"name": "f"}}]
result = convert_tools(tools)
assert result[0]["description"] == ""
assert result[0]["parameters"] == {}
def test_multiple_tools(self):
tools = [
{"type": "function", "function": {"name": "a", "parameters": {}}},
{"type": "function", "function": {"name": "b", "parameters": {}}},
]
assert len(convert_tools(tools)) == 2
# ======================================================================
# parsing - map_finish_reason
# ======================================================================
class TestMapFinishReason:
def test_completed(self):
assert map_finish_reason("completed") == "stop"
def test_incomplete(self):
assert map_finish_reason("incomplete") == "length"
def test_failed(self):
assert map_finish_reason("failed") == "error"
def test_cancelled(self):
assert map_finish_reason("cancelled") == "error"
def test_none_defaults_to_stop(self):
assert map_finish_reason(None) == "stop"
def test_unknown_defaults_to_stop(self):
assert map_finish_reason("some_new_status") == "stop"
# ======================================================================
# parsing - parse_response_output
# ======================================================================
class TestParseResponseOutput:
def test_text_response(self):
resp = {
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "Hello!"}]}],
"status": "completed",
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
}
result = parse_response_output(resp)
assert result.content == "Hello!"
assert result.finish_reason == "stop"
assert result.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
assert result.tool_calls == []
def test_tool_call_response(self):
resp = {
"output": [{
"type": "function_call",
"call_id": "call_1", "id": "fc_1",
"name": "get_weather",
"arguments": '{"city": "SF"}',
}],
"status": "completed",
"usage": {},
}
result = parse_response_output(resp)
assert result.content is None
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"city": "SF"}
assert result.tool_calls[0].id == "call_1|fc_1"
def test_malformed_tool_arguments_logged(self):
"""Malformed JSON arguments should log a warning and fallback."""
resp = {
"output": [{
"type": "function_call",
"call_id": "c1", "id": "fc1",
"name": "f", "arguments": "{bad json",
}],
"status": "completed", "usage": {},
}
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
result = parse_response_output(resp)
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
def test_reasoning_content_extracted(self):
resp = {
"output": [
{"type": "reasoning", "summary": [
{"type": "summary_text", "text": "I think "},
{"type": "summary_text", "text": "therefore I am."},
]},
{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "42"}]},
],
"status": "completed", "usage": {},
}
result = parse_response_output(resp)
assert result.content == "42"
assert result.reasoning_content == "I think therefore I am."
def test_empty_output(self):
resp = {"output": [], "status": "completed", "usage": {}}
result = parse_response_output(resp)
assert result.content is None
assert result.tool_calls == []
def test_incomplete_status(self):
resp = {"output": [], "status": "incomplete", "usage": {}}
result = parse_response_output(resp)
assert result.finish_reason == "length"
def test_sdk_model_object(self):
"""parse_response_output should handle SDK objects with model_dump()."""
mock = MagicMock()
mock.model_dump.return_value = {
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "sdk"}]}],
"status": "completed",
"usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
}
result = parse_response_output(mock)
assert result.content == "sdk"
assert result.usage["prompt_tokens"] == 1
def test_usage_maps_responses_api_keys(self):
"""Responses API uses input_tokens/output_tokens, not prompt_tokens/completion_tokens."""
resp = {
"output": [],
"status": "completed",
"usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
}
result = parse_response_output(resp)
assert result.usage["prompt_tokens"] == 100
assert result.usage["completion_tokens"] == 50
assert result.usage["total_tokens"] == 150
# ======================================================================
# parsing - consume_sdk_stream
# ======================================================================
class TestConsumeSdkStream:
@pytest.mark.asyncio
async def test_text_stream(self):
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev3 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3]:
yield e
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
assert content == "Hello world"
assert tool_calls == []
assert finish_reason == "stop"
@pytest.mark.asyncio
async def test_on_content_delta_called(self):
ev1 = MagicMock(type="response.output_text.delta", delta="hi")
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev2 = MagicMock(type="response.completed", response=resp_obj)
deltas = []
async def cb(text):
deltas.append(text)
async def stream():
for e in [ev1, ev2]:
yield e
await consume_sdk_stream(stream(), on_content_delta=cb)
assert deltas == ["hi"]
@pytest.mark.asyncio
async def test_tool_call_stream(self):
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
item_added.name = "get_weather"
ev1 = MagicMock(type="response.output_item.added", item=item_added)
ev2 = MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci')
ev3 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}')
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments='{"city":"SF"}')
item_done.name = "get_weather"
ev4 = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev5 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3, ev4, ev5]:
yield e
content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream())
assert content == ""
assert len(tool_calls) == 1
assert tool_calls[0].name == "get_weather"
assert tool_calls[0].arguments == {"city": "SF"}
@pytest.mark.asyncio
async def test_usage_extracted(self):
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
resp_obj = MagicMock(status="completed", usage=usage_obj, output=[])
ev = MagicMock(type="response.completed", response=resp_obj)
async def stream():
yield ev
_, _, _, usage, _ = await consume_sdk_stream(stream())
assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
@pytest.mark.asyncio
async def test_reasoning_extracted(self):
summary_item = MagicMock(type="summary_text", text="thinking...")
reasoning_item = MagicMock(type="reasoning", summary=[summary_item])
resp_obj = MagicMock(status="completed", usage=None, output=[reasoning_item])
ev = MagicMock(type="response.completed", response=resp_obj)
async def stream():
yield ev
_, _, _, _, reasoning = await consume_sdk_stream(stream())
assert reasoning == "thinking..."
@pytest.mark.asyncio
async def test_error_event_raises(self):
ev = MagicMock(type="error", error="rate_limit_exceeded")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio
async def test_failed_event_raises(self):
ev = MagicMock(type="response.failed", error="server_error")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed.*server_error"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio
async def test_malformed_tool_args_logged(self):
"""Malformed JSON in streaming tool args should log a warning."""
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
item_added.name = "f"
ev1 = MagicMock(type="response.output_item.added", item=item_added)
ev2 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad")
item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="{bad")
item_done.name = "f"
ev3 = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed", usage=None, output=[])
ev4 = MagicMock(type="response.completed", response=resp_obj)
async def stream():
for e in [ev1, ev2, ev3, ev4]:
yield e
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
assert tool_calls[0].arguments == {"raw": "{bad"}
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)

View File

@ -11,6 +11,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
providers = importlib.import_module("nanobot.providers")
@ -18,6 +19,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
assert "nanobot.providers.anthropic_provider" not in sys.modules
assert "nanobot.providers.openai_compat_provider" not in sys.modules
assert "nanobot.providers.openai_codex_provider" not in sys.modules
assert "nanobot.providers.github_copilot_provider" not in sys.modules
assert "nanobot.providers.azure_openai_provider" not in sys.modules
assert providers.__all__ == [
"LLMProvider",
@ -25,6 +27,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
"AnthropicProvider",
"OpenAICompatProvider",
"OpenAICodexProvider",
"GitHubCopilotProvider",
"AzureOpenAIProvider",
]

View File

@ -0,0 +1,59 @@
"""Tests for build_status_content cache hit rate display."""
from nanobot.utils.helpers import build_status_content
def test_status_shows_cache_hit_rate():
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200},
context_window_tokens=128000,
session_msg_count=10,
context_tokens_estimate=5000,
)
assert "60% cached" in content
assert "2000 in / 300 out" in content
def test_status_no_cache_info():
"""Without cached_tokens, display should not show cache percentage."""
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 2000, "completion_tokens": 300},
context_window_tokens=128000,
session_msg_count=10,
context_tokens_estimate=5000,
)
assert "cached" not in content.lower()
assert "2000 in / 300 out" in content
def test_status_zero_cached_tokens():
"""cached_tokens=0 should not show cache percentage."""
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0},
context_window_tokens=128000,
session_msg_count=10,
context_tokens_estimate=5000,
)
assert "cached" not in content.lower()
def test_status_100_percent_cached():
content = build_status_content(
version="0.1.0",
model="glm-4-plus",
start_time=1000000.0,
last_usage={"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000},
context_window_tokens=128000,
session_msg_count=5,
context_tokens_estimate=3000,
)
assert "100% cached" in content

View File

@ -125,6 +125,27 @@ def test_workspace_override(tmp_path):
assert bot._loop.workspace == custom_ws
def test_sdk_make_provider_uses_github_copilot_backend():
from nanobot.config.schema import Config
from nanobot.nanobot import _make_provider
config = Config.model_validate(
{
"agents": {
"defaults": {
"provider": "github-copilot",
"model": "github-copilot/gpt-4.1",
}
}
}
)
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = _make_provider(config)
assert provider.__class__.__name__ == "GitHubCopilotProvider"
@pytest.mark.asyncio
async def test_run_custom_session_key(tmp_path):
from nanobot.bus.events import OutboundMessage

View File

@ -95,6 +95,14 @@ def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
assert paths == [r"C:\user\workspace\txt"]
def test_exec_extract_absolute_paths_captures_windows_drive_root_path() -> None:
"""Windows drive root paths like `E:\\` must be extracted for workspace guarding."""
# Note: raw strings cannot end with a single backslash.
cmd = "dir E:\\"
paths = ExecTool._extract_absolute_paths(cmd)
assert paths == ["E:\\"]
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
cmd = ".venv/bin/python script.py"
paths = ExecTool._extract_absolute_paths(cmd)
@ -134,6 +142,45 @@ def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
assert error == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_windows_drive_root_outside_workspace(monkeypatch) -> None:
import nanobot.agent.tools.shell as shell_mod
class FakeWindowsPath:
def __init__(self, raw: str) -> None:
self.raw = raw.rstrip("\\") + ("\\" if raw.endswith("\\") else "")
def resolve(self) -> "FakeWindowsPath":
return self
def expanduser(self) -> "FakeWindowsPath":
return self
def is_absolute(self) -> bool:
return len(self.raw) >= 3 and self.raw[1:3] == ":\\"
@property
def parents(self) -> list["FakeWindowsPath"]:
if not self.is_absolute():
return []
trimmed = self.raw.rstrip("\\")
if len(trimmed) <= 2:
return []
idx = trimmed.rfind("\\")
if idx <= 2:
return [FakeWindowsPath(trimmed[:2] + "\\")]
parent = FakeWindowsPath(trimmed[:idx])
return [parent, *parent.parents]
def __eq__(self, other: object) -> bool:
return isinstance(other, FakeWindowsPath) and self.raw.lower() == other.raw.lower()
monkeypatch.setattr(shell_mod, "Path", FakeWindowsPath)
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command("dir E:\\", "E:\\workspace")
assert error == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---