refactor(channels): resolve progress overrides at init-time like transcription

This commit is contained in:
chengyongru 2026-04-29 13:46:35 +08:00 committed by Xubin Ren
parent a0443e8f9e
commit 74270bb8a8
3 changed files with 56 additions and 70 deletions

View File

@ -26,6 +26,8 @@ class BaseChannel(ABC):
transcription_api_key: str = "" transcription_api_key: str = ""
transcription_api_base: str = "" transcription_api_base: str = ""
transcription_language: str | None = None transcription_language: str | None = None
send_progress: bool = True
send_tool_hints: bool = False
def __init__(self, config: Any, bus: MessageBus): def __init__(self, config: Any, bus: MessageBus):
""" """

View File

@ -7,7 +7,6 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from loguru import logger from loguru import logger
from pydantic.alias_generators import to_camel
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@ -29,13 +28,14 @@ def _default_webui_dist() -> Path | None:
return candidate if candidate.is_dir() else None return candidate if candidate.is_dir() else None
def _coerce_optional_bool(value: Any) -> bool | None:
return value if isinstance(value, bool) else None
# Retry delays for message sending (exponential backoff: 1s, 2s, 4s) # Retry delays for message sending (exponential backoff: 1s, 2s, 4s)
_SEND_RETRY_DELAYS = (1, 2, 4) _SEND_RETRY_DELAYS = (1, 2, 4)
_BOOL_CAMEL_ALIASES: dict[str, str] = {
"send_progress": "sendProgress",
"send_tool_hints": "sendToolHints",
}
class ChannelManager: class ChannelManager:
""" """
@ -96,6 +96,12 @@ class ChannelManager:
channel.transcription_api_key = transcription_key channel.transcription_api_key = transcription_key
channel.transcription_api_base = transcription_base channel.transcription_api_base = transcription_base
channel.transcription_language = transcription_language channel.transcription_language = transcription_language
channel.send_progress = self._resolve_bool_override(
section, "send_progress", self.config.channels.send_progress,
)
channel.send_tool_hints = self._resolve_bool_override(
section, "send_tool_hints", self.config.channels.send_tool_hints,
)
self.channels[name] = channel self.channels[name] = channel
logger.info("{} channel enabled", cls.display_name) logger.info("{} channel enabled", cls.display_name)
except Exception as e: except Exception as e:
@ -138,26 +144,29 @@ class ChannelManager:
) )
def _should_send_progress(self, channel_name: str, *, tool_hint: bool = False) -> bool: def _should_send_progress(self, channel_name: str, *, tool_hint: bool = False) -> bool:
"""Resolve progress visibility, allowing per-channel overrides.""" """Return whether progress (or tool-hints) may be sent to *channel_name*."""
key = "send_tool_hints" if tool_hint else "send_progress" ch = self.channels.get(channel_name)
default = getattr(self.config.channels, key) if ch is None:
override = self._channel_bool_override(channel_name, key) logger.warning("Progress check for unknown channel: {}", channel_name)
return default if override is None else override return False
return ch.send_tool_hints if tool_hint else ch.send_progress
def _channel_bool_override(self, channel_name: str, key: str) -> bool | None: def _resolve_bool_override(self, section: Any, key: str, default: bool) -> bool:
section = getattr(self.config.channels, channel_name, None) """Return *key* from *section* if it is a bool, otherwise *default*.
if section is None:
return None
camel_key = to_camel(key) For dict configs also checks the camelCase alias (e.g. ``sendProgress``
for ``send_progress``) so raw JSON/TOML configs work alongside
Pydantic models.
"""
if isinstance(section, dict): if isinstance(section, dict):
value = section.get(key, section.get(camel_key)) value = section.get(key)
return _coerce_optional_bool(value)
value = getattr(section, key, None)
if value is None: if value is None:
value = getattr(section, camel_key, None) camel = _BOOL_CAMEL_ALIASES.get(key)
return _coerce_optional_bool(value) if camel:
value = section.get(camel)
return value if isinstance(value, bool) else default
value = getattr(section, key, None)
return value if isinstance(value, bool) else default
async def _start_channel(self, name: str, channel: BaseChannel) -> None: async def _start_channel(self, name: str, channel: BaseChannel) -> None:
"""Start a channel and log any exceptions.""" """Start a channel and log any exceptions."""

View File

