nanobot/nanobot/channels/manager.py

349 lines
13 KiB
Python

"""Channel manager for coordinating chat channels."""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import TYPE_CHECKING, Any
from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Config
from nanobot.utils.restart import consume_restart_notice_from_env, format_restart_completed_message
if TYPE_CHECKING:
from nanobot.session.manager import SessionManager
def _default_webui_dist() -> Path | None:
"""Return the absolute path to the bundled webui dist directory if it exists."""
try:
import nanobot.web as web_pkg # type: ignore[import-not-found]
except ImportError:
return None
candidate = Path(web_pkg.__file__).resolve().parent / "dist"
return candidate if candidate.is_dir() else None
# Retry delays for message sending (exponential backoff: 1s, 2s, 4s)
_SEND_RETRY_DELAYS = (1, 2, 4)
class ChannelManager:
"""
Manages chat channels and coordinates message routing.
Responsibilities:
- Initialize enabled channels (Telegram, WhatsApp, etc.)
- Start/stop channels
- Route outbound messages
"""
def __init__(
self,
config: Config,
bus: MessageBus,
*,
session_manager: "SessionManager | None" = None,
):
self.config = config
self.bus = bus
self._session_manager = session_manager
self.channels: dict[str, BaseChannel] = {}
self._dispatch_task: asyncio.Task | None = None
self._init_channels()
def _init_channels(self) -> None:
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
from nanobot.channels.registry import discover_all
transcription_provider = self.config.channels.transcription_provider
transcription_key = self._resolve_transcription_key(transcription_provider)
transcription_base = self._resolve_transcription_base(transcription_provider)
transcription_language = self.config.channels.transcription_language
for name, cls in discover_all().items():
section = getattr(self.config.channels, name, None)
if section is None:
continue
enabled = (
section.get("enabled", False)
if isinstance(section, dict)
else getattr(section, "enabled", False)
)
if not enabled:
continue
try:
kwargs: dict[str, Any] = {}
# Only the WebSocket channel currently hosts the embedded webui
# surface; other channels stay oblivious to these knobs.
if cls.name == "websocket" and self._session_manager is not None:
kwargs["session_manager"] = self._session_manager
static_path = _default_webui_dist()
if static_path is not None:
kwargs["static_dist_path"] = static_path
channel = cls(section, self.bus, **kwargs)
channel.transcription_provider = transcription_provider
channel.transcription_api_key = transcription_key
channel.transcription_api_base = transcription_base
channel.transcription_language = transcription_language
self.channels[name] = channel
logger.info("{} channel enabled", cls.display_name)
except Exception as e:
logger.warning("{} channel not available: {}", name, e)
self._validate_allow_from()
def _resolve_transcription_key(self, provider: str) -> str:
"""Pick the API key for the configured transcription provider."""
try:
if provider == "openai":
return self.config.providers.openai.api_key
return self.config.providers.groq.api_key
except AttributeError:
return ""
def _resolve_transcription_base(self, provider: str) -> str:
"""Pick the API base URL for the configured transcription provider."""
try:
if provider == "openai":
return self.config.providers.openai.api_base or ""
return self.config.providers.groq.api_base or ""
except AttributeError:
return ""
def _validate_allow_from(self) -> None:
for name, ch in self.channels.items():
cfg = ch.config
if isinstance(cfg, dict):
if "allow_from" in cfg:
allow = cfg.get("allow_from")
else:
allow = cfg.get("allowFrom")
else:
allow = getattr(cfg, "allow_from", None)
if allow == []:
raise SystemExit(
f'Error: "{name}" has empty allowFrom (denies all). '
f'Set ["*"] to allow everyone, or add specific user IDs.'
)
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
"""Start a channel and log any exceptions."""
try:
await channel.start()
except Exception as e:
logger.error("Failed to start channel {}: {}", name, e)
async def start_all(self) -> None:
"""Start all channels and the outbound dispatcher."""
if not self.channels:
logger.warning("No channels enabled")
return
# Start outbound dispatcher
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
# Start channels
tasks = []
for name, channel in self.channels.items():
logger.info("Starting {} channel...", name)
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
self._notify_restart_done_if_needed()
# Wait for all to complete (they should run forever)
await asyncio.gather(*tasks, return_exceptions=True)
def _notify_restart_done_if_needed(self) -> None:
"""Send restart completion message when runtime env markers are present."""
notice = consume_restart_notice_from_env()
if not notice:
return
target = self.channels.get(notice.channel)
if not target:
return
asyncio.create_task(self._send_with_retry(
target,
OutboundMessage(
channel=notice.channel,
chat_id=notice.chat_id,
content=format_restart_completed_message(notice.started_at_raw),
),
))
async def stop_all(self) -> None:
"""Stop all channels and the dispatcher."""
logger.info("Stopping all channels...")
# Stop dispatcher
if self._dispatch_task:
self._dispatch_task.cancel()
try:
await self._dispatch_task
except asyncio.CancelledError:
pass
# Stop all channels
for name, channel in self.channels.items():
try:
await channel.stop()
logger.info("Stopped {} channel", name)
except Exception as e:
logger.error("Error stopping {}: {}", name, e)
async def _dispatch_outbound(self) -> None:
"""Dispatch outbound messages to the appropriate channel."""
logger.info("Outbound dispatcher started")
# Buffer for messages that couldn't be processed during delta coalescing
# (since asyncio.Queue doesn't support push_front)
pending: list[OutboundMessage] = []
while True:
try:
# First check pending buffer before waiting on queue
if pending:
msg = pending.pop(0)
else:
msg = await asyncio.wait_for(
self.bus.consume_outbound(),
timeout=1.0
)
if msg.metadata.get("_progress"):
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
continue
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
continue
if msg.metadata.get("_retry_wait"):
continue
# Coalesce consecutive _stream_delta messages for the same (channel, chat_id)
# to reduce API calls and improve streaming latency
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
msg, extra_pending = self._coalesce_stream_deltas(msg)
pending.extend(extra_pending)
channel = self.channels.get(msg.channel)
if channel:
await self._send_with_retry(channel, msg)
else:
logger.warning("Unknown channel: {}", msg.channel)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
@staticmethod
async def _send_once(channel: BaseChannel, msg: OutboundMessage) -> None:
"""Send one outbound message without retry policy."""
if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
elif not msg.metadata.get("_streamed"):
await channel.send(msg)
def _coalesce_stream_deltas(
self, first_msg: OutboundMessage
) -> tuple[OutboundMessage, list[OutboundMessage]]:
"""Merge consecutive _stream_delta messages for the same (channel, chat_id).
This reduces the number of API calls when the queue has accumulated multiple
deltas, which happens when LLM generates faster than the channel can process.
Returns:
tuple of (merged_message, list_of_non_matching_messages)
"""
target_key = (first_msg.channel, first_msg.chat_id)
combined_content = first_msg.content
final_metadata = dict(first_msg.metadata or {})
non_matching: list[OutboundMessage] = []
# Only merge consecutive deltas. As soon as we hit any other message,
# stop and hand that boundary back to the dispatcher via `pending`.
while True:
try:
next_msg = self.bus.outbound.get_nowait()
except asyncio.QueueEmpty:
break
# Check if this message belongs to the same stream
same_target = (next_msg.channel, next_msg.chat_id) == target_key
is_delta = next_msg.metadata and next_msg.metadata.get("_stream_delta")
is_end = next_msg.metadata and next_msg.metadata.get("_stream_end")
if same_target and is_delta and not final_metadata.get("_stream_end"):
# Accumulate content
combined_content += next_msg.content
# If we see _stream_end, remember it and stop coalescing this stream
if is_end:
final_metadata["_stream_end"] = True
# Stream ended - stop coalescing this stream
break
else:
# First non-matching message defines the coalescing boundary.
non_matching.append(next_msg)
break
merged = OutboundMessage(
channel=first_msg.channel,
chat_id=first_msg.chat_id,
content=combined_content,
metadata=final_metadata,
)
return merged, non_matching
async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None:
"""Send a message with retry on failure using exponential backoff.
Note: CancelledError is re-raised to allow graceful shutdown.
"""
max_attempts = max(self.config.channels.send_max_retries, 1)
for attempt in range(max_attempts):
try:
await self._send_once(channel, msg)
return # Send succeeded
except asyncio.CancelledError:
raise # Propagate cancellation for graceful shutdown
except Exception as e:
if attempt == max_attempts - 1:
logger.error(
"Failed to send to {} after {} attempts: {} - {}",
msg.channel, max_attempts, type(e).__name__, e
)
return
delay = _SEND_RETRY_DELAYS[min(attempt, len(_SEND_RETRY_DELAYS) - 1)]
logger.warning(
"Send to {} failed (attempt {}/{}): {}, retrying in {}s",
msg.channel, attempt + 1, max_attempts, type(e).__name__, delay
)
try:
await asyncio.sleep(delay)
except asyncio.CancelledError:
raise # Propagate cancellation during sleep
def get_channel(self, name: str) -> BaseChannel | None:
"""Get a channel by name."""
return self.channels.get(name)
def get_status(self) -> dict[str, Any]:
"""Get status of all channels."""
return {
name: {
"enabled": True,
"running": channel.is_running
}
for name, channel in self.channels.items()
}
@property
def enabled_channels(self) -> list[str]:
"""Get list of enabled channel names."""
return list(self.channels.keys())