From a37bc26ed3f384464ce1719a27b51f73d509f349 Mon Sep 17 00:00:00 2001 From: RongLei Date: Tue, 31 Mar 2026 23:36:37 +0800 Subject: [PATCH 01/31] fix: restore GitHub Copilot auth flow Implement the real GitHub device flow and Copilot token exchange for the GitHub Copilot provider. Also route github-copilot models through a dedicated backend and strip the provider prefix before API requests. Add focused regression coverage for provider wiring and model normalization. Generated with GitHub Copilot, GPT-5.4. --- nanobot/cli/commands.py | 31 ++- nanobot/providers/__init__.py | 3 + nanobot/providers/github_copilot_provider.py | 207 +++++++++++++++++++ nanobot/providers/registry.py | 5 +- tests/cli/test_commands.py | 40 ++++ 5 files changed, 265 insertions(+), 21 deletions(-) create mode 100644 nanobot/providers/github_copilot_provider.py diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 7f7d24f39..49521aa16 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -415,6 +415,9 @@ def _make_provider(config: Config): api_base=p.api_base, default_model=model, ) + elif backend == "github_copilot": + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + provider = GitHubCopilotProvider(default_model=model) elif backend == "anthropic": from nanobot.providers.anthropic_provider import AnthropicProvider provider = AnthropicProvider( @@ -1289,26 +1292,16 @@ def _login_openai_codex() -> None: @_register_login("github_copilot") def _login_github_copilot() -> None: - import asyncio - - from openai import AsyncOpenAI - - console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") - - async def _trigger(): - client = AsyncOpenAI( - api_key="dummy", - base_url="https://api.githubcopilot.com", - ) - await client.chat.completions.create( - model="gpt-4o", - messages=[{"role": "user", "content": "hi"}], - max_tokens=1, - ) - try: - asyncio.run(_trigger()) - console.print("[green]✓ Authenticated with GitHub Copilot[/green]") + from nanobot.providers.github_copilot_provider import login_github_copilot + + console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") + token = login_github_copilot( + print_fn=lambda s: console.print(s), + prompt_fn=lambda s: typer.prompt(s), + ) + account = token.account_id or "GitHub" + console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]") except Exception as e: console.print(f"[red]Authentication error: {e}[/red]") raise typer.Exit(1) diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py index 0e259e6f0..ce2378707 100644 --- a/nanobot/providers/__init__.py +++ b/nanobot/providers/__init__.py @@ -13,6 +13,7 @@ __all__ = [ "AnthropicProvider", "OpenAICompatProvider", "OpenAICodexProvider", + "GitHubCopilotProvider", "AzureOpenAIProvider", ] @@ -20,12 +21,14 @@ _LAZY_IMPORTS = { "AnthropicProvider": ".anthropic_provider", "OpenAICompatProvider": ".openai_compat_provider", "OpenAICodexProvider": ".openai_codex_provider", + "GitHubCopilotProvider": ".github_copilot_provider", "AzureOpenAIProvider": ".azure_openai_provider", } if TYPE_CHECKING: from nanobot.providers.anthropic_provider import AnthropicProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider from nanobot.providers.openai_compat_provider import OpenAICompatProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider diff --git a/nanobot/providers/github_copilot_provider.py b/nanobot/providers/github_copilot_provider.py new file mode 100644 index 000000000..eb8b922af --- /dev/null +++ b/nanobot/providers/github_copilot_provider.py @@ -0,0 +1,207 @@ +"""GitHub Copilot OAuth-backed provider.""" + +from __future__ import annotations + +import time +import webbrowser +from collections.abc import Callable + +import httpx +from oauth_cli_kit.models import OAuthToken +from oauth_cli_kit.storage import FileTokenStorage + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + +DEFAULT_GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code" +DEFAULT_GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token" +DEFAULT_GITHUB_USER_URL = "https://api.github.com/user" +DEFAULT_COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token" +DEFAULT_COPILOT_BASE_URL = "https://api.githubcopilot.com" +GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98" +GITHUB_COPILOT_SCOPE = "read:user" +TOKEN_FILENAME = "github-copilot.json" +TOKEN_APP_NAME = "nanobot" +USER_AGENT = "nanobot/0.1" +EDITOR_VERSION = "vscode/1.99.0" +EDITOR_PLUGIN_VERSION = "copilot-chat/0.26.0" +_EXPIRY_SKEW_SECONDS = 60 +_LONG_LIVED_TOKEN_SECONDS = 315360000 + + +def _storage() -> FileTokenStorage: + return FileTokenStorage( + token_filename=TOKEN_FILENAME, + app_name=TOKEN_APP_NAME, + import_codex_cli=False, + ) + + +def _copilot_headers(token: str) -> dict[str, str]: + return { + "Authorization": f"token {token}", + "Accept": "application/json", + "User-Agent": USER_AGENT, + "Editor-Version": EDITOR_VERSION, + "Editor-Plugin-Version": EDITOR_PLUGIN_VERSION, + } + + +def _load_github_token() -> OAuthToken | None: + token = _storage().load() + if not token or not token.access: + return None + return token + + +def get_github_copilot_login_status() -> OAuthToken | None: + """Return the persisted GitHub OAuth token if available.""" + return _load_github_token() + + +def login_github_copilot( + print_fn: Callable[[str], None] | None = None, + prompt_fn: Callable[[str], str] | None = None, +) -> OAuthToken: + """Run GitHub device flow and persist the GitHub OAuth token used for Copilot.""" + del prompt_fn + printer = print_fn or print + timeout = httpx.Timeout(20.0, connect=20.0) + + with httpx.Client(timeout=timeout, follow_redirects=True, trust_env=True) as client: + response = client.post( + DEFAULT_GITHUB_DEVICE_CODE_URL, + headers={"Accept": "application/json", "User-Agent": USER_AGENT}, + data={"client_id": GITHUB_COPILOT_CLIENT_ID, "scope": GITHUB_COPILOT_SCOPE}, + ) + response.raise_for_status() + payload = response.json() + + device_code = str(payload["device_code"]) + user_code = str(payload["user_code"]) + verify_url = str(payload.get("verification_uri") or payload.get("verification_uri_complete") or "") + verify_complete = str(payload.get("verification_uri_complete") or verify_url) + interval = max(1, int(payload.get("interval") or 5)) + expires_in = int(payload.get("expires_in") or 900) + + printer(f"Open: {verify_url}") + printer(f"Code: {user_code}") + if verify_complete: + try: + webbrowser.open(verify_complete) + except Exception: + pass + + deadline = time.time() + expires_in + current_interval = interval + access_token = None + token_expires_in = _LONG_LIVED_TOKEN_SECONDS + while time.time() < deadline: + poll = client.post( + DEFAULT_GITHUB_ACCESS_TOKEN_URL, + headers={"Accept": "application/json", "User-Agent": USER_AGENT}, + data={ + "client_id": GITHUB_COPILOT_CLIENT_ID, + "device_code": device_code, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + }, + ) + poll.raise_for_status() + poll_payload = poll.json() + + access_token = poll_payload.get("access_token") + if access_token: + token_expires_in = int(poll_payload.get("expires_in") or _LONG_LIVED_TOKEN_SECONDS) + break + + error = poll_payload.get("error") + if error == "authorization_pending": + time.sleep(current_interval) + continue + if error == "slow_down": + current_interval += 5 + time.sleep(current_interval) + continue + if error == "expired_token": + raise RuntimeError("GitHub device code expired. Please run login again.") + if error == "access_denied": + raise RuntimeError("GitHub device flow was denied.") + if error: + desc = poll_payload.get("error_description") or error + raise RuntimeError(str(desc)) + time.sleep(current_interval) + else: + raise RuntimeError("GitHub device flow timed out.") + + user = client.get( + DEFAULT_GITHUB_USER_URL, + headers={ + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github+json", + "User-Agent": USER_AGENT, + }, + ) + user.raise_for_status() + user_payload = user.json() + account_id = user_payload.get("login") or str(user_payload.get("id") or "") or None + + expires_ms = int((time.time() + token_expires_in) * 1000) + token = OAuthToken( + access=str(access_token), + refresh="", + expires=expires_ms, + account_id=str(account_id) if account_id else None, + ) + _storage().save(token) + return token + + +class GitHubCopilotProvider(OpenAICompatProvider): + """Provider that exchanges a stored GitHub OAuth token for Copilot access tokens.""" + + def __init__(self, default_model: str = "github-copilot/gpt-4.1"): + from nanobot.providers.registry import find_by_name + + self._copilot_access_token: str | None = None + self._copilot_expires_at: float = 0.0 + super().__init__( + api_key=self._get_copilot_access_token, + api_base=DEFAULT_COPILOT_BASE_URL, + default_model=default_model, + extra_headers={ + "Editor-Version": EDITOR_VERSION, + "Editor-Plugin-Version": EDITOR_PLUGIN_VERSION, + "User-Agent": USER_AGENT, + }, + spec=find_by_name("github_copilot"), + ) + + async def _get_copilot_access_token(self) -> str: + now = time.time() + if self._copilot_access_token and now < self._copilot_expires_at - _EXPIRY_SKEW_SECONDS: + return self._copilot_access_token + + github_token = _load_github_token() + if not github_token or not github_token.access: + raise RuntimeError("GitHub Copilot is not logged in. Run: nanobot provider login github-copilot") + + timeout = httpx.Timeout(20.0, connect=20.0) + async with httpx.AsyncClient(timeout=timeout, follow_redirects=True, trust_env=True) as client: + response = await client.get( + DEFAULT_COPILOT_TOKEN_URL, + headers=_copilot_headers(github_token.access), + ) + response.raise_for_status() + payload = response.json() + + token = payload.get("token") + if not token: + raise RuntimeError("GitHub Copilot token exchange returned no token.") + + expires_at = payload.get("expires_at") + if isinstance(expires_at, (int, float)): + self._copilot_expires_at = float(expires_at) + else: + refresh_in = payload.get("refresh_in") or 1500 + self._copilot_expires_at = time.time() + int(refresh_in) + self._copilot_access_token = str(token) + return self._copilot_access_token diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 5644fc51d..8435005e1 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -34,7 +34,7 @@ class ProviderSpec: display_name: str = "" # shown in `nanobot status` # which provider implementation to use - # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" + # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot" backend: str = "openai_compat" # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),) @@ -218,8 +218,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("github_copilot", "copilot"), env_key="", display_name="Github Copilot", - backend="openai_compat", + backend="github_copilot", default_api_base="https://api.githubcopilot.com", + strip_model_prefix=True, is_oauth=True, ), # DeepSeek: OpenAI-compatible at api.deepseek.com diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 735c02a5a..b9869e74d 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -317,6 +317,46 @@ def test_openai_compat_provider_passes_model_through(): assert provider.get_default_model() == "github-copilot/gpt-5.3-codex" +def test_make_provider_uses_github_copilot_backend(): + from nanobot.cli.commands import _make_provider + from nanobot.config.schema import Config + + config = Config.model_validate( + { + "agents": { + "defaults": { + "provider": "github-copilot", + "model": "github-copilot/gpt-4.1", + } + } + } + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = _make_provider(config) + + assert provider.__class__.__name__ == "GitHubCopilotProvider" + + +def test_github_copilot_provider_strips_prefixed_model_name(): + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1") + + kwargs = provider._build_kwargs( + messages=[{"role": "user", "content": "hi"}], + tools=None, + model="github-copilot/gpt-5.1", + max_tokens=16, + temperature=0.1, + reasoning_effort=None, + tool_choice=None, + ) + + assert kwargs["model"] == "gpt-5.1" + + def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex" assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex" From c5f09973817b6de5461aeca6b6f4fe60ebf1cac1 Mon Sep 17 00:00:00 2001 From: RongLei Date: Wed, 1 Apr 2026 21:43:49 +0800 Subject: [PATCH 02/31] fix: refresh copilot token before requests Address PR review feedback by avoiding an async method reference as the OpenAI client api_key. Initialize the client with a placeholder key, refresh the Copilot token before each chat/chat_stream call, and update the runtime client api_key before dispatch. Add a regression test that verifies the client api_key is refreshed to a real string before chat requests. Generated with GitHub Copilot, GPT-5.4. --- nanobot/providers/github_copilot_provider.py | 52 +++++++++++++++++++- tests/cli/test_commands.py | 29 +++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/nanobot/providers/github_copilot_provider.py b/nanobot/providers/github_copilot_provider.py index eb8b922af..8d50006a0 100644 --- a/nanobot/providers/github_copilot_provider.py +++ b/nanobot/providers/github_copilot_provider.py @@ -164,7 +164,7 @@ class GitHubCopilotProvider(OpenAICompatProvider): self._copilot_access_token: str | None = None self._copilot_expires_at: float = 0.0 super().__init__( - api_key=self._get_copilot_access_token, + api_key="no-key", api_base=DEFAULT_COPILOT_BASE_URL, default_model=default_model, extra_headers={ @@ -205,3 +205,53 @@ class GitHubCopilotProvider(OpenAICompatProvider): self._copilot_expires_at = time.time() + int(refresh_in) self._copilot_access_token = str(token) return self._copilot_access_token + + async def _refresh_client_api_key(self) -> str: + token = await self._get_copilot_access_token() + self.api_key = token + self._client.api_key = token + return token + + async def chat( + self, + messages: list[dict[str, object]], + tools: list[dict[str, object]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, object] | None = None, + ): + await self._refresh_client_api_key() + return await super().chat( + messages=messages, + tools=tools, + model=model, + max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + tool_choice=tool_choice, + ) + + async def chat_stream( + self, + messages: list[dict[str, object]], + tools: list[dict[str, object]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, object] | None = None, + on_content_delta: Callable[[str], None] | None = None, + ): + await self._refresh_client_api_key() + return await super().chat_stream( + messages=messages, + tools=tools, + model=model, + max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + tool_choice=tool_choice, + on_content_delta=on_content_delta, + ) diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index b9869e74d..0f6ff8177 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -357,6 +357,35 @@ def test_github_copilot_provider_strips_prefixed_model_name(): assert kwargs["model"] == "gpt-5.1" +@pytest.mark.asyncio +async def test_github_copilot_provider_refreshes_client_api_key_before_chat(): + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + + mock_client = MagicMock() + mock_client.api_key = "no-key" + mock_client.chat.completions.create = AsyncMock(return_value={ + "choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI", return_value=mock_client): + provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1") + + provider._get_copilot_access_token = AsyncMock(return_value="copilot-access-token") + + response = await provider.chat( + messages=[{"role": "user", "content": "hi"}], + model="github-copilot/gpt-5.1", + max_tokens=16, + temperature=0.1, + ) + + assert response.content == "ok" + assert provider._client.api_key == "copilot-access-token" + provider._get_copilot_access_token.assert_awaited_once() + mock_client.chat.completions.create.assert_awaited_once() + + def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex" assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex" From 2ec68582eb78113be3dfce4b1bf3165668750af6 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 1 Apr 2026 19:37:08 +0000 Subject: [PATCH 03/31] fix(sdk): route github copilot through oauth provider --- nanobot/nanobot.py | 4 ++++ tests/test_nanobot_facade.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 137688455..84fb70934 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -135,6 +135,10 @@ def _make_provider(config: Any) -> Any: from nanobot.providers.openai_codex_provider import OpenAICodexProvider provider = OpenAICodexProvider(default_model=model) + elif backend == "github_copilot": + from nanobot.providers.github_copilot_provider import GitHubCopilotProvider + + provider = GitHubCopilotProvider(default_model=model) elif backend == "azure_openai": from nanobot.providers.azure_openai_provider import AzureOpenAIProvider diff --git a/tests/test_nanobot_facade.py b/tests/test_nanobot_facade.py index 9d0d8a175..9ad9c5db1 100644 --- a/tests/test_nanobot_facade.py +++ b/tests/test_nanobot_facade.py @@ -125,6 +125,27 @@ def test_workspace_override(tmp_path): assert bot._loop.workspace == custom_ws +def test_sdk_make_provider_uses_github_copilot_backend(): + from nanobot.config.schema import Config + from nanobot.nanobot import _make_provider + + config = Config.model_validate( + { + "agents": { + "defaults": { + "provider": "github-copilot", + "model": "github-copilot/gpt-4.1", + } + } + } + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = _make_provider(config) + + assert provider.__class__.__name__ == "GitHubCopilotProvider" + + @pytest.mark.asyncio async def test_run_custom_session_key(tmp_path): from nanobot.bus.events import OutboundMessage From 7e719f41cc7b4c4edf9f5900ff25d91b134e26d2 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 1 Apr 2026 19:43:41 +0000 Subject: [PATCH 04/31] test(providers): cover github copilot lazy export --- tests/providers/test_providers_init.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/providers/test_providers_init.py b/tests/providers/test_providers_init.py index 32cbab478..d6912b437 100644 --- a/tests/providers/test_providers_init.py +++ b/tests/providers/test_providers_init.py @@ -11,6 +11,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False) providers = importlib.import_module("nanobot.providers") @@ -18,6 +19,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: assert "nanobot.providers.anthropic_provider" not in sys.modules assert "nanobot.providers.openai_compat_provider" not in sys.modules assert "nanobot.providers.openai_codex_provider" not in sys.modules + assert "nanobot.providers.github_copilot_provider" not in sys.modules assert "nanobot.providers.azure_openai_provider" not in sys.modules assert providers.__all__ == [ "LLMProvider", @@ -25,6 +27,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: "AnthropicProvider", "OpenAICompatProvider", "OpenAICodexProvider", + "GitHubCopilotProvider", "AzureOpenAIProvider", ] From 6973bfff24b66d81bcb8d673ee6b745dfdbd0f4f Mon Sep 17 00:00:00 2001 From: WormW Date: Wed, 25 Mar 2026 17:37:56 +0800 Subject: [PATCH 05/31] fix(agent): message tool incorrectly replies to original chat when targeting different chat_id When the message tool is used to send a message to a different chat_id than the current conversation, it was incorrectly including the default message_id from the original context. This caused channels like Feishu to send the message as a reply to the original chat instead of creating a new message in the target chat. Changes: - Only use default message_id when chat_id matches the default context - When targeting a different chat, set message_id to None to avoid unintended reply behavior --- nanobot/agent/tools/message.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index c8d50cf1e..efbadca10 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -86,7 +86,13 @@ class MessageTool(Tool): ) -> str: channel = channel or self._default_channel chat_id = chat_id or self._default_chat_id - message_id = message_id or self._default_message_id + # Only use default message_id if chat_id matches the default context. + # If targeting a different chat, don't reply to the original message. + if chat_id == self._default_chat_id: + message_id = message_id or self._default_message_id + else: + # Targeting a different chat - don't use default message_id + message_id = None if not channel or not chat_id: return "Error: No target channel/chat specified" @@ -101,7 +107,7 @@ class MessageTool(Tool): media=media or [], metadata={ "message_id": message_id, - }, + } if message_id else {}, ) try: From ddc9fc4fd286025aebaab5fb3f2f032a18ed2478 Mon Sep 17 00:00:00 2001 From: WormW Date: Wed, 1 Apr 2026 12:32:15 +0800 Subject: [PATCH 06/31] fix: also check channel match before inheriting default message_id Different channels could theoretically share the same chat_id. Check both channel and chat_id to avoid cross-channel reply issues. Co-authored-by: layla <111667698+04cb@users.noreply.github.com> --- nanobot/agent/tools/message.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index efbadca10..3ac813248 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -86,12 +86,14 @@ class MessageTool(Tool): ) -> str: channel = channel or self._default_channel chat_id = chat_id or self._default_chat_id - # Only use default message_id if chat_id matches the default context. - # If targeting a different chat, don't reply to the original message. - if chat_id == self._default_chat_id: + # Only inherit default message_id when targeting the same channel+chat. + # Cross-chat sends must not carry the original message_id, because + # some channels (e.g. Feishu) use it to determine the target + # conversation via their Reply API, which would route the message + # to the wrong chat entirely. + if channel == self._default_channel and chat_id == self._default_chat_id: message_id = message_id or self._default_message_id else: - # Targeting a different chat - don't use default message_id message_id = None if not channel or not chat_id: From bc2e474079a38f0f68039a8767cff088682fd6c0 Mon Sep 17 00:00:00 2001 From: "zhangxiaoyu.york" Date: Tue, 31 Mar 2026 23:27:39 +0800 Subject: [PATCH 07/31] Fix ExecTool to block root directory paths when restrict_to_workspace is enabled --- nanobot/agent/tools/shell.py | 4 +++- tests/tools/test_tool_validation.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index ed552b33e..b051edffc 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -186,7 +186,9 @@ class ExecTool(Tool): @staticmethod def _extract_absolute_paths(command: str) -> list[str]: - win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\... + # Windows: match drive-root paths like `C:\` as well as `C:\path\to\file` + # NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted. + win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command) posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~ return win_paths + posix_paths + home_paths diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py index a95418fe5..af4675310 100644 --- a/tests/tools/test_tool_validation.py +++ b/tests/tools/test_tool_validation.py @@ -95,6 +95,14 @@ def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None: assert paths == [r"C:\user\workspace\txt"] +def test_exec_extract_absolute_paths_captures_windows_drive_root_path() -> None: + """Windows drive root paths like `E:\\` must be extracted for workspace guarding.""" + # Note: raw strings cannot end with a single backslash. + cmd = "dir E:\\" + paths = ExecTool._extract_absolute_paths(cmd) + assert paths == ["E:\\"] + + def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None: cmd = ".venv/bin/python script.py" paths = ExecTool._extract_absolute_paths(cmd) From 485c75e065808aa3d27bb35805d782d3365a5794 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 1 Apr 2026 19:52:54 +0000 Subject: [PATCH 08/31] test(exec): verify windows drive-root workspace guard --- tests/tools/test_tool_validation.py | 39 +++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py index af4675310..98a3dc903 100644 --- a/tests/tools/test_tool_validation.py +++ b/tests/tools/test_tool_validation.py @@ -142,6 +142,45 @@ def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None: assert error == "Error: Command blocked by safety guard (path outside working dir)" +def test_exec_guard_blocks_windows_drive_root_outside_workspace(monkeypatch) -> None: + import nanobot.agent.tools.shell as shell_mod + + class FakeWindowsPath: + def __init__(self, raw: str) -> None: + self.raw = raw.rstrip("\\") + ("\\" if raw.endswith("\\") else "") + + def resolve(self) -> "FakeWindowsPath": + return self + + def expanduser(self) -> "FakeWindowsPath": + return self + + def is_absolute(self) -> bool: + return len(self.raw) >= 3 and self.raw[1:3] == ":\\" + + @property + def parents(self) -> list["FakeWindowsPath"]: + if not self.is_absolute(): + return [] + trimmed = self.raw.rstrip("\\") + if len(trimmed) <= 2: + return [] + idx = trimmed.rfind("\\") + if idx <= 2: + return [FakeWindowsPath(trimmed[:2] + "\\")] + parent = FakeWindowsPath(trimmed[:idx]) + return [parent, *parent.parents] + + def __eq__(self, other: object) -> bool: + return isinstance(other, FakeWindowsPath) and self.raw.lower() == other.raw.lower() + + monkeypatch.setattr(shell_mod, "Path", FakeWindowsPath) + + tool = ExecTool(restrict_to_workspace=True) + error = tool._guard_command("dir E:\\", "E:\\workspace") + assert error == "Error: Command blocked by safety guard (path outside working dir)" + + # --- cast_params tests --- From 05fe73947f219be405be57d9a27eb97e00fa4953 Mon Sep 17 00:00:00 2001 From: Tejas1Koli Date: Wed, 1 Apr 2026 00:51:49 +0530 Subject: [PATCH 09/31] fix(providers): only apply cache_control for Claude models on OpenRouter --- nanobot/providers/openai_compat_provider.py | 117 +++++++++++++------- 1 file changed, 79 insertions(+), 38 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 397b8e797..a033b44ef 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -18,10 +18,17 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest if TYPE_CHECKING: from nanobot.providers.registry import ProviderSpec -_ALLOWED_MSG_KEYS = frozenset({ - "role", "content", "tool_calls", "tool_call_id", "name", - "reasoning_content", "extra_content", -}) +_ALLOWED_MSG_KEYS = frozenset( + { + "role", + "content", + "tool_calls", + "tool_call_id", + "name", + "reasoning_content", + "extra_content", + } +) _ALNUM = string.ascii_letters + string.digits _STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"}) @@ -59,7 +66,9 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None: return None -def _extract_tc_extras(tc: Any) -> tuple[ +def _extract_tc_extras( + tc: Any, +) -> tuple[ dict[str, Any] | None, dict[str, Any] | None, dict[str, Any] | None, @@ -75,14 +84,18 @@ def _extract_tc_extras(tc: Any) -> tuple[ prov = None fn_prov = None if tc_dict is not None: - leftover = {k: v for k, v in tc_dict.items() - if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None} + leftover = { + k: v + for k, v in tc_dict.items() + if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None + } if leftover: prov = leftover fn = _coerce_dict(tc_dict.get("function")) if fn is not None: - fn_leftover = {k: v for k, v in fn.items() - if k not in _STANDARD_FN_KEYS and v is not None} + fn_leftover = { + k: v for k, v in fn.items() if k not in _STANDARD_FN_KEYS and v is not None + } if fn_leftover: fn_prov = fn_leftover else: @@ -163,9 +176,12 @@ class OpenAICompatProvider(LLMProvider): def _mark(msg: dict[str, Any]) -> dict[str, Any]: content = msg.get("content") if isinstance(content, str): - return {**msg, "content": [ - {"type": "text", "text": content, "cache_control": cache_marker}, - ]} + return { + **msg, + "content": [ + {"type": "text", "text": content, "cache_control": cache_marker}, + ], + } if isinstance(content, list) and content: nc = list(content) nc[-1] = {**nc[-1], "cache_control": cache_marker} @@ -235,7 +251,9 @@ class OpenAICompatProvider(LLMProvider): spec = self._spec if spec and spec.supports_prompt_caching: - messages, tools = self._apply_cache_control(messages, tools) + model_name = model or self.default_model + if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")): + messages, tools = self._apply_cache_control(messages, tools) if spec and spec.strip_model_prefix: model_name = model_name.split("/")[-1] @@ -348,7 +366,9 @@ class OpenAICompatProvider(LLMProvider): finish_reason=str(response_map.get("finish_reason") or "stop"), usage=self._extract_usage(response_map), ) - return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + return LLMResponse( + content="Error: API returned empty choices.", finish_reason="error" + ) choice0 = self._maybe_mapping(choices[0]) or {} msg0 = self._maybe_mapping(choice0.get("message")) or {} @@ -378,14 +398,16 @@ class OpenAICompatProvider(LLMProvider): if isinstance(args, str): args = json_repair.loads(args) ec, prov, fn_prov = _extract_tc_extras(tc) - parsed_tool_calls.append(ToolCallRequest( - id=_short_tool_id(), - name=str(fn.get("name") or ""), - arguments=args if isinstance(args, dict) else {}, - extra_content=ec, - provider_specific_fields=prov, - function_provider_specific_fields=fn_prov, - )) + parsed_tool_calls.append( + ToolCallRequest( + id=_short_tool_id(), + name=str(fn.get("name") or ""), + arguments=args if isinstance(args, dict) else {}, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + ) + ) return LLMResponse( content=content, @@ -419,14 +441,16 @@ class OpenAICompatProvider(LLMProvider): if isinstance(args, str): args = json_repair.loads(args) ec, prov, fn_prov = _extract_tc_extras(tc) - tool_calls.append(ToolCallRequest( - id=_short_tool_id(), - name=tc.function.name, - arguments=args, - extra_content=ec, - provider_specific_fields=prov, - function_provider_specific_fields=fn_prov, - )) + tool_calls.append( + ToolCallRequest( + id=_short_tool_id(), + name=tc.function.name, + arguments=args, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + ) + ) return LLMResponse( content=content, @@ -446,10 +470,17 @@ class OpenAICompatProvider(LLMProvider): def _accum_tc(tc: Any, idx_hint: int) -> None: """Accumulate one streaming tool-call delta into *tc_bufs*.""" tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint - buf = tc_bufs.setdefault(tc_index, { - "id": "", "name": "", "arguments": "", - "extra_content": None, "prov": None, "fn_prov": None, - }) + buf = tc_bufs.setdefault( + tc_index, + { + "id": "", + "name": "", + "arguments": "", + "extra_content": None, + "prov": None, + "fn_prov": None, + }, + ) tc_id = _get(tc, "id") if tc_id: buf["id"] = str(tc_id) @@ -547,8 +578,13 @@ class OpenAICompatProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: kwargs = self._build_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, + messages, + tools, + model, + max_tokens, + temperature, + reasoning_effort, + tool_choice, ) try: return self._parse(await self._client.chat.completions.create(**kwargs)) @@ -567,8 +603,13 @@ class OpenAICompatProvider(LLMProvider): on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: kwargs = self._build_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, + messages, + tools, + model, + max_tokens, + temperature, + reasoning_effort, + tool_choice, ) kwargs["stream"] = True kwargs["stream_options"] = {"include_usage": True} From 42fa8fa9339b16c031fdb3671c9ee4f3d55d74de Mon Sep 17 00:00:00 2001 From: Tejas1Koli Date: Wed, 1 Apr 2026 10:36:24 +0530 Subject: [PATCH 10/31] fix(providers): only apply cache_control for Claude models on OpenRouter --- nanobot/providers/openai_compat_provider.py | 115 +++++++------------- 1 file changed, 38 insertions(+), 77 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a033b44ef..967c21976 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -18,17 +18,10 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest if TYPE_CHECKING: from nanobot.providers.registry import ProviderSpec -_ALLOWED_MSG_KEYS = frozenset( - { - "role", - "content", - "tool_calls", - "tool_call_id", - "name", - "reasoning_content", - "extra_content", - } -) +_ALLOWED_MSG_KEYS = frozenset({ + "role", "content", "tool_calls", "tool_call_id", "name", + "reasoning_content", "extra_content", +}) _ALNUM = string.ascii_letters + string.digits _STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"}) @@ -66,9 +59,7 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None: return None -def _extract_tc_extras( - tc: Any, -) -> tuple[ +def _extract_tc_extras(tc: Any) -> tuple[ dict[str, Any] | None, dict[str, Any] | None, dict[str, Any] | None, @@ -84,18 +75,14 @@ def _extract_tc_extras( prov = None fn_prov = None if tc_dict is not None: - leftover = { - k: v - for k, v in tc_dict.items() - if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None - } + leftover = {k: v for k, v in tc_dict.items() + if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None} if leftover: prov = leftover fn = _coerce_dict(tc_dict.get("function")) if fn is not None: - fn_leftover = { - k: v for k, v in fn.items() if k not in _STANDARD_FN_KEYS and v is not None - } + fn_leftover = {k: v for k, v in fn.items() + if k not in _STANDARD_FN_KEYS and v is not None} if fn_leftover: fn_prov = fn_leftover else: @@ -176,12 +163,9 @@ class OpenAICompatProvider(LLMProvider): def _mark(msg: dict[str, Any]) -> dict[str, Any]: content = msg.get("content") if isinstance(content, str): - return { - **msg, - "content": [ - {"type": "text", "text": content, "cache_control": cache_marker}, - ], - } + return {**msg, "content": [ + {"type": "text", "text": content, "cache_control": cache_marker}, + ]} if isinstance(content, list) and content: nc = list(content) nc[-1] = {**nc[-1], "cache_control": cache_marker} @@ -366,9 +350,7 @@ class OpenAICompatProvider(LLMProvider): finish_reason=str(response_map.get("finish_reason") or "stop"), usage=self._extract_usage(response_map), ) - return LLMResponse( - content="Error: API returned empty choices.", finish_reason="error" - ) + return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") choice0 = self._maybe_mapping(choices[0]) or {} msg0 = self._maybe_mapping(choice0.get("message")) or {} @@ -398,16 +380,14 @@ class OpenAICompatProvider(LLMProvider): if isinstance(args, str): args = json_repair.loads(args) ec, prov, fn_prov = _extract_tc_extras(tc) - parsed_tool_calls.append( - ToolCallRequest( - id=_short_tool_id(), - name=str(fn.get("name") or ""), - arguments=args if isinstance(args, dict) else {}, - extra_content=ec, - provider_specific_fields=prov, - function_provider_specific_fields=fn_prov, - ) - ) + parsed_tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=str(fn.get("name") or ""), + arguments=args if isinstance(args, dict) else {}, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + )) return LLMResponse( content=content, @@ -441,16 +421,14 @@ class OpenAICompatProvider(LLMProvider): if isinstance(args, str): args = json_repair.loads(args) ec, prov, fn_prov = _extract_tc_extras(tc) - tool_calls.append( - ToolCallRequest( - id=_short_tool_id(), - name=tc.function.name, - arguments=args, - extra_content=ec, - provider_specific_fields=prov, - function_provider_specific_fields=fn_prov, - ) - ) + tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=tc.function.name, + arguments=args, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + )) return LLMResponse( content=content, @@ -470,17 +448,10 @@ class OpenAICompatProvider(LLMProvider): def _accum_tc(tc: Any, idx_hint: int) -> None: """Accumulate one streaming tool-call delta into *tc_bufs*.""" tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint - buf = tc_bufs.setdefault( - tc_index, - { - "id": "", - "name": "", - "arguments": "", - "extra_content": None, - "prov": None, - "fn_prov": None, - }, - ) + buf = tc_bufs.setdefault(tc_index, { + "id": "", "name": "", "arguments": "", + "extra_content": None, "prov": None, "fn_prov": None, + }) tc_id = _get(tc, "id") if tc_id: buf["id"] = str(tc_id) @@ -578,13 +549,8 @@ class OpenAICompatProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: kwargs = self._build_kwargs( - messages, - tools, - model, - max_tokens, - temperature, - reasoning_effort, - tool_choice, + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, ) try: return self._parse(await self._client.chat.completions.create(**kwargs)) @@ -603,13 +569,8 @@ class OpenAICompatProvider(LLMProvider): on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: kwargs = self._build_kwargs( - messages, - tools, - model, - max_tokens, - temperature, - reasoning_effort, - tool_choice, + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, ) kwargs["stream"] = True kwargs["stream_options"] = {"include_usage": True} @@ -627,4 +588,4 @@ class OpenAICompatProvider(LLMProvider): return self._handle_error(e) def get_default_model(self) -> str: - return self.default_model + return self.default_model \ No newline at end of file From da08dee144bb2d9abf819a56d9a64e67cf76a849 Mon Sep 17 00:00:00 2001 From: chengyongru <61816729+chengyongru@users.noreply.github.com> Date: Tue, 31 Mar 2026 09:48:43 +0800 Subject: [PATCH 11/31] feat(provider): show cache hit rate in /status (#2645) --- nanobot/agent/loop.py | 9 + nanobot/agent/runner.py | 14 +- nanobot/providers/anthropic_provider.py | 4 + nanobot/providers/openai_compat_provider.py | 51 ++++- nanobot/utils/helpers.py | 6 +- tests/agent/test_runner.py | 79 +++++++ tests/cli/test_restart_command.py | 6 +- tests/providers/test_cached_tokens.py | 231 ++++++++++++++++++++ tests/test_build_status.py | 59 +++++ 9 files changed, 445 insertions(+), 14 deletions(-) create mode 100644 tests/providers/test_cached_tokens.py create mode 100644 tests/test_build_status.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index a9dc589e8..50fef58fd 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -97,6 +97,15 @@ class _LoopHook(AgentHook): logger.info("Tool call: {}({})", tc.name, args_str[:200]) self._loop._set_tool_context(self._channel, self._chat_id, self._message_id) + async def after_iteration(self, context: AgentHookContext) -> None: + u = context.usage or {} + logger.debug( + "LLM usage: prompt={} completion={} cached={}", + u.get("prompt_tokens", 0), + u.get("completion_tokens", 0), + u.get("cached_tokens", 0), + ) + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return self._loop._strip_think(content) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index d6242a6b4..4fec539dd 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -60,7 +60,7 @@ class AgentRunner: messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] - usage = {"prompt_tokens": 0, "completion_tokens": 0} + usage: dict[str, int] = {} error: str | None = None stop_reason = "completed" tool_events: list[dict[str, str]] = [] @@ -92,13 +92,15 @@ class AgentRunner: response = await self.provider.chat_with_retry(**kwargs) raw_usage = response.usage or {} - usage = { - "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0), - "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0), - } context.response = response - context.usage = usage + context.usage = raw_usage context.tool_calls = list(response.tool_calls) + # Accumulate standard fields into result usage. + usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0) + usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0) + cached = raw_usage.get("cached_tokens") + if cached: + usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached) if response.has_tool_calls: if hook.wants_streaming(): diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 3c789e730..fabcd5656 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -379,6 +379,10 @@ class AnthropicProvider(LLMProvider): val = getattr(response.usage, attr, 0) if val: usage[attr] = val + # Normalize to cached_tokens for downstream consistency. + cache_read = usage.get("cache_read_input_tokens", 0) + if cache_read: + usage["cached_tokens"] = cache_read return LLMResponse( content="".join(content_parts) or None, diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 967c21976..f89879c90 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -310,6 +310,13 @@ class OpenAICompatProvider(LLMProvider): @classmethod def _extract_usage(cls, response: Any) -> dict[str, int]: + """Extract token usage from an OpenAI-compatible response. + + Handles both dict-based (raw JSON) and object-based (SDK Pydantic) + responses. Provider-specific ``cached_tokens`` fields are normalised + under a single key; see the priority chain inside for details. + """ + # --- resolve usage object --- usage_obj = None response_map = cls._maybe_mapping(response) if response_map is not None: @@ -319,19 +326,53 @@ class OpenAICompatProvider(LLMProvider): usage_map = cls._maybe_mapping(usage_obj) if usage_map is not None: - return { + result = { "prompt_tokens": int(usage_map.get("prompt_tokens") or 0), "completion_tokens": int(usage_map.get("completion_tokens") or 0), "total_tokens": int(usage_map.get("total_tokens") or 0), } - - if usage_obj: - return { + elif usage_obj: + result = { "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0, "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0, "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0, } - return {} + else: + return {} + + # --- cached_tokens (normalised across providers) --- + # Try nested paths first (dict), fall back to attribute (SDK object). + # Priority order ensures the most specific field wins. + for path in ( + ("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI + ("cached_tokens",), # StepFun/Moonshot (top-level) + ("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow + ): + cached = cls._get_nested_int(usage_map, path) + if not cached and usage_obj: + cached = cls._get_nested_int(usage_obj, path) + if cached: + result["cached_tokens"] = cached + break + + return result + + @staticmethod + def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int: + """Drill into *obj* by *path* segments and return an ``int`` value. + + Supports both dict-key access and attribute access so it works + uniformly with raw JSON dicts **and** SDK Pydantic models. + """ + current = obj + for segment in path: + if current is None: + return 0 + if isinstance(current, dict): + current = current.get(segment) + else: + current = getattr(current, segment, None) + return int(current or 0) if current is not None else 0 def _parse(self, response: Any) -> LLMResponse: if isinstance(response, str): diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index a7c2c2574..406a4dd45 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -255,14 +255,18 @@ def build_status_content( ) last_in = last_usage.get("prompt_tokens", 0) last_out = last_usage.get("completion_tokens", 0) + cached = last_usage.get("cached_tokens", 0) ctx_total = max(context_window_tokens, 0) ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0 ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate) ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a" + token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out" + if cached and last_in: + token_line += f" ({cached * 100 // last_in}% cached)" return "\n".join([ f"\U0001f408 nanobot v{version}", f"\U0001f9e0 Model: {model}", - f"\U0001f4ca Tokens: {last_in} in / {last_out} out", + token_line, f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)", f"\U0001f4ac Session: {session_msg_count} messages", f"\u23f1 Uptime: {uptime}", diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 86b0ba710..98f1d73ae 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon args = mgr._announce_result.await_args.args assert args[3] == "Task completed but no final response was generated." assert args[5] == "ok" + + +@pytest.mark.asyncio +async def test_runner_accumulates_usage_and_preserves_cached_tokens(): + """Runner should accumulate prompt/completion tokens across iterations + and preserve cached_tokens from provider responses.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], + usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80}, + ) + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=3, + )) + + # Usage should be accumulated across iterations + assert result.usage["prompt_tokens"] == 300 # 100 + 200 + assert result.usage["completion_tokens"] == 30 # 10 + 20 + assert result.usage["cached_tokens"] == 230 # 80 + 150 + + +@pytest.mark.asyncio +async def test_runner_passes_cached_tokens_to_hook_context(): + """Hook context.usage should contain cached_tokens.""" + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_usage: list[dict] = [] + + class UsageHook(AgentHook): + async def after_iteration(self, context: AgentHookContext) -> None: + captured_usage.append(dict(context.usage)) + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + hook=UsageHook(), + )) + + assert len(captured_usage) == 1 + assert captured_usage[0]["cached_tokens"] == 150 diff --git a/tests/cli/test_restart_command.py b/tests/cli/test_restart_command.py index 3281afe2d..6efcdad0d 100644 --- a/tests/cli/test_restart_command.py +++ b/tests/cli/test_restart_command.py @@ -152,10 +152,12 @@ class TestRestartCommand: ]) await loop._run_agent_loop([]) - assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4} + assert loop._last_usage["prompt_tokens"] == 9 + assert loop._last_usage["completion_tokens"] == 4 await loop._run_agent_loop([]) - assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0} + assert loop._last_usage["prompt_tokens"] == 0 + assert loop._last_usage["completion_tokens"] == 0 @pytest.mark.asyncio async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self): diff --git a/tests/providers/test_cached_tokens.py b/tests/providers/test_cached_tokens.py new file mode 100644 index 000000000..fce22cf65 --- /dev/null +++ b/tests/providers/test_cached_tokens.py @@ -0,0 +1,231 @@ +"""Tests for cached token extraction from OpenAI-compatible providers.""" + +from __future__ import annotations + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +class FakeUsage: + """Mimics an OpenAI SDK usage object (has attributes, not dict keys).""" + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class FakePromptDetails: + """Mimics prompt_tokens_details sub-object.""" + def __init__(self, cached_tokens=0): + self.cached_tokens = cached_tokens + + +class _FakeSpec: + supports_prompt_caching = False + model_id_prefix = None + strip_model_prefix = False + max_completion_tokens = False + reasoning_effort = None + + +def _provider(): + from unittest.mock import MagicMock + p = OpenAICompatProvider.__new__(OpenAICompatProvider) + p.client = MagicMock() + p.spec = _FakeSpec() + return p + + +# Minimal valid choice so _parse reaches _extract_usage. +_DICT_CHOICE = {"message": {"content": "Hello"}} + +class _FakeMessage: + content = "Hello" + tool_calls = None + + +class _FakeChoice: + message = _FakeMessage() + finish_reason = "stop" + + +# --- dict-based response (raw JSON / mapping) --- + +def test_extract_usage_openai_cached_tokens_dict(): + """prompt_tokens_details.cached_tokens from a dict response.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 1200}, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + assert result.usage["prompt_tokens"] == 2000 + + +def test_extract_usage_deepseek_cached_tokens_dict(): + """prompt_cache_hit_tokens from a DeepSeek dict response.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 1500, + "completion_tokens": 200, + "total_tokens": 1700, + "prompt_cache_hit_tokens": 1200, + "prompt_cache_miss_tokens": 300, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_no_cached_tokens_dict(): + """Response without any cache fields -> no cached_tokens key.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 200, + "total_tokens": 1200, + } + } + result = p._parse(response) + assert "cached_tokens" not in result.usage + + +def test_extract_usage_openai_cached_zero_dict(): + """cached_tokens=0 should NOT be included (same as existing fields).""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 0}, + } + } + result = p._parse(response) + assert "cached_tokens" not in result.usage + + +# --- object-based response (OpenAI SDK Pydantic model) --- + +def test_extract_usage_openai_cached_tokens_obj(): + """prompt_tokens_details.cached_tokens from an SDK object response.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=2000, + completion_tokens=300, + total_tokens=2300, + prompt_tokens_details=FakePromptDetails(cached_tokens=1200), + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_deepseek_cached_tokens_obj(): + """prompt_cache_hit_tokens from a DeepSeek SDK object response.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=1500, + completion_tokens=200, + total_tokens=1700, + prompt_cache_hit_tokens=1200, + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_stepfun_top_level_cached_tokens_dict(): + """StepFun/Moonshot: usage.cached_tokens at top level (not nested).""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 591, + "completion_tokens": 120, + "total_tokens": 711, + "cached_tokens": 512, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 512 + + +def test_extract_usage_stepfun_top_level_cached_tokens_obj(): + """StepFun/Moonshot: usage.cached_tokens as SDK object attribute.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=591, + completion_tokens=120, + total_tokens=711, + cached_tokens=512, + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 512 + + +def test_extract_usage_priority_nested_over_top_level_dict(): + """When both nested and top-level cached_tokens exist, nested wins.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 100}, + "cached_tokens": 500, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 100 + + +def test_anthropic_maps_cache_fields_to_cached_tokens(): + """Anthropic's cache_read_input_tokens should map to cached_tokens.""" + from nanobot.providers.anthropic_provider import AnthropicProvider + + usage_obj = FakeUsage( + input_tokens=800, + output_tokens=200, + cache_creation_input_tokens=0, + cache_read_input_tokens=1200, + ) + content_block = FakeUsage(type="text", text="hello") + response = FakeUsage( + id="msg_1", + type="message", + stop_reason="end_turn", + content=[content_block], + usage=usage_obj, + ) + result = AnthropicProvider._parse_response(response) + assert result.usage["cached_tokens"] == 1200 + assert result.usage["prompt_tokens"] == 800 + + +def test_anthropic_no_cache_fields(): + """Anthropic response without cache fields should not have cached_tokens.""" + from nanobot.providers.anthropic_provider import AnthropicProvider + + usage_obj = FakeUsage(input_tokens=800, output_tokens=200) + content_block = FakeUsage(type="text", text="hello") + response = FakeUsage( + id="msg_1", + type="message", + stop_reason="end_turn", + content=[content_block], + usage=usage_obj, + ) + result = AnthropicProvider._parse_response(response) + assert "cached_tokens" not in result.usage diff --git a/tests/test_build_status.py b/tests/test_build_status.py new file mode 100644 index 000000000..d98301cf7 --- /dev/null +++ b/tests/test_build_status.py @@ -0,0 +1,59 @@ +"""Tests for build_status_content cache hit rate display.""" + +from nanobot.utils.helpers import build_status_content + + +def test_status_shows_cache_hit_rate(): + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "60% cached" in content + assert "2000 in / 300 out" in content + + +def test_status_no_cache_info(): + """Without cached_tokens, display should not show cache percentage.""" + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "cached" not in content.lower() + assert "2000 in / 300 out" in content + + +def test_status_zero_cached_tokens(): + """cached_tokens=0 should not show cache percentage.""" + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "cached" not in content.lower() + + +def test_status_100_percent_cached(): + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000}, + context_window_tokens=128000, + session_msg_count=5, + context_tokens_estimate=3000, + ) + assert "100% cached" in content From a3e4c77fff90242f4bd5c344789adc9e46c5ee2e Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 2 Apr 2026 04:48:11 +0000 Subject: [PATCH 12/31] fix(providers): normalize anthropic cached token usage --- nanobot/providers/anthropic_provider.py | 9 ++++++--- tests/providers/test_cached_tokens.py | 6 ++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index fabcd5656..8e102d305 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -370,17 +370,20 @@ class AnthropicProvider(LLMProvider): usage: dict[str, int] = {} if response.usage: + input_tokens = response.usage.input_tokens + cache_creation = getattr(response.usage, "cache_creation_input_tokens", 0) or 0 + cache_read = getattr(response.usage, "cache_read_input_tokens", 0) or 0 + total_prompt_tokens = input_tokens + cache_creation + cache_read usage = { - "prompt_tokens": response.usage.input_tokens, + "prompt_tokens": total_prompt_tokens, "completion_tokens": response.usage.output_tokens, - "total_tokens": response.usage.input_tokens + response.usage.output_tokens, + "total_tokens": total_prompt_tokens + response.usage.output_tokens, } for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"): val = getattr(response.usage, attr, 0) if val: usage[attr] = val # Normalize to cached_tokens for downstream consistency. - cache_read = usage.get("cache_read_input_tokens", 0) if cache_read: usage["cached_tokens"] = cache_read diff --git a/tests/providers/test_cached_tokens.py b/tests/providers/test_cached_tokens.py index fce22cf65..1b01408a4 100644 --- a/tests/providers/test_cached_tokens.py +++ b/tests/providers/test_cached_tokens.py @@ -198,7 +198,7 @@ def test_anthropic_maps_cache_fields_to_cached_tokens(): usage_obj = FakeUsage( input_tokens=800, output_tokens=200, - cache_creation_input_tokens=0, + cache_creation_input_tokens=300, cache_read_input_tokens=1200, ) content_block = FakeUsage(type="text", text="hello") @@ -211,7 +211,9 @@ def test_anthropic_maps_cache_fields_to_cached_tokens(): ) result = AnthropicProvider._parse_response(response) assert result.usage["cached_tokens"] == 1200 - assert result.usage["prompt_tokens"] == 800 + assert result.usage["prompt_tokens"] == 2300 + assert result.usage["total_tokens"] == 2500 + assert result.usage["cache_creation_input_tokens"] == 300 def test_anthropic_no_cache_fields(): From 73e80b199a97b7576fa1c7c5a93f526076d7d27b Mon Sep 17 00:00:00 2001 From: lucario <912156837@qq.com> Date: Wed, 1 Apr 2026 23:17:13 +0800 Subject: [PATCH 13/31] feat(cron): add deliver parameter to support silent jobs, default true for backward compatibility --- nanobot/agent/tools/cron.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 00f726c08..89b403b71 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -97,9 +97,14 @@ class CronTool(Tool): f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}." ), }, - "job_id": {"type": "string", "description": "Job ID (for remove)"}, - }, - "required": ["action"], + "job_id": {"type": "string", "description": "Job ID (for remove)"}, + "deliver": { + "type": "boolean", + "description": "Whether to deliver the execution result to the user channel (default true)", + "default": true + }, + }, + "required": ["action"], } async def execute( @@ -111,12 +116,13 @@ class CronTool(Tool): tz: str | None = None, at: str | None = None, job_id: str | None = None, + deliver: bool = True, **kwargs: Any, ) -> str: if action == "add": if self._in_cron_context.get(): return "Error: cannot schedule new jobs from within a cron job execution" - return self._add_job(message, every_seconds, cron_expr, tz, at) + return self._add_job(message, every_seconds, cron_expr, tz, at, deliver) elif action == "list": return self._list_jobs() elif action == "remove": @@ -130,6 +136,7 @@ class CronTool(Tool): cron_expr: str | None, tz: str | None, at: str | None, + deliver: bool = True, ) -> str: if not message: return "Error: message is required for add" @@ -171,7 +178,7 @@ class CronTool(Tool): name=message[:30], schedule=schedule, message=message, - deliver=True, + deliver=deliver, channel=self._channel, to=self._chat_id, delete_after_run=delete_after, From 2e3cb5b20e1eba863ea05b8e35eb377e9030378b Mon Sep 17 00:00:00 2001 From: archlinux Date: Wed, 1 Apr 2026 23:25:11 +0800 Subject: [PATCH 14/31] fix default value True --- nanobot/agent/tools/cron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 89b403b71..a78ab89b4 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -101,7 +101,7 @@ class CronTool(Tool): "deliver": { "type": "boolean", "description": "Whether to deliver the execution result to the user channel (default true)", - "default": true + "default": True }, }, "required": ["action"], From 5f2157baeb1da922fbfa154f2bd7b6f72213c2b1 Mon Sep 17 00:00:00 2001 From: lucario <912156837@qq.com> Date: Thu, 2 Apr 2026 00:05:53 +0800 Subject: [PATCH 15/31] fix(cron): move deliver param before job_id in parameters schema --- nanobot/agent/tools/cron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index a78ab89b4..850ecdc49 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -97,12 +97,12 @@ class CronTool(Tool): f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}." ), }, - "job_id": {"type": "string", "description": "Job ID (for remove)"}, "deliver": { "type": "boolean", "description": "Whether to deliver the execution result to the user channel (default true)", "default": True }, + "job_id": {"type": "string", "description": "Job ID (for remove)"}, }, "required": ["action"], } From 35b51c0694c87d13d5a2e40603390c5584673946 Mon Sep 17 00:00:00 2001 From: lucario <912156837@qq.com> Date: Thu, 2 Apr 2026 00:15:39 +0800 Subject: [PATCH 16/31] fix(cron): fix extra indent for deliver param --- nanobot/agent/tools/cron.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 850ecdc49..5205d0d63 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -97,12 +97,12 @@ class CronTool(Tool): f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}." ), }, - "deliver": { - "type": "boolean", - "description": "Whether to deliver the execution result to the user channel (default true)", - "default": True - }, - "job_id": {"type": "string", "description": "Job ID (for remove)"}, + "deliver": { + "type": "boolean", + "description": "Whether to deliver the execution result to the user channel (default true)", + "default": True + }, + "job_id": {"type": "string", "description": "Job ID (for remove)"}, }, "required": ["action"], } From 15faa3b1151e4ec5c350f346ece9dc3265bf342a Mon Sep 17 00:00:00 2001 From: lucario <912156837@qq.com> Date: Thu, 2 Apr 2026 00:17:26 +0800 Subject: [PATCH 17/31] fix(cron): fix extra indent for properties closing brace and required field --- nanobot/agent/tools/cron.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 5205d0d63..f2aba0b97 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -103,8 +103,8 @@ class CronTool(Tool): "default": True }, "job_id": {"type": "string", "description": "Job ID (for remove)"}, - }, - "required": ["action"], + }, + "required": ["action"], } async def execute( From 9ba413c82e2157c2f2f4123efb79c42fa5783f60 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 2 Apr 2026 04:57:29 +0000 Subject: [PATCH 18/31] test(cron): cover deliver flag on scheduled jobs --- tests/cron/test_cron_tool_list.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/cron/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py index 22a502fa4..42ad7d419 100644 --- a/tests/cron/test_cron_tool_list.py +++ b/tests/cron/test_cron_tool_list.py @@ -285,6 +285,28 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None: assert job.schedule.at_ms == expected +def test_add_job_delivers_by_default(tmp_path) -> None: + tool = _make_tool(tmp_path) + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Morning standup", 60, None, None, None) + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + assert job.payload.deliver is True + + +def test_add_job_can_disable_delivery(tmp_path) -> None: + tool = _make_tool(tmp_path) + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Background refresh", 60, None, None, None, deliver=False) + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + assert job.payload.deliver is False + + def test_list_excludes_disabled_jobs(tmp_path) -> None: tool = _make_tool(tmp_path) job = tool._cron.add_job( From 0417c3f03b6b1a5fdd61e6a48c8896c1222c66b6 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Tue, 31 Mar 2026 02:05:59 +0000 Subject: [PATCH 19/31] Use OpenAI responses API --- nanobot/providers/azure_openai_provider.py | 316 +++----- nanobot/providers/openai_codex_provider.py | 192 +---- .../openai_responses_common/__init__.py | 27 + .../openai_responses_common/converters.py | 110 +++ .../openai_responses_common/parsing.py | 173 +++++ tests/providers/test_azure_openai_provider.py | 679 +++++++++--------- 6 files changed, 769 insertions(+), 728 deletions(-) create mode 100644 nanobot/providers/openai_responses_common/__init__.py create mode 100644 nanobot/providers/openai_responses_common/converters.py create mode 100644 nanobot/providers/openai_responses_common/parsing.py diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index d71dae917..ab4d187ae 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -1,31 +1,37 @@ -"""Azure OpenAI provider implementation with API version 2024-10-21.""" +"""Azure OpenAI provider using the OpenAI SDK Responses API. + +Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which +routes to the Responses API (``/responses``). Reuses shared conversion +helpers from :mod:`nanobot.providers.openai_responses_common`. +""" from __future__ import annotations -import json import uuid from collections.abc import Awaitable, Callable from typing import Any -from urllib.parse import urljoin import httpx -import json_repair +from openai import AsyncOpenAI -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest - -_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"}) +from nanobot.providers.base import LLMProvider, LLMResponse +from nanobot.providers.openai_responses_common import ( + consume_sse, + convert_messages, + convert_tools, + parse_response_output, +) class AzureOpenAIProvider(LLMProvider): - """ - Azure OpenAI provider with API version 2024-10-21 compliance. - + """Azure OpenAI provider backed by the Responses API. + Features: - - Hardcoded API version 2024-10-21 - - Uses model field as Azure deployment name in URL path - - Uses api-key header instead of Authorization Bearer - - Uses max_completion_tokens instead of max_tokens - - Direct HTTP calls, bypasses LiteLLM + - Uses the OpenAI Python SDK (``AsyncOpenAI``) with + ``base_url = {endpoint}/openai/v1/`` + - Calls ``client.responses.create()`` (Responses API) + - Reuses shared message/tool/SSE conversion from + ``openai_responses_common`` """ def __init__( @@ -36,40 +42,28 @@ class AzureOpenAIProvider(LLMProvider): ): super().__init__(api_key, api_base) self.default_model = default_model - self.api_version = "2024-10-21" - - # Validate required parameters + if not api_key: raise ValueError("Azure OpenAI api_key is required") if not api_base: raise ValueError("Azure OpenAI api_base is required") - - # Ensure api_base ends with / - if not api_base.endswith('/'): - api_base += '/' + + # Normalise: ensure trailing slash + if not api_base.endswith("/"): + api_base += "/" self.api_base = api_base - def _build_chat_url(self, deployment_name: str) -> str: - """Build the Azure OpenAI chat completions URL.""" - # Azure OpenAI URL format: - # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version} - base_url = self.api_base - if not base_url.endswith('/'): - base_url += '/' - - url = urljoin( - base_url, - f"openai/deployments/{deployment_name}/chat/completions" + # SDK client targeting the Azure Responses API endpoint + base_url = f"{api_base.rstrip('/')}/openai/v1/" + self._client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + default_headers={"x-session-affinity": uuid.uuid4().hex}, ) - return f"{url}?api-version={self.api_version}" - def _build_headers(self) -> dict[str, str]: - """Build headers for Azure OpenAI API with api-key header.""" - return { - "Content-Type": "application/json", - "api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization - "x-session-affinity": uuid.uuid4().hex, # For cache locality - } + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ @staticmethod def _supports_temperature( @@ -82,36 +76,50 @@ class AzureOpenAIProvider(LLMProvider): name = deployment_name.lower() return not any(token in name for token in ("gpt-5", "o1", "o3", "o4")) - def _prepare_request_payload( + def _build_body( self, - deployment_name: str, messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, ) -> dict[str, Any]: - """Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" - payload: dict[str, Any] = { - "messages": self._sanitize_request_messages( - self._sanitize_empty_content(messages), - _AZURE_MSG_KEYS, - ), - "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens + """Build the Responses API request body from Chat-Completions-style args.""" + deployment = model or self.default_model + instructions, input_items = convert_messages(messages) + + body: dict[str, Any] = { + "model": deployment, + "instructions": instructions or None, + "input": input_items, + "store": False, + "stream": False, } - if self._supports_temperature(deployment_name, reasoning_effort): - payload["temperature"] = temperature + if self._supports_temperature(deployment, reasoning_effort): + body["temperature"] = temperature if reasoning_effort: - payload["reasoning_effort"] = reasoning_effort + body["reasoning"] = {"effort": reasoning_effort} + body["include"] = ["reasoning.encrypted_content"] if tools: - payload["tools"] = tools - payload["tool_choice"] = tool_choice or "auto" + body["tools"] = convert_tools(tools) + body["tool_choice"] = tool_choice or "auto" - return payload + return body + + @staticmethod + def _handle_error(e: Exception) -> LLMResponse: + body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None) + msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}" + return LLMResponse(content=msg, finish_reason="error") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ async def chat( self, @@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider): reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: - """ - Send a chat completion request to Azure OpenAI. - - Args: - messages: List of message dicts with 'role' and 'content'. - tools: Optional list of tool definitions in OpenAI format. - model: Model identifier (used as deployment name). - max_tokens: Maximum tokens in response (mapped to max_completion_tokens). - temperature: Sampling temperature. - reasoning_effort: Optional reasoning effort parameter. - - Returns: - LLMResponse with content and/or tool calls. - """ - deployment_name = model or self.default_model - url = self._build_chat_url(deployment_name) - headers = self._build_headers() - payload = self._prepare_request_payload( - deployment_name, messages, tools, max_tokens, temperature, reasoning_effort, - tool_choice=tool_choice, + body = self._build_body( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, ) - try: - async with httpx.AsyncClient(timeout=60.0, verify=True) as client: - response = await client.post(url, headers=headers, json=payload) - if response.status_code != 200: - return LLMResponse( - content=f"Azure OpenAI API Error {response.status_code}: {response.text}", - finish_reason="error", - ) - - response_data = response.json() - return self._parse_response(response_data) - + response = await self._client.responses.create(**body) + return parse_response_output(response) except Exception as e: - return LLMResponse( - content=f"Error calling Azure OpenAI: {repr(e)}", - finish_reason="error", - ) - - def _parse_response(self, response: dict[str, Any]) -> LLMResponse: - """Parse Azure OpenAI response into our standard format.""" - try: - choice = response["choices"][0] - message = choice["message"] - - tool_calls = [] - if message.get("tool_calls"): - for tc in message["tool_calls"]: - # Parse arguments from JSON string if needed - args = tc["function"]["arguments"] - if isinstance(args, str): - args = json_repair.loads(args) - - tool_calls.append( - ToolCallRequest( - id=tc["id"], - name=tc["function"]["name"], - arguments=args, - ) - ) - - usage = {} - if response.get("usage"): - usage_data = response["usage"] - usage = { - "prompt_tokens": usage_data.get("prompt_tokens", 0), - "completion_tokens": usage_data.get("completion_tokens", 0), - "total_tokens": usage_data.get("total_tokens", 0), - } - - reasoning_content = message.get("reasoning_content") or None - - return LLMResponse( - content=message.get("content"), - tool_calls=tool_calls, - finish_reason=choice.get("finish_reason", "stop"), - usage=usage, - reasoning_content=reasoning_content, - ) - - except (KeyError, IndexError) as e: - return LLMResponse( - content=f"Error parsing Azure OpenAI response: {str(e)}", - finish_reason="error", - ) + return self._handle_error(e) async def chat_stream( self, @@ -221,89 +152,40 @@ class AzureOpenAIProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: - """Stream a chat completion via Azure OpenAI SSE.""" - deployment_name = model or self.default_model - url = self._build_chat_url(deployment_name) - headers = self._build_headers() - payload = self._prepare_request_payload( - deployment_name, messages, tools, max_tokens, temperature, - reasoning_effort, tool_choice=tool_choice, + body = self._build_body( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, ) - payload["stream"] = True + body["stream"] = True try: - async with httpx.AsyncClient(timeout=60.0, verify=True) as client: - async with client.stream("POST", url, headers=headers, json=payload) as response: + # Use raw httpx stream via the SDK's base URL so we can reuse + # the shared Responses-API SSE parser (same as Codex provider). + base_url = str(self._client.base_url).rstrip("/") + url = f"{base_url}/responses" + headers = { + "Authorization": f"Bearer {self._client.api_key}", + "Content-Type": "application/json", + **(self._client._custom_headers or {}), + } + async with httpx.AsyncClient(timeout=60.0, verify=True) as http: + async with http.stream("POST", url, headers=headers, json=body) as response: if response.status_code != 200: text = await response.aread() return LLMResponse( content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}", finish_reason="error", ) - return await self._consume_stream(response, on_content_delta) + content, tool_calls, finish_reason = await consume_sse( + response, on_content_delta, + ) + return LLMResponse( + content=content or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + ) except Exception as e: - return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error") - - async def _consume_stream( - self, - response: httpx.Response, - on_content_delta: Callable[[str], Awaitable[None]] | None, - ) -> LLMResponse: - """Parse Azure OpenAI SSE stream into an LLMResponse.""" - content_parts: list[str] = [] - tool_call_buffers: dict[int, dict[str, str]] = {} - finish_reason = "stop" - - async for line in response.aiter_lines(): - if not line.startswith("data: "): - continue - data = line[6:].strip() - if data == "[DONE]": - break - try: - chunk = json.loads(data) - except Exception: - continue - - choices = chunk.get("choices") or [] - if not choices: - continue - choice = choices[0] - if choice.get("finish_reason"): - finish_reason = choice["finish_reason"] - delta = choice.get("delta") or {} - - text = delta.get("content") - if text: - content_parts.append(text) - if on_content_delta: - await on_content_delta(text) - - for tc in delta.get("tool_calls") or []: - idx = tc.get("index", 0) - buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""}) - if tc.get("id"): - buf["id"] = tc["id"] - fn = tc.get("function") or {} - if fn.get("name"): - buf["name"] = fn["name"] - if fn.get("arguments"): - buf["arguments"] += fn["arguments"] - - tool_calls = [ - ToolCallRequest( - id=buf["id"], name=buf["name"], - arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {}, - ) - for buf in tool_call_buffers.values() - ] - - return LLMResponse( - content="".join(content_parts) or None, - tool_calls=tool_calls, - finish_reason=finish_reason, - ) + return self._handle_error(e) def get_default_model(self) -> str: - """Get the default model (also used as default deployment name).""" return self.default_model \ No newline at end of file diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index 1c6bc7075..68145173b 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -6,13 +6,18 @@ import asyncio import hashlib import json from collections.abc import Awaitable, Callable -from typing import Any, AsyncGenerator +from typing import Any import httpx from loguru import logger from oauth_cli_kit import get_token as get_codex_token from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.openai_responses_common import ( + consume_sse, + convert_messages, + convert_tools, +) DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses" DEFAULT_ORIGINATOR = "nanobot" @@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider): ) -> LLMResponse: """Shared request logic for both chat() and chat_stream().""" model = model or self.default_model - system_prompt, input_items = _convert_messages(messages) + system_prompt, input_items = convert_messages(messages) token = await asyncio.to_thread(get_codex_token) headers = _build_headers(token.account_id, token.access) @@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider): if reasoning_effort: body["reasoning"] = {"effort": reasoning_effort} if tools: - body["tools"] = _convert_tools(tools) + body["tools"] = convert_tools(tools) try: try: @@ -127,96 +132,7 @@ async def _request_codex( if response.status_code != 200: text = await response.aread() raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore"))) - return await _consume_sse(response, on_content_delta) - - -def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Convert OpenAI function-calling schema to Codex flat format.""" - converted: list[dict[str, Any]] = [] - for tool in tools: - fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool - name = fn.get("name") - if not name: - continue - params = fn.get("parameters") or {} - converted.append({ - "type": "function", - "name": name, - "description": fn.get("description") or "", - "parameters": params if isinstance(params, dict) else {}, - }) - return converted - - -def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: - system_prompt = "" - input_items: list[dict[str, Any]] = [] - - for idx, msg in enumerate(messages): - role = msg.get("role") - content = msg.get("content") - - if role == "system": - system_prompt = content if isinstance(content, str) else "" - continue - - if role == "user": - input_items.append(_convert_user_message(content)) - continue - - if role == "assistant": - if isinstance(content, str) and content: - input_items.append({ - "type": "message", "role": "assistant", - "content": [{"type": "output_text", "text": content}], - "status": "completed", "id": f"msg_{idx}", - }) - for tool_call in msg.get("tool_calls", []) or []: - fn = tool_call.get("function") or {} - call_id, item_id = _split_tool_call_id(tool_call.get("id")) - input_items.append({ - "type": "function_call", - "id": item_id or f"fc_{idx}", - "call_id": call_id or f"call_{idx}", - "name": fn.get("name"), - "arguments": fn.get("arguments") or "{}", - }) - continue - - if role == "tool": - call_id, _ = _split_tool_call_id(msg.get("tool_call_id")) - output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) - input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text}) - - return system_prompt, input_items - - -def _convert_user_message(content: Any) -> dict[str, Any]: - if isinstance(content, str): - return {"role": "user", "content": [{"type": "input_text", "text": content}]} - if isinstance(content, list): - converted: list[dict[str, Any]] = [] - for item in content: - if not isinstance(item, dict): - continue - if item.get("type") == "text": - converted.append({"type": "input_text", "text": item.get("text", "")}) - elif item.get("type") == "image_url": - url = (item.get("image_url") or {}).get("url") - if url: - converted.append({"type": "input_image", "image_url": url, "detail": "auto"}) - if converted: - return {"role": "user", "content": converted} - return {"role": "user", "content": [{"type": "input_text", "text": ""}]} - - -def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]: - if isinstance(tool_call_id, str) and tool_call_id: - if "|" in tool_call_id: - call_id, item_id = tool_call_id.split("|", 1) - return call_id, item_id or None - return tool_call_id, None - return "call_0", None + return await consume_sse(response, on_content_delta) def _prompt_cache_key(messages: list[dict[str, Any]]) -> str: @@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str: return hashlib.sha256(raw.encode("utf-8")).hexdigest() -async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]: - buffer: list[str] = [] - async for line in response.aiter_lines(): - if line == "": - if buffer: - data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")] - buffer = [] - if not data_lines: - continue - data = "\n".join(data_lines).strip() - if not data or data == "[DONE]": - continue - try: - yield json.loads(data) - except Exception: - continue - continue - buffer.append(line) - - -async def _consume_sse( - response: httpx.Response, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, -) -> tuple[str, list[ToolCallRequest], str]: - content = "" - tool_calls: list[ToolCallRequest] = [] - tool_call_buffers: dict[str, dict[str, Any]] = {} - finish_reason = "stop" - - async for event in _iter_sse(response): - event_type = event.get("type") - if event_type == "response.output_item.added": - item = event.get("item") or {} - if item.get("type") == "function_call": - call_id = item.get("call_id") - if not call_id: - continue - tool_call_buffers[call_id] = { - "id": item.get("id") or "fc_0", - "name": item.get("name"), - "arguments": item.get("arguments") or "", - } - elif event_type == "response.output_text.delta": - delta_text = event.get("delta") or "" - content += delta_text - if on_content_delta and delta_text: - await on_content_delta(delta_text) - elif event_type == "response.function_call_arguments.delta": - call_id = event.get("call_id") - if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] += event.get("delta") or "" - elif event_type == "response.function_call_arguments.done": - call_id = event.get("call_id") - if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] = event.get("arguments") or "" - elif event_type == "response.output_item.done": - item = event.get("item") or {} - if item.get("type") == "function_call": - call_id = item.get("call_id") - if not call_id: - continue - buf = tool_call_buffers.get(call_id) or {} - args_raw = buf.get("arguments") or item.get("arguments") or "{}" - try: - args = json.loads(args_raw) - except Exception: - args = {"raw": args_raw} - tool_calls.append( - ToolCallRequest( - id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}", - name=buf.get("name") or item.get("name"), - arguments=args, - ) - ) - elif event_type == "response.completed": - status = (event.get("response") or {}).get("status") - finish_reason = _map_finish_reason(status) - elif event_type in {"error", "response.failed"}: - raise RuntimeError("Codex response failed") - - return content, tool_calls, finish_reason - - -_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"} - - -def _map_finish_reason(status: str | None) -> str: - return _FINISH_REASON_MAP.get(status or "completed", "stop") - - def _friendly_error(status_code: int, raw: str) -> str: if status_code == 429: return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later." diff --git a/nanobot/providers/openai_responses_common/__init__.py b/nanobot/providers/openai_responses_common/__init__.py new file mode 100644 index 000000000..cfc327bdb --- /dev/null +++ b/nanobot/providers/openai_responses_common/__init__.py @@ -0,0 +1,27 @@ +"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI).""" + +from nanobot.providers.openai_responses_common.converters import ( + convert_messages, + convert_tools, + convert_user_message, + split_tool_call_id, +) +from nanobot.providers.openai_responses_common.parsing import ( + FINISH_REASON_MAP, + consume_sse, + iter_sse, + map_finish_reason, + parse_response_output, +) + +__all__ = [ + "convert_messages", + "convert_tools", + "convert_user_message", + "split_tool_call_id", + "iter_sse", + "consume_sse", + "map_finish_reason", + "parse_response_output", + "FINISH_REASON_MAP", +] diff --git a/nanobot/providers/openai_responses_common/converters.py b/nanobot/providers/openai_responses_common/converters.py new file mode 100644 index 000000000..37596692d --- /dev/null +++ b/nanobot/providers/openai_responses_common/converters.py @@ -0,0 +1,110 @@ +"""Convert Chat Completions messages/tools to Responses API format.""" + +from __future__ import annotations + +import json +from typing import Any + + +def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: + """Convert Chat Completions messages to Responses API input items. + + Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted + from any ``system`` role message and *input_items* is the Responses API + ``input`` array. + """ + system_prompt = "" + input_items: list[dict[str, Any]] = [] + + for idx, msg in enumerate(messages): + role = msg.get("role") + content = msg.get("content") + + if role == "system": + system_prompt = content if isinstance(content, str) else "" + continue + + if role == "user": + input_items.append(convert_user_message(content)) + continue + + if role == "assistant": + if isinstance(content, str) and content: + input_items.append({ + "type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": content}], + "status": "completed", "id": f"msg_{idx}", + }) + for tool_call in msg.get("tool_calls", []) or []: + fn = tool_call.get("function") or {} + call_id, item_id = split_tool_call_id(tool_call.get("id")) + input_items.append({ + "type": "function_call", + "id": item_id or f"fc_{idx}", + "call_id": call_id or f"call_{idx}", + "name": fn.get("name"), + "arguments": fn.get("arguments") or "{}", + }) + continue + + if role == "tool": + call_id, _ = split_tool_call_id(msg.get("tool_call_id")) + output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) + input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text}) + + return system_prompt, input_items + + +def convert_user_message(content: Any) -> dict[str, Any]: + """Convert a user message's content to Responses API format. + + Handles plain strings, ``text`` blocks → ``input_text``, and + ``image_url`` blocks → ``input_image``. + """ + if isinstance(content, str): + return {"role": "user", "content": [{"type": "input_text", "text": content}]} + if isinstance(content, list): + converted: list[dict[str, Any]] = [] + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") == "text": + converted.append({"type": "input_text", "text": item.get("text", "")}) + elif item.get("type") == "image_url": + url = (item.get("image_url") or {}).get("url") + if url: + converted.append({"type": "input_image", "image_url": url, "detail": "auto"}) + if converted: + return {"role": "user", "content": converted} + return {"role": "user", "content": [{"type": "input_text", "text": ""}]} + + +def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert OpenAI function-calling tool schema to Responses API flat format.""" + converted: list[dict[str, Any]] = [] + for tool in tools: + fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool + name = fn.get("name") + if not name: + continue + params = fn.get("parameters") or {} + converted.append({ + "type": "function", + "name": name, + "description": fn.get("description") or "", + "parameters": params if isinstance(params, dict) else {}, + }) + return converted + + +def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]: + """Split a compound ``call_id|item_id`` string. + + Returns ``(call_id, item_id)`` where *item_id* may be ``None``. + """ + if isinstance(tool_call_id, str) and tool_call_id: + if "|" in tool_call_id: + call_id, item_id = tool_call_id.split("|", 1) + return call_id, item_id or None + return tool_call_id, None + return "call_0", None diff --git a/nanobot/providers/openai_responses_common/parsing.py b/nanobot/providers/openai_responses_common/parsing.py new file mode 100644 index 000000000..e0d5f4462 --- /dev/null +++ b/nanobot/providers/openai_responses_common/parsing.py @@ -0,0 +1,173 @@ +"""Parse Responses API SSE streams and SDK response objects.""" + +from __future__ import annotations + +import json +from collections.abc import Awaitable, Callable +from typing import Any, AsyncGenerator + +import httpx + +from nanobot.providers.base import LLMResponse, ToolCallRequest + +FINISH_REASON_MAP = { + "completed": "stop", + "incomplete": "length", + "failed": "error", + "cancelled": "error", +} + + +def map_finish_reason(status: str | None) -> str: + """Map a Responses API status string to a Chat-Completions-style finish_reason.""" + return FINISH_REASON_MAP.get(status or "completed", "stop") + + +async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]: + """Yield parsed JSON events from a Responses API SSE stream.""" + buffer: list[str] = [] + async for line in response.aiter_lines(): + if line == "": + if buffer: + data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")] + buffer = [] + if not data_lines: + continue + data = "\n".join(data_lines).strip() + if not data or data == "[DONE]": + continue + try: + yield json.loads(data) + except Exception: + continue + continue + buffer.append(line) + + +async def consume_sse( + response: httpx.Response, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, +) -> tuple[str, list[ToolCallRequest], str]: + """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``.""" + content = "" + tool_calls: list[ToolCallRequest] = [] + tool_call_buffers: dict[str, dict[str, Any]] = {} + finish_reason = "stop" + + async for event in iter_sse(response): + event_type = event.get("type") + if event_type == "response.output_item.added": + item = event.get("item") or {} + if item.get("type") == "function_call": + call_id = item.get("call_id") + if not call_id: + continue + tool_call_buffers[call_id] = { + "id": item.get("id") or "fc_0", + "name": item.get("name"), + "arguments": item.get("arguments") or "", + } + elif event_type == "response.output_text.delta": + delta_text = event.get("delta") or "" + content += delta_text + if on_content_delta and delta_text: + await on_content_delta(delta_text) + elif event_type == "response.function_call_arguments.delta": + call_id = event.get("call_id") + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] += event.get("delta") or "" + elif event_type == "response.function_call_arguments.done": + call_id = event.get("call_id") + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] = event.get("arguments") or "" + elif event_type == "response.output_item.done": + item = event.get("item") or {} + if item.get("type") == "function_call": + call_id = item.get("call_id") + if not call_id: + continue + buf = tool_call_buffers.get(call_id) or {} + args_raw = buf.get("arguments") or item.get("arguments") or "{}" + try: + args = json.loads(args_raw) + except Exception: + args = {"raw": args_raw} + tool_calls.append( + ToolCallRequest( + id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}", + name=buf.get("name") or item.get("name"), + arguments=args, + ) + ) + elif event_type == "response.completed": + status = (event.get("response") or {}).get("status") + finish_reason = map_finish_reason(status) + elif event_type in {"error", "response.failed"}: + raise RuntimeError("Response failed") + + return content, tool_calls, finish_reason + + +def parse_response_output(response: Any) -> LLMResponse: + """Parse an SDK ``Response`` object (from ``client.responses.create()``) + into an ``LLMResponse``. + + Works with both Pydantic model objects and plain dicts. + """ + # Normalise to dict + if not isinstance(response, dict): + dump = getattr(response, "model_dump", None) + response = dump() if callable(dump) else vars(response) + + output = response.get("output") or [] + content_parts: list[str] = [] + tool_calls: list[ToolCallRequest] = [] + + for item in output: + if not isinstance(item, dict): + dump = getattr(item, "model_dump", None) + item = dump() if callable(dump) else vars(item) + + item_type = item.get("type") + if item_type == "message": + for block in item.get("content") or []: + if not isinstance(block, dict): + dump = getattr(block, "model_dump", None) + block = dump() if callable(dump) else vars(block) + if block.get("type") == "output_text": + content_parts.append(block.get("text") or "") + elif item_type == "function_call": + call_id = item.get("call_id") or "" + item_id = item.get("id") or "fc_0" + args_raw = item.get("arguments") or "{}" + try: + args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw + except Exception: + args = {"raw": args_raw} + tool_calls.append(ToolCallRequest( + id=f"{call_id}|{item_id}", + name=item.get("name") or "", + arguments=args if isinstance(args, dict) else {}, + )) + + usage_raw = response.get("usage") or {} + if not isinstance(usage_raw, dict): + dump = getattr(usage_raw, "model_dump", None) + usage_raw = dump() if callable(dump) else vars(usage_raw) + usage = {} + if usage_raw: + usage = { + "prompt_tokens": int(usage_raw.get("input_tokens") or 0), + "completion_tokens": int(usage_raw.get("output_tokens") or 0), + "total_tokens": int(usage_raw.get("total_tokens") or 0), + } + + status = response.get("status") + finish_reason = map_finish_reason(status) + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + usage=usage, + ) diff --git a/tests/providers/test_azure_openai_provider.py b/tests/providers/test_azure_openai_provider.py index 77f36d468..9a95cae5d 100644 --- a/tests/providers/test_azure_openai_provider.py +++ b/tests/providers/test_azure_openai_provider.py @@ -1,6 +1,6 @@ -"""Test Azure OpenAI provider implementation (updated for model-based deployment names).""" +"""Test Azure OpenAI provider (Responses API via OpenAI SDK).""" -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -8,392 +8,415 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.base import LLMResponse -def test_azure_openai_provider_init(): - """Test AzureOpenAIProvider initialization without deployment_name.""" +# --------------------------------------------------------------------------- +# Init & validation +# --------------------------------------------------------------------------- + + +def test_init_creates_sdk_client(): + """Provider creates an AsyncOpenAI client with correct base_url.""" provider = AzureOpenAIProvider( api_key="test-key", api_base="https://test-resource.openai.azure.com", default_model="gpt-4o-deployment", ) - assert provider.api_key == "test-key" assert provider.api_base == "https://test-resource.openai.azure.com/" assert provider.default_model == "gpt-4o-deployment" - assert provider.api_version == "2024-10-21" + # SDK client base_url ends with /openai/v1/ + assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1") -def test_azure_openai_provider_init_validation(): - """Test AzureOpenAIProvider initialization validation.""" - # Missing api_key +def test_init_base_url_no_trailing_slash(): + """Trailing slashes are normalised before building base_url.""" + provider = AzureOpenAIProvider( + api_key="k", api_base="https://res.openai.azure.com", + ) + assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1") + + +def test_init_base_url_with_trailing_slash(): + provider = AzureOpenAIProvider( + api_key="k", api_base="https://res.openai.azure.com/", + ) + assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1") + + +def test_init_validation_missing_key(): with pytest.raises(ValueError, match="Azure OpenAI api_key is required"): AzureOpenAIProvider(api_key="", api_base="https://test.com") - - # Missing api_base + + +def test_init_validation_missing_base(): with pytest.raises(ValueError, match="Azure OpenAI api_base is required"): AzureOpenAIProvider(api_key="test", api_base="") -def test_build_chat_url(): - """Test Azure OpenAI URL building with different deployment names.""" +def test_no_api_version_in_base_url(): + """The /openai/v1/ path should NOT contain an api-version query param.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com") + base = str(provider._client.base_url) + assert "api-version" not in base + + +# --------------------------------------------------------------------------- +# _supports_temperature +# --------------------------------------------------------------------------- + + +def test_supports_temperature_standard_model(): + assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True + + +def test_supports_temperature_reasoning_model(): + assert AzureOpenAIProvider._supports_temperature("o3-mini") is False + assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False + assert AzureOpenAIProvider._supports_temperature("o4-mini") is False + + +def test_supports_temperature_with_reasoning_effort(): + assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False + + +# --------------------------------------------------------------------------- +# _build_body — Responses API body construction +# --------------------------------------------------------------------------- + + +def test_build_body_basic(): provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", + api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o", ) - - # Test various deployment names - test_cases = [ - ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"), - ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"), - ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"), - ] - - for deployment_name, expected_url in test_cases: - url = provider._build_chat_url(deployment_name) - assert url == expected_url + messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}] + body = provider._build_body(messages, None, None, 4096, 0.7, None, None) - -def test_build_chat_url_api_base_without_slash(): - """Test URL building when api_base doesn't end with slash.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", # No trailing slash - default_model="gpt-4o", + assert body["model"] == "gpt-4o" + assert body["instructions"] == "You are helpful." + assert body["temperature"] == 0.7 + assert body["store"] is False + assert "reasoning" not in body + # input should contain the converted user message only (system extracted) + assert any( + item.get("role") == "user" + for item in body["input"] ) - - url = provider._build_chat_url("test-deployment") - expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21" - assert url == expected -def test_build_headers(): - """Test Azure OpenAI header building with api-key authentication.""" - provider = AzureOpenAIProvider( - api_key="test-api-key-123", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - headers = provider._build_headers() - assert headers["Content-Type"] == "application/json" - assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header - assert "x-session-affinity" in headers - - -def test_prepare_request_payload(): - """Test request payload preparation with Azure OpenAI 2024-10-21 compliance.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - messages = [{"role": "user", "content": "Hello"}] - payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8) - - assert payload["messages"] == messages - assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens - assert payload["temperature"] == 0.8 - assert "tools" not in payload - - # Test with tools +def test_build_body_with_tools(): + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] - payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools) - assert payload_with_tools["tools"] == tools - assert payload_with_tools["tool_choice"] == "auto" - - # Test with reasoning_effort - payload_with_reasoning = provider._prepare_request_payload( - "gpt-5-chat", messages, reasoning_effort="medium" + body = provider._build_body( + [{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None, ) - assert payload_with_reasoning["reasoning_effort"] == "medium" - assert "temperature" not in payload_with_reasoning + assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}] + assert body["tool_choice"] == "auto" -def test_prepare_request_payload_sanitizes_messages(): - """Test Azure payload strips non-standard message keys before sending.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", +def test_build_body_with_reasoning(): + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat") + body = provider._build_body( + [{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None, ) + assert body["reasoning"] == {"effort": "medium"} + assert "reasoning.encrypted_content" in body.get("include", []) + # temperature omitted for reasoning models + assert "temperature" not in body - messages = [ - { - "role": "assistant", - "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], - "reasoning_content": "hidden chain-of-thought", - }, - { - "role": "tool", - "tool_call_id": "call_123", - "name": "x", - "content": "ok", - "extra_field": "should be removed", - }, - ] - payload = provider._prepare_request_payload("gpt-4o", messages) +def test_build_body_image_conversion(): + """image_url content blocks should be converted to input_image.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}, + ], + }] + body = provider._build_body(messages, None, None, 4096, 0.7, None, None) + user_item = body["input"][0] + content_types = [b["type"] for b in user_item["content"]] + assert "input_text" in content_types + assert "input_image" in content_types + image_block = next(b for b in user_item["content"] if b["type"] == "input_image") + assert image_block["image_url"] == "https://example.com/img.png" - assert payload["messages"] == [ - { - "role": "assistant", - "content": None, - "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], + +# --------------------------------------------------------------------------- +# chat() — non-streaming +# --------------------------------------------------------------------------- + + +def _make_sdk_response( + content="Hello!", tool_calls=None, status="completed", + usage=None, +): + """Build a mock that quacks like an openai Response object.""" + resp = MagicMock() + resp.model_dump = MagicMock(return_value={ + "output": [ + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]}, + *([{ + "type": "function_call", + "call_id": tc["call_id"], "id": tc["id"], + "name": tc["name"], "arguments": tc["arguments"], + } for tc in (tool_calls or [])]), + ], + "status": status, + "usage": { + "input_tokens": (usage or {}).get("input_tokens", 10), + "output_tokens": (usage or {}).get("output_tokens", 5), + "total_tokens": (usage or {}).get("total_tokens", 15), }, - { - "role": "tool", - "tool_call_id": "call_123", - "name": "x", - "content": "ok", - }, - ] + }) + return resp @pytest.mark.asyncio async def test_chat_success(): - """Test successful chat request using model as deployment name.""" provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o-deployment", + api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - - # Mock response data - mock_response_data = { - "choices": [{ - "message": { - "content": "Hello! How can I help you today?", - "role": "assistant" - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 18, - "total_tokens": 30 - } - } - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 200 - mock_response.json = Mock(return_value=mock_response_data) - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_context - - # Test with specific model (deployment name) - messages = [{"role": "user", "content": "Hello"}] - result = await provider.chat(messages, model="custom-deployment") - - assert isinstance(result, LLMResponse) - assert result.content == "Hello! How can I help you today?" - assert result.finish_reason == "stop" - assert result.usage["prompt_tokens"] == 12 - assert result.usage["completion_tokens"] == 18 - assert result.usage["total_tokens"] == 30 - - # Verify URL was built with the provided model as deployment name - call_args = mock_context.post.call_args - expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21" - assert call_args[0][0] == expected_url + mock_resp = _make_sdk_response(content="Hello!") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + result = await provider.chat([{"role": "user", "content": "Hi"}]) + + assert isinstance(result, LLMResponse) + assert result.content == "Hello!" + assert result.finish_reason == "stop" + assert result.usage["prompt_tokens"] == 10 @pytest.mark.asyncio -async def test_chat_uses_default_model_when_no_model_provided(): - """Test that chat uses default_model when no model is specified.""" +async def test_chat_uses_default_model(): provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="default-deployment", + api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment", ) - - mock_response_data = { - "choices": [{ - "message": {"content": "Response", "role": "assistant"}, - "finish_reason": "stop" - }], - "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10} - } - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 200 - mock_response.json = Mock(return_value=mock_response_data) - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_context - - messages = [{"role": "user", "content": "Test"}] - await provider.chat(messages) # No model specified - - # Verify URL was built with default model as deployment name - call_args = mock_context.post.call_args - expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21" - assert call_args[0][0] == expected_url + mock_resp = _make_sdk_response(content="ok") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + await provider.chat([{"role": "user", "content": "test"}]) + + call_kwargs = provider._client.responses.create.call_args[1] + assert call_kwargs["model"] == "my-deployment" + + +@pytest.mark.asyncio +async def test_chat_custom_model(): + provider = AzureOpenAIProvider( + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + mock_resp = _make_sdk_response(content="ok") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy") + + call_kwargs = provider._client.responses.create.call_args[1] + assert call_kwargs["model"] == "custom-deploy" @pytest.mark.asyncio async def test_chat_with_tool_calls(): - """Test chat request with tool calls in response.""" provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - - # Mock response with tool calls - mock_response_data = { - "choices": [{ - "message": { - "content": None, - "role": "assistant", - "tool_calls": [{ - "id": "call_12345", - "function": { - "name": "get_weather", - "arguments": '{"location": "San Francisco"}' - } - }] - }, - "finish_reason": "tool_calls" + mock_resp = _make_sdk_response( + content=None, + tool_calls=[{ + "call_id": "call_123", "id": "fc_1", + "name": "get_weather", "arguments": '{"location": "SF"}', }], - "usage": { - "prompt_tokens": 20, - "completion_tokens": 15, - "total_tokens": 35 - } - } - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 200 - mock_response.json = Mock(return_value=mock_response_data) - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_context - - messages = [{"role": "user", "content": "What's the weather?"}] - tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] - result = await provider.chat(messages, tools=tools, model="weather-model") - - assert isinstance(result, LLMResponse) - assert result.content is None - assert result.finish_reason == "tool_calls" - assert len(result.tool_calls) == 1 - assert result.tool_calls[0].name == "get_weather" - assert result.tool_calls[0].arguments == {"location": "San Francisco"} + ) + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + result = await provider.chat( + [{"role": "user", "content": "Weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], + ) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "get_weather" + assert result.tool_calls[0].arguments == {"location": "SF"} @pytest.mark.asyncio -async def test_chat_api_error(): - """Test chat request API error handling.""" +async def test_chat_error_handling(): provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 401 - mock_response.text = "Invalid authentication credentials" - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_context - - messages = [{"role": "user", "content": "Hello"}] - result = await provider.chat(messages) - - assert isinstance(result, LLMResponse) - assert "Azure OpenAI API Error 401" in result.content - assert "Invalid authentication credentials" in result.content - assert result.finish_reason == "error" + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed")) + result = await provider.chat([{"role": "user", "content": "Hi"}]) -@pytest.mark.asyncio -async def test_chat_connection_error(): - """Test chat request connection error handling.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - with patch("httpx.AsyncClient") as mock_client: - mock_context = AsyncMock() - mock_context.post = AsyncMock(side_effect=Exception("Connection failed")) - mock_client.return_value.__aenter__.return_value = mock_context - - messages = [{"role": "user", "content": "Hello"}] - result = await provider.chat(messages) - - assert isinstance(result, LLMResponse) - assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content - assert result.finish_reason == "error" - - -def test_parse_response_malformed(): - """Test response parsing with malformed data.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - # Test with missing choices - malformed_response = {"usage": {"prompt_tokens": 10}} - result = provider._parse_response(malformed_response) - assert isinstance(result, LLMResponse) - assert "Error parsing Azure OpenAI response" in result.content + assert "Connection failed" in result.content assert result.finish_reason == "error" +@pytest.mark.asyncio +async def test_chat_reasoning_param_format(): + """reasoning_effort should be sent as reasoning={effort: ...} not a flat string.""" + provider = AzureOpenAIProvider( + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat", + ) + mock_resp = _make_sdk_response(content="thought") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + await provider.chat( + [{"role": "user", "content": "think"}], reasoning_effort="medium", + ) + + call_kwargs = provider._client.responses.create.call_args[1] + assert call_kwargs["reasoning"] == {"effort": "medium"} + assert "reasoning_effort" not in call_kwargs + + +# --------------------------------------------------------------------------- +# chat_stream() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_chat_stream_success(): + """Streaming should call on_content_delta and return combined response.""" + provider = AzureOpenAIProvider( + api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + + # Build SSE lines for the mock httpx stream + sse_events = [ + 'event: response.output_text.delta', + 'data: {"type":"response.output_text.delta","delta":"Hello"}', + '', + 'event: response.output_text.delta', + 'data: {"type":"response.output_text.delta","delta":" world"}', + '', + 'event: response.completed', + 'data: {"type":"response.completed","response":{"status":"completed"}}', + '', + ] + + deltas: list[str] = [] + + async def on_delta(text: str) -> None: + deltas.append(text) + + # Mock httpx stream + mock_response = AsyncMock() + mock_response.status_code = 200 + + async def aiter_lines(): + for line in sse_events: + yield line + + mock_response.aiter_lines = aiter_lines + + with patch("httpx.AsyncClient") as mock_client: + mock_ctx = AsyncMock() + mock_stream_ctx = AsyncMock() + mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) + mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) + mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) + mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_client.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await provider.chat_stream( + [{"role": "user", "content": "Hi"}], on_content_delta=on_delta, + ) + + assert result.content == "Hello world" + assert result.finish_reason == "stop" + assert deltas == ["Hello", " world"] + + +@pytest.mark.asyncio +async def test_chat_stream_with_tool_calls(): + """Streaming tool calls should be accumulated correctly.""" + provider = AzureOpenAIProvider( + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + + sse_events = [ + 'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}', + '', + 'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}', + '', + 'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}', + '', + 'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}', + '', + 'data: {"type":"response.completed","response":{"status":"completed"}}', + '', + ] + + mock_response = AsyncMock() + mock_response.status_code = 200 + + async def aiter_lines(): + for line in sse_events: + yield line + + mock_response.aiter_lines = aiter_lines + + with patch("httpx.AsyncClient") as mock_client: + mock_ctx = AsyncMock() + mock_stream_ctx = AsyncMock() + mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) + mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) + mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) + mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_client.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await provider.chat_stream( + [{"role": "user", "content": "weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], + ) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "get_weather" + assert result.tool_calls[0].arguments == {"location": "SF"} + + +@pytest.mark.asyncio +async def test_chat_stream_http_error(): + """Streaming should return error on non-200 status.""" + provider = AzureOpenAIProvider( + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + + mock_response = AsyncMock() + mock_response.status_code = 401 + mock_response.aread = AsyncMock(return_value=b"Unauthorized") + + with patch("httpx.AsyncClient") as mock_client: + mock_ctx = AsyncMock() + mock_stream_ctx = AsyncMock() + mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) + mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) + mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) + mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_client.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await provider.chat_stream([{"role": "user", "content": "Hi"}]) + + assert "401" in result.content + assert result.finish_reason == "error" + + +# --------------------------------------------------------------------------- +# get_default_model +# --------------------------------------------------------------------------- + + def test_get_default_model(): - """Test get_default_model method.""" provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="my-custom-deployment", + api_key="k", api_base="https://r.com", default_model="my-deploy", ) - - assert provider.get_default_model() == "my-custom-deployment" - - -if __name__ == "__main__": - # Run basic tests - print("Running basic Azure OpenAI provider tests...") - - # Test initialization - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o-deployment", - ) - print("✅ Provider initialization successful") - - # Test URL building - url = provider._build_chat_url("my-deployment") - expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21" - assert url == expected - print("✅ URL building works correctly") - - # Test headers - headers = provider._build_headers() - assert headers["api-key"] == "test-key" - assert headers["Content-Type"] == "application/json" - print("✅ Header building works correctly") - - # Test payload preparation - messages = [{"role": "user", "content": "Test"}] - payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000) - assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format - print("✅ Payload preparation works correctly") - - print("✅ All basic tests passed! Updated test file is working correctly.") \ No newline at end of file + assert provider.get_default_model() == "my-deploy" From 8c0607e079eff78932c3d45013164975501cfe64 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Tue, 31 Mar 2026 02:17:30 +0000 Subject: [PATCH 20/31] Use SDK for stream --- nanobot/providers/azure_openai_provider.py | 38 ++--- .../openai_responses_common/__init__.py | 2 + .../openai_responses_common/parsing.py | 69 +++++++++ tests/providers/test_azure_openai_provider.py | 141 +++++++----------- 4 files changed, 139 insertions(+), 111 deletions(-) diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index ab4d187ae..b97743ab2 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -11,12 +11,11 @@ import uuid from collections.abc import Awaitable, Callable from typing import Any -import httpx from openai import AsyncOpenAI from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.openai_responses_common import ( - consume_sse, + consume_sdk_stream, convert_messages, convert_tools, parse_response_output, @@ -94,6 +93,7 @@ class AzureOpenAIProvider(LLMProvider): "model": deployment, "instructions": instructions or None, "input": input_items, + "max_output_tokens": max(1, max_tokens), "store": False, "stream": False, } @@ -159,31 +159,15 @@ class AzureOpenAIProvider(LLMProvider): body["stream"] = True try: - # Use raw httpx stream via the SDK's base URL so we can reuse - # the shared Responses-API SSE parser (same as Codex provider). - base_url = str(self._client.base_url).rstrip("/") - url = f"{base_url}/responses" - headers = { - "Authorization": f"Bearer {self._client.api_key}", - "Content-Type": "application/json", - **(self._client._custom_headers or {}), - } - async with httpx.AsyncClient(timeout=60.0, verify=True) as http: - async with http.stream("POST", url, headers=headers, json=body) as response: - if response.status_code != 200: - text = await response.aread() - return LLMResponse( - content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}", - finish_reason="error", - ) - content, tool_calls, finish_reason = await consume_sse( - response, on_content_delta, - ) - return LLMResponse( - content=content or None, - tool_calls=tool_calls, - finish_reason=finish_reason, - ) + stream = await self._client.responses.create(**body) + content, tool_calls, finish_reason = await consume_sdk_stream( + stream, on_content_delta, + ) + return LLMResponse( + content=content or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + ) except Exception as e: return self._handle_error(e) diff --git a/nanobot/providers/openai_responses_common/__init__.py b/nanobot/providers/openai_responses_common/__init__.py index cfc327bdb..80a03e43a 100644 --- a/nanobot/providers/openai_responses_common/__init__.py +++ b/nanobot/providers/openai_responses_common/__init__.py @@ -8,6 +8,7 @@ from nanobot.providers.openai_responses_common.converters import ( ) from nanobot.providers.openai_responses_common.parsing import ( FINISH_REASON_MAP, + consume_sdk_stream, consume_sse, iter_sse, map_finish_reason, @@ -21,6 +22,7 @@ __all__ = [ "split_tool_call_id", "iter_sse", "consume_sse", + "consume_sdk_stream", "map_finish_reason", "parse_response_output", "FINISH_REASON_MAP", diff --git a/nanobot/providers/openai_responses_common/parsing.py b/nanobot/providers/openai_responses_common/parsing.py index e0d5f4462..5de895534 100644 --- a/nanobot/providers/openai_responses_common/parsing.py +++ b/nanobot/providers/openai_responses_common/parsing.py @@ -171,3 +171,72 @@ def parse_response_output(response: Any) -> LLMResponse: finish_reason=finish_reason, usage=usage, ) + + +async def consume_sdk_stream( + stream: Any, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, +) -> tuple[str, list[ToolCallRequest], str]: + """Consume an SDK async stream from ``client.responses.create(stream=True)``. + + The SDK yields typed event objects with a ``.type`` attribute and + event-specific fields. Returns ``(content, tool_calls, finish_reason)``. + """ + content = "" + tool_calls: list[ToolCallRequest] = [] + tool_call_buffers: dict[str, dict[str, Any]] = {} + finish_reason = "stop" + + async for event in stream: + event_type = getattr(event, "type", None) + if event_type == "response.output_item.added": + item = getattr(event, "item", None) + if item and getattr(item, "type", None) == "function_call": + call_id = getattr(item, "call_id", None) + if not call_id: + continue + tool_call_buffers[call_id] = { + "id": getattr(item, "id", None) or "fc_0", + "name": getattr(item, "name", None), + "arguments": getattr(item, "arguments", None) or "", + } + elif event_type == "response.output_text.delta": + delta_text = getattr(event, "delta", "") or "" + content += delta_text + if on_content_delta and delta_text: + await on_content_delta(delta_text) + elif event_type == "response.function_call_arguments.delta": + call_id = getattr(event, "call_id", None) + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or "" + elif event_type == "response.function_call_arguments.done": + call_id = getattr(event, "call_id", None) + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or "" + elif event_type == "response.output_item.done": + item = getattr(event, "item", None) + if item and getattr(item, "type", None) == "function_call": + call_id = getattr(item, "call_id", None) + if not call_id: + continue + buf = tool_call_buffers.get(call_id) or {} + args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}" + try: + args = json.loads(args_raw) + except Exception: + args = {"raw": args_raw} + tool_calls.append( + ToolCallRequest( + id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}", + name=buf.get("name") or getattr(item, "name", None), + arguments=args, + ) + ) + elif event_type == "response.completed": + resp = getattr(event, "response", None) + status = getattr(resp, "status", None) if resp else None + finish_reason = map_finish_reason(status) + elif event_type in {"error", "response.failed"}: + raise RuntimeError("Response failed") + + return content, tool_calls, finish_reason diff --git a/tests/providers/test_azure_openai_provider.py b/tests/providers/test_azure_openai_provider.py index 9a95cae5d..4a18f3bf9 100644 --- a/tests/providers/test_azure_openai_provider.py +++ b/tests/providers/test_azure_openai_provider.py @@ -1,6 +1,6 @@ """Test Azure OpenAI provider (Responses API via OpenAI SDK).""" -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest @@ -93,6 +93,7 @@ def test_build_body_basic(): assert body["model"] == "gpt-4o" assert body["instructions"] == "You are helpful." assert body["temperature"] == 0.7 + assert body["max_output_tokens"] == 4096 assert body["store"] is False assert "reasoning" not in body # input should contain the converted user message only (system extracted) @@ -102,6 +103,13 @@ def test_build_body_basic(): ) +def test_build_body_max_tokens_minimum(): + """max_output_tokens should never be less than 1.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") + body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None) + assert body["max_output_tokens"] == 1 + + def test_build_body_with_tools(): provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] @@ -290,46 +298,29 @@ async def test_chat_stream_success(): api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - # Build SSE lines for the mock httpx stream - sse_events = [ - 'event: response.output_text.delta', - 'data: {"type":"response.output_text.delta","delta":"Hello"}', - '', - 'event: response.output_text.delta', - 'data: {"type":"response.output_text.delta","delta":" world"}', - '', - 'event: response.completed', - 'data: {"type":"response.completed","response":{"status":"completed"}}', - '', - ] + # Build mock SDK stream events + events = [] + ev1 = MagicMock(type="response.output_text.delta", delta="Hello") + ev2 = MagicMock(type="response.output_text.delta", delta=" world") + resp_obj = MagicMock(status="completed") + ev3 = MagicMock(type="response.completed", response=resp_obj) + events = [ev1, ev2, ev3] + + async def mock_stream(): + for e in events: + yield e + + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_stream()) deltas: list[str] = [] async def on_delta(text: str) -> None: deltas.append(text) - # Mock httpx stream - mock_response = AsyncMock() - mock_response.status_code = 200 - - async def aiter_lines(): - for line in sse_events: - yield line - - mock_response.aiter_lines = aiter_lines - - with patch("httpx.AsyncClient") as mock_client: - mock_ctx = AsyncMock() - mock_stream_ctx = AsyncMock() - mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) - mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) - mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_client.return_value.__aexit__ = AsyncMock(return_value=False) - - result = await provider.chat_stream( - [{"role": "user", "content": "Hi"}], on_content_delta=on_delta, - ) + result = await provider.chat_stream( + [{"role": "user", "content": "Hi"}], on_content_delta=on_delta, + ) assert result.content == "Hello world" assert result.finish_reason == "stop" @@ -343,41 +334,34 @@ async def test_chat_stream_with_tool_calls(): api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - sse_events = [ - 'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}', - '', - 'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}', - '', - 'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}', - '', - 'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}', - '', - 'data: {"type":"response.completed","response":{"status":"completed"}}', - '', - ] + item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="") + item_added.name = "get_weather" + ev_added = MagicMock(type="response.output_item.added", item=item_added) + ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc') + ev_args_done = MagicMock( + type="response.function_call_arguments.done", + call_id="call_1", arguments='{"location":"SF"}', + ) + item_done = MagicMock( + type="function_call", call_id="call_1", id="fc_1", + arguments='{"location":"SF"}', + ) + item_done.name = "get_weather" + ev_item_done = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed") + ev_completed = MagicMock(type="response.completed", response=resp_obj) - mock_response = AsyncMock() - mock_response.status_code = 200 + async def mock_stream(): + for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]: + yield e - async def aiter_lines(): - for line in sse_events: - yield line + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_stream()) - mock_response.aiter_lines = aiter_lines - - with patch("httpx.AsyncClient") as mock_client: - mock_ctx = AsyncMock() - mock_stream_ctx = AsyncMock() - mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) - mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) - mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_client.return_value.__aexit__ = AsyncMock(return_value=False) - - result = await provider.chat_stream( - [{"role": "user", "content": "weather?"}], - tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], - ) + result = await provider.chat_stream( + [{"role": "user", "content": "weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], + ) assert len(result.tool_calls) == 1 assert result.tool_calls[0].name == "get_weather" @@ -385,28 +369,17 @@ async def test_chat_stream_with_tool_calls(): @pytest.mark.asyncio -async def test_chat_stream_http_error(): - """Streaming should return error on non-200 status.""" +async def test_chat_stream_error(): + """Streaming should return error when SDK raises.""" provider = AzureOpenAIProvider( api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed")) - mock_response = AsyncMock() - mock_response.status_code = 401 - mock_response.aread = AsyncMock(return_value=b"Unauthorized") + result = await provider.chat_stream([{"role": "user", "content": "Hi"}]) - with patch("httpx.AsyncClient") as mock_client: - mock_ctx = AsyncMock() - mock_stream_ctx = AsyncMock() - mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) - mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) - mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_client.return_value.__aexit__ = AsyncMock(return_value=False) - - result = await provider.chat_stream([{"role": "user", "content": "Hi"}]) - - assert "401" in result.content + assert "Connection failed" in result.content assert result.finish_reason == "error" From 7c44aa92ca42847fcf6d01b150a36daa740e3548 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Tue, 31 Mar 2026 02:29:40 +0000 Subject: [PATCH 21/31] Fill up gaps --- nanobot/providers/azure_openai_provider.py | 6 ++-- .../openai_responses_common/parsing.py | 36 +++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index b97743ab2..f2f63a5ba 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -160,13 +160,15 @@ class AzureOpenAIProvider(LLMProvider): try: stream = await self._client.responses.create(**body) - content, tool_calls, finish_reason = await consume_sdk_stream( - stream, on_content_delta, + content, tool_calls, finish_reason, usage, reasoning_content = ( + await consume_sdk_stream(stream, on_content_delta) ) return LLMResponse( content=content or None, tool_calls=tool_calls, finish_reason=finish_reason, + usage=usage, + reasoning_content=reasoning_content, ) except Exception as e: return self._handle_error(e) diff --git a/nanobot/providers/openai_responses_common/parsing.py b/nanobot/providers/openai_responses_common/parsing.py index 5de895534..df59babd5 100644 --- a/nanobot/providers/openai_responses_common/parsing.py +++ b/nanobot/providers/openai_responses_common/parsing.py @@ -122,6 +122,7 @@ def parse_response_output(response: Any) -> LLMResponse: output = response.get("output") or [] content_parts: list[str] = [] tool_calls: list[ToolCallRequest] = [] + reasoning_content: str | None = None for item in output: if not isinstance(item, dict): @@ -136,6 +137,14 @@ def parse_response_output(response: Any) -> LLMResponse: block = dump() if callable(dump) else vars(block) if block.get("type") == "output_text": content_parts.append(block.get("text") or "") + elif item_type == "reasoning": + # Reasoning items may have a summary list with text blocks + for s in item.get("summary") or []: + if not isinstance(s, dict): + dump = getattr(s, "model_dump", None) + s = dump() if callable(dump) else vars(s) + if s.get("type") == "summary_text" and s.get("text"): + reasoning_content = (reasoning_content or "") + s["text"] elif item_type == "function_call": call_id = item.get("call_id") or "" item_id = item.get("id") or "fc_0" @@ -170,22 +179,26 @@ def parse_response_output(response: Any) -> LLMResponse: tool_calls=tool_calls, finish_reason=finish_reason, usage=usage, + reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None, ) async def consume_sdk_stream( stream: Any, on_content_delta: Callable[[str], Awaitable[None]] | None = None, -) -> tuple[str, list[ToolCallRequest], str]: +) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]: """Consume an SDK async stream from ``client.responses.create(stream=True)``. The SDK yields typed event objects with a ``.type`` attribute and - event-specific fields. Returns ``(content, tool_calls, finish_reason)``. + event-specific fields. Returns + ``(content, tool_calls, finish_reason, usage, reasoning_content)``. """ content = "" tool_calls: list[ToolCallRequest] = [] tool_call_buffers: dict[str, dict[str, Any]] = {} finish_reason = "stop" + usage: dict[str, int] = {} + reasoning_content: str | None = None async for event in stream: event_type = getattr(event, "type", None) @@ -236,7 +249,24 @@ async def consume_sdk_stream( resp = getattr(event, "response", None) status = getattr(resp, "status", None) if resp else None finish_reason = map_finish_reason(status) + # Extract usage from the completed response + if resp: + usage_obj = getattr(resp, "usage", None) + if usage_obj: + usage = { + "prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0), + "completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0), + "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0), + } + # Extract reasoning_content from completed output items + for out_item in getattr(resp, "output", None) or []: + if getattr(out_item, "type", None) == "reasoning": + for s in getattr(out_item, "summary", None) or []: + if getattr(s, "type", None) == "summary_text": + text = getattr(s, "text", None) + if text: + reasoning_content = (reasoning_content or "") + text elif event_type in {"error", "response.failed"}: raise RuntimeError("Response failed") - return content, tool_calls, finish_reason + return content, tool_calls, finish_reason, usage, reasoning_content From ac2ee587914bc042cf41f8c9e88b1f7024e4448f Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Tue, 31 Mar 2026 08:30:11 +0000 Subject: [PATCH 22/31] Add tests and logs --- .../openai_responses_common/parsing.py | 9 + .../providers/test_openai_responses_common.py | 532 ++++++++++++++++++ 2 files changed, 541 insertions(+) create mode 100644 tests/providers/test_openai_responses_common.py diff --git a/nanobot/providers/openai_responses_common/parsing.py b/nanobot/providers/openai_responses_common/parsing.py index df59babd5..1e38fdc4e 100644 --- a/nanobot/providers/openai_responses_common/parsing.py +++ b/nanobot/providers/openai_responses_common/parsing.py @@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable from typing import Any, AsyncGenerator import httpx +from loguru import logger from nanobot.providers.base import LLMResponse, ToolCallRequest @@ -39,6 +40,7 @@ async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], N try: yield json.loads(data) except Exception: + logger.warning("Failed to parse SSE event JSON: {}", data[:200]) continue continue buffer.append(line) @@ -91,6 +93,8 @@ async def consume_sse( try: args = json.loads(args_raw) except Exception: + logger.warning("Failed to parse tool call arguments for '{}': {}", + buf.get("name") or item.get("name"), args_raw[:200]) args = {"raw": args_raw} tool_calls.append( ToolCallRequest( @@ -152,6 +156,8 @@ def parse_response_output(response: Any) -> LLMResponse: try: args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw except Exception: + logger.warning("Failed to parse tool call arguments for '{}': {}", + item.get("name"), str(args_raw)[:200]) args = {"raw": args_raw} tool_calls.append(ToolCallRequest( id=f"{call_id}|{item_id}", @@ -237,6 +243,9 @@ async def consume_sdk_stream( try: args = json.loads(args_raw) except Exception: + logger.warning("Failed to parse tool call arguments for '{}': {}", + buf.get("name") or getattr(item, "name", None), + str(args_raw)[:200]) args = {"raw": args_raw} tool_calls.append( ToolCallRequest( diff --git a/tests/providers/test_openai_responses_common.py b/tests/providers/test_openai_responses_common.py new file mode 100644 index 000000000..aa972f08b --- /dev/null +++ b/tests/providers/test_openai_responses_common.py @@ -0,0 +1,532 @@ +"""Tests for the shared openai_responses_common converters and parsers.""" + +from unittest.mock import MagicMock + +import pytest +from loguru import logger + +from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.providers.openai_responses_common.converters import ( + convert_messages, + convert_tools, + convert_user_message, + split_tool_call_id, +) +from nanobot.providers.openai_responses_common.parsing import ( + consume_sdk_stream, + map_finish_reason, + parse_response_output, +) + + +@pytest.fixture() +def loguru_capture(): + """Capture loguru messages into a list for assertion.""" + messages: list[str] = [] + + def sink(message): + messages.append(str(message)) + + handler_id = logger.add(sink, format="{message}", level="DEBUG") + yield messages + logger.remove(handler_id) + + +# ====================================================================== +# converters — split_tool_call_id +# ====================================================================== + + +class TestSplitToolCallId: + def test_plain_id(self): + assert split_tool_call_id("call_abc") == ("call_abc", None) + + def test_compound_id(self): + assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1") + + def test_compound_empty_item_id(self): + assert split_tool_call_id("call_abc|") == ("call_abc", None) + + def test_none(self): + assert split_tool_call_id(None) == ("call_0", None) + + def test_empty_string(self): + assert split_tool_call_id("") == ("call_0", None) + + def test_non_string(self): + assert split_tool_call_id(42) == ("call_0", None) + + +# ====================================================================== +# converters — convert_user_message +# ====================================================================== + + +class TestConvertUserMessage: + def test_string_content(self): + result = convert_user_message("hello") + assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]} + + def test_text_block(self): + result = convert_user_message([{"type": "text", "text": "hi"}]) + assert result["content"] == [{"type": "input_text", "text": "hi"}] + + def test_image_url_block(self): + result = convert_user_message([ + {"type": "image_url", "image_url": {"url": "https://img.example/a.png"}}, + ]) + assert result["content"] == [ + {"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"}, + ] + + def test_mixed_text_and_image(self): + result = convert_user_message([ + {"type": "text", "text": "what's this?"}, + {"type": "image_url", "image_url": {"url": "https://img.example/b.png"}}, + ]) + assert len(result["content"]) == 2 + assert result["content"][0]["type"] == "input_text" + assert result["content"][1]["type"] == "input_image" + + def test_empty_list_falls_back(self): + result = convert_user_message([]) + assert result["content"] == [{"type": "input_text", "text": ""}] + + def test_none_falls_back(self): + result = convert_user_message(None) + assert result["content"] == [{"type": "input_text", "text": ""}] + + def test_image_without_url_skipped(self): + result = convert_user_message([{"type": "image_url", "image_url": {}}]) + assert result["content"] == [{"type": "input_text", "text": ""}] + + def test_meta_fields_not_leaked(self): + """_meta on content blocks must never appear in converted output.""" + result = convert_user_message([ + {"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}}, + ]) + assert "_meta" not in result["content"][0] + + def test_non_dict_items_skipped(self): + result = convert_user_message(["just a string", 42]) + assert result["content"] == [{"type": "input_text", "text": ""}] + + +# ====================================================================== +# converters — convert_messages +# ====================================================================== + + +class TestConvertMessages: + def test_system_extracted_as_instructions(self): + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + ] + instructions, items = convert_messages(msgs) + assert instructions == "You are helpful." + assert len(items) == 1 + assert items[0]["role"] == "user" + + def test_multiple_system_messages_last_wins(self): + msgs = [ + {"role": "system", "content": "first"}, + {"role": "system", "content": "second"}, + {"role": "user", "content": "x"}, + ] + instructions, _ = convert_messages(msgs) + assert instructions == "second" + + def test_user_message_converted(self): + _, items = convert_messages([{"role": "user", "content": "hello"}]) + assert items[0]["role"] == "user" + assert items[0]["content"][0]["type"] == "input_text" + + def test_assistant_text_message(self): + _, items = convert_messages([ + {"role": "assistant", "content": "I'll help"}, + ]) + assert items[0]["type"] == "message" + assert items[0]["role"] == "assistant" + assert items[0]["content"][0]["type"] == "output_text" + assert items[0]["content"][0]["text"] == "I'll help" + + def test_assistant_empty_content_skipped(self): + _, items = convert_messages([{"role": "assistant", "content": ""}]) + assert len(items) == 0 + + def test_assistant_with_tool_calls(self): + _, items = convert_messages([{ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_abc|fc_1", + "function": {"name": "get_weather", "arguments": '{"city":"SF"}'}, + }], + }]) + assert items[0]["type"] == "function_call" + assert items[0]["call_id"] == "call_abc" + assert items[0]["id"] == "fc_1" + assert items[0]["name"] == "get_weather" + + def test_assistant_with_tool_calls_no_id(self): + """Fallback IDs when tool_call.id is missing.""" + _, items = convert_messages([{ + "role": "assistant", + "content": None, + "tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}], + }]) + assert items[0]["call_id"] == "call_0" + assert items[0]["id"].startswith("fc_") + + def test_tool_message(self): + _, items = convert_messages([{ + "role": "tool", + "tool_call_id": "call_abc", + "content": "result text", + }]) + assert items[0]["type"] == "function_call_output" + assert items[0]["call_id"] == "call_abc" + assert items[0]["output"] == "result text" + + def test_tool_message_dict_content(self): + _, items = convert_messages([{ + "role": "tool", + "tool_call_id": "call_1", + "content": {"key": "value"}, + }]) + assert items[0]["output"] == '{"key": "value"}' + + def test_non_standard_keys_not_leaked(self): + """Extra keys on messages must not appear in converted items.""" + _, items = convert_messages([{ + "role": "user", + "content": "hi", + "extra_field": "should vanish", + "_meta": {"path": "/tmp"}, + }]) + item = items[0] + assert "extra_field" not in str(item) + assert "_meta" not in str(item) + + def test_full_conversation_roundtrip(self): + """System + user + assistant(tool_call) + tool → correct structure.""" + msgs = [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "Weather in SF?"}, + { + "role": "assistant", "content": None, + "tool_calls": [{ + "id": "c1|fc1", + "function": {"name": "get_weather", "arguments": '{"city":"SF"}'}, + }], + }, + {"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'}, + ] + instructions, items = convert_messages(msgs) + assert instructions == "Be concise." + assert len(items) == 3 # user, function_call, function_call_output + assert items[0]["role"] == "user" + assert items[1]["type"] == "function_call" + assert items[2]["type"] == "function_call_output" + + +# ====================================================================== +# converters — convert_tools +# ====================================================================== + + +class TestConvertTools: + def test_standard_function_tool(self): + tools = [{"type": "function", "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + }}] + result = convert_tools(tools) + assert len(result) == 1 + assert result[0]["type"] == "function" + assert result[0]["name"] == "get_weather" + assert result[0]["description"] == "Get weather" + assert "properties" in result[0]["parameters"] + + def test_tool_without_name_skipped(self): + tools = [{"type": "function", "function": {"parameters": {}}}] + assert convert_tools(tools) == [] + + def test_tool_without_function_wrapper(self): + """Direct dict without type=function wrapper.""" + tools = [{"name": "f1", "description": "d", "parameters": {}}] + result = convert_tools(tools) + assert result[0]["name"] == "f1" + + def test_missing_optional_fields_default(self): + tools = [{"type": "function", "function": {"name": "f"}}] + result = convert_tools(tools) + assert result[0]["description"] == "" + assert result[0]["parameters"] == {} + + def test_multiple_tools(self): + tools = [ + {"type": "function", "function": {"name": "a", "parameters": {}}}, + {"type": "function", "function": {"name": "b", "parameters": {}}}, + ] + assert len(convert_tools(tools)) == 2 + + +# ====================================================================== +# parsing — map_finish_reason +# ====================================================================== + + +class TestMapFinishReason: + def test_completed(self): + assert map_finish_reason("completed") == "stop" + + def test_incomplete(self): + assert map_finish_reason("incomplete") == "length" + + def test_failed(self): + assert map_finish_reason("failed") == "error" + + def test_cancelled(self): + assert map_finish_reason("cancelled") == "error" + + def test_none_defaults_to_stop(self): + assert map_finish_reason(None) == "stop" + + def test_unknown_defaults_to_stop(self): + assert map_finish_reason("some_new_status") == "stop" + + +# ====================================================================== +# parsing — parse_response_output +# ====================================================================== + + +class TestParseResponseOutput: + def test_text_response(self): + resp = { + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "Hello!"}]}], + "status": "completed", + "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + } + result = parse_response_output(resp) + assert result.content == "Hello!" + assert result.finish_reason == "stop" + assert result.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + assert result.tool_calls == [] + + def test_tool_call_response(self): + resp = { + "output": [{ + "type": "function_call", + "call_id": "call_1", "id": "fc_1", + "name": "get_weather", + "arguments": '{"city": "SF"}', + }], + "status": "completed", + "usage": {}, + } + result = parse_response_output(resp) + assert result.content is None + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "get_weather" + assert result.tool_calls[0].arguments == {"city": "SF"} + assert result.tool_calls[0].id == "call_1|fc_1" + + def test_malformed_tool_arguments_logged(self, loguru_capture): + """Malformed JSON arguments should log a warning and fallback.""" + resp = { + "output": [{ + "type": "function_call", + "call_id": "c1", "id": "fc1", + "name": "f", "arguments": "{bad json", + }], + "status": "completed", "usage": {}, + } + result = parse_response_output(resp) + assert result.tool_calls[0].arguments == {"raw": "{bad json"} + assert any("Failed to parse tool call arguments" in m for m in loguru_capture) + + def test_reasoning_content_extracted(self): + resp = { + "output": [ + {"type": "reasoning", "summary": [ + {"type": "summary_text", "text": "I think "}, + {"type": "summary_text", "text": "therefore I am."}, + ]}, + {"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "42"}]}, + ], + "status": "completed", "usage": {}, + } + result = parse_response_output(resp) + assert result.content == "42" + assert result.reasoning_content == "I think therefore I am." + + def test_empty_output(self): + resp = {"output": [], "status": "completed", "usage": {}} + result = parse_response_output(resp) + assert result.content is None + assert result.tool_calls == [] + + def test_incomplete_status(self): + resp = {"output": [], "status": "incomplete", "usage": {}} + result = parse_response_output(resp) + assert result.finish_reason == "length" + + def test_sdk_model_object(self): + """parse_response_output should handle SDK objects with model_dump().""" + mock = MagicMock() + mock.model_dump.return_value = { + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "sdk"}]}], + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + } + result = parse_response_output(mock) + assert result.content == "sdk" + assert result.usage["prompt_tokens"] == 1 + + def test_usage_maps_responses_api_keys(self): + """Responses API uses input_tokens/output_tokens, not prompt_tokens/completion_tokens.""" + resp = { + "output": [], + "status": "completed", + "usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + } + result = parse_response_output(resp) + assert result.usage["prompt_tokens"] == 100 + assert result.usage["completion_tokens"] == 50 + assert result.usage["total_tokens"] == 150 + + +# ====================================================================== +# parsing — consume_sdk_stream +# ====================================================================== + + +class TestConsumeSdkStream: + @pytest.mark.asyncio + async def test_text_stream(self): + ev1 = MagicMock(type="response.output_text.delta", delta="Hello") + ev2 = MagicMock(type="response.output_text.delta", delta=" world") + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev3 = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + for e in [ev1, ev2, ev3]: + yield e + + content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream()) + assert content == "Hello world" + assert tool_calls == [] + assert finish_reason == "stop" + + @pytest.mark.asyncio + async def test_on_content_delta_called(self): + ev1 = MagicMock(type="response.output_text.delta", delta="hi") + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev2 = MagicMock(type="response.completed", response=resp_obj) + deltas = [] + + async def cb(text): + deltas.append(text) + + async def stream(): + for e in [ev1, ev2]: + yield e + + await consume_sdk_stream(stream(), on_content_delta=cb) + assert deltas == ["hi"] + + @pytest.mark.asyncio + async def test_tool_call_stream(self): + item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") + item_added.name = "get_weather" + ev1 = MagicMock(type="response.output_item.added", item=item_added) + ev2 = MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci') + ev3 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}') + item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments='{"city":"SF"}') + item_done.name = "get_weather" + ev4 = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev5 = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + for e in [ev1, ev2, ev3, ev4, ev5]: + yield e + + content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream()) + assert content == "" + assert len(tool_calls) == 1 + assert tool_calls[0].name == "get_weather" + assert tool_calls[0].arguments == {"city": "SF"} + + @pytest.mark.asyncio + async def test_usage_extracted(self): + usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15) + resp_obj = MagicMock(status="completed", usage=usage_obj, output=[]) + ev = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + yield ev + + _, _, _, usage, _ = await consume_sdk_stream(stream()) + assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + @pytest.mark.asyncio + async def test_reasoning_extracted(self): + summary_item = MagicMock(type="summary_text", text="thinking...") + reasoning_item = MagicMock(type="reasoning", summary=[summary_item]) + resp_obj = MagicMock(status="completed", usage=None, output=[reasoning_item]) + ev = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + yield ev + + _, _, _, _, reasoning = await consume_sdk_stream(stream()) + assert reasoning == "thinking..." + + @pytest.mark.asyncio + async def test_error_event_raises(self): + ev = MagicMock(type="error") + + async def stream(): + yield ev + + with pytest.raises(RuntimeError, match="Response failed"): + await consume_sdk_stream(stream()) + + @pytest.mark.asyncio + async def test_failed_event_raises(self): + ev = MagicMock(type="response.failed") + + async def stream(): + yield ev + + with pytest.raises(RuntimeError, match="Response failed"): + await consume_sdk_stream(stream()) + + @pytest.mark.asyncio + async def test_malformed_tool_args_logged(self, loguru_capture): + """Malformed JSON in streaming tool args should log a warning.""" + item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") + item_added.name = "f" + ev1 = MagicMock(type="response.output_item.added", item=item_added) + ev2 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad") + item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="{bad") + item_done.name = "f" + ev3 = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev4 = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + for e in [ev1, ev2, ev3, ev4]: + yield e + + _, tool_calls, _, _, _ = await consume_sdk_stream(stream()) + assert tool_calls[0].arguments == {"raw": "{bad"} + assert any("Failed to parse tool call arguments" in m for m in loguru_capture) From e206cffd7a59238a9a2bef691b58111e214be2e0 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Tue, 31 Mar 2026 08:37:41 +0000 Subject: [PATCH 23/31] Add tests and handle json --- .../openai_responses_common/parsing.py | 59 +++++++++++++------ .../providers/test_openai_responses_common.py | 8 +-- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/nanobot/providers/openai_responses_common/parsing.py b/nanobot/providers/openai_responses_common/parsing.py index 1e38fdc4e..fa1ba13cf 100644 --- a/nanobot/providers/openai_responses_common/parsing.py +++ b/nanobot/providers/openai_responses_common/parsing.py @@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable from typing import Any, AsyncGenerator import httpx +import json_repair from loguru import logger from nanobot.providers.base import LLMResponse, ToolCallRequest @@ -27,24 +28,36 @@ def map_finish_reason(status: str | None) -> str: async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]: """Yield parsed JSON events from a Responses API SSE stream.""" buffer: list[str] = [] + + def _flush() -> dict[str, Any] | None: + data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")] + buffer.clear() + if not data_lines: + return None + data = "\n".join(data_lines).strip() + if not data or data == "[DONE]": + return None + try: + return json.loads(data) + except Exception: + logger.warning("Failed to parse SSE event JSON: {}", data[:200]) + return None + async for line in response.aiter_lines(): if line == "": if buffer: - data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")] - buffer = [] - if not data_lines: - continue - data = "\n".join(data_lines).strip() - if not data or data == "[DONE]": - continue - try: - yield json.loads(data) - except Exception: - logger.warning("Failed to parse SSE event JSON: {}", data[:200]) - continue + event = _flush() + if event is not None: + yield event continue buffer.append(line) + # Flush any remaining buffer at EOF (#10) + if buffer: + event = _flush() + if event is not None: + yield event + async def consume_sse( response: httpx.Response, @@ -95,11 +108,13 @@ async def consume_sse( except Exception: logger.warning("Failed to parse tool call arguments for '{}': {}", buf.get("name") or item.get("name"), args_raw[:200]) - args = {"raw": args_raw} + args = json_repair.loads(args_raw) + if not isinstance(args, dict): + args = {"raw": args_raw} tool_calls.append( ToolCallRequest( id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}", - name=buf.get("name") or item.get("name"), + name=buf.get("name") or item.get("name") or "", arguments=args, ) ) @@ -107,7 +122,8 @@ async def consume_sse( status = (event.get("response") or {}).get("status") finish_reason = map_finish_reason(status) elif event_type in {"error", "response.failed"}: - raise RuntimeError("Response failed") + detail = event.get("error") or event.get("message") or event + raise RuntimeError(f"Response failed: {str(detail)[:500]}") return content, tool_calls, finish_reason @@ -158,7 +174,9 @@ def parse_response_output(response: Any) -> LLMResponse: except Exception: logger.warning("Failed to parse tool call arguments for '{}': {}", item.get("name"), str(args_raw)[:200]) - args = {"raw": args_raw} + args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw + if not isinstance(args, dict): + args = {"raw": args_raw} tool_calls.append(ToolCallRequest( id=f"{call_id}|{item_id}", name=item.get("name") or "", @@ -246,11 +264,13 @@ async def consume_sdk_stream( logger.warning("Failed to parse tool call arguments for '{}': {}", buf.get("name") or getattr(item, "name", None), str(args_raw)[:200]) - args = {"raw": args_raw} + args = json_repair.loads(args_raw) + if not isinstance(args, dict): + args = {"raw": args_raw} tool_calls.append( ToolCallRequest( id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}", - name=buf.get("name") or getattr(item, "name", None), + name=buf.get("name") or getattr(item, "name", None) or "", arguments=args, ) ) @@ -276,6 +296,7 @@ async def consume_sdk_stream( if text: reasoning_content = (reasoning_content or "") + text elif event_type in {"error", "response.failed"}: - raise RuntimeError("Response failed") + detail = getattr(event, "error", None) or getattr(event, "message", None) or event + raise RuntimeError(f"Response failed: {str(detail)[:500]}") return content, tool_calls, finish_reason, usage, reasoning_content diff --git a/tests/providers/test_openai_responses_common.py b/tests/providers/test_openai_responses_common.py index aa972f08b..adddf49ee 100644 --- a/tests/providers/test_openai_responses_common.py +++ b/tests/providers/test_openai_responses_common.py @@ -492,22 +492,22 @@ class TestConsumeSdkStream: @pytest.mark.asyncio async def test_error_event_raises(self): - ev = MagicMock(type="error") + ev = MagicMock(type="error", error="rate_limit_exceeded") async def stream(): yield ev - with pytest.raises(RuntimeError, match="Response failed"): + with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"): await consume_sdk_stream(stream()) @pytest.mark.asyncio async def test_failed_event_raises(self): - ev = MagicMock(type="response.failed") + ev = MagicMock(type="response.failed", error="server_error") async def stream(): yield ev - with pytest.raises(RuntimeError, match="Response failed"): + with pytest.raises(RuntimeError, match="Response failed.*server_error"): await consume_sdk_stream(stream()) @pytest.mark.asyncio From 76226274bfb5ad51ad6c77f8e1ebae0312783e2a Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Tue, 31 Mar 2026 09:15:08 +0000 Subject: [PATCH 24/31] Failing test --- tests/providers/test_openai_responses_common.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/providers/test_openai_responses_common.py b/tests/providers/test_openai_responses_common.py index adddf49ee..0879685b2 100644 --- a/tests/providers/test_openai_responses_common.py +++ b/tests/providers/test_openai_responses_common.py @@ -23,11 +23,7 @@ from nanobot.providers.openai_responses_common.parsing import ( def loguru_capture(): """Capture loguru messages into a list for assertion.""" messages: list[str] = [] - - def sink(message): - messages.append(str(message)) - - handler_id = logger.add(sink, format="{message}", level="DEBUG") + handler_id = logger.add(lambda m: messages.append(str(m)), format="{message}", level="DEBUG") yield messages logger.remove(handler_id) From 61d7411238131155b545d283d510ab3c1b8650e9 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Tue, 31 Mar 2026 09:22:50 +0000 Subject: [PATCH 25/31] Fix failing test --- .../providers/test_openai_responses_common.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/tests/providers/test_openai_responses_common.py b/tests/providers/test_openai_responses_common.py index 0879685b2..15d24041c 100644 --- a/tests/providers/test_openai_responses_common.py +++ b/tests/providers/test_openai_responses_common.py @@ -1,9 +1,8 @@ """Tests for the shared openai_responses_common converters and parsers.""" -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest -from loguru import logger from nanobot.providers.base import LLMResponse, ToolCallRequest from nanobot.providers.openai_responses_common.converters import ( @@ -19,15 +18,6 @@ from nanobot.providers.openai_responses_common.parsing import ( ) -@pytest.fixture() -def loguru_capture(): - """Capture loguru messages into a list for assertion.""" - messages: list[str] = [] - handler_id = logger.add(lambda m: messages.append(str(m)), format="{message}", level="DEBUG") - yield messages - logger.remove(handler_id) - - # ====================================================================== # converters — split_tool_call_id # ====================================================================== @@ -332,7 +322,7 @@ class TestParseResponseOutput: assert result.tool_calls[0].arguments == {"city": "SF"} assert result.tool_calls[0].id == "call_1|fc_1" - def test_malformed_tool_arguments_logged(self, loguru_capture): + def test_malformed_tool_arguments_logged(self): """Malformed JSON arguments should log a warning and fallback.""" resp = { "output": [{ @@ -342,9 +332,11 @@ class TestParseResponseOutput: }], "status": "completed", "usage": {}, } - result = parse_response_output(resp) + with patch("nanobot.providers.openai_responses_common.parsing.logger") as mock_logger: + result = parse_response_output(resp) assert result.tool_calls[0].arguments == {"raw": "{bad json"} - assert any("Failed to parse tool call arguments" in m for m in loguru_capture) + mock_logger.warning.assert_called_once() + assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args) def test_reasoning_content_extracted(self): resp = { @@ -507,7 +499,7 @@ class TestConsumeSdkStream: await consume_sdk_stream(stream()) @pytest.mark.asyncio - async def test_malformed_tool_args_logged(self, loguru_capture): + async def test_malformed_tool_args_logged(self): """Malformed JSON in streaming tool args should log a warning.""" item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") item_added.name = "f" @@ -523,6 +515,8 @@ class TestConsumeSdkStream: for e in [ev1, ev2, ev3, ev4]: yield e - _, tool_calls, _, _, _ = await consume_sdk_stream(stream()) + with patch("nanobot.providers.openai_responses_common.parsing.logger") as mock_logger: + _, tool_calls, _, _, _ = await consume_sdk_stream(stream()) assert tool_calls[0].arguments == {"raw": "{bad"} - assert any("Failed to parse tool call arguments" in m for m in loguru_capture) + mock_logger.warning.assert_called_once() + assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args) From ded0967c1804be6da4a7eeedd127c1ba7a2f371b Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 2 Apr 2026 05:11:56 +0000 Subject: [PATCH 26/31] fix(providers): sanitize azure responses input messages --- nanobot/providers/azure_openai_provider.py | 2 +- tests/providers/test_azure_openai_provider.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index f2f63a5ba..bf6ccae8b 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -87,7 +87,7 @@ class AzureOpenAIProvider(LLMProvider): ) -> dict[str, Any]: """Build the Responses API request body from Chat-Completions-style args.""" deployment = model or self.default_model - instructions, input_items = convert_messages(messages) + instructions, input_items = convert_messages(self._sanitize_empty_content(messages)) body: dict[str, Any] = { "model": deployment, diff --git a/tests/providers/test_azure_openai_provider.py b/tests/providers/test_azure_openai_provider.py index 4a18f3bf9..89cea64f0 100644 --- a/tests/providers/test_azure_openai_provider.py +++ b/tests/providers/test_azure_openai_provider.py @@ -150,6 +150,19 @@ def test_build_body_image_conversion(): assert image_block["image_url"] == "https://example.com/img.png" +def test_build_body_sanitizes_single_dict_content_block(): + """Single content dicts should be preserved via shared message sanitization.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") + messages = [{ + "role": "user", + "content": {"type": "text", "text": "Hi from dict content"}, + }] + + body = provider._build_body(messages, None, None, 4096, 0.7, None, None) + + assert body["input"][0]["content"] == [{"type": "input_text", "text": "Hi from dict content"}] + + # --------------------------------------------------------------------------- # chat() — non-streaming # --------------------------------------------------------------------------- From cc33057985b265d6af99167758a5265575dc5f3f Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 2 Apr 2026 05:38:19 +0000 Subject: [PATCH 27/31] refactor(providers): rename openai responses helpers --- nanobot/providers/azure_openai_provider.py | 6 +-- nanobot/providers/openai_codex_provider.py | 2 +- .../__init__.py | 4 +- .../converters.py | 4 +- .../parsing.py | 39 ++++++++----------- ...ses_common.py => test_openai_responses.py} | 26 ++++++------- 6 files changed, 38 insertions(+), 43 deletions(-) rename nanobot/providers/{openai_responses_common => openai_responses}/__init__.py (80%) rename nanobot/providers/{openai_responses_common => openai_responses}/converters.py (97%) rename nanobot/providers/{openai_responses_common => openai_responses}/parsing.py (91%) rename tests/providers/{test_openai_responses_common.py => test_openai_responses.py} (96%) diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index bf6ccae8b..12c74be02 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -2,7 +2,7 @@ Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which routes to the Responses API (``/responses``). Reuses shared conversion -helpers from :mod:`nanobot.providers.openai_responses_common`. +helpers from :mod:`nanobot.providers.openai_responses`. """ from __future__ import annotations @@ -14,7 +14,7 @@ from typing import Any from openai import AsyncOpenAI from nanobot.providers.base import LLMProvider, LLMResponse -from nanobot.providers.openai_responses_common import ( +from nanobot.providers.openai_responses import ( consume_sdk_stream, convert_messages, convert_tools, @@ -30,7 +30,7 @@ class AzureOpenAIProvider(LLMProvider): ``base_url = {endpoint}/openai/v1/`` - Calls ``client.responses.create()`` (Responses API) - Reuses shared message/tool/SSE conversion from - ``openai_responses_common`` + ``openai_responses`` """ def __init__( diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index 68145173b..265b4b106 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -13,7 +13,7 @@ from loguru import logger from oauth_cli_kit import get_token as get_codex_token from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -from nanobot.providers.openai_responses_common import ( +from nanobot.providers.openai_responses import ( consume_sse, convert_messages, convert_tools, diff --git a/nanobot/providers/openai_responses_common/__init__.py b/nanobot/providers/openai_responses/__init__.py similarity index 80% rename from nanobot/providers/openai_responses_common/__init__.py rename to nanobot/providers/openai_responses/__init__.py index 80a03e43a..b40e896ed 100644 --- a/nanobot/providers/openai_responses_common/__init__.py +++ b/nanobot/providers/openai_responses/__init__.py @@ -1,12 +1,12 @@ """Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI).""" -from nanobot.providers.openai_responses_common.converters import ( +from nanobot.providers.openai_responses.converters import ( convert_messages, convert_tools, convert_user_message, split_tool_call_id, ) -from nanobot.providers.openai_responses_common.parsing import ( +from nanobot.providers.openai_responses.parsing import ( FINISH_REASON_MAP, consume_sdk_stream, consume_sse, diff --git a/nanobot/providers/openai_responses_common/converters.py b/nanobot/providers/openai_responses/converters.py similarity index 97% rename from nanobot/providers/openai_responses_common/converters.py rename to nanobot/providers/openai_responses/converters.py index 37596692d..e0bfe832d 100644 --- a/nanobot/providers/openai_responses_common/converters.py +++ b/nanobot/providers/openai_responses/converters.py @@ -58,8 +58,8 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str def convert_user_message(content: Any) -> dict[str, Any]: """Convert a user message's content to Responses API format. - Handles plain strings, ``text`` blocks → ``input_text``, and - ``image_url`` blocks → ``input_image``. + Handles plain strings, ``text`` blocks -> ``input_text``, and + ``image_url`` blocks -> ``input_image``. """ if isinstance(content, str): return {"role": "user", "content": [{"type": "input_text", "text": content}]} diff --git a/nanobot/providers/openai_responses_common/parsing.py b/nanobot/providers/openai_responses/parsing.py similarity index 91% rename from nanobot/providers/openai_responses_common/parsing.py rename to nanobot/providers/openai_responses/parsing.py index fa1ba13cf..9e3f0ef02 100644 --- a/nanobot/providers/openai_responses_common/parsing.py +++ b/nanobot/providers/openai_responses/parsing.py @@ -106,8 +106,11 @@ async def consume_sse( try: args = json.loads(args_raw) except Exception: - logger.warning("Failed to parse tool call arguments for '{}': {}", - buf.get("name") or item.get("name"), args_raw[:200]) + logger.warning( + "Failed to parse tool call arguments for '{}': {}", + buf.get("name") or item.get("name"), + args_raw[:200], + ) args = json_repair.loads(args_raw) if not isinstance(args, dict): args = {"raw": args_raw} @@ -129,12 +132,7 @@ async def consume_sse( def parse_response_output(response: Any) -> LLMResponse: - """Parse an SDK ``Response`` object (from ``client.responses.create()``) - into an ``LLMResponse``. - - Works with both Pydantic model objects and plain dicts. - """ - # Normalise to dict + """Parse an SDK ``Response`` object into an ``LLMResponse``.""" if not isinstance(response, dict): dump = getattr(response, "model_dump", None) response = dump() if callable(dump) else vars(response) @@ -158,7 +156,6 @@ def parse_response_output(response: Any) -> LLMResponse: if block.get("type") == "output_text": content_parts.append(block.get("text") or "") elif item_type == "reasoning": - # Reasoning items may have a summary list with text blocks for s in item.get("summary") or []: if not isinstance(s, dict): dump = getattr(s, "model_dump", None) @@ -172,8 +169,11 @@ def parse_response_output(response: Any) -> LLMResponse: try: args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw except Exception: - logger.warning("Failed to parse tool call arguments for '{}': {}", - item.get("name"), str(args_raw)[:200]) + logger.warning( + "Failed to parse tool call arguments for '{}': {}", + item.get("name"), + str(args_raw)[:200], + ) args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw if not isinstance(args, dict): args = {"raw": args_raw} @@ -211,12 +211,7 @@ async def consume_sdk_stream( stream: Any, on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]: - """Consume an SDK async stream from ``client.responses.create(stream=True)``. - - The SDK yields typed event objects with a ``.type`` attribute and - event-specific fields. Returns - ``(content, tool_calls, finish_reason, usage, reasoning_content)``. - """ + """Consume an SDK async stream from ``client.responses.create(stream=True)``.""" content = "" tool_calls: list[ToolCallRequest] = [] tool_call_buffers: dict[str, dict[str, Any]] = {} @@ -261,9 +256,11 @@ async def consume_sdk_stream( try: args = json.loads(args_raw) except Exception: - logger.warning("Failed to parse tool call arguments for '{}': {}", - buf.get("name") or getattr(item, "name", None), - str(args_raw)[:200]) + logger.warning( + "Failed to parse tool call arguments for '{}': {}", + buf.get("name") or getattr(item, "name", None), + str(args_raw)[:200], + ) args = json_repair.loads(args_raw) if not isinstance(args, dict): args = {"raw": args_raw} @@ -278,7 +275,6 @@ async def consume_sdk_stream( resp = getattr(event, "response", None) status = getattr(resp, "status", None) if resp else None finish_reason = map_finish_reason(status) - # Extract usage from the completed response if resp: usage_obj = getattr(resp, "usage", None) if usage_obj: @@ -287,7 +283,6 @@ async def consume_sdk_stream( "completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0), "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0), } - # Extract reasoning_content from completed output items for out_item in getattr(resp, "output", None) or []: if getattr(out_item, "type", None) == "reasoning": for s in getattr(out_item, "summary", None) or []: diff --git a/tests/providers/test_openai_responses_common.py b/tests/providers/test_openai_responses.py similarity index 96% rename from tests/providers/test_openai_responses_common.py rename to tests/providers/test_openai_responses.py index 15d24041c..ce4220655 100644 --- a/tests/providers/test_openai_responses_common.py +++ b/tests/providers/test_openai_responses.py @@ -1,17 +1,17 @@ -"""Tests for the shared openai_responses_common converters and parsers.""" +"""Tests for the shared openai_responses converters and parsers.""" from unittest.mock import MagicMock, patch import pytest from nanobot.providers.base import LLMResponse, ToolCallRequest -from nanobot.providers.openai_responses_common.converters import ( +from nanobot.providers.openai_responses.converters import ( convert_messages, convert_tools, convert_user_message, split_tool_call_id, ) -from nanobot.providers.openai_responses_common.parsing import ( +from nanobot.providers.openai_responses.parsing import ( consume_sdk_stream, map_finish_reason, parse_response_output, @@ -19,7 +19,7 @@ from nanobot.providers.openai_responses_common.parsing import ( # ====================================================================== -# converters — split_tool_call_id +# converters - split_tool_call_id # ====================================================================== @@ -44,7 +44,7 @@ class TestSplitToolCallId: # ====================================================================== -# converters — convert_user_message +# converters - convert_user_message # ====================================================================== @@ -99,7 +99,7 @@ class TestConvertUserMessage: # ====================================================================== -# converters — convert_messages +# converters - convert_messages # ====================================================================== @@ -196,7 +196,7 @@ class TestConvertMessages: assert "_meta" not in str(item) def test_full_conversation_roundtrip(self): - """System + user + assistant(tool_call) + tool → correct structure.""" + """System + user + assistant(tool_call) + tool -> correct structure.""" msgs = [ {"role": "system", "content": "Be concise."}, {"role": "user", "content": "Weather in SF?"}, @@ -218,7 +218,7 @@ class TestConvertMessages: # ====================================================================== -# converters — convert_tools +# converters - convert_tools # ====================================================================== @@ -261,7 +261,7 @@ class TestConvertTools: # ====================================================================== -# parsing — map_finish_reason +# parsing - map_finish_reason # ====================================================================== @@ -286,7 +286,7 @@ class TestMapFinishReason: # ====================================================================== -# parsing — parse_response_output +# parsing - parse_response_output # ====================================================================== @@ -332,7 +332,7 @@ class TestParseResponseOutput: }], "status": "completed", "usage": {}, } - with patch("nanobot.providers.openai_responses_common.parsing.logger") as mock_logger: + with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger: result = parse_response_output(resp) assert result.tool_calls[0].arguments == {"raw": "{bad json"} mock_logger.warning.assert_called_once() @@ -392,7 +392,7 @@ class TestParseResponseOutput: # ====================================================================== -# parsing — consume_sdk_stream +# parsing - consume_sdk_stream # ====================================================================== @@ -515,7 +515,7 @@ class TestConsumeSdkStream: for e in [ev1, ev2, ev3, ev4]: yield e - with patch("nanobot.providers.openai_responses_common.parsing.logger") as mock_logger: + with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger: _, tool_calls, _, _, _ = await consume_sdk_stream(stream()) assert tool_calls[0].arguments == {"raw": "{bad"} mock_logger.warning.assert_called_once() From 7a6416bcb21a61659dcd6670924fcc0c7e80d4b3 Mon Sep 17 00:00:00 2001 From: haosenwang1018 Date: Thu, 2 Apr 2026 06:26:10 +0000 Subject: [PATCH 28/31] test(matrix): skip cleanly when optional deps are missing --- tests/channels/test_matrix_channel.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index 18a8e1097..27b7e1255 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -3,16 +3,14 @@ from pathlib import Path from types import SimpleNamespace import pytest + +pytest.importorskip("nio") +pytest.importorskip("nh3") +pytest.importorskip("mistune") from nio import RoomSendResponse from nanobot.channels.matrix import _build_matrix_text_content -# Check optional matrix dependencies before importing -try: - import nh3 # noqa: F401 -except ImportError: - pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True) - import nanobot.channels.matrix as matrix_module from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus From 7332d133a772e826a753d6c2df823e405ca4fabf Mon Sep 17 00:00:00 2001 From: masterlyj <167326996+masterlyj@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:49:08 +0800 Subject: [PATCH 29/31] feat(cli): add --config option to channels login and status commands Allows users to specify custom config file paths when managing channels. Usage: nanobot channels login weixin --config .nanobot-feishu/config.json nanobot channels status -c .nanobot-qq/config.json - Added optional --config/-c parameter to both commands - Defaults to ~/.nanobot/config.json when not specified - Maintains backward compatibility --- nanobot/cli/commands.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 49521aa16..b1a15ebfd 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -1023,12 +1023,14 @@ app.add_typer(channels_app, name="channels") @channels_app.command("status") -def channels_status(): +def channels_status( + config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), +): """Show channel status.""" from nanobot.channels.registry import discover_all from nanobot.config.loader import load_config - config = load_config() + config = load_config(Path(config_path) if config_path else None) table = Table(title="Channel Status") table.add_column("Channel", style="cyan") @@ -1115,12 +1117,13 @@ def _get_bridge_dir() -> Path: def channels_login( channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"), force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"), + config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), ): """Authenticate with a channel via QR code or other interactive login.""" from nanobot.channels.registry import discover_all from nanobot.config.loader import load_config - config = load_config() + config = load_config(Path(config_path) if config_path else None) channel_cfg = getattr(config.channels, channel_name, None) or {} # Validate channel exists From 11ba733ab6d3c8abe79ca72f22e44b23c0d094a7 Mon Sep 17 00:00:00 2001 From: masterlyj <167326996+masterlyj@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:09:48 +0800 Subject: [PATCH 30/31] fix(test): update load_config mock to accept config_path parameter --- tests/channels/test_channel_plugins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index a0b458a08..93bf7f1d0 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch): seen["config"] = self.config return True - monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config()) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) monkeypatch.setattr( "nanobot.channels.registry.discover_all", lambda: {"fakeplugin": _LoginPlugin}, From 3558fe4933e8b89a27cfda3b1ff04d30f731de5c Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 2 Apr 2026 10:31:50 +0000 Subject: [PATCH 31/31] fix(cli): honor custom config path in channel commands --- nanobot/cli/commands.py | 16 ++++++-- tests/channels/test_channel_plugins.py | 51 ++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index b1a15ebfd..53d17dfa8 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -1028,9 +1028,13 @@ def channels_status( ): """Show channel status.""" from nanobot.channels.registry import discover_all - from nanobot.config.loader import load_config + from nanobot.config.loader import load_config, set_config_path - config = load_config(Path(config_path) if config_path else None) + resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None + if resolved_config_path is not None: + set_config_path(resolved_config_path) + + config = load_config(resolved_config_path) table = Table(title="Channel Status") table.add_column("Channel", style="cyan") @@ -1121,9 +1125,13 @@ def channels_login( ): """Authenticate with a channel via QR code or other interactive login.""" from nanobot.channels.registry import discover_all - from nanobot.config.loader import load_config + from nanobot.config.loader import load_config, set_config_path - config = load_config(Path(config_path) if config_path else None) + resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None + if resolved_config_path is not None: + set_config_path(resolved_config_path) + + config = load_config(resolved_config_path) channel_cfg = getattr(config.channels, channel_name, None) or {} # Validate channel exists diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index 93bf7f1d0..4cf4fab21 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch): assert seen["force"] is True +def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path): + from nanobot.cli.commands import app + from nanobot.config.schema import Config + from typer.testing import CliRunner + + runner = CliRunner() + seen: dict[str, object] = {} + config_path = tmp_path / "custom-config.json" + + class _LoginPlugin(_FakePlugin): + async def login(self, force: bool = False) -> bool: + return True + + monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) + monkeypatch.setattr( + "nanobot.config.loader.set_config_path", + lambda path: seen.__setitem__("config_path", path), + ) + monkeypatch.setattr( + "nanobot.channels.registry.discover_all", + lambda: {"fakeplugin": _LoginPlugin}, + ) + + result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)]) + + assert result.exit_code == 0 + assert seen["config_path"] == config_path.resolve() + + +def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path): + from nanobot.cli.commands import app + from nanobot.config.schema import Config + from typer.testing import CliRunner + + runner = CliRunner() + seen: dict[str, object] = {} + config_path = tmp_path / "custom-config.json" + + monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) + monkeypatch.setattr( + "nanobot.config.loader.set_config_path", + lambda path: seen.__setitem__("config_path", path), + ) + monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {}) + + result = runner.invoke(app, ["channels", "status", "--config", str(config_path)]) + + assert result.exit_code == 0 + assert seen["config_path"] == config_path.resolve() + + @pytest.mark.asyncio async def test_manager_skips_disabled_plugin(): fake_config = SimpleNamespace(