diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index a8e28317f..58f0c265d 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -7,6 +7,7 @@ import email.utils import hmac import http import json +import logging import mimetypes import re import secrets @@ -116,6 +117,45 @@ def _host_for_url(host: str, port: int) -> str: return f"{host}:{port}" +_OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed" + + +def _exception_chain_has_disconnect(exc: BaseException | None) -> bool: + seen: set[int] = set() + while exc is not None: + ident = id(exc) + if ident in seen: + return False + seen.add(ident) + if isinstance(exc, ( + BrokenPipeError, + ConnectionAbortedError, + ConnectionResetError, + ConnectionClosed, + )): + return True + exc = exc.__cause__ or exc.__context__ + return False + + +class _WebSocketHandshakeNoiseFilter(logging.Filter): + """Suppress noisy restart-time handshakes where the client already disconnected.""" + + def filter(self, record: logging.LogRecord) -> bool: + if record.getMessage() != _OPENING_HANDSHAKE_FAILED_MESSAGE: + return True + exc_info = record.exc_info + exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None + return not _exception_chain_has_disconnect(exc) + + +def _websockets_server_logger() -> logging.Logger: + ws_logger = logging.getLogger("websockets.server") + if not any(isinstance(f, _WebSocketHandshakeNoiseFilter) for f in ws_logger.filters): + ws_logger.addFilter(_WebSocketHandshakeNoiseFilter()) + return ws_logger + + class WebSocketConfig(Base): """WebSocket server channel configuration. @@ -1239,6 +1279,7 @@ class WebSocketChannel(BaseChannel): from nanobot.utils.logging_bridge import redirect_lib_logging redirect_lib_logging("websockets", level="WARNING") + ws_logger = _websockets_server_logger() self._running = True self._stop_event = asyncio.Event() @@ -1290,6 +1331,7 @@ class WebSocketChannel(BaseChannel): max_size=self.config.max_message_bytes, ping_interval=self.config.ping_interval_s, ping_timeout=self.config.ping_timeout_s, + logger=ws_logger, ) with suppress(OSError): path_obj.chmod(0o600) @@ -1303,6 +1345,7 @@ class WebSocketChannel(BaseChannel): ping_interval=self.config.ping_interval_s, ping_timeout=self.config.ping_timeout_s, ssl=ssl_context, + logger=ws_logger, ) try: assert self._stop_event is not None diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index 03cee58f7..78d22b969 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -3,6 +3,7 @@ import asyncio import functools import json +import logging import time from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -16,6 +17,7 @@ from websockets.frames import Close from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.websocket import ( + _OPENING_HANDSHAKE_FAILED_MESSAGE, WebSocketChannel, WebSocketConfig, _is_valid_chat_id, @@ -26,6 +28,7 @@ from nanobot.channels.websocket import ( _parse_inbound_payload, _parse_query, _parse_request_path, + _WebSocketHandshakeNoiseFilter, publish_runtime_model_update, ) from nanobot.config.loader import load_config, save_config @@ -39,6 +42,18 @@ from nanobot.webui.settings_api import settings_payload, update_provider_setting _PORT = 29876 +def _log_record(message: str, exc: BaseException) -> logging.LogRecord: + return logging.LogRecord( + name="websockets.server", + level=logging.ERROR, + pathname=__file__, + lineno=1, + msg=message, + args=(), + exc_info=(type(exc), exc, exc.__traceback__), + ) + + def _ch(bus: Any, **kw: Any) -> WebSocketChannel: cfg: dict[str, Any] = { "enabled": True, @@ -113,6 +128,22 @@ def test_websocket_config_rejects_relative_unix_socket() -> None: WebSocketConfig(unix_socket_path="engine.sock") +def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None: + filter_ = _WebSocketHandshakeNoiseFilter() + wrapped = RuntimeError("wrapped") + wrapped.__cause__ = BrokenPipeError(32, "Broken pipe") + + assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError())) + assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped)) + + +def test_websocket_handshake_noise_filter_keeps_real_errors() -> None: + filter_ = _WebSocketHandshakeNoiseFilter() + + assert filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom"))) + assert filter_.filter(_log_record("connection handler failed", BrokenPipeError())) + + def test_parse_query_extracts_token_and_client_id() -> None: query = _parse_query("/?token=secret&client_id=u1") assert query.get("token") == ["secret"]