mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-04 00:35:58 +00:00
refactor(channels): resolve progress overrides at init-time like transcription
This commit is contained in:
parent
a0443e8f9e
commit
74270bb8a8
@ -26,6 +26,8 @@ class BaseChannel(ABC):
|
||||
transcription_api_key: str = ""
|
||||
transcription_api_base: str = ""
|
||||
transcription_language: str | None = None
|
||||
send_progress: bool = True
|
||||
send_tool_hints: bool = False
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
|
||||
@ -7,7 +7,6 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from loguru import logger
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -29,13 +28,14 @@ def _default_webui_dist() -> Path | None:
|
||||
return candidate if candidate.is_dir() else None
|
||||
|
||||
|
||||
def _coerce_optional_bool(value: Any) -> bool | None:
|
||||
return value if isinstance(value, bool) else None
|
||||
|
||||
|
||||
# Retry delays for message sending (exponential backoff: 1s, 2s, 4s)
|
||||
_SEND_RETRY_DELAYS = (1, 2, 4)
|
||||
|
||||
_BOOL_CAMEL_ALIASES: dict[str, str] = {
|
||||
"send_progress": "sendProgress",
|
||||
"send_tool_hints": "sendToolHints",
|
||||
}
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
@ -96,6 +96,12 @@ class ChannelManager:
|
||||
channel.transcription_api_key = transcription_key
|
||||
channel.transcription_api_base = transcription_base
|
||||
channel.transcription_language = transcription_language
|
||||
channel.send_progress = self._resolve_bool_override(
|
||||
section, "send_progress", self.config.channels.send_progress,
|
||||
)
|
||||
channel.send_tool_hints = self._resolve_bool_override(
|
||||
section, "send_tool_hints", self.config.channels.send_tool_hints,
|
||||
)
|
||||
self.channels[name] = channel
|
||||
logger.info("{} channel enabled", cls.display_name)
|
||||
except Exception as e:
|
||||
@ -138,26 +144,29 @@ class ChannelManager:
|
||||
)
|
||||
|
||||
def _should_send_progress(self, channel_name: str, *, tool_hint: bool = False) -> bool:
|
||||
"""Resolve progress visibility, allowing per-channel overrides."""
|
||||
key = "send_tool_hints" if tool_hint else "send_progress"
|
||||
default = getattr(self.config.channels, key)
|
||||
override = self._channel_bool_override(channel_name, key)
|
||||
return default if override is None else override
|
||||
"""Return whether progress (or tool-hints) may be sent to *channel_name*."""
|
||||
ch = self.channels.get(channel_name)
|
||||
if ch is None:
|
||||
logger.warning("Progress check for unknown channel: {}", channel_name)
|
||||
return False
|
||||
return ch.send_tool_hints if tool_hint else ch.send_progress
|
||||
|
||||
def _channel_bool_override(self, channel_name: str, key: str) -> bool | None:
|
||||
section = getattr(self.config.channels, channel_name, None)
|
||||
if section is None:
|
||||
return None
|
||||
def _resolve_bool_override(self, section: Any, key: str, default: bool) -> bool:
|
||||
"""Return *key* from *section* if it is a bool, otherwise *default*.
|
||||
|
||||
camel_key = to_camel(key)
|
||||
For dict configs also checks the camelCase alias (e.g. ``sendProgress``
|
||||
for ``send_progress``) so raw JSON/TOML configs work alongside
|
||||
Pydantic models.
|
||||
"""
|
||||
if isinstance(section, dict):
|
||||
value = section.get(key, section.get(camel_key))
|
||||
return _coerce_optional_bool(value)
|
||||
|
||||
value = section.get(key)
|
||||
if value is None:
|
||||
camel = _BOOL_CAMEL_ALIASES.get(key)
|
||||
if camel:
|
||||
value = section.get(camel)
|
||||
return value if isinstance(value, bool) else default
|
||||
value = getattr(section, key, None)
|
||||
if value is None:
|
||||
value = getattr(section, camel_key, None)
|
||||
return _coerce_optional_bool(value)
|
||||
return value if isinstance(value, bool) else default
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""Start a channel and log any exceptions."""
|
||||
|
||||
@ -8,7 +8,7 @@ from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.schema import ChannelsConfig, Config
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class MockChannel(BaseChannel):
|
||||
@ -299,64 +299,42 @@ class TestDispatchOutboundWithCoalescing:
|
||||
|
||||
|
||||
class TestProgressFiltering:
|
||||
"""Progress filtering should honor per-channel config overrides."""
|
||||
"""Progress filtering should honor per-channel settings."""
|
||||
|
||||
def test_progress_visibility_uses_global_defaults(self, manager):
|
||||
manager.config.channels = ChannelsConfig.model_validate({
|
||||
"sendProgress": True,
|
||||
"sendToolHints": False,
|
||||
})
|
||||
|
||||
assert manager._should_send_progress("mock", tool_hint=False) is True
|
||||
assert manager._should_send_progress("mock", tool_hint=True) is False
|
||||
|
||||
def test_progress_visibility_uses_channel_overrides(self, manager):
|
||||
manager.config.channels = ChannelsConfig.model_validate({
|
||||
"sendProgress": True,
|
||||
"sendToolHints": False,
|
||||
"mock": {
|
||||
"sendProgress": False,
|
||||
"sendToolHints": True,
|
||||
},
|
||||
})
|
||||
|
||||
assert manager._should_send_progress("mock", tool_hint=False) is False
|
||||
assert manager._should_send_progress("mock", tool_hint=True) is True
|
||||
assert manager._should_send_progress("other", tool_hint=False) is True
|
||||
assert manager._should_send_progress("other", tool_hint=True) is False
|
||||
|
||||
def test_progress_visibility_uses_snake_case_channel_overrides(self, manager):
|
||||
manager.config.channels = ChannelsConfig.model_validate({
|
||||
"sendProgress": True,
|
||||
"sendToolHints": False,
|
||||
"mock": {
|
||||
"send_progress": False,
|
||||
"send_tool_hints": True,
|
||||
},
|
||||
})
|
||||
manager.channels["mock"].send_progress = False
|
||||
manager.channels["mock"].send_tool_hints = True
|
||||
|
||||
assert manager._should_send_progress("mock", tool_hint=False) is False
|
||||
assert manager._should_send_progress("mock", tool_hint=True) is True
|
||||
|
||||
def test_progress_visibility_ignores_non_bool_channel_overrides(self, manager):
|
||||
manager.config.channels = ChannelsConfig.model_validate({
|
||||
"sendProgress": True,
|
||||
"sendToolHints": False,
|
||||
"mock": {
|
||||
"sendProgress": "false",
|
||||
"sendToolHints": "true",
|
||||
},
|
||||
})
|
||||
def test_progress_visibility_returns_false_for_missing_channel(self, manager):
|
||||
assert manager._should_send_progress("nonexistent", tool_hint=False) is False
|
||||
assert manager._should_send_progress("nonexistent", tool_hint=True) is False
|
||||
|
||||
assert manager._should_send_progress("mock", tool_hint=False) is True
|
||||
assert manager._should_send_progress("mock", tool_hint=True) is False
|
||||
def test_resolve_bool_override_dict(self, manager):
|
||||
assert manager._resolve_bool_override({}, "send_progress", True) is True
|
||||
assert manager._resolve_bool_override({"send_progress": False}, "send_progress", True) is False
|
||||
assert manager._resolve_bool_override({"sendProgress": False}, "send_progress", True) is False
|
||||
assert manager._resolve_bool_override({"send_progress": "false"}, "send_progress", True) is True
|
||||
|
||||
def test_resolve_bool_override_model(self, manager):
|
||||
class FakeSection:
|
||||
send_progress = False
|
||||
send_tool_hints = True
|
||||
|
||||
assert manager._resolve_bool_override(FakeSection(), "send_progress", True) is False
|
||||
assert manager._resolve_bool_override(FakeSection(), "send_tool_hints", False) is True
|
||||
# Missing attribute falls back to default
|
||||
assert manager._resolve_bool_override(FakeSection(), "unknown_key", True) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_override_can_drop_progress_message(self, manager, bus):
|
||||
manager.config.channels = ChannelsConfig.model_validate({
|
||||
"sendProgress": True,
|
||||
"mock": {"sendProgress": False},
|
||||
})
|
||||
manager.channels["mock"].send_progress = False
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
@ -389,10 +367,7 @@ class TestProgressFiltering:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_override_can_enable_tool_hints(self, manager, bus):
|
||||
manager.config.channels = ChannelsConfig.model_validate({
|
||||
"sendToolHints": False,
|
||||
"mock": {"sendToolHints": True},
|
||||
})
|
||||
manager.channels["mock"].send_tool_hints = True
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user