mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 17:12:32 +00:00
perf: optimize gateway cold start from ~6.9s to ~460ms (#3918)
Channel lazy load: discover_enabled() imports only the enabled channel modules instead of all 18 modules with their heavy SDKs (telegram, discord, slack, etc.). discover_all() now delegates to discover_enabled(). Lazy OpenAI client: construction of AsyncOpenAI() and its httpx client is deferred to _ensure_client(), guarded by an asyncio.Lock with double-checked locking. The openai import moved from module level into _ensure_client(), and the httpx import into _build_client(). Minor: lazy Nanobot/RunResult and CronService exports via module-level __getattr__. Benchmark: gateway cold start 6910 ms → 460 ms (-93.3%).
This commit is contained in:
parent
1391aa3d57
commit
af9f8d54b8
1
.gitignore
vendored
1
.gitignore
vendored
@ -97,3 +97,4 @@ logs/
|
|||||||
tmp/
|
tmp/
|
||||||
temp/
|
temp/
|
||||||
*.tmp
|
*.tmp
|
||||||
|
exp/
|
||||||
|
|||||||
@ -2,9 +2,10 @@
|
|||||||
nanobot - A lightweight AI agent framework
|
nanobot - A lightweight AI agent framework
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from importlib.metadata import PackageNotFoundError, version as _pkg_version
|
|
||||||
from pathlib import Path
|
|
||||||
import tomllib
|
import tomllib
|
||||||
|
from importlib.metadata import PackageNotFoundError
|
||||||
|
from importlib.metadata import version as _pkg_version
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def _read_pyproject_version() -> str | None:
|
def _read_pyproject_version() -> str | None:
|
||||||
@ -27,6 +28,21 @@ def _resolve_version() -> str:
|
|||||||
__version__ = _resolve_version()
|
__version__ = _resolve_version()
|
||||||
__logo__ = "🐈"
|
__logo__ = "🐈"
|
||||||
|
|
||||||
from nanobot.nanobot import Nanobot, RunResult
|
_LAZY_EXPORTS = {
|
||||||
|
"Nanobot": ".nanobot",
|
||||||
|
"RunResult": ".nanobot",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
    """Resolve a lazily exported package attribute on first access.

    Python calls a module-level ``__getattr__`` only after normal attribute
    lookup on the module has failed. Names registered in ``_LAZY_EXPORTS``
    are imported from the relative submodule recorded there, stored in the
    module globals so that later lookups no longer reach this hook, and
    returned.

    Args:
        name: The attribute being looked up on the package.

    Returns:
        The object called *name* taken from the mapped submodule.

    Raises:
        AttributeError: If *name* is not a registered lazy export.
    """
    submodule = _LAZY_EXPORTS.get(name)
    if submodule is not None:
        # Function-scope import keeps the package import itself lightweight.
        import importlib

        exported = getattr(importlib.import_module(submodule, __name__), name)
        # Cache on the module so the next access is a plain global lookup.
        globals()[name] = exported
        return exported
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Nanobot", "RunResult"]
|
__all__ = ["Nanobot", "RunResult"]
|
||||||
|
|||||||
@ -70,28 +70,40 @@ class ChannelManager:
|
|||||||
|
|
||||||
def _init_channels(self) -> None:
|
def _init_channels(self) -> None:
|
||||||
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
|
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
|
||||||
from nanobot.channels.registry import discover_all
|
from nanobot.channels.registry import discover_channel_names, discover_enabled
|
||||||
|
|
||||||
transcription_provider = self.config.channels.transcription_provider
|
transcription_provider = self.config.channels.transcription_provider
|
||||||
transcription_key = self._resolve_transcription_key(transcription_provider)
|
transcription_key = self._resolve_transcription_key(transcription_provider)
|
||||||
transcription_base = self._resolve_transcription_base(transcription_provider)
|
transcription_base = self._resolve_transcription_base(transcription_provider)
|
||||||
transcription_language = self.config.channels.transcription_language
|
transcription_language = self.config.channels.transcription_language
|
||||||
|
|
||||||
for name, cls in discover_all().items():
|
# Collect enabled module names first, then only import those.
|
||||||
|
# Channel configs live in ChannelsConfig's extra fields (via
|
||||||
|
# extra="allow"), so we enumerate candidates from pkgutil scan
|
||||||
|
# (cheap, no imports) and any plugin keys in __pydantic_extra__.
|
||||||
|
names = discover_channel_names()
|
||||||
|
candidate_names = set(names)
|
||||||
|
extra = getattr(self.config.channels, "__pydantic_extra__", None) or {}
|
||||||
|
candidate_names.update(extra.keys())
|
||||||
|
|
||||||
|
enabled_names: set[str] = set()
|
||||||
|
for name in candidate_names:
|
||||||
section = getattr(self.config.channels, name, None)
|
section = getattr(self.config.channels, name, None)
|
||||||
if section is None:
|
if section is None:
|
||||||
continue
|
continue
|
||||||
enabled = (
|
if (
|
||||||
section.get("enabled", False)
|
section.get("enabled", False)
|
||||||
if isinstance(section, dict)
|
if isinstance(section, dict)
|
||||||
else getattr(section, "enabled", False)
|
else getattr(section, "enabled", False)
|
||||||
)
|
):
|
||||||
if not enabled:
|
enabled_names.add(name)
|
||||||
|
|
||||||
|
for name, cls in discover_enabled(enabled_names, _names=names).items():
|
||||||
|
section = getattr(self.config.channels, name, None)
|
||||||
|
if section is None:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
kwargs: dict[str, Any] = {}
|
kwargs: dict[str, Any] = {}
|
||||||
# Only the WebSocket channel currently hosts the embedded webui
|
|
||||||
# surface; other channels stay oblivious to these knobs.
|
|
||||||
if cls.name == "websocket":
|
if cls.name == "websocket":
|
||||||
if self._session_manager is not None:
|
if self._session_manager is not None:
|
||||||
kwargs["session_manager"] = self._session_manager
|
kwargs["session_manager"] = self._session_manager
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
"""Auto-discovery for built-in channel modules and external plugins."""
|
"""Auto-discovery for built-in channel modules and external plugins."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
@ -51,21 +50,44 @@ def discover_plugins() -> dict[str, type[BaseChannel]]:
|
|||||||
return plugins
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
|
def discover_enabled(
    enabled_names: set[str],
    *,
    _names: list[str] | None = None,
    _include_all_external: bool = False,
) -> dict[str, type[BaseChannel]]:
    """Return the channel classes whose names appear in *enabled_names*.

    Only built-in modules that are both listed and enabled get imported, so
    channels that are switched off never trigger their module import.

    Args:
        enabled_names: Channel names that should be loaded.
        _names: Pre-computed list of built-in channel module names. When
            ``None`` the list is obtained from ``discover_channel_names()``.
        _include_all_external: When true, every external plugin that is not
            shadowed by a loaded built-in is returned, whether or not its
            name is in *enabled_names*.

    Returns:
        Mapping of channel name to channel class. Built-ins come first, in
        listing order, followed by the accepted external plugins.
    """
    builtin_names = discover_channel_names() if _names is None else _names

    channels: dict[str, type[BaseChannel]] = {}
    for modname in builtin_names:
        if modname in enabled_names:
            try:
                channels[modname] = load_channel_class(modname)
            except ImportError as e:
                # Best effort: a built-in that fails to import is skipped.
                logger.debug("Skipping built-in channel '{}': {}", modname, e)

    external = discover_plugins()
    # A plugin may not replace a built-in that was successfully loaded above.
    shadowed = set(external) & set(channels)
    if shadowed:
        logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)

    for plugin_name, plugin_cls in external.items():
        if plugin_name in shadowed:
            continue
        if _include_all_external or plugin_name in enabled_names:
            channels[plugin_name] = plugin_cls

    return channels
|
||||||
|
|
||||||
|
|
||||||
def discover_all() -> dict[str, type[BaseChannel]]:
|
def discover_all() -> dict[str, type[BaseChannel]]:
|
||||||
"""Return all channels: built-in (pkgutil) merged with external (entry_points).
|
"""Return all channels: built-in (pkgutil) merged with external (entry_points).
|
||||||
|
|
||||||
Built-in channels take priority — an external plugin cannot shadow a built-in name.
|
Built-in channels take priority — an external plugin cannot shadow a built-in name.
|
||||||
"""
|
"""
|
||||||
builtin: dict[str, type[BaseChannel]] = {}
|
names = discover_channel_names()
|
||||||
for modname in discover_channel_names():
|
return discover_enabled(set(names), _names=names, _include_all_external=True)
|
||||||
try:
|
|
||||||
builtin[modname] = load_channel_class(modname)
|
|
||||||
except ImportError as e:
|
|
||||||
logger.debug("Skipping built-in channel '{}': {}", modname, e)
|
|
||||||
|
|
||||||
external = discover_plugins()
|
|
||||||
shadowed = set(external) & set(builtin)
|
|
||||||
if shadowed:
|
|
||||||
logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
|
|
||||||
|
|
||||||
return {**external, **builtin}
|
|
||||||
|
|||||||
@ -1,6 +1,18 @@
|
|||||||
"""Cron service for scheduled agent tasks."""
|
"""Cron service for scheduled agent tasks."""
|
||||||
|
|
||||||
from nanobot.cron.service import CronService
|
|
||||||
from nanobot.cron.types import CronJob, CronSchedule
|
from nanobot.cron.types import CronJob, CronSchedule
|
||||||
|
|
||||||
__all__ = ["CronService", "CronJob", "CronSchedule"]
|
__all__ = ["CronService", "CronJob", "CronSchedule"]
|
||||||
|
|
||||||
|
_LAZY = {"CronService": ".service"}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
    """Import a lazily exported name of this package on first access.

    Acts as the module-level attribute fallback: a name found in ``_LAZY``
    is fetched from the relative submodule it maps to, cached in the module
    globals, and returned. Any other name raises ``AttributeError``.

    Args:
        name: The attribute requested from the package.

    Returns:
        The attribute called *name* of the mapped submodule.

    Raises:
        AttributeError: If *name* has no entry in ``_LAZY``.
    """
    target = _LAZY.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Deferred import: the submodule is only loaded once it is actually used.
    import importlib

    resolved = getattr(importlib.import_module(target, __name__), name)
    # Store the result so subsequent lookups skip this hook.
    globals()[name] = resolved
    return resolved
|
||||||
|
|||||||
@ -16,20 +16,9 @@ from ipaddress import ip_address
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
|
||||||
import json_repair
|
import json_repair
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
if os.environ.get("LANGFUSE_SECRET_KEY") and importlib.util.find_spec("langfuse"):
|
|
||||||
from langfuse.openai import AsyncOpenAI
|
|
||||||
else:
|
|
||||||
if os.environ.get("LANGFUSE_SECRET_KEY"):
|
|
||||||
logger.warning(
|
|
||||||
"LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
|
|
||||||
"install with `pip install langfuse` to enable tracing"
|
|
||||||
)
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from nanobot.providers.openai_responses import (
|
from nanobot.providers.openai_responses import (
|
||||||
consume_sdk_stream,
|
consume_sdk_stream,
|
||||||
@ -39,8 +28,15 @@ from nanobot.providers.openai_responses import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from openai import AsyncOpenAI as AsyncOpenAIType
|
||||||
|
|
||||||
from nanobot.providers.registry import ProviderSpec
|
from nanobot.providers.registry import ProviderSpec
|
||||||
|
|
||||||
|
# Module-level placeholder — set lazily by _ensure_client on first real
|
||||||
|
# use, or replaced by tests via ``patch(...)``. Kept as a plain name so
|
||||||
|
# that ``unittest.mock.patch`` can find and replace it.
|
||||||
|
AsyncOpenAI: Any = None
|
||||||
|
|
||||||
_ALLOWED_MSG_KEYS = frozenset({
|
_ALLOWED_MSG_KEYS = frozenset({
|
||||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||||
"reasoning_content", "extra_content",
|
"reasoning_content", "extra_content",
|
||||||
@ -302,43 +298,80 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
|
|
||||||
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
||||||
self._effective_base = effective_base
|
self._effective_base = effective_base
|
||||||
default_headers = {"x-session-affinity": uuid.uuid4().hex}
|
self._default_headers = {"x-session-affinity": uuid.uuid4().hex}
|
||||||
if _uses_openrouter_attribution(spec, effective_base):
|
if _uses_openrouter_attribution(spec, effective_base):
|
||||||
default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
|
self._default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
default_headers.update(extra_headers)
|
self._default_headers.update(extra_headers)
|
||||||
|
self._api_key_for_client = api_key or "no-key"
|
||||||
|
self._is_local = _is_local_endpoint(spec, effective_base)
|
||||||
|
|
||||||
# Local model servers (Ollama, llama.cpp, vLLM) often close idle
|
# Lazy-init: the OpenAI client and its httpx transport are expensive
|
||||||
# HTTP connections before the client-side keepalive expires. When
|
# to create (~700 ms on Windows). Defer until first use — unless
|
||||||
# two LLM calls happen seconds apart (e.g. heartbeat _decide then
|
# AsyncOpenAI has been patched (tests), in which case build eagerly.
|
||||||
# process_direct), the second call may grab a now-dead pooled
|
self._client: AsyncOpenAIType | None = None
|
||||||
# connection, causing a transient APIConnectionError on every first
|
self._client_lock = asyncio.Lock()
|
||||||
# attempt. Disabling keepalive for local endpoints avoids this by
|
|
||||||
# opening a fresh connection for each request, which is cheap on a
|
|
||||||
# LAN. Cloud providers benefit from keepalive, so we leave the
|
|
||||||
# default pool settings for them.
|
|
||||||
timeout_s = _openai_compat_timeout_s()
|
|
||||||
http_client: httpx.AsyncClient | None = None
|
|
||||||
if _is_local_endpoint(spec, effective_base):
|
|
||||||
http_client = httpx.AsyncClient(
|
|
||||||
limits=httpx.Limits(keepalive_expiry=0),
|
|
||||||
timeout=timeout_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._client = AsyncOpenAI(
|
if AsyncOpenAI is not None:
|
||||||
api_key=api_key or "no-key",
|
self._build_client()
|
||||||
base_url=effective_base,
|
|
||||||
default_headers=default_headers,
|
|
||||||
max_retries=0,
|
|
||||||
timeout=timeout_s,
|
|
||||||
http_client=http_client,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Responses API circuit breaker: skip after repeated failures,
|
# Responses API circuit breaker: skip after repeated failures,
|
||||||
# probe again after _RESPONSES_PROBE_INTERVAL_S seconds.
|
# probe again after _RESPONSES_PROBE_INTERVAL_S seconds.
|
||||||
self._responses_failures: dict[str, int] = {}
|
self._responses_failures: dict[str, int] = {}
|
||||||
self._responses_tripped_at: dict[str, float] = {}
|
self._responses_tripped_at: dict[str, float] = {}
|
||||||
|
|
||||||
|
def _build_client(self) -> None:
    """Instantiate ``self._client`` from the module-level ``AsyncOpenAI``.

    Uses the connection settings stored on the instance
    (``_api_key_for_client``, ``_effective_base``, ``_default_headers`` and
    ``_is_local``). The caller must make sure the module-level
    ``AsyncOpenAI`` name is bound to a client class before calling this.
    """
    # Imported at call time so that httpx is only loaded when a client is built.
    import httpx

    request_timeout = _openai_compat_timeout_s()

    # Local model servers (Ollama, llama.cpp, vLLM) often close idle HTTP
    # connections before the client-side keepalive expires. When two LLM
    # calls happen seconds apart (e.g. heartbeat _decide then
    # process_direct), the second call may grab a now-dead pooled
    # connection, causing a transient APIConnectionError on every first
    # attempt. Disabling keepalive for local endpoints avoids this by
    # opening a fresh connection for each request, which is cheap on a LAN.
    # Cloud providers benefit from keepalive, so they keep the default pool
    # settings (no custom http client is passed for them).
    local_http_client: httpx.AsyncClient | None = (
        httpx.AsyncClient(
            limits=httpx.Limits(keepalive_expiry=0),
            timeout=request_timeout,
        )
        if self._is_local
        else None
    )

    self._client = AsyncOpenAI(
        api_key=self._api_key_for_client,
        base_url=self._effective_base,
        default_headers=self._default_headers,
        max_retries=0,
        timeout=request_timeout,
        http_client=local_http_client,
    )
|
||||||
|
|
||||||
|
async def _ensure_client(self):
    """Return the shared OpenAI client, building it on the first call.

    If ``self._client`` already exists it is returned without taking the
    lock. Otherwise the client is built while holding ``self._client_lock``.
    When the module-level ``AsyncOpenAI`` is still unset, it is first bound
    to the langfuse wrapper (if ``LANGFUSE_SECRET_KEY`` is set and langfuse
    is importable) or to the plain ``openai`` class.
    """
    global AsyncOpenAI

    if self._client is None:
        async with self._client_lock:
            # Check again under the lock: another task may have finished
            # building the client while this one was waiting.
            if self._client is None:
                if AsyncOpenAI is None:
                    langfuse_key = os.environ.get("LANGFUSE_SECRET_KEY")
                    if langfuse_key and importlib.util.find_spec("langfuse"):
                        from langfuse.openai import AsyncOpenAI as _client_cls
                    else:
                        if langfuse_key:
                            logger.warning(
                                "LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
                                "install with `pip install langfuse` to enable tracing"
                            )
                        from openai import AsyncOpenAI as _client_cls
                    # Publish the class at module level for later builds.
                    AsyncOpenAI = _client_cls

                self._build_client()

    return self._client
|
||||||
|
|
||||||
def _setup_env(self, api_key: str, api_base: str | None) -> None:
|
def _setup_env(self, api_key: str, api_base: str | None) -> None:
|
||||||
"""Set environment variables based on provider spec."""
|
"""Set environment variables based on provider spec."""
|
||||||
spec = self._spec
|
spec = self._spec
|
||||||
@ -1182,6 +1215,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
|
await self._ensure_client()
|
||||||
try:
|
try:
|
||||||
if self._should_use_responses_api(model, reasoning_effort):
|
if self._should_use_responses_api(model, reasoning_effort):
|
||||||
try:
|
try:
|
||||||
@ -1223,6 +1257,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
|
await self._ensure_client()
|
||||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||||
try:
|
try:
|
||||||
if self._should_use_responses_api(model, reasoning_effort):
|
if self._should_use_responses_api(model, reasoning_effort):
|
||||||
|
|||||||
@ -180,7 +180,7 @@ async def test_manager_loads_plugin_from_dict_config():
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"nanobot.channels.registry.discover_all",
|
"nanobot.channels.registry.discover_enabled",
|
||||||
return_value={"fakeplugin": _FakePlugin},
|
return_value={"fakeplugin": _FakePlugin},
|
||||||
):
|
):
|
||||||
mgr = ChannelManager.__new__(ChannelManager)
|
mgr = ChannelManager.__new__(ChannelManager)
|
||||||
@ -210,7 +210,7 @@ async def test_manager_propagates_groq_transcription_api_base_to_channels():
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"nanobot.channels.registry.discover_all",
|
"nanobot.channels.registry.discover_enabled",
|
||||||
return_value={"fakeplugin": _FakePlugin},
|
return_value={"fakeplugin": _FakePlugin},
|
||||||
):
|
):
|
||||||
mgr = ChannelManager.__new__(ChannelManager)
|
mgr = ChannelManager.__new__(ChannelManager)
|
||||||
@ -246,7 +246,7 @@ async def test_manager_propagates_openai_transcription_api_base_to_channels():
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"nanobot.channels.registry.discover_all",
|
"nanobot.channels.registry.discover_enabled",
|
||||||
return_value={"fakeplugin": _FakePlugin},
|
return_value={"fakeplugin": _FakePlugin},
|
||||||
):
|
):
|
||||||
mgr = ChannelManager.__new__(ChannelManager)
|
mgr = ChannelManager.__new__(ChannelManager)
|
||||||
@ -498,10 +498,8 @@ async def test_manager_skips_disabled_plugin():
|
|||||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
ep = _make_entry_point("fakeplugin", _FakePlugin)
|
||||||
"nanobot.channels.registry.discover_all",
|
with patch(_EP_TARGET, return_value=[ep]):
|
||||||
return_value={"fakeplugin": _FakePlugin},
|
|
||||||
):
|
|
||||||
mgr = ChannelManager.__new__(ChannelManager)
|
mgr = ChannelManager.__new__(ChannelManager)
|
||||||
mgr.config = fake_config
|
mgr.config = fake_config
|
||||||
mgr.bus = MessageBus()
|
mgr.bus = MessageBus()
|
||||||
|
|||||||
@ -85,17 +85,18 @@ class TestIsLocalEndpoint:
|
|||||||
class TestLocalKeepaliveConfig:
|
class TestLocalKeepaliveConfig:
|
||||||
"""Verify that local endpoints get keepalive_expiry=0."""
|
"""Verify that local endpoints get keepalive_expiry=0."""
|
||||||
|
|
||||||
def test_local_spec_disables_keepalive(self):
|
async def test_local_spec_disables_keepalive(self):
|
||||||
spec = _make_spec(is_local=True)
|
spec = _make_spec(is_local=True)
|
||||||
spec.env_key = ""
|
spec.env_key = ""
|
||||||
spec.default_api_base = "http://localhost:11434/v1"
|
spec.default_api_base = "http://localhost:11434/v1"
|
||||||
provider = OpenAICompatProvider(
|
provider = OpenAICompatProvider(
|
||||||
api_key="test", api_base="http://localhost:11434/v1", spec=spec,
|
api_key="test", api_base="http://localhost:11434/v1", spec=spec,
|
||||||
)
|
)
|
||||||
|
await provider._ensure_client()
|
||||||
pool = provider._client._client._transport._pool
|
pool = provider._client._client._transport._pool
|
||||||
assert pool._keepalive_expiry == 0
|
assert pool._keepalive_expiry == 0
|
||||||
|
|
||||||
def test_lan_ip_disables_keepalive(self):
|
async def test_lan_ip_disables_keepalive(self):
|
||||||
"""A generic 'openai' spec with a LAN IP should still disable keepalive."""
|
"""A generic 'openai' spec with a LAN IP should still disable keepalive."""
|
||||||
spec = _make_spec(is_local=False)
|
spec = _make_spec(is_local=False)
|
||||||
spec.env_key = ""
|
spec.env_key = ""
|
||||||
@ -103,16 +104,18 @@ class TestLocalKeepaliveConfig:
|
|||||||
provider = OpenAICompatProvider(
|
provider = OpenAICompatProvider(
|
||||||
api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
|
api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
|
||||||
)
|
)
|
||||||
|
await provider._ensure_client()
|
||||||
pool = provider._client._client._transport._pool
|
pool = provider._client._client._transport._pool
|
||||||
assert pool._keepalive_expiry == 0
|
assert pool._keepalive_expiry == 0
|
||||||
|
|
||||||
def test_cloud_keeps_default_keepalive(self):
|
async def test_cloud_keeps_default_keepalive(self):
|
||||||
spec = _make_spec(is_local=False)
|
spec = _make_spec(is_local=False)
|
||||||
spec.env_key = ""
|
spec.env_key = ""
|
||||||
spec.default_api_base = "https://api.openai.com/v1"
|
spec.default_api_base = "https://api.openai.com/v1"
|
||||||
provider = OpenAICompatProvider(
|
provider = OpenAICompatProvider(
|
||||||
api_key="test", api_base=None, spec=spec,
|
api_key="test", api_base=None, spec=spec,
|
||||||
)
|
)
|
||||||
|
await provider._ensure_client()
|
||||||
pool = provider._client._client._transport._pool
|
pool = provider._client._client._transport._pool
|
||||||
# Default httpx keepalive is 5.0s
|
# Default httpx keepalive is 5.0s
|
||||||
assert pool._keepalive_expiry == 5.0
|
assert pool._keepalive_expiry == 5.0
|
||||||
|
|||||||
@ -29,7 +29,7 @@ def test_openai_compat_provider_sets_timeout_on_local_http_client() -> None:
|
|||||||
with (
|
with (
|
||||||
patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai,
|
patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai,
|
||||||
patch(
|
patch(
|
||||||
"nanobot.providers.openai_compat_provider.httpx.AsyncClient",
|
"httpx.AsyncClient",
|
||||||
return_value=sentinel.http_client,
|
return_value=sentinel.http_client,
|
||||||
) as mock_http_client,
|
) as mock_http_client,
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user