From 74270bb8a81d827c40dcc86bd9534e5fcae4c8b1 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 29 Apr 2026 13:46:35 +0800 Subject: [PATCH] refactor(channels): resolve progress overrides at init-time like transcription --- nanobot/channels/base.py | 2 + nanobot/channels/manager.py | 51 +++++++------ .../test_channel_manager_delta_coalescing.py | 73 ++++++------------- 3 files changed, 56 insertions(+), 70 deletions(-) diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 62bcd45c1..6097b420f 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -26,6 +26,8 @@ class BaseChannel(ABC): transcription_api_key: str = "" transcription_api_base: str = "" transcription_language: str | None = None + send_progress: bool = True + send_tool_hints: bool = False def __init__(self, config: Any, bus: MessageBus): """ diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index e2fc0dafc..14a6b2a5e 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any from loguru import logger -from pydantic.alias_generators import to_camel from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus @@ -29,13 +28,14 @@ def _default_webui_dist() -> Path | None: return candidate if candidate.is_dir() else None -def _coerce_optional_bool(value: Any) -> bool | None: - return value if isinstance(value, bool) else None - - # Retry delays for message sending (exponential backoff: 1s, 2s, 4s) _SEND_RETRY_DELAYS = (1, 2, 4) +_BOOL_CAMEL_ALIASES: dict[str, str] = { + "send_progress": "sendProgress", + "send_tool_hints": "sendToolHints", +} + class ChannelManager: """ @@ -96,6 +96,12 @@ class ChannelManager: channel.transcription_api_key = transcription_key channel.transcription_api_base = transcription_base channel.transcription_language = transcription_language + channel.send_progress = self._resolve_bool_override( + section, "send_progress", self.config.channels.send_progress, + ) + channel.send_tool_hints = self._resolve_bool_override( + section, "send_tool_hints", self.config.channels.send_tool_hints, + ) self.channels[name] = channel logger.info("{} channel enabled", cls.display_name) except Exception as e: @@ -138,26 +144,29 @@ class ChannelManager: ) def _should_send_progress(self, channel_name: str, *, tool_hint: bool = False) -> bool: - """Resolve progress visibility, allowing per-channel overrides.""" - key = "send_tool_hints" if tool_hint else "send_progress" - default = getattr(self.config.channels, key) - override = self._channel_bool_override(channel_name, key) - return default if override is None else override + """Return whether progress (or tool-hints) may be sent to *channel_name*.""" + ch = self.channels.get(channel_name) + if ch is None: + logger.warning("Progress check for unknown channel: {}", channel_name) + return False + return ch.send_tool_hints if tool_hint else ch.send_progress - def _channel_bool_override(self, channel_name: str, key: str) -> bool | None: - section = getattr(self.config.channels, channel_name, None) - if section is None: - return None + def _resolve_bool_override(self, section: Any, key: str, default: bool) -> bool: + """Return *key* from *section* if it is a bool, otherwise *default*. - camel_key = to_camel(key) + For dict configs also checks the camelCase alias (e.g. ``sendProgress`` + for ``send_progress``) so raw JSON/TOML configs work alongside + Pydantic models. + """ if isinstance(section, dict): - value = section.get(key, section.get(camel_key)) - return _coerce_optional_bool(value) - + value = section.get(key) + if value is None: + camel = _BOOL_CAMEL_ALIASES.get(key) + if camel: + value = section.get(camel) + return value if isinstance(value, bool) else default value = getattr(section, key, None) - if value is None: - value = getattr(section, camel_key, None) - return _coerce_optional_bool(value) + return value if isinstance(value, bool) else default async def _start_channel(self, name: str, channel: BaseChannel) -> None: """Start a channel and log any exceptions.""" diff --git a/tests/channels/test_channel_manager_delta_coalescing.py b/tests/channels/test_channel_manager_delta_coalescing.py index ea4b68334..adec72e75 100644 --- a/tests/channels/test_channel_manager_delta_coalescing.py +++ b/tests/channels/test_channel_manager_delta_coalescing.py @@ -8,7 +8,7 @@ from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.channels.manager import ChannelManager -from nanobot.config.schema import ChannelsConfig, Config +from nanobot.config.schema import Config class MockChannel(BaseChannel): @@ -299,64 +299,42 @@ class TestDispatchOutboundWithCoalescing: class TestProgressFiltering: - """Progress filtering should honor per-channel config overrides.""" + """Progress filtering should honor per-channel settings.""" def test_progress_visibility_uses_global_defaults(self, manager): - manager.config.channels = ChannelsConfig.model_validate({ - "sendProgress": True, - "sendToolHints": False, - }) - assert manager._should_send_progress("mock", tool_hint=False) is True assert manager._should_send_progress("mock", tool_hint=True) is False def test_progress_visibility_uses_channel_overrides(self, manager): - manager.config.channels = ChannelsConfig.model_validate({ - "sendProgress": True, - "sendToolHints": False, - "mock": { - "sendProgress": False, - "sendToolHints": True, - }, - }) - - assert manager._should_send_progress("mock", tool_hint=False) is False - assert manager._should_send_progress("mock", tool_hint=True) is True - assert manager._should_send_progress("other", tool_hint=False) is True - assert manager._should_send_progress("other", tool_hint=True) is False - - def test_progress_visibility_uses_snake_case_channel_overrides(self, manager): - manager.config.channels = ChannelsConfig.model_validate({ - "sendProgress": True, - "sendToolHints": False, - "mock": { - "send_progress": False, - "send_tool_hints": True, - }, - }) + manager.channels["mock"].send_progress = False + manager.channels["mock"].send_tool_hints = True assert manager._should_send_progress("mock", tool_hint=False) is False assert manager._should_send_progress("mock", tool_hint=True) is True - def test_progress_visibility_ignores_non_bool_channel_overrides(self, manager): - manager.config.channels = ChannelsConfig.model_validate({ - "sendProgress": True, - "sendToolHints": False, - "mock": { - "sendProgress": "false", - "sendToolHints": "true", - }, - }) + def test_progress_visibility_returns_false_for_missing_channel(self, manager): + assert manager._should_send_progress("nonexistent", tool_hint=False) is False + assert manager._should_send_progress("nonexistent", tool_hint=True) is False - assert manager._should_send_progress("mock", tool_hint=False) is True - assert manager._should_send_progress("mock", tool_hint=True) is False + def test_resolve_bool_override_dict(self, manager): + assert manager._resolve_bool_override({}, "send_progress", True) is True + assert manager._resolve_bool_override({"send_progress": False}, "send_progress", True) is False + assert manager._resolve_bool_override({"sendProgress": False}, "send_progress", True) is False + assert manager._resolve_bool_override({"send_progress": "false"}, "send_progress", True) is True + + def test_resolve_bool_override_model(self, manager): + class FakeSection: + send_progress = False + send_tool_hints = True + + assert manager._resolve_bool_override(FakeSection(), "send_progress", True) is False + assert manager._resolve_bool_override(FakeSection(), "send_tool_hints", False) is True + # Missing attribute falls back to default + assert manager._resolve_bool_override(FakeSection(), "unknown_key", True) is True @pytest.mark.asyncio async def test_channel_override_can_drop_progress_message(self, manager, bus): - manager.config.channels = ChannelsConfig.model_validate({ - "sendProgress": True, - "mock": {"sendProgress": False}, - }) + manager.channels["mock"].send_progress = False await bus.publish_outbound(OutboundMessage( channel="mock", chat_id="chat1", @@ -389,10 +367,7 @@ class TestProgressFiltering: @pytest.mark.asyncio async def test_channel_override_can_enable_tool_hints(self, manager, bus): - manager.config.channels = ChannelsConfig.model_validate({ - "sendToolHints": False, - "mock": {"sendToolHints": True}, - }) + manager.channels["mock"].send_tool_hints = True await bus.publish_outbound(OutboundMessage( channel="mock", chat_id="chat1",