@ -8,7 +8,7 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.config.schema import ChannelsConfig, Config from nanobot.config.schema import Config
class MockChannel(BaseChannel): class MockChannel(BaseChannel):
@ -299,64 +299,42 @@ class TestDispatchOutboundWithCoalescing:
class TestProgressFiltering: class TestProgressFiltering:
"""Progress filtering should honor per-channel config overrides.""" """Progress filtering should honor per-channel settings."""
def test_progress_visibility_uses_global_defaults(self, manager): def test_progress_visibility_uses_global_defaults(self, manager):
manager.config.channels = ChannelsConfig.model_validate({
"sendProgress": True,
"sendToolHints": False,
})
assert manager._should_send_progress("mock", tool_hint=False) is True assert manager._should_send_progress("mock", tool_hint=False) is True
assert manager._should_send_progress("mock", tool_hint=True) is False assert manager._should_send_progress("mock", tool_hint=True) is False
def test_progress_visibility_uses_channel_overrides(self, manager): def test_progress_visibility_uses_channel_overrides(self, manager):
manager.config.channels = ChannelsConfig.model_validate({ manager.channels["mock"].send_progress = False
"sendProgress": True, manager.channels["mock"].send_tool_hints = True
"sendToolHints": False,
"mock": {
"sendProgress": False,
"sendToolHints": True,
},
})
assert manager._should_send_progress("mock", tool_hint=False) is False
assert manager._should_send_progress("mock", tool_hint=True) is True
assert manager._should_send_progress("other", tool_hint=False) is True
assert manager._should_send_progress("other", tool_hint=True) is False
def test_progress_visibility_uses_snake_case_channel_overrides(self, manager):
manager.config.channels = ChannelsConfig.model_validate({
"sendProgress": True,
"sendToolHints": False,
"mock": {
"send_progress": False,
"send_tool_hints": True,
},
})
assert manager._should_send_progress("mock", tool_hint=False) is False assert manager._should_send_progress("mock", tool_hint=False) is False
assert manager._should_send_progress("mock", tool_hint=True) is True assert manager._should_send_progress("mock", tool_hint=True) is True
def test_progress_visibility_ignores_non_bool_channel_overrides(self, manager): def test_progress_visibility_returns_false_for_missing_channel(self, manager):
manager.config.channels = ChannelsConfig.model_validate({ assert manager._should_send_progress("nonexistent", tool_hint=False) is False
"sendProgress": True, assert manager._should_send_progress("nonexistent", tool_hint=True) is False
"sendToolHints": False,
"mock": {
"sendProgress": "false",
"sendToolHints": "true",
},
})
assert manager._should_send_progress("mock", tool_hint=False) is True def test_resolve_bool_override_dict(self, manager):
assert manager._should_send_progress("mock", tool_hint=True) is False assert manager._resolve_bool_override({}, "send_progress", True) is True
assert manager._resolve_bool_override({"send_progress": False}, "send_progress", True) is False
assert manager._resolve_bool_override({"sendProgress": False}, "send_progress", True) is False
assert manager._resolve_bool_override({"send_progress": "false"}, "send_progress", True) is True
def test_resolve_bool_override_model(self, manager):
class FakeSection:
send_progress = False
send_tool_hints = True
assert manager._resolve_bool_override(FakeSection(), "send_progress", True) is False
assert manager._resolve_bool_override(FakeSection(), "send_tool_hints", False) is True
# Missing attribute falls back to default
assert manager._resolve_bool_override(FakeSection(), "unknown_key", True) is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_channel_override_can_drop_progress_message(self, manager, bus): async def test_channel_override_can_drop_progress_message(self, manager, bus):
manager.config.channels = ChannelsConfig.model_validate({ manager.channels["mock"].send_progress = False
"sendProgress": True,
"mock": {"sendProgress": False},
})
await bus.publish_outbound(OutboundMessage( await bus.publish_outbound(OutboundMessage(
channel="mock", channel="mock",
chat_id="chat1", chat_id="chat1",
@ -389,10 +367,7 @@ class TestProgressFiltering:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_channel_override_can_enable_tool_hints(self, manager, bus): async def test_channel_override_can_enable_tool_hints(self, manager, bus):
manager.config.channels = ChannelsConfig.model_validate({ manager.channels["mock"].send_tool_hints = True
"sendToolHints": False,
"mock": {"sendToolHints": True},
})
await bus.publish_outbound(OutboundMessage( await bus.publish_outbound(OutboundMessage(
channel="mock", channel="mock",
chat_id="chat1", chat_id="chat1",