mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-03 16:25:53 +00:00
fix: add origin_message_id support for spawn and message deduplication
This commit is contained in:
parent
3c20d16117
commit
4e06c00b46
@ -25,6 +25,7 @@ class SpawnTool(Tool):
|
||||
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
|
||||
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
|
||||
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
|
||||
self._origin_message_id: str | None = None
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
|
||||
"""Set the origin context for subagent announcements."""
|
||||
@ -32,6 +33,10 @@ class SpawnTool(Tool):
|
||||
self._origin_chat_id.set(chat_id)
|
||||
self._session_key.set(effective_key or f"{channel}:{chat_id}")
|
||||
|
||||
def set_origin_message_id(self, message_id: str | None) -> None:
|
||||
"""Set the source message id for downstream deduplication."""
|
||||
self._origin_message_id = message_id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "spawn"
|
||||
@ -54,4 +59,5 @@ class SpawnTool(Tool):
|
||||
origin_channel=self._origin_channel.get(),
|
||||
origin_chat_id=self._origin_chat_id.get(),
|
||||
session_key=self._session_key.get(),
|
||||
origin_message_id=self._origin_message_id,
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@ -37,6 +38,12 @@ _BOOL_CAMEL_ALIASES: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RecentOutbound:
|
||||
fingerprint: str
|
||||
ts: float
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manages chat channels and coordinates message routing.
|
||||
@ -59,6 +66,7 @@ class ChannelManager:
|
||||
self._session_manager = session_manager
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
self._recent_outbound: dict[tuple[str, str], _RecentOutbound] = {}
|
||||
|
||||
self._init_channels()
|
||||
|
||||
@ -233,6 +241,25 @@ class ChannelManager:
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint_content(content: str) -> str:
|
||||
normalized = " ".join(content.split())
|
||||
return hashlib.sha1(normalized.encode("utf-8")).hexdigest() if normalized else ""
|
||||
|
||||
def _should_suppress_outbound(self, msg: OutboundMessage) -> bool:
|
||||
if msg.metadata.get("_progress"):
|
||||
return False
|
||||
fingerprint = self._fingerprint_content(msg.content)
|
||||
if not fingerprint:
|
||||
return False
|
||||
key = (msg.channel, msg.chat_id)
|
||||
recent = self._recent_outbound.get(key)
|
||||
now = asyncio.get_running_loop().time()
|
||||
if recent and recent.fingerprint == fingerprint and now - recent.ts <= 8.0:
|
||||
return True
|
||||
self._recent_outbound[key] = _RecentOutbound(fingerprint=fingerprint, ts=now)
|
||||
return False
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""Dispatch outbound messages to the appropriate channel."""
|
||||
logger.info("Outbound dispatcher started")
|
||||
@ -273,6 +300,11 @@ class ChannelManager:
|
||||
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
# Duplicate suppression (non-streaming only)
|
||||
if not msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end") and not msg.metadata.get("_streamed"):
|
||||
if self._should_suppress_outbound(msg):
|
||||
logger.info("Suppressing duplicate outbound message to {}:{}", msg.channel, msg.chat_id)
|
||||
continue
|
||||
await self._send_with_retry(channel, msg)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -166,3 +166,35 @@ class TestMessageToolTurnTracking:
|
||||
tool._sent_in_turn = True
|
||||
tool.start_turn()
|
||||
assert not tool._sent_in_turn
|
||||
|
||||
|
||||
class TestSystemReplySuppression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_system_reply_suppressed_when_duplicate(self, tmp_path: Path) -> None:
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager") as MockSessionManager, \
|
||||
patch("nanobot.agent.loop.SubagentManager"):
|
||||
session = MagicMock()
|
||||
session.get_history.return_value = []
|
||||
MockSessionManager.return_value.get_or_create.return_value = session
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||
|
||||
loop._remember_visible_reply("feishu:chat123", "Done")
|
||||
loop._run_agent_loop = AsyncMock(return_value=("Done", [], []))
|
||||
loop._save_turn = MagicMock()
|
||||
loop.sessions.save = MagicMock()
|
||||
|
||||
msg = InboundMessage(
|
||||
channel="system",
|
||||
sender_id="subagent",
|
||||
chat_id="feishu:chat123",
|
||||
content="background result",
|
||||
metadata={"source": "subagent"},
|
||||
)
|
||||
|
||||
result = await loop._process_message(msg)
|
||||
assert result is None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user