mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 09:02:32 +00:00
perf: optimize gateway cold start from ~6.9s to ~460ms (#3918)
Channel lazy load: discover_enabled() only imports enabled channel modules instead of all 18 modules with heavy SDKs (telegram, discord, slack, etc). discover_all() now delegates to discover_enabled(). Lazy OpenAI client: defer AsyncOpenAI() + httpx construction to _ensure_client() with asyncio.Lock double-checked locking. openai and httpx imports moved from module-level into _ensure_client(). Minor: lazy Nanobot/RunResult and CronService exports via __getattr__. Benchmark: 6910ms → 460ms (-93.3%)
This commit is contained in:
parent
1391aa3d57
commit
af9f8d54b8
1
.gitignore
vendored
1
.gitignore
vendored
@ -97,3 +97,4 @@ logs/
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
exp/
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
nanobot - A lightweight AI agent framework
|
||||
"""
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version as _pkg_version
|
||||
from pathlib import Path
|
||||
import tomllib
|
||||
from importlib.metadata import PackageNotFoundError
|
||||
from importlib.metadata import version as _pkg_version
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _read_pyproject_version() -> str | None:
|
||||
@ -27,6 +28,21 @@ def _resolve_version() -> str:
|
||||
__version__ = _resolve_version()
|
||||
__logo__ = "🐈"
|
||||
|
||||
from nanobot.nanobot import Nanobot, RunResult
|
||||
_LAZY_EXPORTS = {
|
||||
"Nanobot": ".nanobot",
|
||||
"RunResult": ".nanobot",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
module_path = _LAZY_EXPORTS.get(name)
|
||||
if module_path is None:
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
from importlib import import_module
|
||||
mod = import_module(module_path, __name__)
|
||||
val = getattr(mod, name)
|
||||
globals()[name] = val
|
||||
return val
|
||||
|
||||
|
||||
__all__ = ["Nanobot", "RunResult"]
|
||||
|
||||
@ -70,28 +70,40 @@ class ChannelManager:
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
|
||||
from nanobot.channels.registry import discover_all
|
||||
from nanobot.channels.registry import discover_channel_names, discover_enabled
|
||||
|
||||
transcription_provider = self.config.channels.transcription_provider
|
||||
transcription_key = self._resolve_transcription_key(transcription_provider)
|
||||
transcription_base = self._resolve_transcription_base(transcription_provider)
|
||||
transcription_language = self.config.channels.transcription_language
|
||||
|
||||
for name, cls in discover_all().items():
|
||||
# Collect enabled module names first, then only import those.
|
||||
# Channel configs live in ChannelsConfig's extra fields (via
|
||||
# extra="allow"), so we enumerate candidates from pkgutil scan
|
||||
# (cheap, no imports) and any plugin keys in __pydantic_extra__.
|
||||
names = discover_channel_names()
|
||||
candidate_names = set(names)
|
||||
extra = getattr(self.config.channels, "__pydantic_extra__", None) or {}
|
||||
candidate_names.update(extra.keys())
|
||||
|
||||
enabled_names: set[str] = set()
|
||||
for name in candidate_names:
|
||||
section = getattr(self.config.channels, name, None)
|
||||
if section is None:
|
||||
continue
|
||||
enabled = (
|
||||
if (
|
||||
section.get("enabled", False)
|
||||
if isinstance(section, dict)
|
||||
else getattr(section, "enabled", False)
|
||||
)
|
||||
if not enabled:
|
||||
):
|
||||
enabled_names.add(name)
|
||||
|
||||
for name, cls in discover_enabled(enabled_names, _names=names).items():
|
||||
section = getattr(self.config.channels, name, None)
|
||||
if section is None:
|
||||
continue
|
||||
try:
|
||||
kwargs: dict[str, Any] = {}
|
||||
# Only the WebSocket channel currently hosts the embedded webui
|
||||
# surface; other channels stay oblivious to these knobs.
|
||||
if cls.name == "websocket":
|
||||
if self._session_manager is not None:
|
||||
kwargs["session_manager"] = self._session_manager
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
"""Auto-discovery for built-in channel modules and external plugins."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
@ -51,21 +50,44 @@ def discover_plugins() -> dict[str, type[BaseChannel]]:
|
||||
return plugins
|
||||
|
||||
|
||||
def discover_enabled(
|
||||
enabled_names: set[str],
|
||||
*,
|
||||
_names: list[str] | None = None,
|
||||
_include_all_external: bool = False,
|
||||
) -> dict[str, type[BaseChannel]]:
|
||||
"""Return channels whose module names are in *enabled_names*.
|
||||
|
||||
Uses cheap ``pkgutil.iter_modules`` to list names, then imports only
|
||||
those that match — skipping the heavy third-party SDK imports of
|
||||
unneeded channels.
|
||||
"""
|
||||
names = _names if _names is not None else discover_channel_names()
|
||||
result: dict[str, type[BaseChannel]] = {}
|
||||
for modname in names:
|
||||
if modname not in enabled_names:
|
||||
continue
|
||||
try:
|
||||
result[modname] = load_channel_class(modname)
|
||||
except ImportError as e:
|
||||
logger.debug("Skipping built-in channel '{}': {}", modname, e)
|
||||
|
||||
external = discover_plugins()
|
||||
shadowed = set(external) & set(result)
|
||||
if shadowed:
|
||||
logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
|
||||
if _include_all_external:
|
||||
result.update({k: v for k, v in external.items() if k not in shadowed})
|
||||
else:
|
||||
result.update({k: v for k, v in external.items() if k not in shadowed and k in enabled_names})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def discover_all() -> dict[str, type[BaseChannel]]:
|
||||
"""Return all channels: built-in (pkgutil) merged with external (entry_points).
|
||||
|
||||
Built-in channels take priority — an external plugin cannot shadow a built-in name.
|
||||
"""
|
||||
builtin: dict[str, type[BaseChannel]] = {}
|
||||
for modname in discover_channel_names():
|
||||
try:
|
||||
builtin[modname] = load_channel_class(modname)
|
||||
except ImportError as e:
|
||||
logger.debug("Skipping built-in channel '{}': {}", modname, e)
|
||||
|
||||
external = discover_plugins()
|
||||
shadowed = set(external) & set(builtin)
|
||||
if shadowed:
|
||||
logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
|
||||
|
||||
return {**external, **builtin}
|
||||
names = discover_channel_names()
|
||||
return discover_enabled(set(names), _names=names, _include_all_external=True)
|
||||
|
||||
@ -1,6 +1,18 @@
|
||||
"""Cron service for scheduled agent tasks."""
|
||||
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob, CronSchedule
|
||||
|
||||
__all__ = ["CronService", "CronJob", "CronSchedule"]
|
||||
|
||||
_LAZY = {"CronService": ".service"}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
module_path = _LAZY.get(name)
|
||||
if module_path is None:
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
from importlib import import_module
|
||||
mod = import_module(module_path, __name__)
|
||||
val = getattr(mod, name)
|
||||
globals()[name] = val
|
||||
return val
|
||||
|
||||
@ -16,20 +16,9 @@ from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
if os.environ.get("LANGFUSE_SECRET_KEY") and importlib.util.find_spec("langfuse"):
|
||||
from langfuse.openai import AsyncOpenAI
|
||||
else:
|
||||
if os.environ.get("LANGFUSE_SECRET_KEY"):
|
||||
logger.warning(
|
||||
"LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
|
||||
"install with `pip install langfuse` to enable tracing"
|
||||
)
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_responses import (
|
||||
consume_sdk_stream,
|
||||
@ -39,8 +28,15 @@ from nanobot.providers.openai_responses import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI as AsyncOpenAIType
|
||||
|
||||
from nanobot.providers.registry import ProviderSpec
|
||||
|
||||
# Module-level placeholder — set lazily by _ensure_client on first real
|
||||
# use, or replaced by tests via ``patch(...)``. Kept as a plain name so
|
||||
# that ``unittest.mock.patch`` can find and replace it.
|
||||
AsyncOpenAI: Any = None
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({
|
||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||
"reasoning_content", "extra_content",
|
||||
@ -302,12 +298,35 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
||||
self._effective_base = effective_base
|
||||
default_headers = {"x-session-affinity": uuid.uuid4().hex}
|
||||
self._default_headers = {"x-session-affinity": uuid.uuid4().hex}
|
||||
if _uses_openrouter_attribution(spec, effective_base):
|
||||
default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
|
||||
self._default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
|
||||
if extra_headers:
|
||||
default_headers.update(extra_headers)
|
||||
self._default_headers.update(extra_headers)
|
||||
self._api_key_for_client = api_key or "no-key"
|
||||
self._is_local = _is_local_endpoint(spec, effective_base)
|
||||
|
||||
# Lazy-init: the OpenAI client and its httpx transport are expensive
|
||||
# to create (~700 ms on Windows). Defer until first use — unless
|
||||
# AsyncOpenAI has been patched (tests), in which case build eagerly.
|
||||
self._client: AsyncOpenAIType | None = None
|
||||
self._client_lock = asyncio.Lock()
|
||||
|
||||
if AsyncOpenAI is not None:
|
||||
self._build_client()
|
||||
|
||||
# Responses API circuit breaker: skip after repeated failures,
|
||||
# probe again after _RESPONSES_PROBE_INTERVAL_S seconds.
|
||||
self._responses_failures: dict[str, int] = {}
|
||||
self._responses_tripped_at: dict[str, float] = {}
|
||||
|
||||
def _build_client(self) -> None:
|
||||
"""Create the OpenAI client using the current module-level AsyncOpenAI."""
|
||||
import httpx
|
||||
|
||||
timeout_s = _openai_compat_timeout_s()
|
||||
http_client: httpx.AsyncClient | None = None
|
||||
if self._is_local:
|
||||
# Local model servers (Ollama, llama.cpp, vLLM) often close idle
|
||||
# HTTP connections before the client-side keepalive expires. When
|
||||
# two LLM calls happen seconds apart (e.g. heartbeat _decide then
|
||||
@ -317,27 +336,41 @@ class OpenAICompatProvider(LLMProvider):
|
||||
# opening a fresh connection for each request, which is cheap on a
|
||||
# LAN. Cloud providers benefit from keepalive, so we leave the
|
||||
# default pool settings for them.
|
||||
timeout_s = _openai_compat_timeout_s()
|
||||
http_client: httpx.AsyncClient | None = None
|
||||
if _is_local_endpoint(spec, effective_base):
|
||||
http_client = httpx.AsyncClient(
|
||||
limits=httpx.Limits(keepalive_expiry=0),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key or "no-key",
|
||||
base_url=effective_base,
|
||||
default_headers=default_headers,
|
||||
api_key=self._api_key_for_client,
|
||||
base_url=self._effective_base,
|
||||
default_headers=self._default_headers,
|
||||
max_retries=0,
|
||||
timeout=timeout_s,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
# Responses API circuit breaker: skip after repeated failures,
|
||||
# probe again after _RESPONSES_PROBE_INTERVAL_S seconds.
|
||||
self._responses_failures: dict[str, int] = {}
|
||||
self._responses_tripped_at: dict[str, float] = {}
|
||||
async def _ensure_client(self):
|
||||
"""Return the shared OpenAI client, creating it on first call."""
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
async with self._client_lock:
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
global AsyncOpenAI
|
||||
if AsyncOpenAI is None:
|
||||
if os.environ.get("LANGFUSE_SECRET_KEY") and importlib.util.find_spec("langfuse"):
|
||||
from langfuse.openai import AsyncOpenAI as _AsyncOpenAI
|
||||
else:
|
||||
if os.environ.get("LANGFUSE_SECRET_KEY"):
|
||||
logger.warning(
|
||||
"LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
|
||||
"install with `pip install langfuse` to enable tracing"
|
||||
)
|
||||
from openai import AsyncOpenAI as _AsyncOpenAI
|
||||
AsyncOpenAI = _AsyncOpenAI
|
||||
|
||||
self._build_client()
|
||||
return self._client
|
||||
|
||||
def _setup_env(self, api_key: str, api_base: str | None) -> None:
|
||||
"""Set environment variables based on provider spec."""
|
||||
@ -1182,6 +1215,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
await self._ensure_client()
|
||||
try:
|
||||
if self._should_use_responses_api(model, reasoning_effort):
|
||||
try:
|
||||
@ -1223,6 +1257,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
await self._ensure_client()
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
if self._should_use_responses_api(model, reasoning_effort):
|
||||
|
||||
@ -180,7 +180,7 @@ async def test_manager_loads_plugin_from_dict_config():
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
"nanobot.channels.registry.discover_enabled",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
@ -210,7 +210,7 @@ async def test_manager_propagates_groq_transcription_api_base_to_channels():
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
"nanobot.channels.registry.discover_enabled",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
@ -246,7 +246,7 @@ async def test_manager_propagates_openai_transcription_api_base_to_channels():
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
"nanobot.channels.registry.discover_enabled",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
@ -498,10 +498,8 @@ async def test_manager_skips_disabled_plugin():
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
ep = _make_entry_point("fakeplugin", _FakePlugin)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
|
||||
@ -85,17 +85,18 @@ class TestIsLocalEndpoint:
|
||||
class TestLocalKeepaliveConfig:
|
||||
"""Verify that local endpoints get keepalive_expiry=0."""
|
||||
|
||||
def test_local_spec_disables_keepalive(self):
|
||||
async def test_local_spec_disables_keepalive(self):
|
||||
spec = _make_spec(is_local=True)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = "http://localhost:11434/v1"
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base="http://localhost:11434/v1", spec=spec,
|
||||
)
|
||||
await provider._ensure_client()
|
||||
pool = provider._client._client._transport._pool
|
||||
assert pool._keepalive_expiry == 0
|
||||
|
||||
def test_lan_ip_disables_keepalive(self):
|
||||
async def test_lan_ip_disables_keepalive(self):
|
||||
"""A generic 'openai' spec with a LAN IP should still disable keepalive."""
|
||||
spec = _make_spec(is_local=False)
|
||||
spec.env_key = ""
|
||||
@ -103,16 +104,18 @@ class TestLocalKeepaliveConfig:
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
|
||||
)
|
||||
await provider._ensure_client()
|
||||
pool = provider._client._client._transport._pool
|
||||
assert pool._keepalive_expiry == 0
|
||||
|
||||
def test_cloud_keeps_default_keepalive(self):
|
||||
async def test_cloud_keeps_default_keepalive(self):
|
||||
spec = _make_spec(is_local=False)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = "https://api.openai.com/v1"
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base=None, spec=spec,
|
||||
)
|
||||
await provider._ensure_client()
|
||||
pool = provider._client._client._transport._pool
|
||||
# Default httpx keepalive is 5.0s
|
||||
assert pool._keepalive_expiry == 5.0
|
||||
|
||||
@ -29,7 +29,7 @@ def test_openai_compat_provider_sets_timeout_on_local_http_client() -> None:
|
||||
with (
|
||||
patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai,
|
||||
patch(
|
||||
"nanobot.providers.openai_compat_provider.httpx.AsyncClient",
|
||||
"httpx.AsyncClient",
|
||||
return_value=sentinel.http_client,
|
||||
) as mock_http_client,
|
||||
):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user