mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-02 15:55:50 +00:00
Merge branch 'main' into nightly
This commit is contained in:
commit
723ed8172b
26
README.md
26
README.md
@ -20,6 +20,14 @@
|
|||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We have fully removed the `litellm` dependency in [this commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
|
||||||
|
|
||||||
|
- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
|
||||||
|
- **2026-03-20** 🧙 Interactive setup wizard — pick your provider, model autocomplete, and you're good to go.
|
||||||
|
- **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly.
|
||||||
|
- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details.
|
||||||
|
- **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable.
|
||||||
- **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details.
|
- **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details.
|
||||||
- **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility.
|
- **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility.
|
||||||
- **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling.
|
- **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling.
|
||||||
@ -391,6 +399,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
|||||||
> - `"mention"` (default) — Only respond when @mentioned
|
> - `"mention"` (default) — Only respond when @mentioned
|
||||||
> - `"open"` — Respond to all messages
|
> - `"open"` — Respond to all messages
|
||||||
> DMs always respond when the sender is in `allowFrom`.
|
> DMs always respond when the sender is in `allowFrom`.
|
||||||
|
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
|
||||||
|
|
||||||
**5. Invite the bot**
|
**5. Invite the bot**
|
||||||
- OAuth2 → URL Generator
|
- OAuth2 → URL Generator
|
||||||
@ -772,6 +781,7 @@ pip install -e ".[weixin]"
|
|||||||
|
|
||||||
> - `allowFrom`: Add the sender ID you see in nanobot logs for your WeChat account. Use `["*"]` to allow all users.
|
> - `allowFrom`: Add the sender ID you see in nanobot logs for your WeChat account. Use `["*"]` to allow all users.
|
||||||
> - `token`: Optional. If omitted, log in interactively and nanobot will save the token for you.
|
> - `token`: Optional. If omitted, log in interactively and nanobot will save the token for you.
|
||||||
|
> - `routeTag`: Optional. When your upstream Weixin deployment requires request routing, nanobot will send it as the `SKRouteTag` header.
|
||||||
> - `stateDir`: Optional. Defaults to nanobot's runtime directory for Weixin state.
|
> - `stateDir`: Optional. Defaults to nanobot's runtime directory for Weixin state.
|
||||||
> - `pollTimeout`: Optional long-poll timeout in seconds.
|
> - `pollTimeout`: Optional long-poll timeout in seconds.
|
||||||
|
|
||||||
@ -933,7 +943,7 @@ Config file: `~/.nanobot/config.json`
|
|||||||
|
|
||||||
| Provider | Purpose | Get API Key |
|
| Provider | Purpose | Get API Key |
|
||||||
|----------|---------|-------------|
|
|----------|---------|-------------|
|
||||||
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
| `custom` | Any OpenAI-compatible endpoint | — |
|
||||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||||
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
|
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
|
||||||
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
|
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
|
||||||
@ -1034,7 +1044,7 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
|
|||||||
<details>
|
<details>
|
||||||
<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>
|
<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>
|
||||||
|
|
||||||
Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Bypasses LiteLLM; model name is passed as-is.
|
Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Model name is passed as-is.
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@ -1211,10 +1221,9 @@ Adding a new provider only takes **2 steps** — no if-elif chains to touch.
|
|||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="myprovider", # config field name
|
name="myprovider", # config field name
|
||||||
keywords=("myprovider", "mymodel"), # model-name keywords for auto-matching
|
keywords=("myprovider", "mymodel"), # model-name keywords for auto-matching
|
||||||
env_key="MYPROVIDER_API_KEY", # env var for LiteLLM
|
env_key="MYPROVIDER_API_KEY", # env var name
|
||||||
display_name="My Provider", # shown in `nanobot status`
|
display_name="My Provider", # shown in `nanobot status`
|
||||||
litellm_prefix="myprovider", # auto-prefix: model → myprovider/model
|
default_api_base="https://api.myprovider.com/v1", # OpenAI-compatible endpoint
|
||||||
skip_prefixes=("myprovider/",), # don't double-prefix
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -1226,20 +1235,19 @@ class ProvidersConfig(BaseModel):
|
|||||||
myprovider: ProviderConfig = ProviderConfig()
|
myprovider: ProviderConfig = ProviderConfig()
|
||||||
```
|
```
|
||||||
|
|
||||||
That's it! Environment variables, model prefixing, config matching, and `nanobot status` display will all work automatically.
|
That's it! Environment variables, model routing, config matching, and `nanobot status` display will all work automatically.
|
||||||
|
|
||||||
**Common `ProviderSpec` options:**
|
**Common `ProviderSpec` options:**
|
||||||
|
|
||||||
| Field | Description | Example |
|
| Field | Description | Example |
|
||||||
|-------|-------------|---------|
|
|-------|-------------|---------|
|
||||||
| `litellm_prefix` | Auto-prefix model names for LiteLLM | `"dashscope"` → `dashscope/qwen-max` |
|
| `default_api_base` | OpenAI-compatible base URL | `"https://api.deepseek.com"` |
|
||||||
| `skip_prefixes` | Don't prefix if model already starts with these | `("dashscope/", "openrouter/")` |
|
|
||||||
| `env_extras` | Additional env vars to set | `(("ZHIPUAI_API_KEY", "{api_key}"),)` |
|
| `env_extras` | Additional env vars to set | `(("ZHIPUAI_API_KEY", "{api_key}"),)` |
|
||||||
| `model_overrides` | Per-model parameter overrides | `(("kimi-k2.5", {"temperature": 1.0}),)` |
|
| `model_overrides` | Per-model parameter overrides | `(("kimi-k2.5", {"temperature": 1.0}),)` |
|
||||||
| `is_gateway` | Can route any model (like OpenRouter) | `True` |
|
| `is_gateway` | Can route any model (like OpenRouter) | `True` |
|
||||||
| `detect_by_key_prefix` | Detect gateway by API key prefix | `"sk-or-"` |
|
| `detect_by_key_prefix` | Detect gateway by API key prefix | `"sk-or-"` |
|
||||||
| `detect_by_base_keyword` | Detect gateway by API base URL | `"openrouter"` |
|
| `detect_by_base_keyword` | Detect gateway by API base URL | `"openrouter"` |
|
||||||
| `strip_model_prefix` | Strip existing prefix before re-prefixing | `True` (for AiHubMix) |
|
| `strip_model_prefix` | Strip provider prefix before sending to gateway | `True` (for AiHubMix) |
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ Uses the ilinkai.weixin.qq.com API for personal WeChat messaging.
|
|||||||
No WebSocket, no local WeChat client needed — just HTTP requests with a
|
No WebSocket, no local WeChat client needed — just HTTP requests with a
|
||||||
bot token obtained via QR code login.
|
bot token obtained via QR code login.
|
||||||
|
|
||||||
Protocol reverse-engineered from ``@tencent-weixin/openclaw-weixin`` v1.0.2.
|
Protocol reverse-engineered from ``@tencent-weixin/openclaw-weixin`` v1.0.3.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -53,15 +53,18 @@ MESSAGE_TYPE_BOT = 2
|
|||||||
MESSAGE_STATE_FINISH = 2
|
MESSAGE_STATE_FINISH = 2
|
||||||
|
|
||||||
WEIXIN_MAX_MESSAGE_LEN = 4000
|
WEIXIN_MAX_MESSAGE_LEN = 4000
|
||||||
BASE_INFO: dict[str, str] = {"channel_version": "1.0.2"}
|
WEIXIN_CHANNEL_VERSION = "1.0.3"
|
||||||
|
BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION}
|
||||||
|
|
||||||
# Session-expired error code
|
# Session-expired error code
|
||||||
ERRCODE_SESSION_EXPIRED = -14
|
ERRCODE_SESSION_EXPIRED = -14
|
||||||
|
SESSION_PAUSE_DURATION_S = 60 * 60
|
||||||
|
|
||||||
# Retry constants (matching the reference plugin's monitor.ts)
|
# Retry constants (matching the reference plugin's monitor.ts)
|
||||||
MAX_CONSECUTIVE_FAILURES = 3
|
MAX_CONSECUTIVE_FAILURES = 3
|
||||||
BACKOFF_DELAY_S = 30
|
BACKOFF_DELAY_S = 30
|
||||||
RETRY_DELAY_S = 2
|
RETRY_DELAY_S = 2
|
||||||
|
MAX_QR_REFRESH_COUNT = 3
|
||||||
|
|
||||||
# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
|
# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
|
||||||
DEFAULT_LONG_POLL_TIMEOUT_S = 35
|
DEFAULT_LONG_POLL_TIMEOUT_S = 35
|
||||||
@ -83,6 +86,7 @@ class WeixinConfig(Base):
|
|||||||
allow_from: list[str] = Field(default_factory=list)
|
allow_from: list[str] = Field(default_factory=list)
|
||||||
base_url: str = "https://ilinkai.weixin.qq.com"
|
base_url: str = "https://ilinkai.weixin.qq.com"
|
||||||
cdn_base_url: str = "https://novac2c.cdn.weixin.qq.com/c2c"
|
cdn_base_url: str = "https://novac2c.cdn.weixin.qq.com/c2c"
|
||||||
|
route_tag: str | int | None = None
|
||||||
token: str = "" # Manually set token, or obtained via QR login
|
token: str = "" # Manually set token, or obtained via QR login
|
||||||
state_dir: str = "" # Default: ~/.nanobot/weixin/
|
state_dir: str = "" # Default: ~/.nanobot/weixin/
|
||||||
poll_timeout: int = DEFAULT_LONG_POLL_TIMEOUT_S # seconds for long-poll
|
poll_timeout: int = DEFAULT_LONG_POLL_TIMEOUT_S # seconds for long-poll
|
||||||
@ -119,6 +123,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
self._token: str = ""
|
self._token: str = ""
|
||||||
self._poll_task: asyncio.Task | None = None
|
self._poll_task: asyncio.Task | None = None
|
||||||
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
||||||
|
self._session_pause_until: float = 0.0
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# State persistence
|
# State persistence
|
||||||
@ -144,6 +149,15 @@ class WeixinChannel(BaseChannel):
|
|||||||
data = json.loads(state_file.read_text())
|
data = json.loads(state_file.read_text())
|
||||||
self._token = data.get("token", "")
|
self._token = data.get("token", "")
|
||||||
self._get_updates_buf = data.get("get_updates_buf", "")
|
self._get_updates_buf = data.get("get_updates_buf", "")
|
||||||
|
context_tokens = data.get("context_tokens", {})
|
||||||
|
if isinstance(context_tokens, dict):
|
||||||
|
self._context_tokens = {
|
||||||
|
str(user_id): str(token)
|
||||||
|
for user_id, token in context_tokens.items()
|
||||||
|
if str(user_id).strip() and str(token).strip()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self._context_tokens = {}
|
||||||
base_url = data.get("base_url", "")
|
base_url = data.get("base_url", "")
|
||||||
if base_url:
|
if base_url:
|
||||||
self.config.base_url = base_url
|
self.config.base_url = base_url
|
||||||
@ -158,6 +172,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
data = {
|
data = {
|
||||||
"token": self._token,
|
"token": self._token,
|
||||||
"get_updates_buf": self._get_updates_buf,
|
"get_updates_buf": self._get_updates_buf,
|
||||||
|
"context_tokens": self._context_tokens,
|
||||||
"base_url": self.config.base_url,
|
"base_url": self.config.base_url,
|
||||||
}
|
}
|
||||||
state_file.write_text(json.dumps(data, ensure_ascii=False))
|
state_file.write_text(json.dumps(data, ensure_ascii=False))
|
||||||
@ -187,6 +202,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
if auth and self._token:
|
if auth and self._token:
|
||||||
headers["Authorization"] = f"Bearer {self._token}"
|
headers["Authorization"] = f"Bearer {self._token}"
|
||||||
|
if self.config.route_tag is not None and str(self.config.route_tag).strip():
|
||||||
|
headers["SKRouteTag"] = str(self.config.route_tag).strip()
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
async def _api_get(
|
async def _api_get(
|
||||||
@ -226,24 +243,25 @@ class WeixinChannel(BaseChannel):
|
|||||||
# QR Code Login (matches login-qr.ts)
|
# QR Code Login (matches login-qr.ts)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _fetch_qr_code(self) -> tuple[str, str]:
|
||||||
|
"""Fetch a fresh QR code. Returns (qrcode_id, scan_url)."""
|
||||||
|
data = await self._api_get(
|
||||||
|
"ilink/bot/get_bot_qrcode",
|
||||||
|
params={"bot_type": "3"},
|
||||||
|
auth=False,
|
||||||
|
)
|
||||||
|
qrcode_img_content = data.get("qrcode_img_content", "")
|
||||||
|
qrcode_id = data.get("qrcode", "")
|
||||||
|
if not qrcode_id:
|
||||||
|
raise RuntimeError(f"Failed to get QR code from WeChat API: {data}")
|
||||||
|
return qrcode_id, (qrcode_img_content or qrcode_id)
|
||||||
|
|
||||||
async def _qr_login(self) -> bool:
|
async def _qr_login(self) -> bool:
|
||||||
"""Perform QR code login flow. Returns True on success."""
|
"""Perform QR code login flow. Returns True on success."""
|
||||||
try:
|
try:
|
||||||
logger.info("Starting WeChat QR code login...")
|
logger.info("Starting WeChat QR code login...")
|
||||||
|
refresh_count = 0
|
||||||
data = await self._api_get(
|
qrcode_id, scan_url = await self._fetch_qr_code()
|
||||||
"ilink/bot/get_bot_qrcode",
|
|
||||||
params={"bot_type": "3"},
|
|
||||||
auth=False,
|
|
||||||
)
|
|
||||||
qrcode_img_content = data.get("qrcode_img_content", "")
|
|
||||||
qrcode_id = data.get("qrcode", "")
|
|
||||||
|
|
||||||
if not qrcode_id:
|
|
||||||
logger.error("Failed to get QR code from WeChat API: {}", data)
|
|
||||||
return False
|
|
||||||
|
|
||||||
scan_url = qrcode_img_content or qrcode_id
|
|
||||||
self._print_qr_code(scan_url)
|
self._print_qr_code(scan_url)
|
||||||
|
|
||||||
logger.info("Waiting for QR code scan...")
|
logger.info("Waiting for QR code scan...")
|
||||||
@ -283,8 +301,23 @@ class WeixinChannel(BaseChannel):
|
|||||||
elif status == "scaned":
|
elif status == "scaned":
|
||||||
logger.info("QR code scanned, waiting for confirmation...")
|
logger.info("QR code scanned, waiting for confirmation...")
|
||||||
elif status == "expired":
|
elif status == "expired":
|
||||||
logger.warning("QR code expired")
|
refresh_count += 1
|
||||||
return False
|
if refresh_count > MAX_QR_REFRESH_COUNT:
|
||||||
|
logger.warning(
|
||||||
|
"QR code expired too many times ({}/{}), giving up.",
|
||||||
|
refresh_count - 1,
|
||||||
|
MAX_QR_REFRESH_COUNT,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
logger.warning(
|
||||||
|
"QR code expired, refreshing... ({}/{})",
|
||||||
|
refresh_count,
|
||||||
|
MAX_QR_REFRESH_COUNT,
|
||||||
|
)
|
||||||
|
qrcode_id, scan_url = await self._fetch_qr_code()
|
||||||
|
self._print_qr_code(scan_url)
|
||||||
|
logger.info("New QR code generated, waiting for scan...")
|
||||||
|
continue
|
||||||
# status == "wait" — keep polling
|
# status == "wait" — keep polling
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
@ -392,7 +425,34 @@ class WeixinChannel(BaseChannel):
|
|||||||
# Polling (matches monitor.ts monitorWeixinProvider)
|
# Polling (matches monitor.ts monitorWeixinProvider)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _pause_session(self, duration_s: int = SESSION_PAUSE_DURATION_S) -> None:
|
||||||
|
self._session_pause_until = time.time() + duration_s
|
||||||
|
|
||||||
|
def _session_pause_remaining_s(self) -> int:
|
||||||
|
remaining = int(self._session_pause_until - time.time())
|
||||||
|
if remaining <= 0:
|
||||||
|
self._session_pause_until = 0.0
|
||||||
|
return 0
|
||||||
|
return remaining
|
||||||
|
|
||||||
|
def _assert_session_active(self) -> None:
|
||||||
|
remaining = self._session_pause_remaining_s()
|
||||||
|
if remaining > 0:
|
||||||
|
remaining_min = max((remaining + 59) // 60, 1)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"WeChat session paused, {remaining_min} min remaining (errcode {ERRCODE_SESSION_EXPIRED})"
|
||||||
|
)
|
||||||
|
|
||||||
async def _poll_once(self) -> None:
|
async def _poll_once(self) -> None:
|
||||||
|
remaining = self._session_pause_remaining_s()
|
||||||
|
if remaining > 0:
|
||||||
|
logger.warning(
|
||||||
|
"WeChat session paused, waiting {} min before next poll.",
|
||||||
|
max((remaining + 59) // 60, 1),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(remaining)
|
||||||
|
return
|
||||||
|
|
||||||
body: dict[str, Any] = {
|
body: dict[str, Any] = {
|
||||||
"get_updates_buf": self._get_updates_buf,
|
"get_updates_buf": self._get_updates_buf,
|
||||||
"base_info": BASE_INFO,
|
"base_info": BASE_INFO,
|
||||||
@ -411,11 +471,13 @@ class WeixinChannel(BaseChannel):
|
|||||||
|
|
||||||
if is_error:
|
if is_error:
|
||||||
if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED:
|
if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED:
|
||||||
|
self._pause_session()
|
||||||
|
remaining = self._session_pause_remaining_s()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"WeChat session expired (errcode {}). Pausing 60 min.",
|
"WeChat session expired (errcode {}). Pausing {} min.",
|
||||||
errcode,
|
errcode,
|
||||||
|
max((remaining + 59) // 60, 1),
|
||||||
)
|
)
|
||||||
await asyncio.sleep(3600)
|
|
||||||
return
|
return
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}"
|
f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}"
|
||||||
@ -468,6 +530,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
ctx_token = msg.get("context_token", "")
|
ctx_token = msg.get("context_token", "")
|
||||||
if ctx_token:
|
if ctx_token:
|
||||||
self._context_tokens[from_user_id] = ctx_token
|
self._context_tokens[from_user_id] = ctx_token
|
||||||
|
self._save_state()
|
||||||
|
|
||||||
# Parse item_list (WeixinMessage.item_list — types.ts:161)
|
# Parse item_list (WeixinMessage.item_list — types.ts:161)
|
||||||
item_list: list[dict] = msg.get("item_list") or []
|
item_list: list[dict] = msg.get("item_list") or []
|
||||||
@ -651,6 +714,11 @@ class WeixinChannel(BaseChannel):
|
|||||||
if not self._client or not self._token:
|
if not self._client or not self._token:
|
||||||
logger.warning("WeChat client not initialized or not authenticated")
|
logger.warning("WeChat client not initialized or not authenticated")
|
||||||
return
|
return
|
||||||
|
try:
|
||||||
|
self._assert_session_active()
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning("WeChat send blocked: {}", e)
|
||||||
|
return
|
||||||
|
|
||||||
content = msg.content.strip()
|
content = msg.content.strip()
|
||||||
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
||||||
@ -731,7 +799,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Upload a local file to WeChat CDN and send it as a media message.
|
"""Upload a local file to WeChat CDN and send it as a media message.
|
||||||
|
|
||||||
Follows the exact protocol from ``@tencent-weixin/openclaw-weixin`` v1.0.2:
|
Follows the exact protocol from ``@tencent-weixin/openclaw-weixin`` v1.0.3:
|
||||||
1. Generate a random 16-byte AES key (client-side).
|
1. Generate a random 16-byte AES key (client-side).
|
||||||
2. Call ``getuploadurl`` with file metadata + hex-encoded AES key.
|
2. Call ``getuploadurl`` with file metadata + hex-encoded AES key.
|
||||||
3. AES-128-ECB encrypt the file and POST to CDN (``{cdnBaseUrl}/upload``).
|
3. AES-128-ECB encrypt the file and POST to CDN (``{cdnBaseUrl}/upload``).
|
||||||
|
|||||||
@ -376,61 +376,61 @@ def _onboard_plugins(config_path: Path) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
"""Create the appropriate LLM provider from config."""
|
"""Create the appropriate LLM provider from config.
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
|
||||||
|
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
||||||
|
"""
|
||||||
from nanobot.providers.base import GenerationSettings
|
from nanobot.providers.base import GenerationSettings
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
model = config.agents.defaults.model
|
model = config.agents.defaults.model
|
||||||
provider_name = config.get_provider_name(model)
|
provider_name = config.get_provider_name(model)
|
||||||
p = config.get_provider(model)
|
p = config.get_provider(model)
|
||||||
|
spec = find_by_name(provider_name) if provider_name else None
|
||||||
|
backend = spec.backend if spec else "openai_compat"
|
||||||
|
|
||||||
# OpenAI Codex (OAuth)
|
# --- validation ---
|
||||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
if backend == "azure_openai":
|
||||||
provider = OpenAICodexProvider(default_model=model)
|
|
||||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
|
||||||
elif provider_name == "custom":
|
|
||||||
from nanobot.providers.custom_provider import CustomProvider
|
|
||||||
provider = CustomProvider(
|
|
||||||
api_key=p.api_key if p else "no-key",
|
|
||||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
)
|
|
||||||
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
|
||||||
elif provider_name == "azure_openai":
|
|
||||||
if not p or not p.api_key or not p.api_base:
|
if not p or not p.api_key or not p.api_base:
|
||||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||||
console.print("Use the model field to specify the deployment name.")
|
console.print("Use the model field to specify the deployment name.")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||||
|
needs_key = not (p and p.api_key)
|
||||||
|
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||||
|
if needs_key and not exempt:
|
||||||
|
console.print("[red]Error: No API key configured.[/red]")
|
||||||
|
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# --- instantiation by backend ---
|
||||||
|
if backend == "openai_codex":
|
||||||
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
provider = OpenAICodexProvider(default_model=model)
|
||||||
|
elif backend == "azure_openai":
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key=p.api_key,
|
api_key=p.api_key,
|
||||||
api_base=p.api_base,
|
api_base=p.api_base,
|
||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
# OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3
|
elif backend == "anthropic":
|
||||||
elif provider_name == "ovms":
|
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||||
from nanobot.providers.custom_provider import CustomProvider
|
provider = AnthropicProvider(
|
||||||
provider = CustomProvider(
|
|
||||||
api_key=p.api_key if p else "no-key",
|
|
||||||
api_base=config.get_api_base(model) or "http://localhost:8000/v3",
|
|
||||||
default_model=model,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
|
||||||
from nanobot.providers.registry import find_by_name
|
|
||||||
spec = find_by_name(provider_name)
|
|
||||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
|
|
||||||
console.print("[red]Error: No API key configured.[/red]")
|
|
||||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
provider = LiteLLMProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
api_key=p.api_key if p else None,
|
||||||
api_base=config.get_api_base(model),
|
api_base=config.get_api_base(model),
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=p.extra_headers if p else None,
|
||||||
provider_name=provider_name,
|
)
|
||||||
|
else:
|
||||||
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key=p.api_key if p else None,
|
||||||
|
api_base=config.get_api_base(model),
|
||||||
|
default_model=model,
|
||||||
|
extra_headers=p.extra_headers if p else None,
|
||||||
|
spec=spec,
|
||||||
)
|
)
|
||||||
|
|
||||||
defaults = config.agents.defaults
|
defaults = config.agents.defaults
|
||||||
@ -1207,11 +1207,20 @@ def _login_openai_codex() -> None:
|
|||||||
def _login_github_copilot() -> None:
|
def _login_github_copilot() -> None:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||||
|
|
||||||
async def _trigger():
|
async def _trigger():
|
||||||
from litellm import acompletion
|
client = AsyncOpenAI(
|
||||||
await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1)
|
api_key="dummy",
|
||||||
|
base_url="https://api.githubcopilot.com",
|
||||||
|
)
|
||||||
|
await client.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
max_tokens=1,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(_trigger())
|
asyncio.run(_trigger())
|
||||||
|
|||||||
@ -1,229 +1,29 @@
|
|||||||
"""Model information helpers for the onboard wizard.
|
"""Model information helpers for the onboard wizard.
|
||||||
|
|
||||||
Provides model context window lookup and autocomplete suggestions using litellm.
|
Model database / autocomplete is temporarily disabled while litellm is
|
||||||
|
being replaced. All public function signatures are preserved so callers
|
||||||
|
continue to work without changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def _litellm():
|
|
||||||
"""Lazy accessor for litellm (heavy import deferred until actually needed)."""
|
|
||||||
import litellm as _ll
|
|
||||||
|
|
||||||
return _ll
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def _get_model_cost_map() -> dict[str, Any]:
|
|
||||||
"""Get litellm's model cost map (cached)."""
|
|
||||||
return getattr(_litellm(), "model_cost", {})
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def get_all_models() -> list[str]:
|
def get_all_models() -> list[str]:
|
||||||
"""Get all known model names from litellm.
|
return []
|
||||||
"""
|
|
||||||
models = set()
|
|
||||||
|
|
||||||
# From model_cost (has pricing info)
|
|
||||||
cost_map = _get_model_cost_map()
|
|
||||||
for k in cost_map.keys():
|
|
||||||
if k != "sample_spec":
|
|
||||||
models.add(k)
|
|
||||||
|
|
||||||
# From models_by_provider (more complete provider coverage)
|
|
||||||
for provider_models in getattr(_litellm(), "models_by_provider", {}).values():
|
|
||||||
if isinstance(provider_models, (set, list)):
|
|
||||||
models.update(provider_models)
|
|
||||||
|
|
||||||
return sorted(models)
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_model_name(model: str) -> str:
|
|
||||||
"""Normalize model name for comparison."""
|
|
||||||
return model.lower().replace("-", "_").replace(".", "")
|
|
||||||
|
|
||||||
|
|
||||||
def find_model_info(model_name: str) -> dict[str, Any] | None:
|
def find_model_info(model_name: str) -> dict[str, Any] | None:
|
||||||
"""Find model info with fuzzy matching.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Model name in any common format
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model info dict or None if not found
|
|
||||||
"""
|
|
||||||
cost_map = _get_model_cost_map()
|
|
||||||
if not cost_map:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Direct match
|
|
||||||
if model_name in cost_map:
|
|
||||||
return cost_map[model_name]
|
|
||||||
|
|
||||||
# Extract base name (without provider prefix)
|
|
||||||
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
|
||||||
base_normalized = _normalize_model_name(base_name)
|
|
||||||
|
|
||||||
candidates = []
|
|
||||||
|
|
||||||
for key, info in cost_map.items():
|
|
||||||
if key == "sample_spec":
|
|
||||||
continue
|
|
||||||
|
|
||||||
key_base = key.split("/")[-1] if "/" in key else key
|
|
||||||
key_base_normalized = _normalize_model_name(key_base)
|
|
||||||
|
|
||||||
# Score the match
|
|
||||||
score = 0
|
|
||||||
|
|
||||||
# Exact base name match (highest priority)
|
|
||||||
if base_normalized == key_base_normalized:
|
|
||||||
score = 100
|
|
||||||
# Base name contains model
|
|
||||||
elif base_normalized in key_base_normalized:
|
|
||||||
score = 80
|
|
||||||
# Model contains base name
|
|
||||||
elif key_base_normalized in base_normalized:
|
|
||||||
score = 70
|
|
||||||
# Partial match
|
|
||||||
elif base_normalized[:10] in key_base_normalized:
|
|
||||||
score = 50
|
|
||||||
|
|
||||||
if score > 0:
|
|
||||||
# Prefer models with max_input_tokens
|
|
||||||
if info.get("max_input_tokens"):
|
|
||||||
score += 10
|
|
||||||
candidates.append((score, key, info))
|
|
||||||
|
|
||||||
if not candidates:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Return the best match
|
|
||||||
candidates.sort(key=lambda x: (-x[0], x[1]))
|
|
||||||
return candidates[0][2]
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
|
|
||||||
"""Get the maximum input context tokens for a model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o")
|
|
||||||
provider: Provider name for informational purposes (not yet used for filtering)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Maximum input tokens, or None if unknown
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The provider parameter is currently informational only. Future versions may
|
|
||||||
use it to prefer provider-specific model variants in the lookup.
|
|
||||||
"""
|
|
||||||
# First try fuzzy search in model_cost (has more accurate max_input_tokens)
|
|
||||||
info = find_model_info(model)
|
|
||||||
if info:
|
|
||||||
# Prefer max_input_tokens (this is what we want for context window)
|
|
||||||
max_input = info.get("max_input_tokens")
|
|
||||||
if max_input and isinstance(max_input, int):
|
|
||||||
return max_input
|
|
||||||
|
|
||||||
# Fall back to litellm's get_max_tokens (returns max_output_tokens typically)
|
|
||||||
try:
|
|
||||||
result = _litellm().get_max_tokens(model)
|
|
||||||
if result and result > 0:
|
|
||||||
return result
|
|
||||||
except (KeyError, ValueError, AttributeError):
|
|
||||||
# Model not found in litellm's database or invalid response
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Last resort: use max_tokens from model_cost
|
|
||||||
if info:
|
|
||||||
max_tokens = info.get("max_tokens")
|
|
||||||
if max_tokens and isinstance(max_tokens, int):
|
|
||||||
return max_tokens
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
|
||||||
def _get_provider_keywords() -> dict[str, list[str]]:
|
return None
|
||||||
"""Build provider keywords mapping from nanobot's provider registry.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping provider name to list of keywords for model filtering.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from nanobot.providers.registry import PROVIDERS
|
|
||||||
|
|
||||||
mapping = {}
|
|
||||||
for spec in PROVIDERS:
|
|
||||||
if spec.keywords:
|
|
||||||
mapping[spec.name] = list(spec.keywords)
|
|
||||||
return mapping
|
|
||||||
except ImportError:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
|
def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
|
||||||
"""Get autocomplete suggestions for model names.
|
return []
|
||||||
|
|
||||||
Args:
|
|
||||||
partial: Partial model name typed by user
|
|
||||||
provider: Provider name for filtering (e.g., "openrouter", "minimax")
|
|
||||||
limit: Maximum number of suggestions to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of matching model names
|
|
||||||
"""
|
|
||||||
all_models = get_all_models()
|
|
||||||
if not all_models:
|
|
||||||
return []
|
|
||||||
|
|
||||||
partial_lower = partial.lower()
|
|
||||||
partial_normalized = _normalize_model_name(partial)
|
|
||||||
|
|
||||||
# Get provider keywords from registry
|
|
||||||
provider_keywords = _get_provider_keywords()
|
|
||||||
|
|
||||||
# Filter by provider if specified
|
|
||||||
allowed_keywords = None
|
|
||||||
if provider and provider != "auto":
|
|
||||||
allowed_keywords = provider_keywords.get(provider.lower())
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
|
|
||||||
for model in all_models:
|
|
||||||
model_lower = model.lower()
|
|
||||||
|
|
||||||
# Apply provider filter
|
|
||||||
if allowed_keywords:
|
|
||||||
if not any(kw in model_lower for kw in allowed_keywords):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Match against partial input
|
|
||||||
if not partial:
|
|
||||||
matches.append(model)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if partial_lower in model_lower:
|
|
||||||
# Score by position of match (earlier = better)
|
|
||||||
pos = model_lower.find(partial_lower)
|
|
||||||
score = 100 - pos
|
|
||||||
matches.append((score, model))
|
|
||||||
elif partial_normalized in _normalize_model_name(model):
|
|
||||||
score = 50
|
|
||||||
matches.append((score, model))
|
|
||||||
|
|
||||||
# Sort by score if we have scored matches
|
|
||||||
if matches and isinstance(matches[0], tuple):
|
|
||||||
matches.sort(key=lambda x: (-x[0], x[1]))
|
|
||||||
matches = [m[1] for m in matches]
|
|
||||||
else:
|
|
||||||
matches.sort()
|
|
||||||
|
|
||||||
return matches[:limit]
|
|
||||||
|
|
||||||
|
|
||||||
def format_token_count(tokens: int) -> str:
|
def format_token_count(tokens: int) -> str:
|
||||||
|
|||||||
@ -259,8 +259,7 @@ class Config(BaseSettings):
|
|||||||
if p and p.api_base:
|
if p and p.api_base:
|
||||||
return p.api_base
|
return p.api_base
|
||||||
# Only gateways get a default api_base here. Standard providers
|
# Only gateways get a default api_base here. Standard providers
|
||||||
# (like Moonshot) set their base URL via env vars in _setup_env
|
# resolve their base URL from the registry in the provider constructor.
|
||||||
# to avoid polluting the global litellm.api_base.
|
|
||||||
if name:
|
if name:
|
||||||
spec = find_by_name(name)
|
spec = find_by_name(name)
|
||||||
if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
|
if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
|
||||||
|
|||||||
@ -7,17 +7,26 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
|
|
||||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
|
__all__ = [
|
||||||
|
"LLMProvider",
|
||||||
|
"LLMResponse",
|
||||||
|
"AnthropicProvider",
|
||||||
|
"OpenAICompatProvider",
|
||||||
|
"OpenAICodexProvider",
|
||||||
|
"AzureOpenAIProvider",
|
||||||
|
]
|
||||||
|
|
||||||
_LAZY_IMPORTS = {
|
_LAZY_IMPORTS = {
|
||||||
"LiteLLMProvider": ".litellm_provider",
|
"AnthropicProvider": ".anthropic_provider",
|
||||||
|
"OpenAICompatProvider": ".openai_compat_provider",
|
||||||
"OpenAICodexProvider": ".openai_codex_provider",
|
"OpenAICodexProvider": ".openai_codex_provider",
|
||||||
"AzureOpenAIProvider": ".azure_openai_provider",
|
"AzureOpenAIProvider": ".azure_openai_provider",
|
||||||
}
|
}
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
441
nanobot/providers/anthropic_provider.py
Normal file
441
nanobot/providers/anthropic_provider.py
Normal file
@ -0,0 +1,441 @@
|
|||||||
|
"""Anthropic provider — direct SDK integration for Claude models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import json_repair
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_tool_id() -> str:
|
||||||
|
return "toolu_" + "".join(secrets.choice(_ALNUM) for _ in range(22))
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(LLMProvider):
|
||||||
|
"""LLM provider using the native Anthropic SDK for Claude models.
|
||||||
|
|
||||||
|
Handles message format conversion (OpenAI → Anthropic Messages API),
|
||||||
|
prompt caching, extended thinking, tool calls, and streaming.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str | None = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
default_model: str = "claude-sonnet-4-20250514",
|
||||||
|
extra_headers: dict[str, str] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(api_key, api_base)
|
||||||
|
self.default_model = default_model
|
||||||
|
self.extra_headers = extra_headers or {}
|
||||||
|
|
||||||
|
from anthropic import AsyncAnthropic
|
||||||
|
|
||||||
|
client_kw: dict[str, Any] = {}
|
||||||
|
if api_key:
|
||||||
|
client_kw["api_key"] = api_key
|
||||||
|
if api_base:
|
||||||
|
client_kw["base_url"] = api_base
|
||||||
|
if extra_headers:
|
||||||
|
client_kw["default_headers"] = extra_headers
|
||||||
|
self._client = AsyncAnthropic(**client_kw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_prefix(model: str) -> str:
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return model[len("anthropic/"):]
|
||||||
|
return model
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Message conversion: OpenAI chat format → Anthropic Messages API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _convert_messages(
|
||||||
|
self, messages: list[dict[str, Any]],
|
||||||
|
) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]]]:
|
||||||
|
"""Return ``(system, anthropic_messages)``."""
|
||||||
|
system: str | list[dict[str, Any]] = ""
|
||||||
|
raw: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "")
|
||||||
|
content = msg.get("content")
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
system = content if isinstance(content, (str, list)) else str(content or "")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if role == "tool":
|
||||||
|
block = self._tool_result_block(msg)
|
||||||
|
if raw and raw[-1]["role"] == "user":
|
||||||
|
prev_c = raw[-1]["content"]
|
||||||
|
if isinstance(prev_c, list):
|
||||||
|
prev_c.append(block)
|
||||||
|
else:
|
||||||
|
raw[-1]["content"] = [
|
||||||
|
{"type": "text", "text": prev_c or ""}, block,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raw.append({"role": "user", "content": [block]})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if role == "assistant":
|
||||||
|
raw.append({"role": "assistant", "content": self._assistant_blocks(msg)})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
raw.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": self._convert_user_content(content),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
return system, self._merge_consecutive(raw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_result_block(msg: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
content = msg.get("content")
|
||||||
|
block: dict[str, Any] = {
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": msg.get("tool_call_id", ""),
|
||||||
|
}
|
||||||
|
if isinstance(content, (str, list)):
|
||||||
|
block["content"] = content
|
||||||
|
else:
|
||||||
|
block["content"] = str(content) if content else ""
|
||||||
|
return block
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _assistant_blocks(msg: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
blocks: list[dict[str, Any]] = []
|
||||||
|
content = msg.get("content")
|
||||||
|
|
||||||
|
for tb in msg.get("thinking_blocks") or []:
|
||||||
|
if isinstance(tb, dict) and tb.get("type") == "thinking":
|
||||||
|
blocks.append({
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": tb.get("thinking", ""),
|
||||||
|
"signature": tb.get("signature", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
if isinstance(content, str) and content:
|
||||||
|
blocks.append({"type": "text", "text": content})
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for item in content:
|
||||||
|
blocks.append(item if isinstance(item, dict) else {"type": "text", "text": str(item)})
|
||||||
|
|
||||||
|
for tc in msg.get("tool_calls") or []:
|
||||||
|
if not isinstance(tc, dict):
|
||||||
|
continue
|
||||||
|
func = tc.get("function", {})
|
||||||
|
args = func.get("arguments", "{}")
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json_repair.loads(args)
|
||||||
|
blocks.append({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": tc.get("id") or _gen_tool_id(),
|
||||||
|
"name": func.get("name", ""),
|
||||||
|
"input": args,
|
||||||
|
})
|
||||||
|
|
||||||
|
return blocks or [{"type": "text", "text": ""}]
|
||||||
|
|
||||||
|
def _convert_user_content(self, content: Any) -> Any:
|
||||||
|
"""Convert user message content, translating image_url blocks."""
|
||||||
|
if isinstance(content, str) or content is None:
|
||||||
|
return content or "(empty)"
|
||||||
|
if not isinstance(content, list):
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
result: list[dict[str, Any]] = []
|
||||||
|
for item in content:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
result.append({"type": "text", "text": str(item)})
|
||||||
|
continue
|
||||||
|
if item.get("type") == "image_url":
|
||||||
|
converted = self._convert_image_block(item)
|
||||||
|
if converted:
|
||||||
|
result.append(converted)
|
||||||
|
continue
|
||||||
|
result.append(item)
|
||||||
|
return result or "(empty)"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_image_block(block: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Convert OpenAI image_url block to Anthropic image block."""
|
||||||
|
url = (block.get("image_url") or {}).get("url", "")
|
||||||
|
if not url:
|
||||||
|
return None
|
||||||
|
m = re.match(r"data:(image/\w+);base64,(.+)", url, re.DOTALL)
|
||||||
|
if m:
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": m.group(1), "data": m.group(2)},
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "url", "url": url},
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_consecutive(msgs: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Anthropic requires alternating user/assistant roles."""
|
||||||
|
merged: list[dict[str, Any]] = []
|
||||||
|
for msg in msgs:
|
||||||
|
if merged and merged[-1]["role"] == msg["role"]:
|
||||||
|
prev_c = merged[-1]["content"]
|
||||||
|
cur_c = msg["content"]
|
||||||
|
if isinstance(prev_c, str):
|
||||||
|
prev_c = [{"type": "text", "text": prev_c}]
|
||||||
|
if isinstance(cur_c, str):
|
||||||
|
cur_c = [{"type": "text", "text": cur_c}]
|
||||||
|
if isinstance(cur_c, list):
|
||||||
|
prev_c.extend(cur_c)
|
||||||
|
merged[-1]["content"] = prev_c
|
||||||
|
else:
|
||||||
|
merged.append(msg)
|
||||||
|
return merged
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Tool definition conversion
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
|
||||||
|
if not tools:
|
||||||
|
return None
|
||||||
|
result = []
|
||||||
|
for tool in tools:
|
||||||
|
func = tool.get("function", tool)
|
||||||
|
entry: dict[str, Any] = {
|
||||||
|
"name": func.get("name", ""),
|
||||||
|
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
||||||
|
}
|
||||||
|
desc = func.get("description")
|
||||||
|
if desc:
|
||||||
|
entry["description"] = desc
|
||||||
|
if "cache_control" in tool:
|
||||||
|
entry["cache_control"] = tool["cache_control"]
|
||||||
|
result.append(entry)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_tool_choice(
|
||||||
|
tool_choice: str | dict[str, Any] | None,
|
||||||
|
thinking_enabled: bool = False,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
if thinking_enabled:
|
||||||
|
return {"type": "auto"}
|
||||||
|
if tool_choice is None or tool_choice == "auto":
|
||||||
|
return {"type": "auto"}
|
||||||
|
if tool_choice == "required":
|
||||||
|
return {"type": "any"}
|
||||||
|
if tool_choice == "none":
|
||||||
|
return None
|
||||||
|
if isinstance(tool_choice, dict):
|
||||||
|
name = tool_choice.get("function", {}).get("name")
|
||||||
|
if name:
|
||||||
|
return {"type": "tool", "name": name}
|
||||||
|
return {"type": "auto"}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Prompt caching
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _apply_cache_control(
|
||||||
|
system: str | list[dict[str, Any]],
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None,
|
||||||
|
) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||||
|
marker = {"type": "ephemeral"}
|
||||||
|
|
||||||
|
if isinstance(system, str) and system:
|
||||||
|
system = [{"type": "text", "text": system, "cache_control": marker}]
|
||||||
|
elif isinstance(system, list) and system:
|
||||||
|
system = list(system)
|
||||||
|
system[-1] = {**system[-1], "cache_control": marker}
|
||||||
|
|
||||||
|
new_msgs = list(messages)
|
||||||
|
if len(new_msgs) >= 3:
|
||||||
|
m = new_msgs[-2]
|
||||||
|
c = m.get("content")
|
||||||
|
if isinstance(c, str):
|
||||||
|
new_msgs[-2] = {**m, "content": [{"type": "text", "text": c, "cache_control": marker}]}
|
||||||
|
elif isinstance(c, list) and c:
|
||||||
|
nc = list(c)
|
||||||
|
nc[-1] = {**nc[-1], "cache_control": marker}
|
||||||
|
new_msgs[-2] = {**m, "content": nc}
|
||||||
|
|
||||||
|
new_tools = tools
|
||||||
|
if tools:
|
||||||
|
new_tools = list(tools)
|
||||||
|
new_tools[-1] = {**new_tools[-1], "cache_control": marker}
|
||||||
|
|
||||||
|
return system, new_msgs, new_tools
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Build API kwargs
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_kwargs(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None,
|
||||||
|
model: str | None,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
reasoning_effort: str | None,
|
||||||
|
tool_choice: str | dict[str, Any] | None,
|
||||||
|
supports_caching: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
model_name = self._strip_prefix(model or self.default_model)
|
||||||
|
system, anthropic_msgs = self._convert_messages(self._sanitize_empty_content(messages))
|
||||||
|
anthropic_tools = self._convert_tools(tools)
|
||||||
|
|
||||||
|
if supports_caching:
|
||||||
|
system, anthropic_msgs, anthropic_tools = self._apply_cache_control(
|
||||||
|
system, anthropic_msgs, anthropic_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_tokens = max(1, max_tokens)
|
||||||
|
thinking_enabled = bool(reasoning_effort)
|
||||||
|
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"model": model_name,
|
||||||
|
"messages": anthropic_msgs,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if system:
|
||||||
|
kwargs["system"] = system
|
||||||
|
|
||||||
|
if thinking_enabled:
|
||||||
|
budget_map = {"low": 1024, "medium": 4096, "high": max(8192, max_tokens)}
|
||||||
|
budget = budget_map.get(reasoning_effort.lower(), 4096) # type: ignore[union-attr]
|
||||||
|
kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
|
||||||
|
kwargs["max_tokens"] = max(max_tokens, budget + 4096)
|
||||||
|
kwargs["temperature"] = 1.0
|
||||||
|
else:
|
||||||
|
kwargs["temperature"] = temperature
|
||||||
|
|
||||||
|
if anthropic_tools:
|
||||||
|
kwargs["tools"] = anthropic_tools
|
||||||
|
tc = self._convert_tool_choice(tool_choice, thinking_enabled)
|
||||||
|
if tc:
|
||||||
|
kwargs["tool_choice"] = tc
|
||||||
|
|
||||||
|
if self.extra_headers:
|
||||||
|
kwargs["extra_headers"] = self.extra_headers
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Response parsing
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_response(response: Any) -> LLMResponse:
|
||||||
|
content_parts: list[str] = []
|
||||||
|
tool_calls: list[ToolCallRequest] = []
|
||||||
|
thinking_blocks: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for block in response.content:
|
||||||
|
if block.type == "text":
|
||||||
|
content_parts.append(block.text)
|
||||||
|
elif block.type == "tool_use":
|
||||||
|
tool_calls.append(ToolCallRequest(
|
||||||
|
id=block.id,
|
||||||
|
name=block.name,
|
||||||
|
arguments=block.input if isinstance(block.input, dict) else {},
|
||||||
|
))
|
||||||
|
elif block.type == "thinking":
|
||||||
|
thinking_blocks.append({
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": block.thinking,
|
||||||
|
"signature": getattr(block, "signature", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
stop_map = {"tool_use": "tool_calls", "end_turn": "stop", "max_tokens": "length"}
|
||||||
|
finish_reason = stop_map.get(response.stop_reason or "", response.stop_reason or "stop")
|
||||||
|
|
||||||
|
usage: dict[str, int] = {}
|
||||||
|
if response.usage:
|
||||||
|
usage = {
|
||||||
|
"prompt_tokens": response.usage.input_tokens,
|
||||||
|
"completion_tokens": response.usage.output_tokens,
|
||||||
|
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
|
||||||
|
}
|
||||||
|
for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"):
|
||||||
|
val = getattr(response.usage, attr, 0)
|
||||||
|
if val:
|
||||||
|
usage[attr] = val
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content="".join(content_parts) or None,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
thinking_blocks=thinking_blocks or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
kwargs = self._build_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = await self._client.messages.create(**kwargs)
|
||||||
|
return self._parse_response(response)
|
||||||
|
except Exception as e:
|
||||||
|
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
kwargs = self._build_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with self._client.messages.stream(**kwargs) as stream:
|
||||||
|
if on_content_delta:
|
||||||
|
async for text in stream.text_stream:
|
||||||
|
await on_content_delta(text)
|
||||||
|
response = await stream.get_final_message()
|
||||||
|
return self._parse_response(response)
|
||||||
|
except Exception as e:
|
||||||
|
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return self.default_model
|
||||||
@ -16,6 +16,7 @@ class ToolCallRequest:
|
|||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
arguments: dict[str, Any]
|
arguments: dict[str, Any]
|
||||||
|
extra_content: dict[str, Any] | None = None
|
||||||
provider_specific_fields: dict[str, Any] | None = None
|
provider_specific_fields: dict[str, Any] | None = None
|
||||||
function_provider_specific_fields: dict[str, Any] | None = None
|
function_provider_specific_fields: dict[str, Any] | None = None
|
||||||
|
|
||||||
@ -29,6 +30,8 @@ class ToolCallRequest:
|
|||||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
if self.extra_content:
|
||||||
|
tool_call["extra_content"] = self.extra_content
|
||||||
if self.provider_specific_fields:
|
if self.provider_specific_fields:
|
||||||
tool_call["provider_specific_fields"] = self.provider_specific_fields
|
tool_call["provider_specific_fields"] = self.provider_specific_fields
|
||||||
if self.function_provider_specific_fields:
|
if self.function_provider_specific_fields:
|
||||||
|
|||||||
@ -1,152 +0,0 @@
|
|||||||
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import json_repair
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|
||||||
|
|
||||||
|
|
||||||
class CustomProvider(LLMProvider):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str = "no-key",
|
|
||||||
api_base: str = "http://localhost:8000/v1",
|
|
||||||
default_model: str = "default",
|
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
):
|
|
||||||
super().__init__(api_key, api_base)
|
|
||||||
self.default_model = default_model
|
|
||||||
self._client = AsyncOpenAI(
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=api_base,
|
|
||||||
default_headers={
|
|
||||||
"x-session-affinity": uuid.uuid4().hex,
|
|
||||||
**(extra_headers or {}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _build_kwargs(
|
|
||||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None,
|
|
||||||
model: str | None, max_tokens: int, temperature: float,
|
|
||||||
reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"model": model or self.default_model,
|
|
||||||
"messages": self._sanitize_empty_content(messages),
|
|
||||||
"max_tokens": max(1, max_tokens),
|
|
||||||
"temperature": temperature,
|
|
||||||
}
|
|
||||||
if reasoning_effort:
|
|
||||||
kwargs["reasoning_effort"] = reasoning_effort
|
|
||||||
if tools:
|
|
||||||
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
def _handle_error(self, e: Exception) -> LLMResponse:
|
|
||||||
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
|
|
||||||
msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error: {e}"
|
|
||||||
return LLMResponse(content=msg, finish_reason="error")
|
|
||||||
|
|
||||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
|
||||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
|
||||||
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
|
|
||||||
try:
|
|
||||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
|
||||||
except Exception as e:
|
|
||||||
return self._handle_error(e)
|
|
||||||
|
|
||||||
async def chat_stream(
|
|
||||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
|
||||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
|
|
||||||
kwargs["stream"] = True
|
|
||||||
try:
|
|
||||||
stream = await self._client.chat.completions.create(**kwargs)
|
|
||||||
chunks: list[Any] = []
|
|
||||||
async for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
if on_content_delta and chunk.choices:
|
|
||||||
text = getattr(chunk.choices[0].delta, "content", None)
|
|
||||||
if text:
|
|
||||||
await on_content_delta(text)
|
|
||||||
return self._parse_chunks(chunks)
|
|
||||||
except Exception as e:
|
|
||||||
return self._handle_error(e)
|
|
||||||
|
|
||||||
def _parse(self, response: Any) -> LLMResponse:
|
|
||||||
if not response.choices:
|
|
||||||
return LLMResponse(
|
|
||||||
content="Error: API returned empty choices.",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
choice = response.choices[0]
|
|
||||||
msg = choice.message
|
|
||||||
tool_calls = [
|
|
||||||
ToolCallRequest(
|
|
||||||
id=tc.id, name=tc.function.name,
|
|
||||||
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments,
|
|
||||||
)
|
|
||||||
for tc in (msg.tool_calls or [])
|
|
||||||
]
|
|
||||||
u = response.usage
|
|
||||||
return LLMResponse(
|
|
||||||
content=msg.content, tool_calls=tool_calls,
|
|
||||||
finish_reason=choice.finish_reason or "stop",
|
|
||||||
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
|
||||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_chunks(self, chunks: list[Any]) -> LLMResponse:
|
|
||||||
"""Reassemble streamed chunks into a single LLMResponse."""
|
|
||||||
content_parts: list[str] = []
|
|
||||||
tc_bufs: dict[int, dict[str, str]] = {}
|
|
||||||
finish_reason = "stop"
|
|
||||||
usage: dict[str, int] = {}
|
|
||||||
|
|
||||||
for chunk in chunks:
|
|
||||||
if not chunk.choices:
|
|
||||||
if hasattr(chunk, "usage") and chunk.usage:
|
|
||||||
u = chunk.usage
|
|
||||||
usage = {"prompt_tokens": u.prompt_tokens or 0, "completion_tokens": u.completion_tokens or 0,
|
|
||||||
"total_tokens": u.total_tokens or 0}
|
|
||||||
continue
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
if choice.finish_reason:
|
|
||||||
finish_reason = choice.finish_reason
|
|
||||||
delta = choice.delta
|
|
||||||
if delta and delta.content:
|
|
||||||
content_parts.append(delta.content)
|
|
||||||
for tc in (delta.tool_calls or []) if delta else []:
|
|
||||||
buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
|
|
||||||
if tc.id:
|
|
||||||
buf["id"] = tc.id
|
|
||||||
if tc.function and tc.function.name:
|
|
||||||
buf["name"] = tc.function.name
|
|
||||||
if tc.function and tc.function.arguments:
|
|
||||||
buf["arguments"] += tc.function.arguments
|
|
||||||
|
|
||||||
return LLMResponse(
|
|
||||||
content="".join(content_parts) or None,
|
|
||||||
tool_calls=[
|
|
||||||
ToolCallRequest(id=b["id"], name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {})
|
|
||||||
for b in tc_bufs.values()
|
|
||||||
],
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
usage=usage,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
|
||||||
return self.default_model
|
|
||||||
@ -1,413 +0,0 @@
|
|||||||
"""LiteLLM provider implementation for multi-provider support."""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import secrets
|
|
||||||
import string
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import json_repair
|
|
||||||
import litellm
|
|
||||||
from litellm import acompletion
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|
||||||
from nanobot.providers.registry import find_by_model, find_gateway
|
|
||||||
|
|
||||||
# Standard chat-completion message keys.
|
|
||||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
|
||||||
_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
|
|
||||||
_ALNUM = string.ascii_letters + string.digits
|
|
||||||
|
|
||||||
def _short_tool_id() -> str:
|
|
||||||
"""Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
|
||||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(LLMProvider):
|
|
||||||
"""
|
|
||||||
LLM provider using LiteLLM for multi-provider support.
|
|
||||||
|
|
||||||
Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through
|
|
||||||
a unified interface. Provider-specific logic is driven by the registry
|
|
||||||
(see providers/registry.py) — no if-elif chains needed here.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str | None = None,
|
|
||||||
api_base: str | None = None,
|
|
||||||
default_model: str = "anthropic/claude-opus-4-5",
|
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
provider_name: str | None = None,
|
|
||||||
):
|
|
||||||
super().__init__(api_key, api_base)
|
|
||||||
self.default_model = default_model
|
|
||||||
self.extra_headers = extra_headers or {}
|
|
||||||
|
|
||||||
# Detect gateway / local deployment.
|
|
||||||
# provider_name (from config key) is the primary signal;
|
|
||||||
# api_key / api_base are fallback for auto-detection.
|
|
||||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
|
||||||
|
|
||||||
# Configure environment variables
|
|
||||||
if api_key:
|
|
||||||
self._setup_env(api_key, api_base, default_model)
|
|
||||||
|
|
||||||
if api_base:
|
|
||||||
litellm.api_base = api_base
|
|
||||||
|
|
||||||
# Disable LiteLLM logging noise
|
|
||||||
litellm.suppress_debug_info = True
|
|
||||||
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
|
|
||||||
litellm.drop_params = True
|
|
||||||
|
|
||||||
self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY"))
|
|
||||||
|
|
||||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
|
||||||
"""Set environment variables based on detected provider."""
|
|
||||||
spec = self._gateway or find_by_model(model)
|
|
||||||
if not spec:
|
|
||||||
return
|
|
||||||
if not spec.env_key:
|
|
||||||
# OAuth/provider-only specs (for example: openai_codex)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Gateway/local overrides existing env; standard provider doesn't
|
|
||||||
if self._gateway:
|
|
||||||
os.environ[spec.env_key] = api_key
|
|
||||||
else:
|
|
||||||
os.environ.setdefault(spec.env_key, api_key)
|
|
||||||
|
|
||||||
# Resolve env_extras placeholders:
|
|
||||||
# {api_key} → user's API key
|
|
||||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
|
||||||
effective_base = api_base or spec.default_api_base
|
|
||||||
for env_name, env_val in spec.env_extras:
|
|
||||||
resolved = env_val.replace("{api_key}", api_key)
|
|
||||||
resolved = resolved.replace("{api_base}", effective_base)
|
|
||||||
os.environ.setdefault(env_name, resolved)
|
|
||||||
|
|
||||||
def _resolve_model(self, model: str) -> str:
|
|
||||||
"""Resolve model name by applying provider/gateway prefixes."""
|
|
||||||
if self._gateway:
|
|
||||||
prefix = self._gateway.litellm_prefix
|
|
||||||
if self._gateway.strip_model_prefix:
|
|
||||||
model = model.split("/")[-1]
|
|
||||||
if prefix:
|
|
||||||
model = f"{prefix}/{model}"
|
|
||||||
return model
|
|
||||||
|
|
||||||
# Standard mode: auto-prefix for known providers
|
|
||||||
spec = find_by_model(model)
|
|
||||||
if spec and spec.litellm_prefix:
|
|
||||||
model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
|
|
||||||
if not any(model.startswith(s) for s in spec.skip_prefixes):
|
|
||||||
model = f"{spec.litellm_prefix}/{model}"
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
|
|
||||||
"""Normalize explicit provider prefixes like `github-copilot/...`."""
|
|
||||||
if "/" not in model:
|
|
||||||
return model
|
|
||||||
prefix, remainder = model.split("/", 1)
|
|
||||||
if prefix.lower().replace("-", "_") != spec_name:
|
|
||||||
return model
|
|
||||||
return f"{canonical_prefix}/{remainder}"
|
|
||||||
|
|
||||||
def _supports_cache_control(self, model: str) -> bool:
|
|
||||||
"""Return True when the provider supports cache_control on content blocks."""
|
|
||||||
if self._gateway is not None:
|
|
||||||
return self._gateway.supports_prompt_caching
|
|
||||||
spec = find_by_model(model)
|
|
||||||
return spec is not None and spec.supports_prompt_caching
|
|
||||||
|
|
||||||
def _apply_cache_control(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
tools: list[dict[str, Any]] | None,
|
|
||||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
|
||||||
"""Return copies of messages and tools with cache_control injected.
|
|
||||||
|
|
||||||
Two breakpoints are placed:
|
|
||||||
1. System message — caches the static system prompt
|
|
||||||
2. Second-to-last message — caches the conversation history prefix
|
|
||||||
This maximises cache hits across multi-turn conversations.
|
|
||||||
"""
|
|
||||||
cache_marker = {"type": "ephemeral"}
|
|
||||||
new_messages = list(messages)
|
|
||||||
|
|
||||||
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
content = msg.get("content")
|
|
||||||
if isinstance(content, str):
|
|
||||||
return {**msg, "content": [
|
|
||||||
{"type": "text", "text": content, "cache_control": cache_marker}
|
|
||||||
]}
|
|
||||||
elif isinstance(content, list) and content:
|
|
||||||
new_content = list(content)
|
|
||||||
new_content[-1] = {**new_content[-1], "cache_control": cache_marker}
|
|
||||||
return {**msg, "content": new_content}
|
|
||||||
return msg
|
|
||||||
|
|
||||||
# Breakpoint 1: system message
|
|
||||||
if new_messages and new_messages[0].get("role") == "system":
|
|
||||||
new_messages[0] = _mark(new_messages[0])
|
|
||||||
|
|
||||||
# Breakpoint 2: second-to-last message (caches conversation history prefix)
|
|
||||||
if len(new_messages) >= 3:
|
|
||||||
new_messages[-2] = _mark(new_messages[-2])
|
|
||||||
|
|
||||||
new_tools = tools
|
|
||||||
if tools:
|
|
||||||
new_tools = list(tools)
|
|
||||||
new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
|
|
||||||
|
|
||||||
return new_messages, new_tools
|
|
||||||
|
|
||||||
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
|
||||||
"""Apply model-specific parameter overrides from the registry."""
|
|
||||||
model_lower = model.lower()
|
|
||||||
spec = find_by_model(model)
|
|
||||||
if spec:
|
|
||||||
for pattern, overrides in spec.model_overrides:
|
|
||||||
if pattern in model_lower:
|
|
||||||
kwargs.update(overrides)
|
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
|
|
||||||
"""Return provider-specific extra keys to preserve in request messages."""
|
|
||||||
spec = find_by_model(original_model) or find_by_model(resolved_model)
|
|
||||||
if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
|
|
||||||
return _ANTHROPIC_EXTRA_KEYS
|
|
||||||
return frozenset()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
|
|
||||||
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
|
|
||||||
if not isinstance(tool_call_id, str):
|
|
||||||
return tool_call_id
|
|
||||||
if len(tool_call_id) == 9 and tool_call_id.isalnum():
|
|
||||||
return tool_call_id
|
|
||||||
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
|
||||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
|
||||||
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
|
||||||
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
|
|
||||||
id_map: dict[str, str] = {}
|
|
||||||
|
|
||||||
def map_id(value: Any) -> Any:
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return value
|
|
||||||
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
|
||||||
|
|
||||||
for clean in sanitized:
|
|
||||||
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
|
||||||
# shortening, otherwise strict providers reject the broken linkage.
|
|
||||||
if isinstance(clean.get("tool_calls"), list):
|
|
||||||
normalized_tool_calls = []
|
|
||||||
for tc in clean["tool_calls"]:
|
|
||||||
if not isinstance(tc, dict):
|
|
||||||
normalized_tool_calls.append(tc)
|
|
||||||
continue
|
|
||||||
tc_clean = dict(tc)
|
|
||||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
|
||||||
normalized_tool_calls.append(tc_clean)
|
|
||||||
clean["tool_calls"] = normalized_tool_calls
|
|
||||||
|
|
||||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
|
||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
def _build_chat_kwargs(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
tools: list[dict[str, Any]] | None,
|
|
||||||
model: str | None,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
reasoning_effort: str | None,
|
|
||||||
tool_choice: str | dict[str, Any] | None,
|
|
||||||
) -> tuple[dict[str, Any], str]:
|
|
||||||
"""Build the kwargs dict for ``acompletion``.
|
|
||||||
|
|
||||||
Returns ``(kwargs, original_model)`` so callers can reuse the
|
|
||||||
original model string for downstream logic.
|
|
||||||
"""
|
|
||||||
original_model = model or self.default_model
|
|
||||||
resolved = self._resolve_model(original_model)
|
|
||||||
extra_msg_keys = self._extra_msg_keys(original_model, resolved)
|
|
||||||
|
|
||||||
if self._supports_cache_control(original_model):
|
|
||||||
messages, tools = self._apply_cache_control(messages, tools)
|
|
||||||
|
|
||||||
max_tokens = max(1, max_tokens)
|
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"model": resolved,
|
|
||||||
"messages": self._sanitize_messages(
|
|
||||||
self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
|
|
||||||
),
|
|
||||||
"max_tokens": max_tokens,
|
|
||||||
"temperature": temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self._gateway:
|
|
||||||
kwargs.update(self._gateway.litellm_kwargs)
|
|
||||||
|
|
||||||
self._apply_model_overrides(resolved, kwargs)
|
|
||||||
|
|
||||||
if self._langsmith_enabled:
|
|
||||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
|
||||||
|
|
||||||
if self.api_key:
|
|
||||||
kwargs["api_key"] = self.api_key
|
|
||||||
if self.api_base:
|
|
||||||
kwargs["api_base"] = self.api_base
|
|
||||||
if self.extra_headers:
|
|
||||||
kwargs["extra_headers"] = self.extra_headers
|
|
||||||
|
|
||||||
if reasoning_effort:
|
|
||||||
kwargs["reasoning_effort"] = reasoning_effort
|
|
||||||
kwargs["drop_params"] = True
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
kwargs["tools"] = tools
|
|
||||||
kwargs["tool_choice"] = tool_choice or "auto"
|
|
||||||
|
|
||||||
return kwargs, original_model
|
|
||||||
|
|
||||||
async def chat(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
tools: list[dict[str, Any]] | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
max_tokens: int = 4096,
|
|
||||||
temperature: float = 0.7,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
"""Send a chat completion request via LiteLLM."""
|
|
||||||
kwargs, _ = self._build_chat_kwargs(
|
|
||||||
messages, tools, model, max_tokens, temperature,
|
|
||||||
reasoning_effort, tool_choice,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = await acompletion(**kwargs)
|
|
||||||
return self._parse_response(response)
|
|
||||||
except Exception as e:
|
|
||||||
return LLMResponse(
|
|
||||||
content=f"Error calling LLM: {str(e)}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def chat_stream(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
tools: list[dict[str, Any]] | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
max_tokens: int = 4096,
|
|
||||||
temperature: float = 0.7,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
"""Stream a chat completion via LiteLLM, forwarding text deltas."""
|
|
||||||
kwargs, _ = self._build_chat_kwargs(
|
|
||||||
messages, tools, model, max_tokens, temperature,
|
|
||||||
reasoning_effort, tool_choice,
|
|
||||||
)
|
|
||||||
kwargs["stream"] = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
stream = await acompletion(**kwargs)
|
|
||||||
chunks: list[Any] = []
|
|
||||||
async for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
if on_content_delta:
|
|
||||||
delta = chunk.choices[0].delta if chunk.choices else None
|
|
||||||
text = getattr(delta, "content", None) if delta else None
|
|
||||||
if text:
|
|
||||||
await on_content_delta(text)
|
|
||||||
|
|
||||||
full_response = litellm.stream_chunk_builder(
|
|
||||||
chunks, messages=kwargs["messages"],
|
|
||||||
)
|
|
||||||
return self._parse_response(full_response)
|
|
||||||
except Exception as e:
|
|
||||||
return LLMResponse(
|
|
||||||
content=f"Error calling LLM: {str(e)}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_response(self, response: Any) -> LLMResponse:
|
|
||||||
"""Parse LiteLLM response into our standard format."""
|
|
||||||
choice = response.choices[0]
|
|
||||||
message = choice.message
|
|
||||||
content = message.content
|
|
||||||
finish_reason = choice.finish_reason
|
|
||||||
|
|
||||||
# Some providers (e.g. GitHub Copilot) split content and tool_calls
|
|
||||||
# across multiple choices. Merge them so tool_calls are not lost.
|
|
||||||
raw_tool_calls = []
|
|
||||||
for ch in response.choices:
|
|
||||||
msg = ch.message
|
|
||||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
|
||||||
raw_tool_calls.extend(msg.tool_calls)
|
|
||||||
if ch.finish_reason in ("tool_calls", "stop"):
|
|
||||||
finish_reason = ch.finish_reason
|
|
||||||
if not content and msg.content:
|
|
||||||
content = msg.content
|
|
||||||
|
|
||||||
if len(response.choices) > 1:
|
|
||||||
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
|
|
||||||
len(response.choices), len(raw_tool_calls))
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
for tc in raw_tool_calls:
|
|
||||||
# Parse arguments from JSON string if needed
|
|
||||||
args = tc.function.arguments
|
|
||||||
if isinstance(args, str):
|
|
||||||
args = json_repair.loads(args)
|
|
||||||
|
|
||||||
provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
|
|
||||||
function_provider_specific_fields = (
|
|
||||||
getattr(tc.function, "provider_specific_fields", None) or None
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_calls.append(ToolCallRequest(
|
|
||||||
id=_short_tool_id(),
|
|
||||||
name=tc.function.name,
|
|
||||||
arguments=args,
|
|
||||||
provider_specific_fields=provider_specific_fields,
|
|
||||||
function_provider_specific_fields=function_provider_specific_fields,
|
|
||||||
))
|
|
||||||
|
|
||||||
usage = {}
|
|
||||||
if hasattr(response, "usage") and response.usage:
|
|
||||||
usage = {
|
|
||||||
"prompt_tokens": response.usage.prompt_tokens,
|
|
||||||
"completion_tokens": response.usage.completion_tokens,
|
|
||||||
"total_tokens": response.usage.total_tokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
reasoning_content = getattr(message, "reasoning_content", None) or None
|
|
||||||
thinking_blocks = getattr(message, "thinking_blocks", None) or None
|
|
||||||
|
|
||||||
return LLMResponse(
|
|
||||||
content=content,
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
finish_reason=finish_reason or "stop",
|
|
||||||
usage=usage,
|
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
thinking_blocks=thinking_blocks,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
|
||||||
"""Get the default model."""
|
|
||||||
return self.default_model
|
|
||||||
571
nanobot/providers/openai_compat_provider.py
Normal file
571
nanobot/providers/openai_compat_provider.py
Normal file
@ -0,0 +1,571 @@
|
|||||||
|
"""OpenAI-compatible provider for all non-Anthropic LLM APIs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import json_repair
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.providers.registry import ProviderSpec
|
||||||
|
|
||||||
|
_ALLOWED_MSG_KEYS = frozenset({
|
||||||
|
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||||
|
"reasoning_content", "extra_content",
|
||||||
|
})
|
||||||
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
|
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
|
||||||
|
_STANDARD_FN_KEYS = frozenset({"name", "arguments"})
|
||||||
|
|
||||||
|
|
||||||
|
def _short_tool_id() -> str:
|
||||||
|
"""9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
||||||
|
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||||
|
|
||||||
|
|
||||||
|
def _get(obj: Any, key: str) -> Any:
|
||||||
|
"""Get a value from dict or object attribute, returning None if absent."""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return obj.get(key)
|
||||||
|
return getattr(obj, key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
||||||
|
"""Try to coerce *value* to a dict; return None if not possible or empty."""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value if value else None
|
||||||
|
model_dump = getattr(value, "model_dump", None)
|
||||||
|
if callable(model_dump):
|
||||||
|
dumped = model_dump()
|
||||||
|
if isinstance(dumped, dict) and dumped:
|
||||||
|
return dumped
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tc_extras(tc: Any) -> tuple[
|
||||||
|
dict[str, Any] | None,
|
||||||
|
dict[str, Any] | None,
|
||||||
|
dict[str, Any] | None,
|
||||||
|
]:
|
||||||
|
"""Extract (extra_content, provider_specific_fields, fn_provider_specific_fields).
|
||||||
|
|
||||||
|
Works for both SDK objects and dicts. Captures Gemini ``extra_content``
|
||||||
|
verbatim and any non-standard keys on the tool-call / function.
|
||||||
|
"""
|
||||||
|
extra_content = _coerce_dict(_get(tc, "extra_content"))
|
||||||
|
|
||||||
|
tc_dict = _coerce_dict(tc)
|
||||||
|
prov = None
|
||||||
|
fn_prov = None
|
||||||
|
if tc_dict is not None:
|
||||||
|
leftover = {k: v for k, v in tc_dict.items()
|
||||||
|
if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None}
|
||||||
|
if leftover:
|
||||||
|
prov = leftover
|
||||||
|
fn = _coerce_dict(tc_dict.get("function"))
|
||||||
|
if fn is not None:
|
||||||
|
fn_leftover = {k: v for k, v in fn.items()
|
||||||
|
if k not in _STANDARD_FN_KEYS and v is not None}
|
||||||
|
if fn_leftover:
|
||||||
|
fn_prov = fn_leftover
|
||||||
|
else:
|
||||||
|
prov = _coerce_dict(_get(tc, "provider_specific_fields"))
|
||||||
|
fn_obj = _get(tc, "function")
|
||||||
|
if fn_obj is not None:
|
||||||
|
fn_prov = _coerce_dict(_get(fn_obj, "provider_specific_fields"))
|
||||||
|
|
||||||
|
return extra_content, prov, fn_prov
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatProvider(LLMProvider):
|
||||||
|
"""Unified provider for all OpenAI-compatible APIs.
|
||||||
|
|
||||||
|
Receives a resolved ``ProviderSpec`` from the caller — no internal
|
||||||
|
registry lookups needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str | None = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
default_model: str = "gpt-4o",
|
||||||
|
extra_headers: dict[str, str] | None = None,
|
||||||
|
spec: ProviderSpec | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(api_key, api_base)
|
||||||
|
self.default_model = default_model
|
||||||
|
self.extra_headers = extra_headers or {}
|
||||||
|
self._spec = spec
|
||||||
|
|
||||||
|
if api_key and spec and spec.env_key:
|
||||||
|
self._setup_env(api_key, api_base)
|
||||||
|
|
||||||
|
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
||||||
|
|
||||||
|
self._client = AsyncOpenAI(
|
||||||
|
api_key=api_key or "no-key",
|
||||||
|
base_url=effective_base,
|
||||||
|
default_headers={
|
||||||
|
"x-session-affinity": uuid.uuid4().hex,
|
||||||
|
**(extra_headers or {}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _setup_env(self, api_key: str, api_base: str | None) -> None:
|
||||||
|
"""Set environment variables based on provider spec."""
|
||||||
|
spec = self._spec
|
||||||
|
if not spec or not spec.env_key:
|
||||||
|
return
|
||||||
|
if spec.is_gateway:
|
||||||
|
os.environ[spec.env_key] = api_key
|
||||||
|
else:
|
||||||
|
os.environ.setdefault(spec.env_key, api_key)
|
||||||
|
effective_base = api_base or spec.default_api_base
|
||||||
|
for env_name, env_val in spec.env_extras:
|
||||||
|
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
|
||||||
|
os.environ.setdefault(env_name, resolved)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _apply_cache_control(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None,
|
||||||
|
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||||
|
"""Inject cache_control markers for prompt caching."""
|
||||||
|
cache_marker = {"type": "ephemeral"}
|
||||||
|
new_messages = list(messages)
|
||||||
|
|
||||||
|
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return {**msg, "content": [
|
||||||
|
{"type": "text", "text": content, "cache_control": cache_marker},
|
||||||
|
]}
|
||||||
|
if isinstance(content, list) and content:
|
||||||
|
nc = list(content)
|
||||||
|
nc[-1] = {**nc[-1], "cache_control": cache_marker}
|
||||||
|
return {**msg, "content": nc}
|
||||||
|
return msg
|
||||||
|
|
||||||
|
if new_messages and new_messages[0].get("role") == "system":
|
||||||
|
new_messages[0] = _mark(new_messages[0])
|
||||||
|
if len(new_messages) >= 3:
|
||||||
|
new_messages[-2] = _mark(new_messages[-2])
|
||||||
|
|
||||||
|
new_tools = tools
|
||||||
|
if tools:
|
||||||
|
new_tools = list(tools)
|
||||||
|
new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
|
||||||
|
return new_messages, new_tools
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
|
||||||
|
"""Normalize to a provider-safe 9-char alphanumeric form."""
|
||||||
|
if not isinstance(tool_call_id, str):
|
||||||
|
return tool_call_id
|
||||||
|
if len(tool_call_id) == 9 and tool_call_id.isalnum():
|
||||||
|
return tool_call_id
|
||||||
|
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
||||||
|
|
||||||
|
def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||||
|
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||||
|
id_map: dict[str, str] = {}
|
||||||
|
|
||||||
|
def map_id(value: Any) -> Any:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
||||||
|
|
||||||
|
for clean in sanitized:
|
||||||
|
if isinstance(clean.get("tool_calls"), list):
|
||||||
|
normalized = []
|
||||||
|
for tc in clean["tool_calls"]:
|
||||||
|
if not isinstance(tc, dict):
|
||||||
|
normalized.append(tc)
|
||||||
|
continue
|
||||||
|
tc_clean = dict(tc)
|
||||||
|
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||||
|
normalized.append(tc_clean)
|
||||||
|
clean["tool_calls"] = normalized
|
||||||
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Build kwargs
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_kwargs(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None,
|
||||||
|
model: str | None,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
reasoning_effort: str | None,
|
||||||
|
tool_choice: str | dict[str, Any] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
model_name = model or self.default_model
|
||||||
|
spec = self._spec
|
||||||
|
|
||||||
|
if spec and spec.supports_prompt_caching:
|
||||||
|
messages, tools = self._apply_cache_control(messages, tools)
|
||||||
|
|
||||||
|
if spec and spec.strip_model_prefix:
|
||||||
|
model_name = model_name.split("/")[-1]
|
||||||
|
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"model": model_name,
|
||||||
|
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
||||||
|
"max_tokens": max(1, max_tokens),
|
||||||
|
"temperature": temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
if spec:
|
||||||
|
model_lower = model_name.lower()
|
||||||
|
for pattern, overrides in spec.model_overrides:
|
||||||
|
if pattern in model_lower:
|
||||||
|
kwargs.update(overrides)
|
||||||
|
break
|
||||||
|
|
||||||
|
if reasoning_effort:
|
||||||
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
kwargs["tools"] = tools
|
||||||
|
kwargs["tool_choice"] = tool_choice or "auto"
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Response parsing
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _maybe_mapping(value: Any) -> dict[str, Any] | None:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value
|
||||||
|
model_dump = getattr(value, "model_dump", None)
|
||||||
|
if callable(model_dump):
|
||||||
|
dumped = model_dump()
|
||||||
|
if isinstance(dumped, dict):
|
||||||
|
return dumped
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_text_content(cls, value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if isinstance(value, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in value:
|
||||||
|
item_map = cls._maybe_mapping(item)
|
||||||
|
if item_map:
|
||||||
|
text = item_map.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
continue
|
||||||
|
text = getattr(item, "text", None)
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
continue
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
return "".join(parts) or None
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_usage(cls, response: Any) -> dict[str, int]:
|
||||||
|
usage_obj = None
|
||||||
|
response_map = cls._maybe_mapping(response)
|
||||||
|
if response_map is not None:
|
||||||
|
usage_obj = response_map.get("usage")
|
||||||
|
elif hasattr(response, "usage") and response.usage:
|
||||||
|
usage_obj = response.usage
|
||||||
|
|
||||||
|
usage_map = cls._maybe_mapping(usage_obj)
|
||||||
|
if usage_map is not None:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
|
||||||
|
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
|
||||||
|
"total_tokens": int(usage_map.get("total_tokens") or 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage_obj:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
|
||||||
|
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
|
||||||
|
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _parse(self, response: Any) -> LLMResponse:
|
||||||
|
if isinstance(response, str):
|
||||||
|
return LLMResponse(content=response, finish_reason="stop")
|
||||||
|
|
||||||
|
response_map = self._maybe_mapping(response)
|
||||||
|
if response_map is not None:
|
||||||
|
choices = response_map.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
content = self._extract_text_content(
|
||||||
|
response_map.get("content") or response_map.get("output_text")
|
||||||
|
)
|
||||||
|
if content is not None:
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
finish_reason=str(response_map.get("finish_reason") or "stop"),
|
||||||
|
usage=self._extract_usage(response_map),
|
||||||
|
)
|
||||||
|
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
||||||
|
|
||||||
|
choice0 = self._maybe_mapping(choices[0]) or {}
|
||||||
|
msg0 = self._maybe_mapping(choice0.get("message")) or {}
|
||||||
|
content = self._extract_text_content(msg0.get("content"))
|
||||||
|
finish_reason = str(choice0.get("finish_reason") or "stop")
|
||||||
|
|
||||||
|
raw_tool_calls: list[Any] = []
|
||||||
|
reasoning_content = msg0.get("reasoning_content")
|
||||||
|
for ch in choices:
|
||||||
|
ch_map = self._maybe_mapping(ch) or {}
|
||||||
|
m = self._maybe_mapping(ch_map.get("message")) or {}
|
||||||
|
tool_calls = m.get("tool_calls")
|
||||||
|
if isinstance(tool_calls, list) and tool_calls:
|
||||||
|
raw_tool_calls.extend(tool_calls)
|
||||||
|
if ch_map.get("finish_reason") in ("tool_calls", "stop"):
|
||||||
|
finish_reason = str(ch_map["finish_reason"])
|
||||||
|
if not content:
|
||||||
|
content = self._extract_text_content(m.get("content"))
|
||||||
|
if not reasoning_content:
|
||||||
|
reasoning_content = m.get("reasoning_content")
|
||||||
|
|
||||||
|
parsed_tool_calls = []
|
||||||
|
for tc in raw_tool_calls:
|
||||||
|
tc_map = self._maybe_mapping(tc) or {}
|
||||||
|
fn = self._maybe_mapping(tc_map.get("function")) or {}
|
||||||
|
args = fn.get("arguments", {})
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json_repair.loads(args)
|
||||||
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
|
parsed_tool_calls.append(ToolCallRequest(
|
||||||
|
id=_short_tool_id(),
|
||||||
|
name=str(fn.get("name") or ""),
|
||||||
|
arguments=args if isinstance(args, dict) else {},
|
||||||
|
extra_content=ec,
|
||||||
|
provider_specific_fields=prov,
|
||||||
|
function_provider_specific_fields=fn_prov,
|
||||||
|
))
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
tool_calls=parsed_tool_calls,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=self._extract_usage(response_map),
|
||||||
|
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.choices:
|
||||||
|
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
||||||
|
|
||||||
|
choice = response.choices[0]
|
||||||
|
msg = choice.message
|
||||||
|
content = msg.content
|
||||||
|
finish_reason = choice.finish_reason
|
||||||
|
|
||||||
|
raw_tool_calls: list[Any] = []
|
||||||
|
for ch in response.choices:
|
||||||
|
m = ch.message
|
||||||
|
if hasattr(m, "tool_calls") and m.tool_calls:
|
||||||
|
raw_tool_calls.extend(m.tool_calls)
|
||||||
|
if ch.finish_reason in ("tool_calls", "stop"):
|
||||||
|
finish_reason = ch.finish_reason
|
||||||
|
if not content and m.content:
|
||||||
|
content = m.content
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
for tc in raw_tool_calls:
|
||||||
|
args = tc.function.arguments
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json_repair.loads(args)
|
||||||
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
|
tool_calls.append(ToolCallRequest(
|
||||||
|
id=_short_tool_id(),
|
||||||
|
name=tc.function.name,
|
||||||
|
arguments=args,
|
||||||
|
extra_content=ec,
|
||||||
|
provider_specific_fields=prov,
|
||||||
|
function_provider_specific_fields=fn_prov,
|
||||||
|
))
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason=finish_reason or "stop",
|
||||||
|
usage=self._extract_usage(response),
|
||||||
|
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
|
||||||
|
content_parts: list[str] = []
|
||||||
|
tc_bufs: dict[int, dict[str, Any]] = {}
|
||||||
|
finish_reason = "stop"
|
||||||
|
usage: dict[str, int] = {}
|
||||||
|
|
||||||
|
def _accum_tc(tc: Any, idx_hint: int) -> None:
|
||||||
|
"""Accumulate one streaming tool-call delta into *tc_bufs*."""
|
||||||
|
tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
|
||||||
|
buf = tc_bufs.setdefault(tc_index, {
|
||||||
|
"id": "", "name": "", "arguments": "",
|
||||||
|
"extra_content": None, "prov": None, "fn_prov": None,
|
||||||
|
})
|
||||||
|
tc_id = _get(tc, "id")
|
||||||
|
if tc_id:
|
||||||
|
buf["id"] = str(tc_id)
|
||||||
|
fn = _get(tc, "function")
|
||||||
|
if fn is not None:
|
||||||
|
fn_name = _get(fn, "name")
|
||||||
|
if fn_name:
|
||||||
|
buf["name"] = str(fn_name)
|
||||||
|
fn_args = _get(fn, "arguments")
|
||||||
|
if fn_args:
|
||||||
|
buf["arguments"] += str(fn_args)
|
||||||
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
|
if ec:
|
||||||
|
buf["extra_content"] = ec
|
||||||
|
if prov:
|
||||||
|
buf["prov"] = prov
|
||||||
|
if fn_prov:
|
||||||
|
buf["fn_prov"] = fn_prov
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
if isinstance(chunk, str):
|
||||||
|
content_parts.append(chunk)
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_map = cls._maybe_mapping(chunk)
|
||||||
|
if chunk_map is not None:
|
||||||
|
choices = chunk_map.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
usage = cls._extract_usage(chunk_map) or usage
|
||||||
|
text = cls._extract_text_content(
|
||||||
|
chunk_map.get("content") or chunk_map.get("output_text")
|
||||||
|
)
|
||||||
|
if text:
|
||||||
|
content_parts.append(text)
|
||||||
|
continue
|
||||||
|
choice = cls._maybe_mapping(choices[0]) or {}
|
||||||
|
if choice.get("finish_reason"):
|
||||||
|
finish_reason = str(choice["finish_reason"])
|
||||||
|
delta = cls._maybe_mapping(choice.get("delta")) or {}
|
||||||
|
text = cls._extract_text_content(delta.get("content"))
|
||||||
|
if text:
|
||||||
|
content_parts.append(text)
|
||||||
|
for idx, tc in enumerate(delta.get("tool_calls") or []):
|
||||||
|
_accum_tc(tc, idx)
|
||||||
|
usage = cls._extract_usage(chunk_map) or usage
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not chunk.choices:
|
||||||
|
usage = cls._extract_usage(chunk) or usage
|
||||||
|
continue
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
if choice.finish_reason:
|
||||||
|
finish_reason = choice.finish_reason
|
||||||
|
delta = choice.delta
|
||||||
|
if delta and delta.content:
|
||||||
|
content_parts.append(delta.content)
|
||||||
|
for tc in (delta.tool_calls or []) if delta else []:
|
||||||
|
_accum_tc(tc, getattr(tc, "index", 0))
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content="".join(content_parts) or None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id=b["id"] or _short_tool_id(),
|
||||||
|
name=b["name"],
|
||||||
|
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
|
||||||
|
extra_content=b.get("extra_content"),
|
||||||
|
provider_specific_fields=b.get("prov"),
|
||||||
|
function_provider_specific_fields=b.get("fn_prov"),
|
||||||
|
)
|
||||||
|
for b in tc_bufs.values()
|
||||||
|
],
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _handle_error(e: Exception) -> LLMResponse:
|
||||||
|
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
|
||||||
|
msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}"
|
||||||
|
return LLMResponse(content=msg, finish_reason="error")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
kwargs = self._build_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
return self._handle_error(e)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
kwargs = self._build_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
kwargs["stream"] = True
|
||||||
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
try:
|
||||||
|
stream = await self._client.chat.completions.create(**kwargs)
|
||||||
|
chunks: list[Any] = []
|
||||||
|
async for chunk in stream:
|
||||||
|
chunks.append(chunk)
|
||||||
|
if on_content_delta and chunk.choices:
|
||||||
|
text = getattr(chunk.choices[0].delta, "content", None)
|
||||||
|
if text:
|
||||||
|
await on_content_delta(text)
|
||||||
|
return self._parse_chunks(chunks)
|
||||||
|
except Exception as e:
|
||||||
|
return self._handle_error(e)
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return self.default_model
|
||||||
@ -4,7 +4,7 @@ Provider Registry — single source of truth for LLM provider metadata.
|
|||||||
Adding a new provider:
|
Adding a new provider:
|
||||||
1. Add a ProviderSpec to PROVIDERS below.
|
1. Add a ProviderSpec to PROVIDERS below.
|
||||||
2. Add a field to ProvidersConfig in config/schema.py.
|
2. Add a field to ProvidersConfig in config/schema.py.
|
||||||
Done. Env vars, prefixing, config matching, status display all derive from here.
|
Done. Env vars, config matching, status display all derive from here.
|
||||||
|
|
||||||
Order matters — it controls match priority and fallback. Gateways first.
|
Order matters — it controls match priority and fallback. Gateways first.
|
||||||
Every entry writes out all fields so you can copy-paste as a template.
|
Every entry writes out all fields so you can copy-paste as a template.
|
||||||
@ -12,7 +12,7 @@ Every entry writes out all fields so you can copy-paste as a template.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic.alias_generators import to_snake
|
from pydantic.alias_generators import to_snake
|
||||||
@ -30,12 +30,12 @@ class ProviderSpec:
|
|||||||
# identity
|
# identity
|
||||||
name: str # config field name, e.g. "dashscope"
|
name: str # config field name, e.g. "dashscope"
|
||||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
env_key: str # env var for API key, e.g. "DASHSCOPE_API_KEY"
|
||||||
display_name: str = "" # shown in `nanobot status`
|
display_name: str = "" # shown in `nanobot status`
|
||||||
|
|
||||||
# model prefixing
|
# which provider implementation to use
|
||||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex"
|
||||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
backend: str = "openai_compat"
|
||||||
|
|
||||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||||
env_extras: tuple[tuple[str, str], ...] = ()
|
env_extras: tuple[tuple[str, str], ...] = ()
|
||||||
@ -45,19 +45,18 @@ class ProviderSpec:
|
|||||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||||
default_api_base: str = "" # fallback base URL
|
default_api_base: str = "" # OpenAI-compatible base URL for this provider
|
||||||
|
|
||||||
# gateway behavior
|
# gateway behavior
|
||||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
strip_model_prefix: bool = False # strip "provider/" before sending to gateway
|
||||||
litellm_kwargs: dict[str, Any] = field(default_factory=dict) # extra kwargs passed to LiteLLM
|
|
||||||
|
|
||||||
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
||||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||||
|
|
||||||
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
||||||
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
is_oauth: bool = False
|
||||||
|
|
||||||
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
# Direct providers skip API-key validation (user supplies everything)
|
||||||
is_direct: bool = False
|
is_direct: bool = False
|
||||||
|
|
||||||
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
||||||
@ -73,13 +72,13 @@ class ProviderSpec:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||||
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
|
# === Custom (direct OpenAI-compatible endpoint) ========================
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="custom",
|
name="custom",
|
||||||
keywords=(),
|
keywords=(),
|
||||||
env_key="",
|
env_key="",
|
||||||
display_name="Custom",
|
display_name="Custom",
|
||||||
litellm_prefix="",
|
backend="openai_compat",
|
||||||
is_direct=True,
|
is_direct=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -89,7 +88,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("azure", "azure-openai"),
|
keywords=("azure", "azure-openai"),
|
||||||
env_key="",
|
env_key="",
|
||||||
display_name="Azure OpenAI",
|
display_name="Azure OpenAI",
|
||||||
litellm_prefix="",
|
backend="azure_openai",
|
||||||
is_direct=True,
|
is_direct=True,
|
||||||
),
|
),
|
||||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||||
@ -100,36 +99,26 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("openrouter",),
|
keywords=("openrouter",),
|
||||||
env_key="OPENROUTER_API_KEY",
|
env_key="OPENROUTER_API_KEY",
|
||||||
display_name="OpenRouter",
|
display_name="OpenRouter",
|
||||||
litellm_prefix="openrouter", # anthropic/claude-3 → openrouter/anthropic/claude-3
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="sk-or-",
|
detect_by_key_prefix="sk-or-",
|
||||||
detect_by_base_keyword="openrouter",
|
detect_by_base_keyword="openrouter",
|
||||||
default_api_base="https://openrouter.ai/api/v1",
|
default_api_base="https://openrouter.ai/api/v1",
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
supports_prompt_caching=True,
|
supports_prompt_caching=True,
|
||||||
),
|
),
|
||||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
# strip_model_prefix=True: doesn't understand "anthropic/claude-3",
|
||||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
# strips to bare "claude-3".
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="aihubmix",
|
name="aihubmix",
|
||||||
keywords=("aihubmix",),
|
keywords=("aihubmix",),
|
||||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
env_key="OPENAI_API_KEY",
|
||||||
display_name="AiHubMix",
|
display_name="AiHubMix",
|
||||||
litellm_prefix="openai", # → openai/{model}
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="aihubmix",
|
detect_by_base_keyword="aihubmix",
|
||||||
default_api_base="https://aihubmix.com/v1",
|
default_api_base="https://aihubmix.com/v1",
|
||||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
strip_model_prefix=True,
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
@ -137,16 +126,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("siliconflow",),
|
keywords=("siliconflow",),
|
||||||
env_key="OPENAI_API_KEY",
|
env_key="OPENAI_API_KEY",
|
||||||
display_name="SiliconFlow",
|
display_name="SiliconFlow",
|
||||||
litellm_prefix="openai",
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="siliconflow",
|
detect_by_base_keyword="siliconflow",
|
||||||
default_api_base="https://api.siliconflow.cn/v1",
|
default_api_base="https://api.siliconflow.cn/v1",
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
|
|
||||||
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
||||||
@ -155,16 +138,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("volcengine", "volces", "ark"),
|
keywords=("volcengine", "volces", "ark"),
|
||||||
env_key="OPENAI_API_KEY",
|
env_key="OPENAI_API_KEY",
|
||||||
display_name="VolcEngine",
|
display_name="VolcEngine",
|
||||||
litellm_prefix="volcengine",
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="volces",
|
detect_by_base_keyword="volces",
|
||||||
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
|
|
||||||
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
|
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
|
||||||
@ -173,16 +150,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("volcengine-plan",),
|
keywords=("volcengine-plan",),
|
||||||
env_key="OPENAI_API_KEY",
|
env_key="OPENAI_API_KEY",
|
||||||
display_name="VolcEngine Coding Plan",
|
display_name="VolcEngine Coding Plan",
|
||||||
litellm_prefix="volcengine",
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
|
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||||
strip_model_prefix=True,
|
strip_model_prefix=True,
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
|
|
||||||
# BytePlus: VolcEngine international, pay-per-use models
|
# BytePlus: VolcEngine international, pay-per-use models
|
||||||
@ -191,16 +162,11 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("byteplus",),
|
keywords=("byteplus",),
|
||||||
env_key="OPENAI_API_KEY",
|
env_key="OPENAI_API_KEY",
|
||||||
display_name="BytePlus",
|
display_name="BytePlus",
|
||||||
litellm_prefix="volcengine",
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="bytepluses",
|
detect_by_base_keyword="bytepluses",
|
||||||
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
|
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
|
||||||
strip_model_prefix=True,
|
strip_model_prefix=True,
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
|
|
||||||
# BytePlus Coding Plan: same key as byteplus
|
# BytePlus Coding Plan: same key as byteplus
|
||||||
@ -209,250 +175,137 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("byteplus-plan",),
|
keywords=("byteplus-plan",),
|
||||||
env_key="OPENAI_API_KEY",
|
env_key="OPENAI_API_KEY",
|
||||||
display_name="BytePlus Coding Plan",
|
display_name="BytePlus Coding Plan",
|
||||||
litellm_prefix="volcengine",
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
|
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
|
||||||
strip_model_prefix=True,
|
strip_model_prefix=True,
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
|
|
||||||
|
|
||||||
# === Standard providers (matched by model-name keywords) ===============
|
# === Standard providers (matched by model-name keywords) ===============
|
||||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
# Anthropic: native Anthropic SDK
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="anthropic",
|
name="anthropic",
|
||||||
keywords=("anthropic", "claude"),
|
keywords=("anthropic", "claude"),
|
||||||
env_key="ANTHROPIC_API_KEY",
|
env_key="ANTHROPIC_API_KEY",
|
||||||
display_name="Anthropic",
|
display_name="Anthropic",
|
||||||
litellm_prefix="",
|
backend="anthropic",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
supports_prompt_caching=True,
|
supports_prompt_caching=True,
|
||||||
),
|
),
|
||||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
# OpenAI: SDK default base URL (no override needed)
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="openai",
|
name="openai",
|
||||||
keywords=("openai", "gpt"),
|
keywords=("openai", "gpt"),
|
||||||
env_key="OPENAI_API_KEY",
|
env_key="OPENAI_API_KEY",
|
||||||
display_name="OpenAI",
|
display_name="OpenAI",
|
||||||
litellm_prefix="",
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# OpenAI Codex: uses OAuth, not API key.
|
# OpenAI Codex: OAuth-based, dedicated provider
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="openai_codex",
|
name="openai_codex",
|
||||||
keywords=("openai-codex",),
|
keywords=("openai-codex",),
|
||||||
env_key="", # OAuth-based, no API key
|
env_key="",
|
||||||
display_name="OpenAI Codex",
|
display_name="OpenAI Codex",
|
||||||
litellm_prefix="", # Not routed through LiteLLM
|
backend="openai_codex",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="codex",
|
detect_by_base_keyword="codex",
|
||||||
default_api_base="https://chatgpt.com/backend-api",
|
default_api_base="https://chatgpt.com/backend-api",
|
||||||
strip_model_prefix=False,
|
is_oauth=True,
|
||||||
model_overrides=(),
|
|
||||||
is_oauth=True, # OAuth-based authentication
|
|
||||||
),
|
),
|
||||||
# Github Copilot: uses OAuth, not API key.
|
# GitHub Copilot: OAuth-based
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="github_copilot",
|
name="github_copilot",
|
||||||
keywords=("github_copilot", "copilot"),
|
keywords=("github_copilot", "copilot"),
|
||||||
env_key="", # OAuth-based, no API key
|
env_key="",
|
||||||
display_name="Github Copilot",
|
display_name="Github Copilot",
|
||||||
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
backend="openai_compat",
|
||||||
skip_prefixes=("github_copilot/",),
|
default_api_base="https://api.githubcopilot.com",
|
||||||
env_extras=(),
|
is_oauth=True,
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
is_oauth=True, # OAuth-based authentication
|
|
||||||
),
|
),
|
||||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
# DeepSeek: OpenAI-compatible at api.deepseek.com
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="deepseek",
|
name="deepseek",
|
||||||
keywords=("deepseek",),
|
keywords=("deepseek",),
|
||||||
env_key="DEEPSEEK_API_KEY",
|
env_key="DEEPSEEK_API_KEY",
|
||||||
display_name="DeepSeek",
|
display_name="DeepSeek",
|
||||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
backend="openai_compat",
|
||||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
default_api_base="https://api.deepseek.com",
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
# Gemini: Google's OpenAI-compatible endpoint
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="gemini",
|
name="gemini",
|
||||||
keywords=("gemini",),
|
keywords=("gemini",),
|
||||||
env_key="GEMINI_API_KEY",
|
env_key="GEMINI_API_KEY",
|
||||||
display_name="Gemini",
|
display_name="Gemini",
|
||||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
backend="openai_compat",
|
||||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
default_api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
# Zhipu (智谱): OpenAI-compatible at open.bigmodel.cn
|
||||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
|
||||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="zhipu",
|
name="zhipu",
|
||||||
keywords=("zhipu", "glm", "zai"),
|
keywords=("zhipu", "glm", "zai"),
|
||||||
env_key="ZAI_API_KEY",
|
env_key="ZAI_API_KEY",
|
||||||
display_name="Zhipu AI",
|
display_name="Zhipu AI",
|
||||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
backend="openai_compat",
|
||||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
|
||||||
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
||||||
is_gateway=False,
|
default_api_base="https://open.bigmodel.cn/api/paas/v4",
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
# DashScope (通义): Qwen models, OpenAI-compatible endpoint
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="dashscope",
|
name="dashscope",
|
||||||
keywords=("qwen", "dashscope"),
|
keywords=("qwen", "dashscope"),
|
||||||
env_key="DASHSCOPE_API_KEY",
|
env_key="DASHSCOPE_API_KEY",
|
||||||
display_name="DashScope",
|
display_name="DashScope",
|
||||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
backend="openai_compat",
|
||||||
skip_prefixes=("dashscope/", "openrouter/"),
|
default_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
# Moonshot (月之暗面): Kimi models. K2.5 enforces temperature >= 1.0.
|
||||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
|
||||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="moonshot",
|
name="moonshot",
|
||||||
keywords=("moonshot", "kimi"),
|
keywords=("moonshot", "kimi"),
|
||||||
env_key="MOONSHOT_API_KEY",
|
env_key="MOONSHOT_API_KEY",
|
||||||
display_name="Moonshot",
|
display_name="Moonshot",
|
||||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
backend="openai_compat",
|
||||||
skip_prefixes=("moonshot/", "openrouter/"),
|
default_api_base="https://api.moonshot.ai/v1",
|
||||||
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
||||||
),
|
),
|
||||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
# MiniMax: OpenAI-compatible API
|
||||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="minimax",
|
name="minimax",
|
||||||
keywords=("minimax",),
|
keywords=("minimax",),
|
||||||
env_key="MINIMAX_API_KEY",
|
env_key="MINIMAX_API_KEY",
|
||||||
display_name="MiniMax",
|
display_name="MiniMax",
|
||||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
backend="openai_compat",
|
||||||
skip_prefixes=("minimax/", "openrouter/"),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="https://api.minimax.io/v1",
|
default_api_base="https://api.minimax.io/v1",
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# Mistral AI: OpenAI-compatible API at api.mistral.ai/v1.
|
# Mistral AI: OpenAI-compatible API
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="mistral",
|
name="mistral",
|
||||||
keywords=("mistral",),
|
keywords=("mistral",),
|
||||||
env_key="MISTRAL_API_KEY",
|
env_key="MISTRAL_API_KEY",
|
||||||
display_name="Mistral",
|
display_name="Mistral",
|
||||||
litellm_prefix="mistral", # mistral-large-latest → mistral/mistral-large-latest
|
backend="openai_compat",
|
||||||
skip_prefixes=("mistral/",), # avoid double-prefix
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="https://api.mistral.ai/v1",
|
default_api_base="https://api.mistral.ai/v1",
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||||
# vLLM / any OpenAI-compatible local server.
|
# vLLM / any OpenAI-compatible local server
|
||||||
# Detected when config key is "vllm" (provider_name="vllm").
|
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="vllm",
|
name="vllm",
|
||||||
keywords=("vllm",),
|
keywords=("vllm",),
|
||||||
env_key="HOSTED_VLLM_API_KEY",
|
env_key="HOSTED_VLLM_API_KEY",
|
||||||
display_name="vLLM/Local",
|
display_name="vLLM/Local",
|
||||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
backend="openai_compat",
|
||||||
skip_prefixes=(),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=True,
|
is_local=True,
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="", # user must provide in config
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# === Ollama (local, OpenAI-compatible) ===================================
|
# Ollama (local, OpenAI-compatible)
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="ollama",
|
name="ollama",
|
||||||
keywords=("ollama", "nemotron"),
|
keywords=("ollama", "nemotron"),
|
||||||
env_key="OLLAMA_API_KEY",
|
env_key="OLLAMA_API_KEY",
|
||||||
display_name="Ollama",
|
display_name="Ollama",
|
||||||
litellm_prefix="ollama_chat", # model → ollama_chat/model
|
backend="openai_compat",
|
||||||
skip_prefixes=("ollama/", "ollama_chat/"),
|
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=True,
|
is_local=True,
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="11434",
|
detect_by_base_keyword="11434",
|
||||||
default_api_base="http://localhost:11434",
|
default_api_base="http://localhost:11434/v1",
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
# === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) ===
|
# === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) ===
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
@ -460,29 +313,20 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("openvino", "ovms"),
|
keywords=("openvino", "ovms"),
|
||||||
env_key="",
|
env_key="",
|
||||||
display_name="OpenVINO Model Server",
|
display_name="OpenVINO Model Server",
|
||||||
litellm_prefix="",
|
backend="openai_compat",
|
||||||
is_direct=True,
|
is_direct=True,
|
||||||
is_local=True,
|
is_local=True,
|
||||||
default_api_base="http://localhost:8000/v3",
|
default_api_base="http://localhost:8000/v3",
|
||||||
),
|
),
|
||||||
# === Auxiliary (not a primary LLM provider) ============================
|
# === Auxiliary (not a primary LLM provider) ============================
|
||||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
# Groq: mainly used for Whisper voice transcription, also usable for LLM
|
||||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="groq",
|
name="groq",
|
||||||
keywords=("groq",),
|
keywords=("groq",),
|
||||||
env_key="GROQ_API_KEY",
|
env_key="GROQ_API_KEY",
|
||||||
display_name="Groq",
|
display_name="Groq",
|
||||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
backend="openai_compat",
|
||||||
skip_prefixes=("groq/",), # avoid double-prefix
|
default_api_base="https://api.groq.com/openai/v1",
|
||||||
env_extras=(),
|
|
||||||
is_gateway=False,
|
|
||||||
is_local=False,
|
|
||||||
detect_by_key_prefix="",
|
|
||||||
detect_by_base_keyword="",
|
|
||||||
default_api_base="",
|
|
||||||
strip_model_prefix=False,
|
|
||||||
model_overrides=(),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -492,59 +336,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def find_by_model(model: str) -> ProviderSpec | None:
|
|
||||||
"""Match a standard provider by model-name keyword (case-insensitive).
|
|
||||||
Skips gateways/local — those are matched by api_key/api_base instead."""
|
|
||||||
model_lower = model.lower()
|
|
||||||
model_normalized = model_lower.replace("-", "_")
|
|
||||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
|
||||||
normalized_prefix = model_prefix.replace("-", "_")
|
|
||||||
std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local]
|
|
||||||
|
|
||||||
# Prefer explicit provider prefix — prevents `github-copilot/...codex` matching openai_codex.
|
|
||||||
for spec in std_specs:
|
|
||||||
if model_prefix and normalized_prefix == spec.name:
|
|
||||||
return spec
|
|
||||||
|
|
||||||
for spec in std_specs:
|
|
||||||
if any(
|
|
||||||
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
|
|
||||||
):
|
|
||||||
return spec
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def find_gateway(
|
|
||||||
provider_name: str | None = None,
|
|
||||||
api_key: str | None = None,
|
|
||||||
api_base: str | None = None,
|
|
||||||
) -> ProviderSpec | None:
|
|
||||||
"""Detect gateway/local provider.
|
|
||||||
|
|
||||||
Priority:
|
|
||||||
1. provider_name — if it maps to a gateway/local spec, use it directly.
|
|
||||||
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
|
|
||||||
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
|
|
||||||
|
|
||||||
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
|
|
||||||
will NOT be mistaken for vLLM — the old fallback is gone.
|
|
||||||
"""
|
|
||||||
# 1. Direct match by config key
|
|
||||||
if provider_name:
|
|
||||||
spec = find_by_name(provider_name)
|
|
||||||
if spec and (spec.is_gateway or spec.is_local):
|
|
||||||
return spec
|
|
||||||
|
|
||||||
# 2. Auto-detect by api_key prefix / api_base keyword
|
|
||||||
for spec in PROVIDERS:
|
|
||||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
|
||||||
return spec
|
|
||||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
|
||||||
return spec
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def find_by_name(name: str) -> ProviderSpec | None:
|
def find_by_name(name: str) -> ProviderSpec | None:
|
||||||
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
||||||
normalized = to_snake(name.replace("-", "_"))
|
normalized = to_snake(name.replace("-", "_"))
|
||||||
|
|||||||
@ -19,7 +19,7 @@ classifiers = [
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"typer>=0.20.0,<1.0.0",
|
"typer>=0.20.0,<1.0.0",
|
||||||
"litellm>=1.82.1,<2.0.0",
|
"anthropic>=0.45.0,<1.0.0",
|
||||||
"pydantic>=2.12.0,<3.0.0",
|
"pydantic>=2.12.0,<3.0.0",
|
||||||
"pydantic-settings>=2.12.0,<3.0.0",
|
"pydantic-settings>=2.12.0,<3.0.0",
|
||||||
"websockets>=16.0,<17.0",
|
"websockets>=16.0,<17.0",
|
||||||
|
|||||||
@ -1,53 +1,200 @@
|
|||||||
|
"""Tests for Gemini thought_signature round-trip through extra_content.
|
||||||
|
|
||||||
|
The Gemini OpenAI-compatibility API returns tool calls with an extra_content
|
||||||
|
field: ``{"google": {"thought_signature": "..."}}``. This MUST survive the
|
||||||
|
parse → serialize round-trip so the model can continue reasoning.
|
||||||
|
"""
|
||||||
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from nanobot.providers.base import ToolCallRequest
|
from nanobot.providers.base import ToolCallRequest
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
|
||||||
|
|
||||||
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
|
GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}}
|
||||||
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
|
|
||||||
|
|
||||||
response = SimpleNamespace(
|
|
||||||
choices=[
|
|
||||||
SimpleNamespace(
|
|
||||||
finish_reason="tool_calls",
|
|
||||||
message=SimpleNamespace(
|
|
||||||
content=None,
|
|
||||||
tool_calls=[
|
|
||||||
SimpleNamespace(
|
|
||||||
id="call_123",
|
|
||||||
function=SimpleNamespace(
|
|
||||||
name="read_file",
|
|
||||||
arguments='{"path":"todo.md"}',
|
|
||||||
provider_specific_fields={"inner": "value"},
|
|
||||||
),
|
|
||||||
provider_specific_fields={"thought_signature": "signed-token"},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
usage=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
parsed = provider._parse_response(response)
|
|
||||||
|
|
||||||
assert len(parsed.tool_calls) == 1
|
|
||||||
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
|
|
||||||
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_call_request_serializes_provider_fields() -> None:
|
# ── ToolCallRequest serialization ──────────────────────────────────────
|
||||||
tool_call = ToolCallRequest(
|
|
||||||
|
def test_tool_call_request_serializes_extra_content() -> None:
|
||||||
|
tc = ToolCallRequest(
|
||||||
id="abc123xyz",
|
id="abc123xyz",
|
||||||
name="read_file",
|
name="read_file",
|
||||||
arguments={"path": "todo.md"},
|
arguments={"path": "todo.md"},
|
||||||
provider_specific_fields={"thought_signature": "signed-token"},
|
extra_content=GEMINI_EXTRA,
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = tc.to_openai_tool_call()
|
||||||
|
|
||||||
|
assert payload["extra_content"] == GEMINI_EXTRA
|
||||||
|
assert payload["function"]["arguments"] == '{"path": "todo.md"}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_call_request_serializes_provider_fields() -> None:
|
||||||
|
tc = ToolCallRequest(
|
||||||
|
id="abc123xyz",
|
||||||
|
name="read_file",
|
||||||
|
arguments={"path": "todo.md"},
|
||||||
|
provider_specific_fields={"custom_key": "custom_val"},
|
||||||
function_provider_specific_fields={"inner": "value"},
|
function_provider_specific_fields={"inner": "value"},
|
||||||
)
|
)
|
||||||
|
|
||||||
message = tool_call.to_openai_tool_call()
|
payload = tc.to_openai_tool_call()
|
||||||
|
|
||||||
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
|
assert payload["provider_specific_fields"] == {"custom_key": "custom_val"}
|
||||||
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
|
assert payload["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||||
assert message["function"]["arguments"] == '{"path": "todo.md"}'
|
|
||||||
|
|
||||||
|
def test_tool_call_request_omits_absent_extras() -> None:
|
||||||
|
tc = ToolCallRequest(id="x", name="fn", arguments={})
|
||||||
|
payload = tc.to_openai_tool_call()
|
||||||
|
|
||||||
|
assert "extra_content" not in payload
|
||||||
|
assert "provider_specific_fields" not in payload
|
||||||
|
assert "provider_specific_fields" not in payload["function"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── _parse: SDK-object branch ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_sdk_response_with_extra_content():
|
||||||
|
"""Simulate a Gemini response via the OpenAI SDK (SimpleNamespace)."""
|
||||||
|
fn = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}')
|
||||||
|
tc = SimpleNamespace(
|
||||||
|
id="call_1",
|
||||||
|
index=0,
|
||||||
|
type="function",
|
||||||
|
function=fn,
|
||||||
|
extra_content=GEMINI_EXTRA,
|
||||||
|
)
|
||||||
|
msg = SimpleNamespace(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[tc],
|
||||||
|
reasoning_content=None,
|
||||||
|
)
|
||||||
|
choice = SimpleNamespace(message=msg, finish_reason="tool_calls")
|
||||||
|
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||||
|
return SimpleNamespace(choices=[choice], usage=usage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_sdk_object_preserves_extra_content() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
result = provider._parse(_make_sdk_response_with_extra_content())
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
tc = result.tool_calls[0]
|
||||||
|
assert tc.name == "get_weather"
|
||||||
|
assert tc.extra_content == GEMINI_EXTRA
|
||||||
|
|
||||||
|
payload = tc.to_openai_tool_call()
|
||||||
|
assert payload["extra_content"] == GEMINI_EXTRA
|
||||||
|
|
||||||
|
|
||||||
|
# ── _parse: dict/mapping branch ───────────────────────────────────────
|
||||||
|
|
||||||
|
def test_parse_dict_preserves_extra_content() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
response_dict = {
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
|
||||||
|
"extra_content": GEMINI_EXTRA,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
}],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = provider._parse(response_dict)
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
tc = result.tool_calls[0]
|
||||||
|
assert tc.name == "get_weather"
|
||||||
|
assert tc.extra_content == GEMINI_EXTRA
|
||||||
|
|
||||||
|
payload = tc.to_openai_tool_call()
|
||||||
|
assert payload["extra_content"] == GEMINI_EXTRA
|
||||||
|
|
||||||
|
|
||||||
|
# ── _parse_chunks: streaming round-trip ───────────────────────────────
|
||||||
|
|
||||||
|
def test_parse_chunks_sdk_preserves_extra_content() -> None:
|
||||||
|
fn_delta = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}')
|
||||||
|
tc_delta = SimpleNamespace(
|
||||||
|
id="call_1",
|
||||||
|
index=0,
|
||||||
|
function=fn_delta,
|
||||||
|
extra_content=GEMINI_EXTRA,
|
||||||
|
)
|
||||||
|
delta = SimpleNamespace(content=None, tool_calls=[tc_delta])
|
||||||
|
choice = SimpleNamespace(finish_reason="tool_calls", delta=delta)
|
||||||
|
chunk = SimpleNamespace(choices=[choice], usage=None)
|
||||||
|
|
||||||
|
result = OpenAICompatProvider._parse_chunks([chunk])
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
tc = result.tool_calls[0]
|
||||||
|
assert tc.extra_content == GEMINI_EXTRA
|
||||||
|
|
||||||
|
payload = tc.to_openai_tool_call()
|
||||||
|
assert payload["extra_content"] == GEMINI_EXTRA
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chunks_dict_preserves_extra_content() -> None:
|
||||||
|
chunk = {
|
||||||
|
"choices": [{
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
"delta": {
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"index": 0,
|
||||||
|
"id": "call_1",
|
||||||
|
"function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
|
||||||
|
"extra_content": GEMINI_EXTRA,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = OpenAICompatProvider._parse_chunks([chunk])
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
tc = result.tool_calls[0]
|
||||||
|
assert tc.extra_content == GEMINI_EXTRA
|
||||||
|
|
||||||
|
payload = tc.to_openai_tool_call()
|
||||||
|
assert payload["extra_content"] == GEMINI_EXTRA
|
||||||
|
|
||||||
|
|
||||||
|
# ── Model switching: stale extras shouldn't break other providers ─────
|
||||||
|
|
||||||
|
def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
|
||||||
|
"""When switching from Gemini to OpenAI, extra_content inside tool_calls
|
||||||
|
should survive message sanitization (it lives inside the tool_call dict,
|
||||||
|
not at message level, so it bypasses _ALLOWED_MSG_KEYS filtering)."""
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
messages = [{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "fn", "arguments": "{}"},
|
||||||
|
"extra_content": GEMINI_EXTRA,
|
||||||
|
}],
|
||||||
|
}]
|
||||||
|
|
||||||
|
sanitized = provider._sanitize_messages(messages)
|
||||||
|
|
||||||
|
assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
||||||
|
|||||||
@ -380,7 +380,7 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
|
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
|
||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
error_resp = LLMResponse(
|
error_resp = LLMResponse(
|
||||||
content="Error calling LLM: litellm.BadRequestError: "
|
content="Error calling LLM: BadRequestError: "
|
||||||
"The tool_choice parameter does not support being set to required or object",
|
"The tool_choice parameter does not support being set to required or object",
|
||||||
finish_reason="error",
|
finish_reason="error",
|
||||||
tool_calls=[],
|
tool_calls=[],
|
||||||
|
|||||||
@ -1,4 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -8,6 +11,7 @@ from nanobot.channels.weixin import (
|
|||||||
ITEM_IMAGE,
|
ITEM_IMAGE,
|
||||||
ITEM_TEXT,
|
ITEM_TEXT,
|
||||||
MESSAGE_TYPE_BOT,
|
MESSAGE_TYPE_BOT,
|
||||||
|
WEIXIN_CHANNEL_VERSION,
|
||||||
WeixinChannel,
|
WeixinChannel,
|
||||||
WeixinConfig,
|
WeixinConfig,
|
||||||
)
|
)
|
||||||
@ -16,12 +20,58 @@ from nanobot.channels.weixin import (
|
|||||||
def _make_channel() -> tuple[WeixinChannel, MessageBus]:
|
def _make_channel() -> tuple[WeixinChannel, MessageBus]:
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
channel = WeixinChannel(
|
channel = WeixinChannel(
|
||||||
WeixinConfig(enabled=True, allow_from=["*"]),
|
WeixinConfig(
|
||||||
|
enabled=True,
|
||||||
|
allow_from=["*"],
|
||||||
|
state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"),
|
||||||
|
),
|
||||||
bus,
|
bus,
|
||||||
)
|
)
|
||||||
return channel, bus
|
return channel, bus
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_headers_includes_route_tag_when_configured() -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = WeixinChannel(
|
||||||
|
WeixinConfig(enabled=True, allow_from=["*"], route_tag=123),
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
channel._token = "token"
|
||||||
|
|
||||||
|
headers = channel._make_headers()
|
||||||
|
|
||||||
|
assert headers["Authorization"] == "Bearer token"
|
||||||
|
assert headers["SKRouteTag"] == "123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_version_matches_reference_plugin_version() -> None:
|
||||||
|
assert WEIXIN_CHANNEL_VERSION == "1.0.3"
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_and_load_state_persists_context_tokens(tmp_path) -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = WeixinChannel(
|
||||||
|
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
channel._token = "token"
|
||||||
|
channel._get_updates_buf = "cursor"
|
||||||
|
channel._context_tokens = {"wx-user": "ctx-1"}
|
||||||
|
|
||||||
|
channel._save_state()
|
||||||
|
|
||||||
|
saved = json.loads((tmp_path / "account.json").read_text())
|
||||||
|
assert saved["context_tokens"] == {"wx-user": "ctx-1"}
|
||||||
|
|
||||||
|
restored = WeixinChannel(
|
||||||
|
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert restored._load_state() is True
|
||||||
|
assert restored._context_tokens == {"wx-user": "ctx-1"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_deduplicates_inbound_ids() -> None:
|
async def test_process_message_deduplicates_inbound_ids() -> None:
|
||||||
channel, bus = _make_channel()
|
channel, bus = _make_channel()
|
||||||
@ -71,6 +121,30 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None:
|
|||||||
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
|
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = WeixinChannel(
|
||||||
|
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._process_message(
|
||||||
|
{
|
||||||
|
"message_type": 1,
|
||||||
|
"message_id": "m2b",
|
||||||
|
"from_user_id": "wx-user",
|
||||||
|
"context_token": "ctx-2b",
|
||||||
|
"item_list": [
|
||||||
|
{"type": ITEM_TEXT, "text_item": {"text": "ping"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
saved = json.loads((tmp_path / "account.json").read_text())
|
||||||
|
assert saved["context_tokens"] == {"wx-user": "ctx-2b"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_extracts_media_and_preserves_paths() -> None:
|
async def test_process_message_extracts_media_and_preserves_paths() -> None:
|
||||||
channel, bus = _make_channel()
|
channel, bus = _make_channel()
|
||||||
@ -109,6 +183,85 @@ async def test_send_without_context_token_does_not_send_text() -> None:
|
|||||||
channel._send_text.assert_not_awaited()
|
channel._send_text.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_does_not_send_when_session_is_paused() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-2"
|
||||||
|
channel._pause_session(60)
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._send_text.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_once_pauses_session_on_expired_errcode() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = SimpleNamespace(timeout=None)
|
||||||
|
channel._token = "token"
|
||||||
|
channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"})
|
||||||
|
|
||||||
|
await channel._poll_once()
|
||||||
|
|
||||||
|
assert channel._session_pause_remaining_s() > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._running = True
|
||||||
|
channel._save_state = lambda: None
|
||||||
|
channel._print_qr_code = lambda url: None
|
||||||
|
channel._api_get = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||||
|
{"status": "expired"},
|
||||||
|
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||||
|
{
|
||||||
|
"status": "confirmed",
|
||||||
|
"bot_token": "token-2",
|
||||||
|
"ilink_bot_id": "bot-2",
|
||||||
|
"baseurl": "https://example.test",
|
||||||
|
"ilink_user_id": "wx-user",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ok = await channel._qr_login()
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert channel._token == "token-2"
|
||||||
|
assert channel.config.base_url == "https://example.test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._running = True
|
||||||
|
channel._print_qr_code = lambda url: None
|
||||||
|
channel._api_get = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||||
|
{"status": "expired"},
|
||||||
|
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||||
|
{"status": "expired"},
|
||||||
|
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
|
||||||
|
{"status": "expired"},
|
||||||
|
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
|
||||||
|
{"status": "expired"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ok = await channel._qr_login()
|
||||||
|
|
||||||
|
assert ok is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_skips_bot_messages() -> None:
|
async def test_process_message_skips_bot_messages() -> None:
|
||||||
channel, bus = _make_channel()
|
channel, bus = _make_channel()
|
||||||
|
|||||||
@ -9,9 +9,8 @@ from typer.testing import CliRunner
|
|||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.cli.commands import _make_provider, app
|
from nanobot.cli.commands import _make_provider, app
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
|
||||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||||
from nanobot.providers.registry import find_by_model, find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
@ -228,7 +227,7 @@ def test_config_matches_explicit_ollama_prefix_without_api_key():
|
|||||||
config.agents.defaults.model = "ollama/llama3.2"
|
config.agents.defaults.model = "ollama/llama3.2"
|
||||||
|
|
||||||
assert config.get_provider_name() == "ollama"
|
assert config.get_provider_name() == "ollama"
|
||||||
assert config.get_api_base() == "http://localhost:11434"
|
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
||||||
@ -237,7 +236,7 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
|||||||
config.agents.defaults.model = "llama3.2"
|
config.agents.defaults.model = "llama3.2"
|
||||||
|
|
||||||
assert config.get_provider_name() == "ollama"
|
assert config.get_provider_name() == "ollama"
|
||||||
assert config.get_api_base() == "http://localhost:11434"
|
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
|
def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
|
||||||
@ -272,12 +271,12 @@ def test_config_auto_detects_ollama_from_local_api_base():
|
|||||||
config = Config.model_validate(
|
config = Config.model_validate(
|
||||||
{
|
{
|
||||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||||
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
|
"providers": {"ollama": {"apiBase": "http://localhost:11434/v1"}},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert config.get_provider_name() == "ollama"
|
assert config.get_provider_name() == "ollama"
|
||||||
assert config.get_api_base() == "http://localhost:11434"
|
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
||||||
@ -286,13 +285,13 @@ def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
|||||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||||
"providers": {
|
"providers": {
|
||||||
"vllm": {"apiBase": "http://localhost:8000"},
|
"vllm": {"apiBase": "http://localhost:8000"},
|
||||||
"ollama": {"apiBase": "http://localhost:11434"},
|
"ollama": {"apiBase": "http://localhost:11434/v1"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert config.get_provider_name() == "ollama"
|
assert config.get_provider_name() == "ollama"
|
||||||
assert config.get_api_base() == "http://localhost:11434"
|
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
||||||
@ -309,19 +308,13 @@ def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
|||||||
assert config.get_api_base() == "http://localhost:8000"
|
assert config.get_api_base() == "http://localhost:8000"
|
||||||
|
|
||||||
|
|
||||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
def test_openai_compat_provider_passes_model_through():
|
||||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
|
||||||
assert spec is not None
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
assert spec.name == "github_copilot"
|
provider = OpenAICompatProvider(default_model="github-copilot/gpt-5.3-codex")
|
||||||
|
|
||||||
|
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
|
||||||
def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
|
|
||||||
provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex")
|
|
||||||
|
|
||||||
resolved = provider._resolve_model("github-copilot/gpt-5.3-codex")
|
|
||||||
|
|
||||||
assert resolved == "github_copilot/gpt-5.3-codex"
|
|
||||||
|
|
||||||
|
|
||||||
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||||
@ -346,7 +339,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai:
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
|
||||||
_make_provider(config)
|
_make_provider(config)
|
||||||
|
|
||||||
kwargs = mock_async_openai.call_args.kwargs
|
kwargs = mock_async_openai.call_args.kwargs
|
||||||
|
|||||||
@ -1,13 +1,55 @@
|
|||||||
from types import SimpleNamespace
|
"""Tests for OpenAICompatProvider handling custom/direct endpoints."""
|
||||||
|
|
||||||
from nanobot.providers.custom_provider import CustomProvider
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
|
||||||
|
|
||||||
def test_custom_provider_parse_handles_empty_choices() -> None:
|
def test_custom_provider_parse_handles_empty_choices() -> None:
|
||||||
provider = CustomProvider()
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
response = SimpleNamespace(choices=[])
|
response = SimpleNamespace(choices=[])
|
||||||
|
|
||||||
result = provider._parse(response)
|
result = provider._parse(response)
|
||||||
|
|
||||||
assert result.finish_reason == "error"
|
assert result.finish_reason == "error"
|
||||||
assert "empty choices" in result.content
|
assert "empty choices" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_provider_parse_accepts_plain_string_response() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
result = provider._parse("hello from backend")
|
||||||
|
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.content == "hello from backend"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_provider_parse_accepts_dict_response() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
result = provider._parse({
|
||||||
|
"choices": [{
|
||||||
|
"message": {"content": "hello from dict"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 1,
|
||||||
|
"completion_tokens": 2,
|
||||||
|
"total_tokens": 3,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.content == "hello from dict"
|
||||||
|
assert result.usage["total_tokens"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
||||||
|
result = OpenAICompatProvider._parse_chunks(["hello ", "world"])
|
||||||
|
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.content == "hello world"
|
||||||
|
|||||||
@ -1,161 +1,177 @@
|
|||||||
"""Regression tests for PR #2026 — litellm_kwargs injection from ProviderSpec.
|
"""Tests for OpenAICompatProvider spec-driven behavior.
|
||||||
|
|
||||||
Validates that:
|
Validates that:
|
||||||
- OpenRouter uses litellm_prefix (NOT custom_llm_provider) to avoid LiteLLM double-prefixing.
|
- OpenRouter (no strip) keeps model names intact.
|
||||||
- The litellm_kwargs mechanism works correctly for providers that declare it.
|
- AiHubMix (strip_model_prefix=True) strips provider prefixes.
|
||||||
- Non-gateway providers are unaffected.
|
- Standard providers pass model names through as-is.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
|
|
||||||
def _fake_response(content: str = "ok") -> SimpleNamespace:
|
def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
|
||||||
"""Build a minimal acompletion-shaped response object."""
|
"""Build a minimal OpenAI chat completion response."""
|
||||||
message = SimpleNamespace(
|
message = SimpleNamespace(
|
||||||
content=content,
|
content=content,
|
||||||
tool_calls=None,
|
tool_calls=None,
|
||||||
reasoning_content=None,
|
reasoning_content=None,
|
||||||
thinking_blocks=None,
|
|
||||||
)
|
)
|
||||||
choice = SimpleNamespace(message=message, finish_reason="stop")
|
choice = SimpleNamespace(message=message, finish_reason="stop")
|
||||||
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||||
return SimpleNamespace(choices=[choice], usage=usage)
|
return SimpleNamespace(choices=[choice], usage=usage)
|
||||||
|
|
||||||
|
|
||||||
def test_openrouter_spec_uses_prefix_not_custom_llm_provider() -> None:
|
def _fake_tool_call_response() -> SimpleNamespace:
|
||||||
"""OpenRouter must rely on litellm_prefix, not custom_llm_provider kwarg.
|
"""Build a minimal chat response that includes Gemini-style extra_content."""
|
||||||
|
function = SimpleNamespace(
|
||||||
|
name="exec",
|
||||||
|
arguments='{"cmd":"ls"}',
|
||||||
|
provider_specific_fields={"inner": "value"},
|
||||||
|
)
|
||||||
|
tool_call = SimpleNamespace(
|
||||||
|
id="call_123",
|
||||||
|
index=0,
|
||||||
|
type="function",
|
||||||
|
function=function,
|
||||||
|
extra_content={"google": {"thought_signature": "signed-token"}},
|
||||||
|
)
|
||||||
|
message = SimpleNamespace(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[tool_call],
|
||||||
|
reasoning_content=None,
|
||||||
|
)
|
||||||
|
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
|
||||||
|
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||||
|
return SimpleNamespace(choices=[choice], usage=usage)
|
||||||
|
|
||||||
LiteLLM internally adds a provider/ prefix when custom_llm_provider is set,
|
|
||||||
which double-prefixes models (openrouter/anthropic/model) and breaks the API.
|
def test_openrouter_spec_is_gateway() -> None:
|
||||||
"""
|
|
||||||
spec = find_by_name("openrouter")
|
spec = find_by_name("openrouter")
|
||||||
assert spec is not None
|
assert spec is not None
|
||||||
assert spec.litellm_prefix == "openrouter"
|
assert spec.is_gateway is True
|
||||||
assert "custom_llm_provider" not in spec.litellm_kwargs, (
|
assert spec.default_api_base == "https://openrouter.ai/api/v1"
|
||||||
"custom_llm_provider causes LiteLLM to double-prefix the model name"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_openrouter_prefixes_model_correctly() -> None:
|
async def test_openrouter_keeps_model_name_intact() -> None:
|
||||||
"""OpenRouter should prefix model as openrouter/vendor/model for LiteLLM routing."""
|
"""OpenRouter gateway keeps the full model name (gateway does its own routing)."""
|
||||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||||
|
spec = find_by_name("openrouter")
|
||||||
|
|
||||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||||
provider = LiteLLMProvider(
|
client_instance = MockClient.return_value
|
||||||
|
client_instance.chat.completions.create = mock_create
|
||||||
|
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
api_key="sk-or-test-key",
|
api_key="sk-or-test-key",
|
||||||
api_base="https://openrouter.ai/api/v1",
|
api_base="https://openrouter.ai/api/v1",
|
||||||
default_model="anthropic/claude-sonnet-4-5",
|
default_model="anthropic/claude-sonnet-4-5",
|
||||||
provider_name="openrouter",
|
spec=spec,
|
||||||
)
|
)
|
||||||
await provider.chat(
|
await provider.chat(
|
||||||
messages=[{"role": "user", "content": "hello"}],
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
model="anthropic/claude-sonnet-4-5",
|
model="anthropic/claude-sonnet-4-5",
|
||||||
)
|
)
|
||||||
|
|
||||||
call_kwargs = mock_acompletion.call_args.kwargs
|
call_kwargs = mock_create.call_args.kwargs
|
||||||
assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
|
assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5"
|
||||||
"LiteLLM needs openrouter/ prefix to detect the provider and strip it before API call"
|
|
||||||
)
|
|
||||||
assert "custom_llm_provider" not in call_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_non_gateway_provider_no_extra_kwargs() -> None:
|
async def test_aihubmix_strips_model_prefix() -> None:
|
||||||
"""Standard (non-gateway) providers must NOT inject any litellm_kwargs."""
|
"""AiHubMix strips the provider prefix (strip_model_prefix=True)."""
|
||||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||||
|
spec = find_by_name("aihubmix")
|
||||||
|
|
||||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||||
provider = LiteLLMProvider(
|
client_instance = MockClient.return_value
|
||||||
api_key="sk-ant-test-key",
|
client_instance.chat.completions.create = mock_create
|
||||||
default_model="claude-sonnet-4-5",
|
|
||||||
)
|
|
||||||
await provider.chat(
|
|
||||||
messages=[{"role": "user", "content": "hello"}],
|
|
||||||
model="claude-sonnet-4-5",
|
|
||||||
)
|
|
||||||
|
|
||||||
call_kwargs = mock_acompletion.call_args.kwargs
|
provider = OpenAICompatProvider(
|
||||||
assert "custom_llm_provider" not in call_kwargs, (
|
|
||||||
"Standard Anthropic provider should NOT inject custom_llm_provider"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_gateway_without_litellm_kwargs_injects_nothing_extra() -> None:
|
|
||||||
"""Gateways without litellm_kwargs (e.g. AiHubMix) must not add extra keys."""
|
|
||||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
|
||||||
|
|
||||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
|
||||||
provider = LiteLLMProvider(
|
|
||||||
api_key="sk-aihub-test-key",
|
api_key="sk-aihub-test-key",
|
||||||
api_base="https://aihubmix.com/v1",
|
api_base="https://aihubmix.com/v1",
|
||||||
default_model="claude-sonnet-4-5",
|
default_model="claude-sonnet-4-5",
|
||||||
provider_name="aihubmix",
|
spec=spec,
|
||||||
)
|
|
||||||
await provider.chat(
|
|
||||||
messages=[{"role": "user", "content": "hello"}],
|
|
||||||
model="claude-sonnet-4-5",
|
|
||||||
)
|
|
||||||
|
|
||||||
call_kwargs = mock_acompletion.call_args.kwargs
|
|
||||||
assert "custom_llm_provider" not in call_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_openrouter_autodetect_by_key_prefix() -> None:
|
|
||||||
"""OpenRouter should be auto-detected by sk-or- key prefix even without explicit provider_name."""
|
|
||||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
|
||||||
|
|
||||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
|
||||||
provider = LiteLLMProvider(
|
|
||||||
api_key="sk-or-auto-detect-key",
|
|
||||||
default_model="anthropic/claude-sonnet-4-5",
|
|
||||||
)
|
)
|
||||||
await provider.chat(
|
await provider.chat(
|
||||||
messages=[{"role": "user", "content": "hello"}],
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
model="anthropic/claude-sonnet-4-5",
|
model="anthropic/claude-sonnet-4-5",
|
||||||
)
|
)
|
||||||
|
|
||||||
call_kwargs = mock_acompletion.call_args.kwargs
|
call_kwargs = mock_create.call_args.kwargs
|
||||||
assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
|
assert call_kwargs["model"] == "claude-sonnet-4-5"
|
||||||
"Auto-detected OpenRouter should prefix model for LiteLLM routing"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_openrouter_native_model_id_gets_double_prefixed() -> None:
|
async def test_standard_provider_passes_model_through() -> None:
|
||||||
"""Models like openrouter/free must be double-prefixed so LiteLLM strips one layer.
|
"""Standard provider (e.g. deepseek) passes model name through as-is."""
|
||||||
|
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||||
|
spec = find_by_name("deepseek")
|
||||||
|
|
||||||
openrouter/free is an actual OpenRouter model ID. LiteLLM strips the first
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||||
openrouter/ for routing, so we must send openrouter/openrouter/free to ensure
|
client_instance = MockClient.return_value
|
||||||
the API receives openrouter/free.
|
client_instance.chat.completions.create = mock_create
|
||||||
"""
|
|
||||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
|
||||||
|
|
||||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
provider = OpenAICompatProvider(
|
||||||
provider = LiteLLMProvider(
|
api_key="sk-deepseek-test-key",
|
||||||
api_key="sk-or-test-key",
|
default_model="deepseek-chat",
|
||||||
api_base="https://openrouter.ai/api/v1",
|
spec=spec,
|
||||||
default_model="openrouter/free",
|
|
||||||
provider_name="openrouter",
|
|
||||||
)
|
)
|
||||||
await provider.chat(
|
await provider.chat(
|
||||||
messages=[{"role": "user", "content": "hello"}],
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
model="openrouter/free",
|
model="deepseek-chat",
|
||||||
)
|
)
|
||||||
|
|
||||||
call_kwargs = mock_acompletion.call_args.kwargs
|
call_kwargs = mock_create.call_args.kwargs
|
||||||
assert call_kwargs["model"] == "openrouter/openrouter/free", (
|
assert call_kwargs["model"] == "deepseek-chat"
|
||||||
"openrouter/free must become openrouter/openrouter/free — "
|
|
||||||
"LiteLLM strips one layer so the API receives openrouter/free"
|
|
||||||
)
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
|
||||||
|
"""Gemini extra_content (thought signatures) must survive parse→serialize round-trip."""
|
||||||
|
mock_create = AsyncMock(return_value=_fake_tool_call_response())
|
||||||
|
spec = find_by_name("gemini")
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||||
|
client_instance = MockClient.return_value
|
||||||
|
client_instance.chat.completions.create = mock_create
|
||||||
|
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||||
|
default_model="google/gemini-3.1-pro-preview",
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
result = await provider.chat(
|
||||||
|
messages=[{"role": "user", "content": "run exec"}],
|
||||||
|
model="google/gemini-3.1-pro-preview",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
tool_call = result.tool_calls[0]
|
||||||
|
assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}}
|
||||||
|
assert tool_call.function_provider_specific_fields == {"inner": "value"}
|
||||||
|
|
||||||
|
serialized = tool_call.to_openai_tool_call()
|
||||||
|
assert serialized["extra_content"] == {"google": {"thought_signature": "signed-token"}}
|
||||||
|
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_model_passthrough() -> None:
|
||||||
|
"""OpenAI models pass through unchanged."""
|
||||||
|
spec = find_by_name("openai")
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="sk-test-key",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
assert provider.get_default_model() == "gpt-4o"
|
||||||
|
|||||||
@ -17,6 +17,4 @@ def test_mistral_provider_in_registry():
|
|||||||
|
|
||||||
mistral = specs["mistral"]
|
mistral = specs["mistral"]
|
||||||
assert mistral.env_key == "MISTRAL_API_KEY"
|
assert mistral.env_key == "MISTRAL_API_KEY"
|
||||||
assert mistral.litellm_prefix == "mistral"
|
|
||||||
assert mistral.default_api_base == "https://api.mistral.ai/v1"
|
assert mistral.default_api_base == "https://api.mistral.ai/v1"
|
||||||
assert "mistral/" in mistral.skip_prefixes
|
|
||||||
|
|||||||
@ -8,19 +8,22 @@ import sys
|
|||||||
|
|
||||||
def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||||
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
||||||
monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False)
|
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||||
|
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
|
||||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
|
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
|
||||||
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
|
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
|
||||||
|
|
||||||
providers = importlib.import_module("nanobot.providers")
|
providers = importlib.import_module("nanobot.providers")
|
||||||
|
|
||||||
assert "nanobot.providers.litellm_provider" not in sys.modules
|
assert "nanobot.providers.anthropic_provider" not in sys.modules
|
||||||
|
assert "nanobot.providers.openai_compat_provider" not in sys.modules
|
||||||
assert "nanobot.providers.openai_codex_provider" not in sys.modules
|
assert "nanobot.providers.openai_codex_provider" not in sys.modules
|
||||||
assert "nanobot.providers.azure_openai_provider" not in sys.modules
|
assert "nanobot.providers.azure_openai_provider" not in sys.modules
|
||||||
assert providers.__all__ == [
|
assert providers.__all__ == [
|
||||||
"LLMProvider",
|
"LLMProvider",
|
||||||
"LLMResponse",
|
"LLMResponse",
|
||||||
"LiteLLMProvider",
|
"AnthropicProvider",
|
||||||
|
"OpenAICompatProvider",
|
||||||
"OpenAICodexProvider",
|
"OpenAICodexProvider",
|
||||||
"AzureOpenAIProvider",
|
"AzureOpenAIProvider",
|
||||||
]
|
]
|
||||||
@ -28,10 +31,10 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
|||||||
|
|
||||||
def test_explicit_provider_import_still_works(monkeypatch) -> None:
|
def test_explicit_provider_import_still_works(monkeypatch) -> None:
|
||||||
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
||||||
monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False)
|
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||||
|
|
||||||
namespace: dict[str, object] = {}
|
namespace: dict[str, object] = {}
|
||||||
exec("from nanobot.providers import LiteLLMProvider", namespace)
|
exec("from nanobot.providers import AnthropicProvider", namespace)
|
||||||
|
|
||||||
assert namespace["LiteLLMProvider"].__name__ == "LiteLLMProvider"
|
assert namespace["AnthropicProvider"].__name__ == "AnthropicProvider"
|
||||||
assert "nanobot.providers.litellm_provider" in sys.modules
|
assert "nanobot.providers.anthropic_provider" in sys.modules
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user