mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
Merge remote-tracking branch 'origin/main' into feat/runtime-hardening
This commit is contained in:
commit
eefd7e60f2
@ -95,6 +95,15 @@ class _LoopHook(AgentHook):
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
    """Write a debug log line with the token usage carried by ``context``.

    Args:
        context: Hook context whose ``usage`` mapping is read. A missing or
            empty mapping is treated as ``{}``, so every counter logs as 0.
    """
    usage = context.usage
    if not usage:
        usage = {}

    prompt_tokens = usage.get("prompt_tokens", 0)
    completion_tokens = usage.get("completion_tokens", 0)
    cached_tokens = usage.get("cached_tokens", 0)

    logger.debug(
        "LLM usage: prompt={} completion={} cached={}",
        prompt_tokens,
        completion_tokens,
        cached_tokens,
    )
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ class AgentRunner:
|
||||
messages = list(spec.initial_messages)
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = []
|
||||
usage = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
usage: dict[str, int] = {}
|
||||
error: str | None = None
|
||||
stop_reason = "completed"
|
||||
tool_events: list[dict[str, str]] = []
|
||||
@ -122,13 +122,15 @@ class AgentRunner:
|
||||
response = await self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
raw_usage = response.usage or {}
|
||||
usage = {
|
||||
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
|
||||
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
|
||||
}
|
||||
context.response = response
|
||||
context.usage = usage
|
||||
context.usage = raw_usage
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
# Accumulate standard fields into result usage.
|
||||
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
|
||||
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
|
||||
cached = raw_usage.get("cached_tokens")
|
||||
if cached:
|
||||
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if hook.wants_streaming():
|
||||
|
||||
@ -97,6 +97,11 @@ class CronTool(Tool):
|
||||
f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}."
|
||||
),
|
||||
},
|
||||
"deliver": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to deliver the execution result to the user channel (default true)",
|
||||
"default": True
|
||||
},
|
||||
"job_id": {"type": "string", "description": "Job ID (for remove)"},
|
||||
},
|
||||
"required": ["action"],
|
||||
@ -111,12 +116,13 @@ class CronTool(Tool):
|
||||
tz: str | None = None,
|
||||
at: str | None = None,
|
||||
job_id: str | None = None,
|
||||
deliver: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if action == "add":
|
||||
if self._in_cron_context.get():
|
||||
return "Error: cannot schedule new jobs from within a cron job execution"
|
||||
return self._add_job(message, every_seconds, cron_expr, tz, at)
|
||||
return self._add_job(message, every_seconds, cron_expr, tz, at, deliver)
|
||||
elif action == "list":
|
||||
return self._list_jobs()
|
||||
elif action == "remove":
|
||||
@ -130,6 +136,7 @@ class CronTool(Tool):
|
||||
cron_expr: str | None,
|
||||
tz: str | None,
|
||||
at: str | None,
|
||||
deliver: bool = True,
|
||||
) -> str:
|
||||
if not message:
|
||||
return "Error: message is required for add"
|
||||
@ -171,7 +178,7 @@ class CronTool(Tool):
|
||||
name=message[:30],
|
||||
schedule=schedule,
|
||||
message=message,
|
||||
deliver=True,
|
||||
deliver=deliver,
|
||||
channel=self._channel,
|
||||
to=self._chat_id,
|
||||
delete_after_run=delete_after,
|
||||
|
||||
@ -86,7 +86,15 @@ class MessageTool(Tool):
|
||||
) -> str:
|
||||
channel = channel or self._default_channel
|
||||
chat_id = chat_id or self._default_chat_id
|
||||
message_id = message_id or self._default_message_id
|
||||
# Only inherit default message_id when targeting the same channel+chat.
|
||||
# Cross-chat sends must not carry the original message_id, because
|
||||
# some channels (e.g. Feishu) use it to determine the target
|
||||
# conversation via their Reply API, which would route the message
|
||||
# to the wrong chat entirely.
|
||||
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||
message_id = message_id or self._default_message_id
|
||||
else:
|
||||
message_id = None
|
||||
|
||||
if not channel or not chat_id:
|
||||
return "Error: No target channel/chat specified"
|
||||
@ -101,7 +109,7 @@ class MessageTool(Tool):
|
||||
media=media or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
},
|
||||
} if message_id else {},
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -190,7 +190,9 @@ class ExecTool(Tool):
|
||||
|
||||
@staticmethod
|
||||
def _extract_absolute_paths(command: str) -> list[str]:
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
|
||||
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`
|
||||
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command)
|
||||
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
|
||||
home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
|
||||
return win_paths + posix_paths + home_paths
|
||||
|
||||
@ -415,6 +415,9 @@ def _make_provider(config: Config):
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
provider = AnthropicProvider(
|
||||
@ -1029,12 +1032,18 @@ app.add_typer(channels_app, name="channels")
|
||||
|
||||
|
||||
@channels_app.command("status")
|
||||
def channels_status():
|
||||
def channels_status(
|
||||
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Show channel status."""
|
||||
from nanobot.channels.registry import discover_all
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.loader import load_config, set_config_path
|
||||
|
||||
config = load_config()
|
||||
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
|
||||
if resolved_config_path is not None:
|
||||
set_config_path(resolved_config_path)
|
||||
|
||||
config = load_config(resolved_config_path)
|
||||
|
||||
table = Table(title="Channel Status")
|
||||
table.add_column("Channel", style="cyan")
|
||||
@ -1121,12 +1130,17 @@ def _get_bridge_dir() -> Path:
|
||||
def channels_login(
|
||||
channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
|
||||
force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
|
||||
config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Authenticate with a channel via QR code or other interactive login."""
|
||||
from nanobot.channels.registry import discover_all
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.loader import load_config, set_config_path
|
||||
|
||||
config = load_config()
|
||||
resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
|
||||
if resolved_config_path is not None:
|
||||
set_config_path(resolved_config_path)
|
||||
|
||||
config = load_config(resolved_config_path)
|
||||
channel_cfg = getattr(config.channels, channel_name, None) or {}
|
||||
|
||||
# Validate channel exists
|
||||
@ -1298,26 +1312,16 @@ def _login_openai_codex() -> None:
|
||||
|
||||
@_register_login("github_copilot")
|
||||
def _login_github_copilot() -> None:
|
||||
import asyncio
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||
|
||||
async def _trigger():
|
||||
client = AsyncOpenAI(
|
||||
api_key="dummy",
|
||||
base_url="https://api.githubcopilot.com",
|
||||
)
|
||||
await client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(_trigger())
|
||||
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
|
||||
from nanobot.providers.github_copilot_provider import login_github_copilot
|
||||
|
||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||
token = login_github_copilot(
|
||||
print_fn=lambda s: console.print(s),
|
||||
prompt_fn=lambda s: typer.prompt(s),
|
||||
)
|
||||
account = token.account_id or "GitHub"
|
||||
console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Authentication error: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
@ -138,6 +138,10 @@ def _make_provider(config: Any) -> Any:
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ __all__ = [
|
||||
"AnthropicProvider",
|
||||
"OpenAICompatProvider",
|
||||
"OpenAICodexProvider",
|
||||
"GitHubCopilotProvider",
|
||||
"AzureOpenAIProvider",
|
||||
]
|
||||
|
||||
@ -20,12 +21,14 @@ _LAZY_IMPORTS = {
|
||||
"AnthropicProvider": ".anthropic_provider",
|
||||
"OpenAICompatProvider": ".openai_compat_provider",
|
||||
"OpenAICodexProvider": ".openai_codex_provider",
|
||||
"GitHubCopilotProvider": ".github_copilot_provider",
|
||||
"AzureOpenAIProvider": ".azure_openai_provider",
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
|
||||
@ -372,15 +372,22 @@ class AnthropicProvider(LLMProvider):
|
||||
|
||||
usage: dict[str, int] = {}
|
||||
if response.usage:
|
||||
input_tokens = response.usage.input_tokens
|
||||
cache_creation = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
|
||||
cache_read = getattr(response.usage, "cache_read_input_tokens", 0) or 0
|
||||
total_prompt_tokens = input_tokens + cache_creation + cache_read
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.input_tokens,
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": response.usage.output_tokens,
|
||||
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
|
||||
"total_tokens": total_prompt_tokens + response.usage.output_tokens,
|
||||
}
|
||||
for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"):
|
||||
val = getattr(response.usage, attr, 0)
|
||||
if val:
|
||||
usage[attr] = val
|
||||
# Normalize to cached_tokens for downstream consistency.
|
||||
if cache_read:
|
||||
usage["cached_tokens"] = cache_read
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
|
||||
@ -1,31 +1,36 @@
|
||||
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||
"""Azure OpenAI provider using the OpenAI SDK Responses API.
|
||||
|
||||
Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
|
||||
routes to the Responses API (``/responses``). Reuses shared conversion
|
||||
helpers from :mod:`nanobot.providers.openai_responses`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.openai_responses import (
|
||||
consume_sdk_stream,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIProvider(LLMProvider):
|
||||
"""
|
||||
Azure OpenAI provider with API version 2024-10-21 compliance.
|
||||
|
||||
"""Azure OpenAI provider backed by the Responses API.
|
||||
|
||||
Features:
|
||||
- Hardcoded API version 2024-10-21
|
||||
- Uses model field as Azure deployment name in URL path
|
||||
- Uses api-key header instead of Authorization Bearer
|
||||
- Uses max_completion_tokens instead of max_tokens
|
||||
- Direct HTTP calls, bypasses LiteLLM
|
||||
- Uses the OpenAI Python SDK (``AsyncOpenAI``) with
|
||||
``base_url = {endpoint}/openai/v1/``
|
||||
- Calls ``client.responses.create()`` (Responses API)
|
||||
- Reuses shared message/tool/SSE conversion from
|
||||
``openai_responses``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -36,40 +41,28 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.api_version = "2024-10-21"
|
||||
|
||||
# Validate required parameters
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("Azure OpenAI api_key is required")
|
||||
if not api_base:
|
||||
raise ValueError("Azure OpenAI api_base is required")
|
||||
|
||||
# Ensure api_base ends with /
|
||||
if not api_base.endswith('/'):
|
||||
api_base += '/'
|
||||
|
||||
# Normalise: ensure trailing slash
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
self.api_base = api_base
|
||||
|
||||
def _build_chat_url(self, deployment_name: str) -> str:
|
||||
"""Build the Azure OpenAI chat completions URL."""
|
||||
# Azure OpenAI URL format:
|
||||
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
||||
base_url = self.api_base
|
||||
if not base_url.endswith('/'):
|
||||
base_url += '/'
|
||||
|
||||
url = urljoin(
|
||||
base_url,
|
||||
f"openai/deployments/{deployment_name}/chat/completions"
|
||||
# SDK client targeting the Azure Responses API endpoint
|
||||
base_url = f"{api_base.rstrip('/')}/openai/v1/"
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
default_headers={"x-session-affinity": uuid.uuid4().hex},
|
||||
)
|
||||
return f"{url}?api-version={self.api_version}"
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Build headers for Azure OpenAI API with api-key header."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||
}
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _supports_temperature(
|
||||
@ -82,36 +75,51 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
name = deployment_name.lower()
|
||||
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||
|
||||
def _prepare_request_payload(
|
||||
def _build_body(
|
||||
self,
|
||||
deployment_name: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||
payload: dict[str, Any] = {
|
||||
"messages": self._sanitize_request_messages(
|
||||
self._sanitize_empty_content(messages),
|
||||
_AZURE_MSG_KEYS,
|
||||
),
|
||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||
"""Build the Responses API request body from Chat-Completions-style args."""
|
||||
deployment = model or self.default_model
|
||||
instructions, input_items = convert_messages(self._sanitize_empty_content(messages))
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": deployment,
|
||||
"instructions": instructions or None,
|
||||
"input": input_items,
|
||||
"max_output_tokens": max(1, max_tokens),
|
||||
"store": False,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||
payload["temperature"] = temperature
|
||||
if self._supports_temperature(deployment, reasoning_effort):
|
||||
body["temperature"] = temperature
|
||||
|
||||
if reasoning_effort:
|
||||
payload["reasoning_effort"] = reasoning_effort
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
body["include"] = ["reasoning.encrypted_content"]
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = tool_choice or "auto"
|
||||
body["tools"] = convert_tools(tools)
|
||||
body["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return payload
|
||||
return body
|
||||
|
||||
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
    """Turn an exception into an ``LLMResponse`` with ``finish_reason="error"``.

    The message is built from the first truthy value of ``e.body`` and
    ``e.response.text`` (stringified, stripped, cut to 500 characters).
    When neither is available it falls back to the exception's own string.

    Args:
        e: The exception raised while calling the API.

    Returns:
        An ``LLMResponse`` whose ``content`` is the error message.
    """
    detail = getattr(e, "body", None)
    if not detail:
        # No usable body: try the text of an attached response object.
        detail = getattr(getattr(e, "response", None), "text", None)

    if detail:
        message = f"Error: {str(detail).strip()[:500]}"
    else:
        message = f"Error calling Azure OpenAI: {e}"

    return LLMResponse(content=message, finish_reason="error")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request to Azure OpenAI.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (used as deployment name).
|
||||
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
||||
temperature: Sampling temperature.
|
||||
reasoning_effort: Optional reasoning effort parameter.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
body = self._build_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
return self._parse_response(response_data)
|
||||
|
||||
response = await self._client.responses.create(**body)
|
||||
return parse_response_output(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Azure OpenAI: {repr(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse Azure OpenAI response into our standard format."""
|
||||
try:
|
||||
choice = response["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
tool_calls = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
usage = {}
|
||||
if response.get("usage"):
|
||||
usage_data = response["usage"]
|
||||
usage = {
|
||||
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage_data.get("completion_tokens", 0),
|
||||
"total_tokens": usage_data.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
reasoning_content = message.get("reasoning_content") or None
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content"),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
return LLMResponse(
|
||||
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return self._handle_error(e)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
@ -221,89 +152,26 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion via Azure OpenAI SSE."""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice=tool_choice,
|
||||
body = self._build_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
payload["stream"] = True
|
||||
body["stream"] = True
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return await self._consume_stream(response, on_content_delta)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
|
||||
|
||||
async def _consume_stream(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None,
|
||||
) -> LLMResponse:
|
||||
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
|
||||
content_parts: list[str] = []
|
||||
tool_call_buffers: dict[int, dict[str, str]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:].strip()
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
choice = choices[0]
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice["finish_reason"]
|
||||
delta = choice.get("delta") or {}
|
||||
|
||||
text = delta.get("content")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
if on_content_delta:
|
||||
await on_content_delta(text)
|
||||
|
||||
for tc in delta.get("tool_calls") or []:
|
||||
idx = tc.get("index", 0)
|
||||
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
|
||||
if tc.get("id"):
|
||||
buf["id"] = tc["id"]
|
||||
fn = tc.get("function") or {}
|
||||
if fn.get("name"):
|
||||
buf["name"] = fn["name"]
|
||||
if fn.get("arguments"):
|
||||
buf["arguments"] += fn["arguments"]
|
||||
|
||||
tool_calls = [
|
||||
ToolCallRequest(
|
||||
id=buf["id"], name=buf["name"],
|
||||
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
|
||||
stream = await self._client.responses.create(**body)
|
||||
content, tool_calls, finish_reason, usage, reasoning_content = (
|
||||
await consume_sdk_stream(stream, on_content_delta)
|
||||
)
|
||||
for buf in tool_call_buffers.values()
|
||||
]
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model (also used as default deployment name)."""
|
||||
return self.default_model
|
||||
257
nanobot/providers/github_copilot_provider.py
Normal file
257
nanobot/providers/github_copilot_provider.py
Normal file
@ -0,0 +1,257 @@
|
||||
"""GitHub Copilot OAuth-backed provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import webbrowser
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
from oauth_cli_kit.models import OAuthToken
|
||||
from oauth_cli_kit.storage import FileTokenStorage
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
# Endpoint that login_github_copilot() POSTs to in order to obtain a device code and user code.
DEFAULT_GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code"
|
||||
# Endpoint polled with the device code until an access token (or an error) is returned.
DEFAULT_GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
# Queried after login to read the account's "login"/"id" for the stored token's account_id.
DEFAULT_GITHUB_USER_URL = "https://api.github.com/user"
|
||||
# Endpoint the provider GETs with the GitHub token to receive a Copilot access token.
DEFAULT_COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
# Passed as api_base to the OpenAI-compatible base provider.
DEFAULT_COPILOT_BASE_URL = "https://api.githubcopilot.com"
|
||||
# client_id sent in both the device-code request and the token polling request.
GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"
|
||||
# OAuth scope sent with the device-code request.
GITHUB_COPILOT_SCOPE = "read:user"
|
||||
# File name handed to FileTokenStorage for the persisted GitHub token.
TOKEN_FILENAME = "github-copilot.json"
|
||||
# app_name handed to FileTokenStorage.
TOKEN_APP_NAME = "nanobot"
|
||||
# User-Agent header sent on every GitHub / Copilot request made by this module.
USER_AGENT = "nanobot/0.1"
|
||||
# Editor-Version header sent on the token exchange and as a provider extra header.
EDITOR_VERSION = "vscode/1.99.0"
|
||||
# Editor-Plugin-Version header sent alongside Editor-Version.
EDITOR_PLUGIN_VERSION = "copilot-chat/0.26.0"
|
||||
# A cached Copilot token is treated as stale this many seconds before its recorded expiry.
_EXPIRY_SKEW_SECONDS = 60
|
||||
# Fallback lifetime (315360000 s = 3650 days) used when the token response has no expires_in.
_LONG_LIVED_TOKEN_SECONDS = 315360000
|
||||
|
||||
|
||||
def _storage() -> FileTokenStorage:
    """Create the file-backed token store used for the GitHub OAuth token.

    A new ``FileTokenStorage`` is constructed on every call, configured with
    this module's token file name and app name, and with Codex CLI import
    disabled.
    """
    storage_options = {
        "token_filename": TOKEN_FILENAME,
        "app_name": TOKEN_APP_NAME,
        "import_codex_cli": False,
    }
    return FileTokenStorage(**storage_options)
|
||||
|
||||
|
||||
def _copilot_headers(token: str) -> dict[str, str]:
    """Return the request headers for the Copilot token-exchange call.

    Args:
        token: GitHub OAuth access token, sent as ``Authorization: token <token>``.

    Returns:
        A new dict with the authorization, accept, user-agent and editor
        identification headers.
    """
    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/json",
    }
    # Client identification headers, taken from the module-level constants.
    headers["User-Agent"] = USER_AGENT
    headers["Editor-Version"] = EDITOR_VERSION
    headers["Editor-Plugin-Version"] = EDITOR_PLUGIN_VERSION
    return headers
|
||||
|
||||
|
||||
def _load_github_token() -> OAuthToken | None:
    """Load the persisted GitHub OAuth token, if a usable one exists.

    Returns:
        The stored token when one is present and its ``access`` value is
        truthy; otherwise ``None``.
    """
    stored = _storage().load()
    if stored and stored.access:
        return stored
    return None
|
||||
|
||||
|
||||
def get_github_copilot_login_status() -> OAuthToken | None:
    """Report the login state as the persisted GitHub OAuth token.

    Returns:
        The stored token, or ``None`` when no usable token is persisted.
    """
    persisted = _load_github_token()
    return persisted
|
||||
|
||||
|
||||
def login_github_copilot(
    print_fn: Callable[[str], None] | None = None,
    prompt_fn: Callable[[str], str] | None = None,
) -> OAuthToken:
    """Authenticate through the GitHub device flow and persist the resulting token.

    The flow requests a device code, shows the verification URL and user code,
    makes a best-effort attempt to open the browser, then polls the token
    endpoint until the user approves, an error is reported, or the device code
    lifetime runs out. On success the account name is looked up and the token
    is saved through ``_storage()``.

    Args:
        print_fn: Callable used to show the URL and code; falls back to ``print``.
        prompt_fn: Accepted for interface compatibility; not used by this flow.

    Returns:
        The ``OAuthToken`` that was saved (empty ``refresh``; ``expires`` in
        milliseconds since the epoch).

    Raises:
        RuntimeError: The device code expired, access was denied, GitHub
            reported another error, or polling timed out.
        KeyError: The device-code response lacks ``device_code`` or ``user_code``.
        Errors raised by ``raise_for_status()`` on any of the HTTP responses.
    """
    del prompt_fn
    emit = print_fn if print_fn else print
    http_timeout = httpx.Timeout(20.0, connect=20.0)
    json_headers = {"Accept": "application/json", "User-Agent": USER_AGENT}

    with httpx.Client(timeout=http_timeout, follow_redirects=True, trust_env=True) as client:
        # Step 1: ask GitHub for a device code / user code pair.
        device_response = client.post(
            DEFAULT_GITHUB_DEVICE_CODE_URL,
            headers=json_headers,
            data={"client_id": GITHUB_COPILOT_CLIENT_ID, "scope": GITHUB_COPILOT_SCOPE},
        )
        device_response.raise_for_status()
        grant = device_response.json()

        device_code = str(grant["device_code"])
        user_code = str(grant["user_code"])
        verify_url = str(grant.get("verification_uri") or grant.get("verification_uri_complete") or "")
        verify_complete = str(grant.get("verification_uri_complete") or verify_url)
        poll_interval = max(1, int(grant.get("interval") or 5))
        code_lifetime = int(grant.get("expires_in") or 900)

        # Step 2: tell the user where to go and what to type.
        emit(f"Open: {verify_url}")
        emit(f"Code: {user_code}")
        if verify_complete:
            try:
                webbrowser.open(verify_complete)
            except Exception:
                # Opening a browser is a convenience only; the URL was already printed.
                pass

        # Step 3: poll until approved, rejected, or the device code lifetime elapses.
        deadline = time.time() + code_lifetime
        poll_form = {
            "client_id": GITHUB_COPILOT_CLIENT_ID,
            "device_code": device_code,
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
        }
        access_token = None
        token_ttl = _LONG_LIVED_TOKEN_SECONDS
        while True:
            if time.time() >= deadline:
                raise RuntimeError("GitHub device flow timed out.")

            poll = client.post(
                DEFAULT_GITHUB_ACCESS_TOKEN_URL,
                headers=json_headers,
                data=poll_form,
            )
            poll.raise_for_status()
            poll_result = poll.json()

            access_token = poll_result.get("access_token")
            if access_token:
                token_ttl = int(poll_result.get("expires_in") or _LONG_LIVED_TOKEN_SECONDS)
                break

            error = poll_result.get("error")
            if error == "slow_down":
                # Back off by five more seconds before polling again.
                poll_interval += 5
            elif error == "expired_token":
                raise RuntimeError("GitHub device code expired. Please run login again.")
            elif error == "access_denied":
                raise RuntimeError("GitHub device flow was denied.")
            elif error and error != "authorization_pending":
                raise RuntimeError(str(poll_result.get("error_description") or error))

            # authorization_pending, slow_down, or no error field: wait and retry.
            time.sleep(poll_interval)

        # Step 4: resolve a display identity for the authenticated account.
        user_response = client.get(
            DEFAULT_GITHUB_USER_URL,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Accept": "application/vnd.github+json",
                "User-Agent": USER_AGENT,
            },
        )
        user_response.raise_for_status()
        profile = user_response.json()
        account_id = profile.get("login") or str(profile.get("id") or "") or None

        # Step 5: build and persist the token; expiry is stored in milliseconds.
        expires_ms = int((time.time() + token_ttl) * 1000)
        saved_token = OAuthToken(
            access=str(access_token),
            refresh="",
            expires=expires_ms,
            account_id=str(account_id) if account_id else None,
        )
        _storage().save(saved_token)
        return saved_token
|
||||
|
||||
|
||||
class GitHubCopilotProvider(OpenAICompatProvider):
    """OpenAI-compatible provider backed by short-lived GitHub Copilot tokens.

    The stored GitHub OAuth token is exchanged for a Copilot access token,
    which is cached on the instance and installed as the client's API key
    before every ``chat`` / ``chat_stream`` call.
    """

    def __init__(self, default_model: str = "github-copilot/gpt-4.1"):
        """Set up an empty token cache and configure the base provider for Copilot."""
        # Function-scope import, as in the original; presumably avoids an
        # import cycle with the registry -- TODO confirm.
        from nanobot.providers.registry import find_by_name

        # Nothing cached yet, so the first request performs a token exchange.
        self._copilot_access_token: str | None = None
        self._copilot_expires_at: float = 0.0

        editor_headers = {
            "Editor-Version": EDITOR_VERSION,
            "Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
            "User-Agent": USER_AGENT,
        }
        super().__init__(
            # Placeholder key; the real token is installed per request.
            api_key="no-key",
            api_base=DEFAULT_COPILOT_BASE_URL,
            default_model=default_model,
            extra_headers=editor_headers,
            spec=find_by_name("github_copilot"),
        )

    async def _get_copilot_access_token(self) -> str:
        """Return a Copilot access token, exchanging the GitHub token when needed.

        The cached token is reused while the current time is earlier than its
        expiry minus ``_EXPIRY_SKEW_SECONDS``.

        Raises:
            RuntimeError: if no GitHub token is stored, or the exchange
                response contains no ``token`` value.
        """
        cached = self._copilot_access_token
        if cached and time.time() < self._copilot_expires_at - _EXPIRY_SKEW_SECONDS:
            return cached

        stored = _load_github_token()
        if not (stored and stored.access):
            raise RuntimeError("GitHub Copilot is not logged in. Run: nanobot provider login github-copilot")

        async with httpx.AsyncClient(
            timeout=httpx.Timeout(20.0, connect=20.0),
            follow_redirects=True,
            trust_env=True,
        ) as client:
            exchange = await client.get(
                DEFAULT_COPILOT_TOKEN_URL,
                headers=_copilot_headers(stored.access),
            )
            exchange.raise_for_status()
            payload = exchange.json()

        copilot_token = payload.get("token")
        if not copilot_token:
            raise RuntimeError("GitHub Copilot token exchange returned no token.")

        absolute_expiry = payload.get("expires_at")
        if isinstance(absolute_expiry, (int, float)):
            self._copilot_expires_at = float(absolute_expiry)
        else:
            # No numeric absolute expiry: use the relative ``refresh_in`` hint,
            # defaulting to 1500 when it is missing or falsy.
            self._copilot_expires_at = time.time() + int(payload.get("refresh_in") or 1500)

        self._copilot_access_token = str(copilot_token)
        return self._copilot_access_token

    async def _refresh_client_api_key(self) -> str:
        """Install a current Copilot token on this provider and on its client; return it."""
        fresh = await self._get_copilot_access_token()
        self.api_key = self._client.api_key = fresh
        return fresh

    async def chat(
        self,
        messages: list[dict[str, object]],
        tools: list[dict[str, object]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, object] | None = None,
    ):
        """Refresh the API key, then delegate unchanged to the base ``chat``."""
        await self._refresh_client_api_key()
        forwarded = {
            "messages": messages,
            "tools": tools,
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "reasoning_effort": reasoning_effort,
            "tool_choice": tool_choice,
        }
        return await super().chat(**forwarded)

    async def chat_stream(
        self,
        messages: list[dict[str, object]],
        tools: list[dict[str, object]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, object] | None = None,
        on_content_delta: Callable[[str], None] | None = None,
    ):
        """Refresh the API key, then delegate unchanged to the base ``chat_stream``."""
        await self._refresh_client_api_key()
        forwarded = {
            "messages": messages,
            "tools": tools,
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "reasoning_effort": reasoning_effort,
            "tool_choice": tool_choice,
            "on_content_delta": on_content_delta,
        }
        return await super().chat_stream(**forwarded)
|
||||
@ -6,13 +6,18 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_responses import (
|
||||
consume_sse,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
)
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "nanobot"
|
||||
@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
) -> LLMResponse:
|
||||
"""Shared request logic for both chat() and chat_stream()."""
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
system_prompt, input_items = convert_messages(messages)
|
||||
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
headers = _build_headers(token.account_id, token.access)
|
||||
@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
if reasoning_effort:
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
body["tools"] = convert_tools(tools)
|
||||
|
||||
try:
|
||||
try:
|
||||
@ -127,96 +132,7 @@ async def _request_codex(
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||
return await _consume_sse(response, on_content_delta)
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI function-calling schema to Codex flat format."""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
})
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append(_convert_user_message(content))
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append({
|
||||
"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed", "id": f"msg_{idx}",
|
||||
})
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
input_items.append({
|
||||
"type": "function_call",
|
||||
"id": item_id or f"fc_{idx}",
|
||||
"call_id": call_id or f"call_{idx}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
})
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
return await consume_sse(response, on_content_delta)
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Fold a Codex SSE stream into ``(content, tool_calls, finish_reason)``.

    Args:
        response: Streaming HTTP response; its events are read via ``_iter_sse``.
        on_content_delta: Optional coroutine function awaited with each
            non-empty text delta as it arrives.

    Returns:
        The concatenated output text, one ``ToolCallRequest`` per finished
        ``function_call`` item (id formatted as ``call_id|item_id``), and the
        finish reason, which stays ``"stop"`` unless a terminal event maps to
        something else.

    Raises:
        RuntimeError: when an ``error`` or ``response.failed`` event arrives.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in _iter_sse(response):
        event_type = event.get("type")

        if event_type == "response.output_item.added":
            item = event.get("item") or {}
            call_id = item.get("call_id")
            if item.get("type") == "function_call" and call_id:
                tool_call_buffers[call_id] = {
                    "id": item.get("id") or "fc_0",
                    "name": item.get("name"),
                    "arguments": item.get("arguments") or "",
                }

        elif event_type == "response.output_text.delta":
            delta_text = event.get("delta") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)

        elif event_type == "response.function_call_arguments.delta":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""

        elif event_type == "response.function_call_arguments.done":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""

        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            call_id = item.get("call_id")
            if item.get("type") == "function_call" and call_id:
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or item.get("arguments") or "{}"
                try:
                    args = json.loads(args_raw)
                except (TypeError, ValueError):
                    args = None
                # Valid JSON is not necessarily an object; tool arguments must
                # be a dict, so anything else is preserved under "raw".
                if not isinstance(args, dict):
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
                        name=buf.get("name") or item.get("name"),
                        arguments=args,
                    )
                )

        elif event_type in {"response.completed", "response.incomplete"}:
            # A truncated response ends with ``response.incomplete``; without
            # this branch it would be reported as a normal "stop".
            status = (event.get("response") or {}).get("status")
            finish_reason = _map_finish_reason(status)

        elif event_type in {"error", "response.failed"}:
            detail = (
                event.get("error")
                or event.get("message")
                or (event.get("response") or {}).get("error")
            )
            if detail:
                raise RuntimeError(f"Codex response failed: {str(detail)[:500]}")
            raise RuntimeError("Codex response failed")

    return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
|
||||
|
||||
|
||||
def _map_finish_reason(status: str | None) -> str:
|
||||
return _FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, raw: str) -> str:
|
||||
if status_code == 429:
|
||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||
|
||||
@ -235,7 +235,9 @@ class OpenAICompatProvider(LLMProvider):
|
||||
spec = self._spec
|
||||
|
||||
if spec and spec.supports_prompt_caching:
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
model_name = model or self.default_model
|
||||
if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
if spec and spec.strip_model_prefix:
|
||||
model_name = model_name.split("/")[-1]
|
||||
@ -308,6 +310,13 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
@classmethod
|
||||
def _extract_usage(cls, response: Any) -> dict[str, int]:
|
||||
"""Extract token usage from an OpenAI-compatible response.
|
||||
|
||||
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
|
||||
responses. Provider-specific ``cached_tokens`` fields are normalised
|
||||
under a single key; see the priority chain inside for details.
|
||||
"""
|
||||
# --- resolve usage object ---
|
||||
usage_obj = None
|
||||
response_map = cls._maybe_mapping(response)
|
||||
if response_map is not None:
|
||||
@ -317,19 +326,53 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
usage_map = cls._maybe_mapping(usage_obj)
|
||||
if usage_map is not None:
|
||||
return {
|
||||
result = {
|
||||
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
|
||||
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
|
||||
"total_tokens": int(usage_map.get("total_tokens") or 0),
|
||||
}
|
||||
|
||||
if usage_obj:
|
||||
return {
|
||||
elif usage_obj:
|
||||
result = {
|
||||
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
|
||||
}
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
# --- cached_tokens (normalised across providers) ---
|
||||
# Try nested paths first (dict), fall back to attribute (SDK object).
|
||||
# Priority order ensures the most specific field wins.
|
||||
for path in (
|
||||
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
|
||||
("cached_tokens",), # StepFun/Moonshot (top-level)
|
||||
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
|
||||
):
|
||||
cached = cls._get_nested_int(usage_map, path)
|
||||
if not cached and usage_obj:
|
||||
cached = cls._get_nested_int(usage_obj, path)
|
||||
if cached:
|
||||
result["cached_tokens"] = cached
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
|
||||
"""Drill into *obj* by *path* segments and return an ``int`` value.
|
||||
|
||||
Supports both dict-key access and attribute access so it works
|
||||
uniformly with raw JSON dicts **and** SDK Pydantic models.
|
||||
"""
|
||||
current = obj
|
||||
for segment in path:
|
||||
if current is None:
|
||||
return 0
|
||||
if isinstance(current, dict):
|
||||
current = current.get(segment)
|
||||
else:
|
||||
current = getattr(current, segment, None)
|
||||
return int(current or 0) if current is not None else 0
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
if isinstance(response, str):
|
||||
@ -603,4 +646,4 @@ class OpenAICompatProvider(LLMProvider):
|
||||
return self._handle_error(e)
|
||||
|
||||
def get_default_model(self) -> str:
    """Return the model name stored in ``self.default_model``."""
    return self.default_model
|
||||
29
nanobot/providers/openai_responses/__init__.py
Normal file
29
nanobot/providers/openai_responses/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""
|
||||
|
||||
from nanobot.providers.openai_responses.converters import (
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
convert_user_message,
|
||||
split_tool_call_id,
|
||||
)
|
||||
from nanobot.providers.openai_responses.parsing import (
|
||||
FINISH_REASON_MAP,
|
||||
consume_sdk_stream,
|
||||
consume_sse,
|
||||
iter_sse,
|
||||
map_finish_reason,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"convert_messages",
|
||||
"convert_tools",
|
||||
"convert_user_message",
|
||||
"split_tool_call_id",
|
||||
"iter_sse",
|
||||
"consume_sse",
|
||||
"consume_sdk_stream",
|
||||
"map_finish_reason",
|
||||
"parse_response_output",
|
||||
"FINISH_REASON_MAP",
|
||||
]
|
||||
110
nanobot/providers/openai_responses/converters.py
Normal file
110
nanobot/providers/openai_responses/converters.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Convert Chat Completions messages/tools to Responses API format."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Convert Chat Completions messages to a system prompt and Responses API input.

    Returns ``(system_prompt, input_items)``. Every ``system`` message sets the
    prompt (the last one wins; non-string content yields ``""``). ``user``,
    ``assistant`` and ``tool`` messages are appended to *input_items* in order;
    messages with any other role are ignored.
    """
    prompt = ""
    converted: list[dict[str, Any]] = []

    for index, entry in enumerate(messages):
        role = entry.get("role")
        payload = entry.get("content")

        if role == "system":
            prompt = payload if isinstance(payload, str) else ""

        elif role == "user":
            converted.append(convert_user_message(payload))

        elif role == "assistant":
            # Only a non-empty string body produces an assistant text item.
            if isinstance(payload, str) and payload:
                converted.append(
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": payload}],
                        "status": "completed",
                        "id": f"msg_{index}",
                    }
                )
            for requested in entry.get("tool_calls", []) or []:
                spec = requested.get("function") or {}
                call_id, item_id = split_tool_call_id(requested.get("id"))
                converted.append(
                    {
                        "type": "function_call",
                        # Index-derived ids replace an empty half of the compound id.
                        "id": item_id or f"fc_{index}",
                        "call_id": call_id or f"call_{index}",
                        "name": spec.get("name"),
                        "arguments": spec.get("arguments") or "{}",
                    }
                )

        elif role == "tool":
            call_id, _ = split_tool_call_id(entry.get("tool_call_id"))
            if isinstance(payload, str):
                result_text = payload
            else:
                result_text = json.dumps(payload, ensure_ascii=False)
            converted.append(
                {"type": "function_call_output", "call_id": call_id, "output": result_text}
            )

    return prompt, converted
|
||||
|
||||
|
||||
def convert_user_message(content: Any) -> dict[str, Any]:
    """Convert a user message's content to a Responses API ``user`` item.

    A plain string becomes one ``input_text`` part. Within a list, ``text``
    blocks map to ``input_text`` and ``image_url`` blocks that carry a URL map
    to ``input_image``; non-dict entries and other block types are skipped.
    If no part results, a single empty ``input_text`` part is used.
    """
    if isinstance(content, str):
        return {"role": "user", "content": [{"type": "input_text", "text": content}]}

    collected: list[dict[str, Any]] = []
    for entry in content if isinstance(content, list) else ():
        if not isinstance(entry, dict):
            continue
        entry_type = entry.get("type")
        if entry_type == "text":
            collected.append({"type": "input_text", "text": entry.get("text", "")})
        elif entry_type == "image_url":
            link = (entry.get("image_url") or {}).get("url")
            if link:
                collected.append({"type": "input_image", "image_url": link, "detail": "auto"})

    return {"role": "user", "content": collected or [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert OpenAI function-calling tool schemas to the flat Responses API form.

    Specs of ``type == "function"`` are read from their nested ``function``
    mapping, all others directly. Nameless specs are omitted; ``parameters``
    that is missing, falsy or not a dict becomes ``{}``.
    """

    def _definition(spec: dict[str, Any]) -> dict[str, Any]:
        """Return the mapping that holds name/description/parameters for *spec*."""
        if spec.get("type") == "function":
            return spec.get("function") or {}
        return spec

    result: list[dict[str, Any]] = []
    for definition in map(_definition, tools):
        identifier = definition.get("name")
        if not identifier:
            continue
        parameters = definition.get("parameters") or {}
        if not isinstance(parameters, dict):
            parameters = {}
        result.append(
            {
                "type": "function",
                "name": identifier,
                "description": definition.get("description") or "",
                "parameters": parameters,
            }
        )
    return result
|
||||
|
||||
|
||||
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
    """Split a compound ``call_id|item_id`` string at the first ``|``.

    Returns ``(call_id, item_id)`` where *item_id* is ``None`` when there is no
    separator or nothing follows it. A value that is not a non-empty string
    yields the placeholder ``("call_0", None)``.
    """
    usable = isinstance(tool_call_id, str) and bool(tool_call_id)
    if not usable:
        return "call_0", None

    head, _separator, tail = tool_call_id.partition("|")
    return head, (tail if tail else None)
|
||||
297
nanobot/providers/openai_responses/parsing.py
Normal file
297
nanobot/providers/openai_responses/parsing.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""Parse Responses API SSE streams and SDK response objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
# Responses API status -> Chat-Completions-style finish_reason.
FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}


def map_finish_reason(status: str | None) -> str:
    """Translate a Responses API status into a Chat-Completions-style finish_reason.

    ``None`` or an empty status counts as ``"completed"``; any status missing
    from ``FINISH_REASON_MAP`` falls back to ``"stop"``.
    """
    if not status:
        status = "completed"
    return FINISH_REASON_MAP.get(status, "stop")
|
||||
|
||||
|
||||
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
    """Yield parsed JSON events from a Responses API SSE stream.

    Lines are collected until a blank line terminates the event; the event's
    ``data:`` lines are then stripped, joined with newlines and decoded.
    Events without data, the ``[DONE]`` sentinel, malformed JSON (logged as a
    warning) and payloads that are not JSON objects are skipped, so every
    yielded value is a dict. An event still buffered at end of stream is
    decoded too.
    """
    buffer: list[str] = []

    def _flush() -> dict[str, Any] | None:
        """Decode and clear the buffered event; ``None`` when there is nothing to yield."""
        data_lines = [line[5:].strip() for line in buffer if line.startswith("data:")]
        buffer.clear()
        data = "\n".join(data_lines).strip()
        if not data or data == "[DONE]":
            return None
        try:
            event = json.loads(data)
        except ValueError:  # json.JSONDecodeError is a ValueError
            logger.warning("Failed to parse SSE event JSON: {}", data[:200])
            return None
        # Consumers call ``event.get(...)``; a JSON array or scalar would make
        # them fail with AttributeError in the middle of the stream.
        return event if isinstance(event, dict) else None

    async for line in response.aiter_lines():
        if line != "":
            buffer.append(line)
            continue
        if buffer:
            event = _flush()
            if event is not None:
                yield event

    # Flush any remaining buffer at EOF (#10)
    if buffer:
        event = _flush()
        if event is not None:
            yield event
|
||||
|
||||
|
||||
async def consume_sse(
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``.

    Args:
        response: Streaming HTTP response; its events are read via ``iter_sse``.
        on_content_delta: Optional coroutine function awaited with each
            non-empty text delta as it arrives.

    Returns:
        The concatenated output text, one ``ToolCallRequest`` per finished
        ``function_call`` item (id formatted as ``call_id|item_id``, arguments
        always a dict), and the finish reason, which stays ``"stop"`` unless a
        terminal event maps to something else.

    Raises:
        RuntimeError: when an ``error`` or ``response.failed`` event arrives.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in iter_sse(response):
        event_type = event.get("type")

        if event_type == "response.output_item.added":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                tool_call_buffers[call_id] = {
                    "id": item.get("id") or "fc_0",
                    "name": item.get("name"),
                    "arguments": item.get("arguments") or "",
                }

        elif event_type == "response.output_text.delta":
            delta_text = event.get("delta") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)

        elif event_type == "response.function_call_arguments.delta":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""

        elif event_type == "response.function_call_arguments.done":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""

        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or item.get("arguments") or "{}"
                try:
                    args = json.loads(args_raw)
                except (TypeError, ValueError):
                    logger.warning(
                        "Failed to parse tool call arguments for '{}': {}",
                        buf.get("name") or item.get("name"),
                        str(args_raw)[:200],
                    )
                    args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
                # Checked after both paths: JSON that parses cleanly can still
                # be an array or scalar, and tool arguments must be a dict.
                if not isinstance(args, dict):
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
                        name=buf.get("name") or item.get("name") or "",
                        arguments=args,
                    )
                )

        elif event_type in {"response.completed", "response.incomplete"}:
            # A truncated response ends with ``response.incomplete``; without
            # this branch its status never reaches map_finish_reason and the
            # result is reported as a normal "stop".
            status = (event.get("response") or {}).get("status")
            finish_reason = map_finish_reason(status)

        elif event_type in {"error", "response.failed"}:
            detail = event.get("error") or event.get("message") or event
            raise RuntimeError(f"Response failed: {str(detail)[:500]}")

    return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
def _as_mapping(value: Any) -> Any:
    """Return *value* if it is a dict, else its ``model_dump()`` result or ``vars()``."""
    if isinstance(value, dict):
        return value
    dump = getattr(value, "model_dump", None)
    return dump() if callable(dump) else vars(value)


def parse_response_output(response: Any) -> LLMResponse:
    """Parse an SDK ``Response`` object (or equivalent dict) into an ``LLMResponse``.

    ``message`` items contribute their ``output_text`` blocks to the content,
    ``reasoning`` items contribute their ``summary_text`` entries to the
    reasoning content, and ``function_call`` items become tool calls whose id
    is ``call_id|item_id``. Usage is read from ``input_tokens``,
    ``output_tokens`` and ``total_tokens``; the finish reason is derived from
    the response ``status``.
    """
    response = _as_mapping(response)

    text_parts: list[str] = []
    tool_calls: list[ToolCallRequest] = []
    reasoning_content: str | None = None

    for raw_item in response.get("output") or []:
        item = _as_mapping(raw_item)
        item_type = item.get("type")

        if item_type == "message":
            for raw_block in item.get("content") or []:
                block = _as_mapping(raw_block)
                if block.get("type") == "output_text":
                    text_parts.append(block.get("text") or "")

        elif item_type == "reasoning":
            for raw_summary in item.get("summary") or []:
                summary = _as_mapping(raw_summary)
                if summary.get("type") == "summary_text" and summary.get("text"):
                    reasoning_content = (reasoning_content or "") + summary["text"]

        elif item_type == "function_call":
            call_id = item.get("call_id") or ""
            item_id = item.get("id") or "fc_0"
            args_raw = item.get("arguments") or "{}"
            try:
                args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
            except Exception:
                logger.warning(
                    "Failed to parse tool call arguments for '{}': {}",
                    item.get("name"),
                    str(args_raw)[:200],
                )
                args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
                # After a failed parse, keep the raw text if repair gave no dict.
                if not isinstance(args, dict):
                    args = {"raw": args_raw}
            if not isinstance(args, dict):
                args = {}
            tool_calls.append(
                ToolCallRequest(
                    id=f"{call_id}|{item_id}",
                    name=item.get("name") or "",
                    arguments=args,
                )
            )

    usage_raw = _as_mapping(response.get("usage") or {})
    usage = (
        {
            "prompt_tokens": int(usage_raw.get("input_tokens") or 0),
            "completion_tokens": int(usage_raw.get("output_tokens") or 0),
            "total_tokens": int(usage_raw.get("total_tokens") or 0),
        }
        if usage_raw
        else {}
    )

    return LLMResponse(
        content="".join(text_parts) or None,
        tool_calls=tool_calls,
        finish_reason=map_finish_reason(response.get("status")),
        usage=usage,
        reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
    )
|
||||
|
||||
|
||||
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
    """Consume an SDK async stream from ``client.responses.create(stream=True)``.

    Args:
        stream: Async iterable of event objects; fields are read by attribute.
        on_content_delta: Optional coroutine function awaited with each
            non-empty text delta as it arrives.

    Returns:
        ``(content, tool_calls, finish_reason, usage, reasoning_content)``.
        *usage* and *reasoning_content* are taken from the response attached to
        the terminal event and stay ``{}`` / ``None`` when it carries none.
        Tool call arguments are always a dict.

    Raises:
        RuntimeError: when an ``error`` or ``response.failed`` event arrives.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}
    reasoning_content: str | None = None

    async for event in stream:
        event_type = getattr(event, "type", None)

        if event_type == "response.output_item.added":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                tool_call_buffers[call_id] = {
                    "id": getattr(item, "id", None) or "fc_0",
                    "name": getattr(item, "name", None),
                    "arguments": getattr(item, "arguments", None) or "",
                }

        elif event_type == "response.output_text.delta":
            delta_text = getattr(event, "delta", "") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)

        elif event_type == "response.function_call_arguments.delta":
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""

        elif event_type == "response.function_call_arguments.done":
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""

        elif event_type == "response.output_item.done":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
                try:
                    args = json.loads(args_raw)
                except (TypeError, ValueError):
                    logger.warning(
                        "Failed to parse tool call arguments for '{}': {}",
                        buf.get("name") or getattr(item, "name", None),
                        str(args_raw)[:200],
                    )
                    args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
                # Checked after both paths: JSON that parses cleanly can still
                # be an array or scalar, and tool arguments must be a dict.
                if not isinstance(args, dict):
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
                        name=buf.get("name") or getattr(item, "name", None) or "",
                        arguments=args,
                    )
                )

        elif event_type in {"response.completed", "response.incomplete"}:
            # A truncated response ends with ``response.incomplete``; handling
            # it here lets its status, usage and reasoning be reported instead
            # of silently returning "stop" with empty usage.
            resp = getattr(event, "response", None)
            status = getattr(resp, "status", None) if resp else None
            finish_reason = map_finish_reason(status)
            if resp:
                usage_obj = getattr(resp, "usage", None)
                if usage_obj:
                    usage = {
                        "prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
                        "completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
                        "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
                    }
                for out_item in getattr(resp, "output", None) or []:
                    if getattr(out_item, "type", None) != "reasoning":
                        continue
                    for summary in getattr(out_item, "summary", None) or []:
                        if getattr(summary, "type", None) == "summary_text":
                            text = getattr(summary, "text", None)
                            if text:
                                reasoning_content = (reasoning_content or "") + text

        elif event_type in {"error", "response.failed"}:
            detail = getattr(event, "error", None) or getattr(event, "message", None) or event
            raise RuntimeError(f"Response failed: {str(detail)[:500]}")

    return content, tool_calls, finish_reason, usage, reasoning_content
|
||||
@ -34,7 +34,7 @@ class ProviderSpec:
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
|
||||
# which provider implementation to use
|
||||
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex"
|
||||
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot"
|
||||
backend: str = "openai_compat"
|
||||
|
||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||
@ -218,8 +218,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="",
|
||||
display_name="Github Copilot",
|
||||
backend="openai_compat",
|
||||
backend="github_copilot",
|
||||
default_api_base="https://api.githubcopilot.com",
|
||||
strip_model_prefix=True,
|
||||
is_oauth=True,
|
||||
),
|
||||
# DeepSeek: OpenAI-compatible at api.deepseek.com
|
||||
|
||||
@ -405,14 +405,18 @@ def build_status_content(
|
||||
)
|
||||
last_in = last_usage.get("prompt_tokens", 0)
|
||||
last_out = last_usage.get("completion_tokens", 0)
|
||||
cached = last_usage.get("cached_tokens", 0)
|
||||
ctx_total = max(context_window_tokens, 0)
|
||||
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
|
||||
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
|
||||
if cached and last_in:
|
||||
token_line += f" ({cached * 100 // last_in}% cached)"
|
||||
return "\n".join([
|
||||
f"\U0001f408 nanobot v{version}",
|
||||
f"\U0001f9e0 Model: {model}",
|
||||
f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
|
||||
token_line,
|
||||
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
||||
f"\U0001f4ac Session: {session_msg_count} messages",
|
||||
f"\u23f1 Uptime: {uptime}",
|
||||
|
||||
@ -578,3 +578,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
||||
args = mgr._announce_result.await_args.args
|
||||
assert args[3] == "Task completed but no final response was generated."
|
||||
assert args[5] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
    """Check that usage reported by each provider call is summed into ``result.usage``.

    The stub provider answers its first call with a ``read_file`` tool call and
    every later call with a plain final answer; each response reports prompt,
    completion and cached token counts.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            # Second and later calls: final answer, no tool calls.
            return LLMResponse(
                content="done",
                tool_calls=[],
                usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
            )
        # First call: ask for a tool so the runner performs another iteration.
        read_request = ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})
        return LLMResponse(
            content="thinking",
            tool_calls=[read_request],
            usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.execute = AsyncMock(return_value="file content")
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
    )
    result = await runner.run(spec)

    # Usage should be accumulated across iterations
    expected_totals = {
        "prompt_tokens": 300,  # 100 + 200
        "completion_tokens": 30,  # 10 + 20
        "cached_tokens": 230,  # 80 + 150
    }
    for field, total in expected_totals.items():
        assert result.usage[field] == total
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
    """Check that ``context.usage`` seen in ``after_iteration`` includes ``cached_tokens``."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    usage_snapshots: list[dict] = []

    class UsageHook(AgentHook):
        async def after_iteration(self, context: AgentHookContext) -> None:
            # Store a shallow copy of the usage mapping for later inspection.
            usage_snapshots.append(dict(context.usage))

    async def chat_with_retry(**kwargs):
        reported_usage = {"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}
        return LLMResponse(content="done", tool_calls=[], usage=reported_usage)

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        hook=UsageHook(),
    )
    await runner.run(spec)

    assert len(usage_snapshots) == 1
    only_snapshot = usage_snapshots[0]
    assert only_snapshot["cached_tokens"] == 150
|
||||
|
||||
@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||
seen["config"] = self.config
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {"fakeplugin": _LoginPlugin},
|
||||
@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||
assert seen["force"] is True
|
||||
|
||||
|
||||
def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path):
    """``channels login ... --config PATH`` should pass the resolved path to ``set_config_path``."""
    from nanobot.cli.commands import app
    from nanobot.config.schema import Config
    from typer.testing import CliRunner

    cli = CliRunner()
    recorded: dict[str, object] = {}
    custom_path = tmp_path / "custom-config.json"

    class _LoginPlugin(_FakePlugin):
        async def login(self, force: bool = False) -> bool:
            return True

    def _fake_load_config(config_path=None):
        return Config()

    def _record_config_path(path):
        recorded["config_path"] = path

    def _discover_fake_plugin():
        return {"fakeplugin": _LoginPlugin}

    monkeypatch.setattr("nanobot.config.loader.load_config", _fake_load_config)
    monkeypatch.setattr("nanobot.config.loader.set_config_path", _record_config_path)
    monkeypatch.setattr("nanobot.channels.registry.discover_all", _discover_fake_plugin)

    cli_args = ["channels", "login", "fakeplugin", "--config", str(custom_path)]
    result = cli.invoke(app, cli_args)

    assert result.exit_code == 0
    assert recorded["config_path"] == custom_path.resolve()
|
||||
|
||||
|
||||
def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path):
    """``channels status --config PATH`` should pass the resolved path to ``set_config_path``."""
    from nanobot.cli.commands import app
    from nanobot.config.schema import Config
    from typer.testing import CliRunner

    cli = CliRunner()
    recorded: dict[str, object] = {}
    custom_path = tmp_path / "custom-config.json"

    def _fake_load_config(config_path=None):
        return Config()

    def _record_config_path(path):
        recorded["config_path"] = path

    def _discover_nothing():
        return {}

    monkeypatch.setattr("nanobot.config.loader.load_config", _fake_load_config)
    monkeypatch.setattr("nanobot.config.loader.set_config_path", _record_config_path)
    monkeypatch.setattr("nanobot.channels.registry.discover_all", _discover_nothing)

    cli_args = ["channels", "status", "--config", str(custom_path)]
    result = cli.invoke(app, cli_args)

    assert result.exit_code == 0
    assert recorded["config_path"] == custom_path.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_skips_disabled_plugin():
|
||||
fake_config = SimpleNamespace(
|
||||
|
||||
@ -3,16 +3,14 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("nio")
|
||||
pytest.importorskip("nh3")
|
||||
pytest.importorskip("mistune")
|
||||
from nio import RoomSendResponse
|
||||
|
||||
from nanobot.channels.matrix import _build_matrix_text_content
|
||||
|
||||
# Check optional matrix dependencies before importing
|
||||
try:
|
||||
import nh3 # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True)
|
||||
|
||||
import nanobot.channels.matrix as matrix_module
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
@ -317,6 +317,75 @@ def test_openai_compat_provider_passes_model_through():
|
||||
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
|
||||
|
||||
|
||||
def test_make_provider_uses_github_copilot_backend():
    """A config naming the ``github-copilot`` provider should yield a ``GitHubCopilotProvider``."""
    from nanobot.cli.commands import _make_provider
    from nanobot.config.schema import Config

    agent_defaults = {
        "provider": "github-copilot",
        "model": "github-copilot/gpt-4.1",
    }
    config = Config.model_validate({"agents": {"defaults": agent_defaults}})

    # Patch the SDK client class so constructing the provider uses a mock instead.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = _make_provider(config)

    provider_class_name = provider.__class__.__name__
    assert provider_class_name == "GitHubCopilotProvider"
|
||||
|
||||
|
||||
def test_github_copilot_provider_strips_prefixed_model_name():
    """``_build_kwargs`` should turn ``github-copilot/gpt-5.1`` into ``gpt-5.1``."""
    from nanobot.providers.github_copilot_provider import GitHubCopilotProvider

    prefixed_model = "github-copilot/gpt-5.1"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = GitHubCopilotProvider(default_model=prefixed_model)

    request_kwargs = provider._build_kwargs(
        messages=[{"role": "user", "content": "hi"}],
        tools=None,
        model=prefixed_model,
        max_tokens=16,
        temperature=0.1,
        reasoning_effort=None,
        tool_choice=None,
    )

    assert request_kwargs["model"] == "gpt-5.1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_github_copilot_provider_refreshes_client_api_key_before_chat():
    """After ``chat``, the client's ``api_key`` should equal the fetched Copilot token.

    Also checks that both the token getter and the completions endpoint were
    awaited exactly once.
    """
    from nanobot.providers.github_copilot_provider import GitHubCopilotProvider

    completion_payload = {
        "choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    }
    create_mock = AsyncMock(return_value=completion_payload)

    mock_client = MagicMock()
    mock_client.api_key = "no-key"
    mock_client.chat.completions.create = create_mock

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI", return_value=mock_client):
        provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")

    # Replace the token getter so no real token exchange is attempted.
    provider._get_copilot_access_token = AsyncMock(return_value="copilot-access-token")

    user_messages = [{"role": "user", "content": "hi"}]
    response = await provider.chat(
        messages=user_messages,
        model="github-copilot/gpt-5.1",
        max_tokens=16,
        temperature=0.1,
    )

    assert response.content == "ok"
    assert provider._client.api_key == "copilot-access-token"
    provider._get_copilot_access_token.assert_awaited_once()
    create_mock.assert_awaited_once()
|
||||
|
||||
|
||||
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
|
||||
@ -152,10 +152,12 @@ class TestRestartCommand:
|
||||
])
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
|
||||
assert loop._last_usage["prompt_tokens"] == 9
|
||||
assert loop._last_usage["completion_tokens"] == 4
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
assert loop._last_usage["prompt_tokens"] == 0
|
||||
assert loop._last_usage["completion_tokens"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):
|
||||
|
||||
@ -285,6 +285,28 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
|
||||
assert job.schedule.at_ms == expected
|
||||
|
||||
|
||||
def test_add_job_delivers_by_default(tmp_path) -> None:
    """``_add_job`` called without ``deliver`` should store a job with ``deliver`` set to True."""
    cron_tool = _make_tool(tmp_path)
    cron_tool.set_context("telegram", "chat-1")

    message = cron_tool._add_job("Morning standup", 60, None, None, None)
    assert message.startswith("Created job")

    created_job = cron_tool._cron.list_jobs()[0]
    assert created_job.payload.deliver is True
|
||||
|
||||
|
||||
def test_add_job_can_disable_delivery(tmp_path) -> None:
    """``_add_job`` called with ``deliver=False`` should store a job with ``deliver`` set to False."""
    cron_tool = _make_tool(tmp_path)
    cron_tool.set_context("telegram", "chat-1")

    message = cron_tool._add_job("Background refresh", 60, None, None, None, deliver=False)
    assert message.startswith("Created job")

    created_job = cron_tool._cron.list_jobs()[0]
    assert created_job.payload.deliver is False
|
||||
|
||||
|
||||
def test_list_excludes_disabled_jobs(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
job = tool._cron.add_job(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -8,392 +8,401 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def test_azure_openai_provider_init():
|
||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init & validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_init_creates_sdk_client():
|
||||
"""Provider creates an AsyncOpenAI client with correct base_url."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||
assert provider.default_model == "gpt-4o-deployment"
|
||||
assert provider.api_version == "2024-10-21"
|
||||
# SDK client base_url ends with /openai/v1/
|
||||
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_azure_openai_provider_init_validation():
|
||||
"""Test AzureOpenAIProvider initialization validation."""
|
||||
# Missing api_key
|
||||
def test_init_base_url_no_trailing_slash():
    """An ``api_base`` without a trailing slash should give a client base_url ending in ``/openai/v1``."""
    provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
    # Compare with any trailing slash on the client URL stripped.
    client_url = str(provider._client.base_url).rstrip("/")
    assert client_url.endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_base_url_with_trailing_slash():
    """An ``api_base`` with a trailing slash should give a client base_url ending in ``/openai/v1``."""
    provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com/")
    # Compare with any trailing slash on the client URL stripped.
    client_url = str(provider._client.base_url).rstrip("/")
    assert client_url.endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_validation_missing_key():
    """An empty ``api_key`` should be rejected with a ``ValueError``."""
    expected_message = "Azure OpenAI api_key is required"
    with pytest.raises(ValueError, match=expected_message):
        AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||
|
||||
# Missing api_base
|
||||
|
||||
|
||||
def test_init_validation_missing_base():
    """An empty ``api_base`` should be rejected with a ``ValueError``."""
    expected_message = "Azure OpenAI api_base is required"
    with pytest.raises(ValueError, match=expected_message):
        AzureOpenAIProvider(api_key="test", api_base="")
|
||||
|
||||
|
||||
def test_build_chat_url():
|
||||
"""Test Azure OpenAI URL building with different deployment names."""
|
||||
def test_no_api_version_in_base_url():
    """The SDK client's base_url should not contain the text ``api-version``."""
    provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
    client_url = str(provider._client.base_url)
    assert "api-version" not in client_url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _supports_temperature
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_supports_temperature_standard_model():
    """``gpt-4o`` should be reported as supporting temperature."""
    supported = AzureOpenAIProvider._supports_temperature("gpt-4o")
    assert supported is True
|
||||
|
||||
|
||||
def test_supports_temperature_reasoning_model():
    """``o3-mini``, ``gpt-5-chat`` and ``o4-mini`` should be reported as not supporting temperature."""
    for model_name in ("o3-mini", "gpt-5-chat", "o4-mini"):
        assert AzureOpenAIProvider._supports_temperature(model_name) is False
|
||||
|
||||
|
||||
def test_supports_temperature_with_reasoning_effort():
    """Passing a ``reasoning_effort`` should disable temperature support even for ``gpt-4o``."""
    supported = AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium")
    assert supported is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_body — Responses API body construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_body_basic():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test various deployment names
|
||||
test_cases = [
|
||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||
]
|
||||
|
||||
for deployment_name, expected_url in test_cases:
|
||||
url = provider._build_chat_url(deployment_name)
|
||||
assert url == expected_url
|
||||
messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
|
||||
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||
|
||||
|
||||
def test_build_chat_url_api_base_without_slash():
|
||||
"""Test URL building when api_base doesn't end with slash."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||
default_model="gpt-4o",
|
||||
assert body["model"] == "gpt-4o"
|
||||
assert body["instructions"] == "You are helpful."
|
||||
assert body["temperature"] == 0.7
|
||||
assert body["max_output_tokens"] == 4096
|
||||
assert body["store"] is False
|
||||
assert "reasoning" not in body
|
||||
# input should contain the converted user message only (system extracted)
|
||||
assert any(
|
||||
item.get("role") == "user"
|
||||
for item in body["input"]
|
||||
)
|
||||
|
||||
url = provider._build_chat_url("test-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
|
||||
|
||||
def test_build_headers():
|
||||
"""Test Azure OpenAI header building with api-key authentication."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-api-key-123",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
headers = provider._build_headers()
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||
assert "x-session-affinity" in headers
|
||||
def test_build_body_max_tokens_minimum():
    """Passing 0 as the token limit should still produce ``max_output_tokens == 1``."""
    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
    user_messages = [{"role": "user", "content": "x"}]
    body = provider._build_body(user_messages, None, None, 0, 0.7, None, None)
    assert body["max_output_tokens"] == 1
|
||||
|
||||
|
||||
def test_prepare_request_payload():
|
||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||
|
||||
assert payload["messages"] == messages
|
||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||
assert payload["temperature"] == 0.8
|
||||
assert "tools" not in payload
|
||||
|
||||
# Test with tools
|
||||
def test_build_body_with_tools():
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||
assert payload_with_tools["tools"] == tools
|
||||
assert payload_with_tools["tool_choice"] == "auto"
|
||||
|
||||
# Test with reasoning_effort
|
||||
payload_with_reasoning = provider._prepare_request_payload(
|
||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||
body = provider._build_body(
|
||||
[{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
|
||||
)
|
||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||
assert "temperature" not in payload_with_reasoning
|
||||
assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
|
||||
assert body["tool_choice"] == "auto"
|
||||
|
||||
|
||||
def test_prepare_request_payload_sanitizes_messages():
|
||||
"""Test Azure payload strips non-standard message keys before sending."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
def test_build_body_with_reasoning():
    """A ``"medium"`` reasoning effort should add a reasoning block and drop ``temperature``."""
    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
    user_messages = [{"role": "user", "content": "think"}]
    body = provider._build_body(user_messages, None, "gpt-5-chat", 4096, 0.7, "medium", None)

    assert body["reasoning"] == {"effort": "medium"}
    included = body.get("include", [])
    assert "reasoning.encrypted_content" in included
    # temperature omitted for reasoning models
    assert "temperature" not in body
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
"reasoning_content": "hidden chain-of-thought",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
"extra_field": "should be removed",
|
||||
},
|
||||
]
|
||||
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||
def test_build_body_image_conversion():
    """``text`` / ``image_url`` content blocks should come out as ``input_text`` / ``input_image``."""
    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")

    text_block = {"type": "text", "text": "What's in this image?"}
    picture_block = {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}
    messages = [{"role": "user", "content": [text_block, picture_block]}]

    body = provider._build_body(messages, None, None, 4096, 0.7, None, None)

    converted_blocks = body["input"][0]["content"]
    content_types = [block["type"] for block in converted_blocks]
    assert "input_text" in content_types
    assert "input_image" in content_types

    # The converted image block carries the URL directly under "image_url".
    image_block = next(block for block in converted_blocks if block["type"] == "input_image")
    assert image_block["image_url"] == "https://example.com/img.png"
|
||||
|
||||
assert payload["messages"] == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
|
||||
def test_build_body_sanitizes_single_dict_content_block():
    """A message whose ``content`` is one dict should become a one-item ``input_text`` list."""
    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
    single_block = {"type": "text", "text": "Hi from dict content"}
    messages = [{"role": "user", "content": single_block}]

    body = provider._build_body(messages, None, None, 4096, 0.7, None, None)

    first_item_content = body["input"][0]["content"]
    assert first_item_content == [{"type": "input_text", "text": "Hi from dict content"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat() — non-streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sdk_response(
|
||||
content="Hello!", tool_calls=None, status="completed",
|
||||
usage=None,
|
||||
):
|
||||
"""Build a mock that quacks like an openai Response object."""
|
||||
resp = MagicMock()
|
||||
resp.model_dump = MagicMock(return_value={
|
||||
"output": [
|
||||
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
|
||||
*([{
|
||||
"type": "function_call",
|
||||
"call_id": tc["call_id"], "id": tc["id"],
|
||||
"name": tc["name"], "arguments": tc["arguments"],
|
||||
} for tc in (tool_calls or [])]),
|
||||
],
|
||||
"status": status,
|
||||
"usage": {
|
||||
"input_tokens": (usage or {}).get("input_tokens", 10),
|
||||
"output_tokens": (usage or {}).get("output_tokens", 5),
|
||||
"total_tokens": (usage or {}).get("total_tokens", 15),
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
},
|
||||
]
|
||||
})
|
||||
return resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_success():
|
||||
"""Test successful chat request using model as deployment name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 18,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
# Test with specific model (deployment name)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages, model="custom-deployment")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 12
|
||||
assert result.usage["completion_tokens"] == 18
|
||||
assert result.usage["total_tokens"] == 30
|
||||
|
||||
# Verify URL was built with the provided model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
mock_resp = _make_sdk_response(content="Hello!")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello!"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_uses_default_model_when_no_model_provided():
|
||||
"""Test that chat uses default_model when no model is specified."""
|
||||
async def test_chat_uses_default_model():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="default-deployment",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
|
||||
)
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {"content": "Response", "role": "assistant"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
await provider.chat(messages) # No model specified
|
||||
|
||||
# Verify URL was built with default model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
mock_resp = _make_sdk_response(content="ok")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat([{"role": "user", "content": "test"}])
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["model"] == "my-deployment"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_custom_model():
    """A ``model`` argument to ``chat`` should be forwarded to ``responses.create``."""
    provider = AzureOpenAIProvider(
        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
    )

    # Swap the SDK's responses endpoint for mocks so no request is sent.
    sdk_response = _make_sdk_response(content="ok")
    provider._client.responses = MagicMock()
    provider._client.responses.create = AsyncMock(return_value=sdk_response)

    user_messages = [{"role": "user", "content": "test"}]
    await provider.chat(user_messages, model="custom-deploy")

    # Index 1 of call_args is the keyword-argument dict of the last call.
    create_kwargs = provider._client.responses.create.call_args[1]
    assert create_kwargs["model"] == "custom-deploy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tool_calls():
|
||||
"""Test chat request with tool calls in response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response with tool calls
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_12345",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}'
|
||||
}
|
||||
}]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
mock_resp = _make_sdk_response(
|
||||
content=None,
|
||||
tool_calls=[{
|
||||
"call_id": "call_123", "id": "fc_1",
|
||||
"name": "get_weather", "arguments": '{"location": "SF"}',
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content is None
|
||||
assert result.finish_reason == "tool_calls"
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||
)
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await provider.chat(
|
||||
[{"role": "user", "content": "Weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_api_error():
|
||||
"""Test chat request API error handling."""
|
||||
async def test_chat_error_handling():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid authentication credentials"
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Azure OpenAI API Error 401" in result.content
|
||||
assert "Invalid authentication credentials" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_connection_error():
|
||||
"""Test chat request connection error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_parse_response_malformed():
|
||||
"""Test response parsing with malformed data."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test with missing choices
|
||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||
result = provider._parse_response(malformed_response)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error parsing Azure OpenAI response" in result.content
|
||||
assert "Connection failed" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_reasoning_param_format():
|
||||
"""reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
|
||||
)
|
||||
mock_resp = _make_sdk_response(content="thought")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat(
|
||||
[{"role": "user", "content": "think"}], reasoning_effort="medium",
|
||||
)
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["reasoning"] == {"effort": "medium"}
|
||||
assert "reasoning_effort" not in call_kwargs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat_stream()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_success():
|
||||
"""Streaming should call on_content_delta and return combined response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Build mock SDK stream events
|
||||
events = []
|
||||
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
|
||||
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
|
||||
resp_obj = MagicMock(status="completed")
|
||||
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||
events = [ev1, ev2, ev3]
|
||||
|
||||
async def mock_stream():
|
||||
for e in events:
|
||||
yield e
|
||||
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||
|
||||
deltas: list[str] = []
|
||||
|
||||
async def on_delta(text: str) -> None:
|
||||
deltas.append(text)
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
||||
)
|
||||
|
||||
assert result.content == "Hello world"
|
||||
assert result.finish_reason == "stop"
|
||||
assert deltas == ["Hello", " world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_with_tool_calls():
|
||||
"""Streaming tool calls should be accumulated correctly."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
|
||||
item_added.name = "get_weather"
|
||||
ev_added = MagicMock(type="response.output_item.added", item=item_added)
|
||||
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
|
||||
ev_args_done = MagicMock(
|
||||
type="response.function_call_arguments.done",
|
||||
call_id="call_1", arguments='{"location":"SF"}',
|
||||
)
|
||||
item_done = MagicMock(
|
||||
type="function_call", call_id="call_1", id="fc_1",
|
||||
arguments='{"location":"SF"}',
|
||||
)
|
||||
item_done.name = "get_weather"
|
||||
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
|
||||
resp_obj = MagicMock(status="completed")
|
||||
ev_completed = MagicMock(type="response.completed", response=resp_obj)
|
||||
|
||||
async def mock_stream():
|
||||
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
|
||||
yield e
|
||||
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_error():
|
||||
"""Streaming should return error when SDK raises."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
|
||||
|
||||
assert "Connection failed" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_default_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_default_model():
|
||||
"""Test get_default_model method."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="my-custom-deployment",
|
||||
api_key="k", api_base="https://r.com", default_model="my-deploy",
|
||||
)
|
||||
|
||||
assert provider.get_default_model() == "my-custom-deployment"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Running basic Azure OpenAI provider tests...")
|
||||
|
||||
# Test initialization
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
print("✅ Provider initialization successful")
|
||||
|
||||
# Test URL building
|
||||
url = provider._build_chat_url("my-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
print("✅ URL building works correctly")
|
||||
|
||||
# Test headers
|
||||
headers = provider._build_headers()
|
||||
assert headers["api-key"] == "test-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
print("✅ Header building works correctly")
|
||||
|
||||
# Test payload preparation
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||
print("✅ Payload preparation works correctly")
|
||||
|
||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||
assert provider.get_default_model() == "my-deploy"
|
||||
|
||||
233
tests/providers/test_cached_tokens.py
Normal file
233
tests/providers/test_cached_tokens.py
Normal file
@ -0,0 +1,233 @@
|
||||
"""Tests for cached token extraction from OpenAI-compatible providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
class FakeUsage:
|
||||
"""Mimics an OpenAI SDK usage object (has attributes, not dict keys)."""
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class FakePromptDetails:
|
||||
"""Mimics prompt_tokens_details sub-object."""
|
||||
def __init__(self, cached_tokens=0):
|
||||
self.cached_tokens = cached_tokens
|
||||
|
||||
|
||||
class _FakeSpec:
|
||||
supports_prompt_caching = False
|
||||
model_id_prefix = None
|
||||
strip_model_prefix = False
|
||||
max_completion_tokens = False
|
||||
reasoning_effort = None
|
||||
|
||||
|
||||
def _provider():
|
||||
from unittest.mock import MagicMock
|
||||
p = OpenAICompatProvider.__new__(OpenAICompatProvider)
|
||||
p.client = MagicMock()
|
||||
p.spec = _FakeSpec()
|
||||
return p
|
||||
|
||||
|
||||
# Minimal valid choice so _parse reaches _extract_usage.
|
||||
_DICT_CHOICE = {"message": {"content": "Hello"}}
|
||||
|
||||
class _FakeMessage:
|
||||
content = "Hello"
|
||||
tool_calls = None
|
||||
|
||||
|
||||
class _FakeChoice:
|
||||
message = _FakeMessage()
|
||||
finish_reason = "stop"
|
||||
|
||||
|
||||
# --- dict-based response (raw JSON / mapping) ---
|
||||
|
||||
def test_extract_usage_openai_cached_tokens_dict():
|
||||
"""prompt_tokens_details.cached_tokens from a dict response."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
"completion_tokens": 300,
|
||||
"total_tokens": 2300,
|
||||
"prompt_tokens_details": {"cached_tokens": 1200},
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
assert result.usage["prompt_tokens"] == 2000
|
||||
|
||||
|
||||
def test_extract_usage_deepseek_cached_tokens_dict():
|
||||
"""prompt_cache_hit_tokens from a DeepSeek dict response."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 1500,
|
||||
"completion_tokens": 200,
|
||||
"total_tokens": 1700,
|
||||
"prompt_cache_hit_tokens": 1200,
|
||||
"prompt_cache_miss_tokens": 300,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
|
||||
|
||||
def test_extract_usage_no_cached_tokens_dict():
|
||||
"""Response without any cache fields -> no cached_tokens key."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 200,
|
||||
"total_tokens": 1200,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert "cached_tokens" not in result.usage
|
||||
|
||||
|
||||
def test_extract_usage_openai_cached_zero_dict():
|
||||
"""cached_tokens=0 should NOT be included (same as existing fields)."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
"completion_tokens": 300,
|
||||
"total_tokens": 2300,
|
||||
"prompt_tokens_details": {"cached_tokens": 0},
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert "cached_tokens" not in result.usage
|
||||
|
||||
|
||||
# --- object-based response (OpenAI SDK Pydantic model) ---
|
||||
|
||||
def test_extract_usage_openai_cached_tokens_obj():
|
||||
"""prompt_tokens_details.cached_tokens from an SDK object response."""
|
||||
p = _provider()
|
||||
usage_obj = FakeUsage(
|
||||
prompt_tokens=2000,
|
||||
completion_tokens=300,
|
||||
total_tokens=2300,
|
||||
prompt_tokens_details=FakePromptDetails(cached_tokens=1200),
|
||||
)
|
||||
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
|
||||
|
||||
def test_extract_usage_deepseek_cached_tokens_obj():
|
||||
"""prompt_cache_hit_tokens from a DeepSeek SDK object response."""
|
||||
p = _provider()
|
||||
usage_obj = FakeUsage(
|
||||
prompt_tokens=1500,
|
||||
completion_tokens=200,
|
||||
total_tokens=1700,
|
||||
prompt_cache_hit_tokens=1200,
|
||||
)
|
||||
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
|
||||
|
||||
def test_extract_usage_stepfun_top_level_cached_tokens_dict():
|
||||
"""StepFun/Moonshot: usage.cached_tokens at top level (not nested)."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 591,
|
||||
"completion_tokens": 120,
|
||||
"total_tokens": 711,
|
||||
"cached_tokens": 512,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 512
|
||||
|
||||
|
||||
def test_extract_usage_stepfun_top_level_cached_tokens_obj():
|
||||
"""StepFun/Moonshot: usage.cached_tokens as SDK object attribute."""
|
||||
p = _provider()
|
||||
usage_obj = FakeUsage(
|
||||
prompt_tokens=591,
|
||||
completion_tokens=120,
|
||||
total_tokens=711,
|
||||
cached_tokens=512,
|
||||
)
|
||||
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 512
|
||||
|
||||
|
||||
def test_extract_usage_priority_nested_over_top_level_dict():
|
||||
"""When both nested and top-level cached_tokens exist, nested wins."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
"completion_tokens": 300,
|
||||
"total_tokens": 2300,
|
||||
"prompt_tokens_details": {"cached_tokens": 100},
|
||||
"cached_tokens": 500,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 100
|
||||
|
||||
|
||||
def test_anthropic_maps_cache_fields_to_cached_tokens():
|
||||
"""Anthropic's cache_read_input_tokens should map to cached_tokens."""
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
usage_obj = FakeUsage(
|
||||
input_tokens=800,
|
||||
output_tokens=200,
|
||||
cache_creation_input_tokens=300,
|
||||
cache_read_input_tokens=1200,
|
||||
)
|
||||
content_block = FakeUsage(type="text", text="hello")
|
||||
response = FakeUsage(
|
||||
id="msg_1",
|
||||
type="message",
|
||||
stop_reason="end_turn",
|
||||
content=[content_block],
|
||||
usage=usage_obj,
|
||||
)
|
||||
result = AnthropicProvider._parse_response(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
assert result.usage["prompt_tokens"] == 2300
|
||||
assert result.usage["total_tokens"] == 2500
|
||||
assert result.usage["cache_creation_input_tokens"] == 300
|
||||
|
||||
|
||||
def test_anthropic_no_cache_fields():
|
||||
"""Anthropic response without cache fields should not have cached_tokens."""
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
usage_obj = FakeUsage(input_tokens=800, output_tokens=200)
|
||||
content_block = FakeUsage(type="text", text="hello")
|
||||
response = FakeUsage(
|
||||
id="msg_1",
|
||||
type="message",
|
||||
stop_reason="end_turn",
|
||||
content=[content_block],
|
||||
usage=usage_obj,
|
||||
)
|
||||
result = AnthropicProvider._parse_response(response)
|
||||
assert "cached_tokens" not in result.usage
|
||||
522
tests/providers/test_openai_responses.py
Normal file
522
tests/providers/test_openai_responses.py
Normal file
@ -0,0 +1,522 @@
|
||||
"""Tests for the shared openai_responses converters and parsers."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_responses.converters import (
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
convert_user_message,
|
||||
split_tool_call_id,
|
||||
)
|
||||
from nanobot.providers.openai_responses.parsing import (
|
||||
consume_sdk_stream,
|
||||
map_finish_reason,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - split_tool_call_id
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestSplitToolCallId:
|
||||
def test_plain_id(self):
|
||||
assert split_tool_call_id("call_abc") == ("call_abc", None)
|
||||
|
||||
def test_compound_id(self):
|
||||
assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1")
|
||||
|
||||
def test_compound_empty_item_id(self):
|
||||
assert split_tool_call_id("call_abc|") == ("call_abc", None)
|
||||
|
||||
def test_none(self):
|
||||
assert split_tool_call_id(None) == ("call_0", None)
|
||||
|
||||
def test_empty_string(self):
|
||||
assert split_tool_call_id("") == ("call_0", None)
|
||||
|
||||
def test_non_string(self):
|
||||
assert split_tool_call_id(42) == ("call_0", None)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - convert_user_message
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConvertUserMessage:
|
||||
def test_string_content(self):
|
||||
result = convert_user_message("hello")
|
||||
assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}
|
||||
|
||||
def test_text_block(self):
|
||||
result = convert_user_message([{"type": "text", "text": "hi"}])
|
||||
assert result["content"] == [{"type": "input_text", "text": "hi"}]
|
||||
|
||||
def test_image_url_block(self):
|
||||
result = convert_user_message([
|
||||
{"type": "image_url", "image_url": {"url": "https://img.example/a.png"}},
|
||||
])
|
||||
assert result["content"] == [
|
||||
{"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"},
|
||||
]
|
||||
|
||||
def test_mixed_text_and_image(self):
|
||||
result = convert_user_message([
|
||||
{"type": "text", "text": "what's this?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://img.example/b.png"}},
|
||||
])
|
||||
assert len(result["content"]) == 2
|
||||
assert result["content"][0]["type"] == "input_text"
|
||||
assert result["content"][1]["type"] == "input_image"
|
||||
|
||||
def test_empty_list_falls_back(self):
|
||||
result = convert_user_message([])
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
def test_none_falls_back(self):
|
||||
result = convert_user_message(None)
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
def test_image_without_url_skipped(self):
|
||||
result = convert_user_message([{"type": "image_url", "image_url": {}}])
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
def test_meta_fields_not_leaked(self):
|
||||
"""_meta on content blocks must never appear in converted output."""
|
||||
result = convert_user_message([
|
||||
{"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}},
|
||||
])
|
||||
assert "_meta" not in result["content"][0]
|
||||
|
||||
def test_non_dict_items_skipped(self):
|
||||
result = convert_user_message(["just a string", 42])
|
||||
assert result["content"] == [{"type": "input_text", "text": ""}]
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - convert_messages
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConvertMessages:
|
||||
def test_system_extracted_as_instructions(self):
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
instructions, items = convert_messages(msgs)
|
||||
assert instructions == "You are helpful."
|
||||
assert len(items) == 1
|
||||
assert items[0]["role"] == "user"
|
||||
|
||||
def test_multiple_system_messages_last_wins(self):
|
||||
msgs = [
|
||||
{"role": "system", "content": "first"},
|
||||
{"role": "system", "content": "second"},
|
||||
{"role": "user", "content": "x"},
|
||||
]
|
||||
instructions, _ = convert_messages(msgs)
|
||||
assert instructions == "second"
|
||||
|
||||
def test_user_message_converted(self):
|
||||
_, items = convert_messages([{"role": "user", "content": "hello"}])
|
||||
assert items[0]["role"] == "user"
|
||||
assert items[0]["content"][0]["type"] == "input_text"
|
||||
|
||||
def test_assistant_text_message(self):
|
||||
_, items = convert_messages([
|
||||
{"role": "assistant", "content": "I'll help"},
|
||||
])
|
||||
assert items[0]["type"] == "message"
|
||||
assert items[0]["role"] == "assistant"
|
||||
assert items[0]["content"][0]["type"] == "output_text"
|
||||
assert items[0]["content"][0]["text"] == "I'll help"
|
||||
|
||||
def test_assistant_empty_content_skipped(self):
|
||||
_, items = convert_messages([{"role": "assistant", "content": ""}])
|
||||
assert len(items) == 0
|
||||
|
||||
def test_assistant_with_tool_calls(self):
|
||||
_, items = convert_messages([{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_abc|fc_1",
|
||||
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
|
||||
}],
|
||||
}])
|
||||
assert items[0]["type"] == "function_call"
|
||||
assert items[0]["call_id"] == "call_abc"
|
||||
assert items[0]["id"] == "fc_1"
|
||||
assert items[0]["name"] == "get_weather"
|
||||
|
||||
def test_assistant_with_tool_calls_no_id(self):
|
||||
"""Fallback IDs when tool_call.id is missing."""
|
||||
_, items = convert_messages([{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}],
|
||||
}])
|
||||
assert items[0]["call_id"] == "call_0"
|
||||
assert items[0]["id"].startswith("fc_")
|
||||
|
||||
def test_tool_message(self):
|
||||
_, items = convert_messages([{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "result text",
|
||||
}])
|
||||
assert items[0]["type"] == "function_call_output"
|
||||
assert items[0]["call_id"] == "call_abc"
|
||||
assert items[0]["output"] == "result text"
|
||||
|
||||
def test_tool_message_dict_content(self):
|
||||
_, items = convert_messages([{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": {"key": "value"},
|
||||
}])
|
||||
assert items[0]["output"] == '{"key": "value"}'
|
||||
|
||||
def test_non_standard_keys_not_leaked(self):
|
||||
"""Extra keys on messages must not appear in converted items."""
|
||||
_, items = convert_messages([{
|
||||
"role": "user",
|
||||
"content": "hi",
|
||||
"extra_field": "should vanish",
|
||||
"_meta": {"path": "/tmp"},
|
||||
}])
|
||||
item = items[0]
|
||||
assert "extra_field" not in str(item)
|
||||
assert "_meta" not in str(item)
|
||||
|
||||
def test_full_conversation_roundtrip(self):
|
||||
"""System + user + assistant(tool_call) + tool -> correct structure."""
|
||||
msgs = [
|
||||
{"role": "system", "content": "Be concise."},
|
||||
{"role": "user", "content": "Weather in SF?"},
|
||||
{
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [{
|
||||
"id": "c1|fc1",
|
||||
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
|
||||
}],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'},
|
||||
]
|
||||
instructions, items = convert_messages(msgs)
|
||||
assert instructions == "Be concise."
|
||||
assert len(items) == 3 # user, function_call, function_call_output
|
||||
assert items[0]["role"] == "user"
|
||||
assert items[1]["type"] == "function_call"
|
||||
assert items[2]["type"] == "function_call_output"
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# converters - convert_tools
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConvertTools:
|
||||
def test_standard_function_tool(self):
|
||||
tools = [{"type": "function", "function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
|
||||
}}]
|
||||
result = convert_tools(tools)
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == "function"
|
||||
assert result[0]["name"] == "get_weather"
|
||||
assert result[0]["description"] == "Get weather"
|
||||
assert "properties" in result[0]["parameters"]
|
||||
|
||||
def test_tool_without_name_skipped(self):
|
||||
tools = [{"type": "function", "function": {"parameters": {}}}]
|
||||
assert convert_tools(tools) == []
|
||||
|
||||
def test_tool_without_function_wrapper(self):
|
||||
"""Direct dict without type=function wrapper."""
|
||||
tools = [{"name": "f1", "description": "d", "parameters": {}}]
|
||||
result = convert_tools(tools)
|
||||
assert result[0]["name"] == "f1"
|
||||
|
||||
def test_missing_optional_fields_default(self):
|
||||
tools = [{"type": "function", "function": {"name": "f"}}]
|
||||
result = convert_tools(tools)
|
||||
assert result[0]["description"] == ""
|
||||
assert result[0]["parameters"] == {}
|
||||
|
||||
def test_multiple_tools(self):
|
||||
tools = [
|
||||
{"type": "function", "function": {"name": "a", "parameters": {}}},
|
||||
{"type": "function", "function": {"name": "b", "parameters": {}}},
|
||||
]
|
||||
assert len(convert_tools(tools)) == 2
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# parsing - map_finish_reason
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestMapFinishReason:
|
||||
def test_completed(self):
|
||||
assert map_finish_reason("completed") == "stop"
|
||||
|
||||
def test_incomplete(self):
|
||||
assert map_finish_reason("incomplete") == "length"
|
||||
|
||||
def test_failed(self):
|
||||
assert map_finish_reason("failed") == "error"
|
||||
|
||||
def test_cancelled(self):
|
||||
assert map_finish_reason("cancelled") == "error"
|
||||
|
||||
def test_none_defaults_to_stop(self):
|
||||
assert map_finish_reason(None) == "stop"
|
||||
|
||||
def test_unknown_defaults_to_stop(self):
|
||||
assert map_finish_reason("some_new_status") == "stop"
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# parsing - parse_response_output
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestParseResponseOutput:
    """Checks for parse_response_output on dict payloads and SDK-like objects."""

    def test_text_response(self):
        """One assistant message yields its text, "stop", mapped usage, no tools."""
        assistant_message = {
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": "Hello!"}],
        }
        payload = {
            "output": [assistant_message],
            "status": "completed",
            "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
        }

        parsed = parse_response_output(payload)

        assert parsed.content == "Hello!"
        assert parsed.finish_reason == "stop"
        assert parsed.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        assert parsed.tool_calls == []

    def test_tool_call_response(self):
        """A function_call item becomes one tool call with decoded arguments."""
        function_call = {
            "type": "function_call",
            "call_id": "call_1",
            "id": "fc_1",
            "name": "get_weather",
            "arguments": '{"city": "SF"}',
        }
        payload = {"output": [function_call], "status": "completed", "usage": {}}

        parsed = parse_response_output(payload)

        assert parsed.content is None
        assert len(parsed.tool_calls) == 1
        only_call = parsed.tool_calls[0]
        assert only_call.name == "get_weather"
        assert only_call.arguments == {"city": "SF"}
        # The tool call id is the call_id and the item id joined with "|".
        assert only_call.id == "call_1|fc_1"

    def test_malformed_tool_arguments_logged(self):
        """Unparseable arguments are kept under "raw" and exactly one warning is logged."""
        function_call = {
            "type": "function_call",
            "call_id": "c1",
            "id": "fc1",
            "name": "f",
            "arguments": "{bad json",
        }
        payload = {"output": [function_call], "status": "completed", "usage": {}}

        with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
            parsed = parse_response_output(payload)

        assert parsed.tool_calls[0].arguments == {"raw": "{bad json"}
        mock_logger.warning.assert_called_once()
        assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)

    def test_reasoning_content_extracted(self):
        """Reasoning summary texts are concatenated into reasoning_content."""
        reasoning_item = {
            "type": "reasoning",
            "summary": [
                {"type": "summary_text", "text": "I think "},
                {"type": "summary_text", "text": "therefore I am."},
            ],
        }
        assistant_message = {
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": "42"}],
        }
        payload = {
            "output": [reasoning_item, assistant_message],
            "status": "completed",
            "usage": {},
        }

        parsed = parse_response_output(payload)

        assert parsed.content == "42"
        assert parsed.reasoning_content == "I think therefore I am."

    def test_empty_output(self):
        """An empty output list gives no content and no tool calls."""
        parsed = parse_response_output({"output": [], "status": "completed", "usage": {}})

        assert parsed.content is None
        assert parsed.tool_calls == []

    def test_incomplete_status(self):
        """An "incomplete" response status is reported as finish reason "length"."""
        parsed = parse_response_output({"output": [], "status": "incomplete", "usage": {}})

        assert parsed.finish_reason == "length"

    def test_sdk_model_object(self):
        """Objects exposing model_dump() are accepted in place of a plain dict."""
        dumped = {
            "output": [
                {
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": "sdk"}],
                }
            ],
            "status": "completed",
            "usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        }
        sdk_response = MagicMock()
        sdk_response.model_dump.return_value = dumped

        parsed = parse_response_output(sdk_response)

        assert parsed.content == "sdk"
        assert parsed.usage["prompt_tokens"] == 1

    def test_usage_maps_responses_api_keys(self):
        """input_tokens/output_tokens are exposed as prompt_tokens/completion_tokens."""
        payload = {
            "output": [],
            "status": "completed",
            "usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
        }

        reported = parse_response_output(payload).usage

        assert reported["prompt_tokens"] == 100
        assert reported["completion_tokens"] == 50
        assert reported["total_tokens"] == 150
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# parsing - consume_sdk_stream
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConsumeSdkStream:
    """Checks for consume_sdk_stream driven by mocked SDK stream events."""

    @staticmethod
    async def _replay(events):
        """Yield the given mocked events one by one as an async stream."""
        for event in events:
            yield event

    @staticmethod
    def _completed_event(usage=None, output=None):
        """Build a "response.completed" event wrapping a completed response mock."""
        response = MagicMock(
            status="completed",
            usage=usage,
            output=[] if output is None else output,
        )
        return MagicMock(type="response.completed", response=response)

    @pytest.mark.asyncio
    async def test_text_stream(self):
        """Text deltas are concatenated; a completed response finishes with "stop"."""
        events = [
            MagicMock(type="response.output_text.delta", delta="Hello"),
            MagicMock(type="response.output_text.delta", delta=" world"),
            self._completed_event(),
        ]

        text, calls, reason, _usage, _reasoning = await consume_sdk_stream(self._replay(events))

        assert text == "Hello world"
        assert calls == []
        assert reason == "stop"

    @pytest.mark.asyncio
    async def test_on_content_delta_called(self):
        """The on_content_delta callback receives each text delta."""
        received = []

        async def record(text):
            received.append(text)

        events = [
            MagicMock(type="response.output_text.delta", delta="hi"),
            self._completed_event(),
        ]

        await consume_sdk_stream(self._replay(events), on_content_delta=record)

        assert received == ["hi"]

    @pytest.mark.asyncio
    async def test_tool_call_stream(self):
        """A streamed function call produces one tool call with decoded arguments."""
        # `name` is assigned after construction because MagicMock(name=...)
        # names the mock itself instead of setting a `name` attribute.
        added_item = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
        added_item.name = "get_weather"
        done_item = MagicMock(type="function_call", call_id="c1", id="fc1", arguments='{"city":"SF"}')
        done_item.name = "get_weather"

        events = [
            MagicMock(type="response.output_item.added", item=added_item),
            MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci'),
            MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}'),
            MagicMock(type="response.output_item.done", item=done_item),
            self._completed_event(),
        ]

        text, calls, _reason, _usage, _reasoning = await consume_sdk_stream(self._replay(events))

        assert text == ""
        assert len(calls) == 1
        assert calls[0].name == "get_weather"
        assert calls[0].arguments == {"city": "SF"}

    @pytest.mark.asyncio
    async def test_usage_extracted(self):
        """Usage on the completed response is returned with prompt/completion keys."""
        token_counts = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
        events = [self._completed_event(usage=token_counts)]

        result = await consume_sdk_stream(self._replay(events))
        _, _, _, usage, _ = result

        assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}

    @pytest.mark.asyncio
    async def test_reasoning_extracted(self):
        """Summary text of a reasoning output item is returned as reasoning."""
        summary_part = MagicMock(type="summary_text", text="thinking...")
        reasoning_item = MagicMock(type="reasoning", summary=[summary_part])
        events = [self._completed_event(output=[reasoning_item])]

        result = await consume_sdk_stream(self._replay(events))
        _, _, _, _, reasoning = result

        assert reasoning == "thinking..."

    @pytest.mark.asyncio
    async def test_error_event_raises(self):
        """An "error" event raises RuntimeError carrying the error text."""
        events = [MagicMock(type="error", error="rate_limit_exceeded")]

        with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
            await consume_sdk_stream(self._replay(events))

    @pytest.mark.asyncio
    async def test_failed_event_raises(self):
        """A "response.failed" event raises RuntimeError carrying the error text."""
        events = [MagicMock(type="response.failed", error="server_error")]

        with pytest.raises(RuntimeError, match="Response failed.*server_error"):
            await consume_sdk_stream(self._replay(events))

    @pytest.mark.asyncio
    async def test_malformed_tool_args_logged(self):
        """Unparseable streamed arguments are kept under "raw" and one warning is logged."""
        added_item = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
        added_item.name = "f"
        done_item = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="{bad")
        done_item.name = "f"

        events = [
            MagicMock(type="response.output_item.added", item=added_item),
            MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad"),
            MagicMock(type="response.output_item.done", item=done_item),
            self._completed_event(),
        ]

        with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
            _, calls, _, _, _ = await consume_sdk_stream(self._replay(events))

        assert calls[0].arguments == {"raw": "{bad"}
        mock_logger.warning.assert_called_once()
        assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||
@ -11,6 +11,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
|
||||
|
||||
providers = importlib.import_module("nanobot.providers")
|
||||
@ -18,6 +19,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
assert "nanobot.providers.anthropic_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_compat_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_codex_provider" not in sys.modules
|
||||
assert "nanobot.providers.github_copilot_provider" not in sys.modules
|
||||
assert "nanobot.providers.azure_openai_provider" not in sys.modules
|
||||
assert providers.__all__ == [
|
||||
"LLMProvider",
|
||||
@ -25,6 +27,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
"AnthropicProvider",
|
||||
"OpenAICompatProvider",
|
||||
"OpenAICodexProvider",
|
||||
"GitHubCopilotProvider",
|
||||
"AzureOpenAIProvider",
|
||||
]
|
||||
|
||||
|
||||
59
tests/test_build_status.py
Normal file
59
tests/test_build_status.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Tests for build_status_content cache hit rate display."""
|
||||
|
||||
from nanobot.utils.helpers import build_status_content
|
||||
|
||||
|
||||
def test_status_shows_cache_hit_rate():
    """1200 cached of 2000 prompt tokens is rendered as "60% cached"."""
    status_kwargs = {
        "version": "0.1.0",
        "model": "glm-4-plus",
        "start_time": 1000000.0,
        "last_usage": {"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200},
        "context_window_tokens": 128000,
        "session_msg_count": 10,
        "context_tokens_estimate": 5000,
    }

    rendered = build_status_content(**status_kwargs)

    assert "60% cached" in rendered
    assert "2000 in / 300 out" in rendered
|
||||
|
||||
|
||||
def test_status_no_cache_info():
    """When usage has no cached_tokens key, no cache percentage is rendered."""
    status_kwargs = {
        "version": "0.1.0",
        "model": "glm-4-plus",
        "start_time": 1000000.0,
        "last_usage": {"prompt_tokens": 2000, "completion_tokens": 300},
        "context_window_tokens": 128000,
        "session_msg_count": 10,
        "context_tokens_estimate": 5000,
    }

    rendered = build_status_content(**status_kwargs)

    assert "cached" not in rendered.lower()
    assert "2000 in / 300 out" in rendered
|
||||
|
||||
|
||||
def test_status_zero_cached_tokens():
    """An explicit cached_tokens of 0 renders no cache percentage."""
    status_kwargs = {
        "version": "0.1.0",
        "model": "glm-4-plus",
        "start_time": 1000000.0,
        "last_usage": {"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0},
        "context_window_tokens": 128000,
        "session_msg_count": 10,
        "context_tokens_estimate": 5000,
    }

    rendered = build_status_content(**status_kwargs)

    assert "cached" not in rendered.lower()
|
||||
|
||||
|
||||
def test_status_100_percent_cached():
    """cached_tokens equal to prompt_tokens is rendered as "100% cached"."""
    status_kwargs = {
        "version": "0.1.0",
        "model": "glm-4-plus",
        "start_time": 1000000.0,
        "last_usage": {"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000},
        "context_window_tokens": 128000,
        "session_msg_count": 5,
        "context_tokens_estimate": 3000,
    }

    rendered = build_status_content(**status_kwargs)

    assert "100% cached" in rendered
|
||||
@ -125,6 +125,27 @@ def test_workspace_override(tmp_path):
|
||||
assert bot._loop.workspace == custom_ws
|
||||
|
||||
|
||||
def test_sdk_make_provider_uses_github_copilot_backend():
    """A "github-copilot" default provider makes _make_provider build GitHubCopilotProvider."""
    from nanobot.config.schema import Config
    from nanobot.nanobot import _make_provider

    agent_defaults = {
        "provider": "github-copilot",
        "model": "github-copilot/gpt-4.1",
    }
    config = Config.model_validate({"agents": {"defaults": agent_defaults}})

    # AsyncOpenAI is patched out where the OpenAI-compatible provider module
    # looks it up, for the duration of provider construction.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = _make_provider(config)

    provider_class_name = provider.__class__.__name__
    assert provider_class_name == "GitHubCopilotProvider"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_custom_session_key(tmp_path):
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
@ -95,6 +95,14 @@ def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
|
||||
assert paths == [r"C:\user\workspace\txt"]
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_windows_drive_root_path() -> None:
    """A bare Windows drive root such as `E:\\` is returned as an absolute path."""
    # A regular (escaped) literal is used because a raw string cannot end
    # with a single backslash.
    drive_root = "E:\\"

    extracted = ExecTool._extract_absolute_paths("dir " + drive_root)

    assert extracted == [drive_root]
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
||||
cmd = ".venv/bin/python script.py"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
@ -134,6 +142,45 @@ def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
def test_exec_guard_blocks_windows_drive_root_outside_workspace(monkeypatch) -> None:
    """A command touching a drive root outside the workspace is blocked.

    The shell module's ``Path`` name is replaced with a small fake that
    understands backslash-separated drive paths, so the check does not depend
    on the path flavour of the platform running the test.
    """
    import nanobot.agent.tools.shell as shell_mod

    class FakeWindowsPath:
        """Minimal stand-in for ``Path`` using Windows-style drive paths."""

        def __init__(self, raw: str) -> None:
            # Collapse any run of trailing backslashes to at most one, so a
            # drive root keeps its single trailing separator.
            self.raw = raw.rstrip("\\") + ("\\" if raw.endswith("\\") else "")

        def resolve(self) -> "FakeWindowsPath":
            """Return self; the fake performs no filesystem resolution."""
            return self

        def expanduser(self) -> "FakeWindowsPath":
            """Return self; the fake performs no home-directory expansion."""
            return self

        def is_absolute(self) -> bool:
            """Report whether the path starts with ``<char>:`` plus a backslash."""
            return len(self.raw) >= 3 and self.raw[1:3] == ":\\"

        @property
        def parents(self) -> list["FakeWindowsPath"]:
            """Return ancestors from nearest to the drive root (empty for a root)."""
            if not self.is_absolute():
                return []
            trimmed = self.raw.rstrip("\\")
            if len(trimmed) <= 2:
                # Only the drive letter and colon remain: a root has no parents.
                return []
            idx = trimmed.rfind("\\")
            if idx <= 2:
                # The last separator is the one right after the drive letter,
                # so the only parent is the drive root itself.
                return [FakeWindowsPath(trimmed[:2] + "\\")]
            parent = FakeWindowsPath(trimmed[:idx])
            return [parent, *parent.parents]

        def __eq__(self, other: object) -> bool:
            """Compare case-insensitively against another fake path."""
            return isinstance(other, FakeWindowsPath) and self.raw.lower() == other.raw.lower()

        def __hash__(self) -> int:
            # Defining __eq__ alone sets __hash__ to None and makes instances
            # unhashable; hash the same case-folded key that __eq__ compares
            # so equal paths hash equally.
            return hash(self.raw.lower())

    monkeypatch.setattr(shell_mod, "Path", FakeWindowsPath)

    tool = ExecTool(restrict_to_workspace=True)
    error = tool._guard_command("dir E:\\", "E:\\workspace")
    assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user