mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-04 00:35:58 +00:00
fix: add origin_message_id support for spawn and message deduplication
This commit is contained in:
parent
3c20d16117
commit
4e06c00b46
@ -25,6 +25,7 @@ class SpawnTool(Tool):
|
|||||||
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
|
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
|
||||||
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
|
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
|
||||||
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
|
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
|
||||||
|
self._origin_message_id: str | None = None
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
|
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
|
||||||
"""Set the origin context for subagent announcements."""
|
"""Set the origin context for subagent announcements."""
|
||||||
@ -32,6 +33,10 @@ class SpawnTool(Tool):
|
|||||||
self._origin_chat_id.set(chat_id)
|
self._origin_chat_id.set(chat_id)
|
||||||
self._session_key.set(effective_key or f"{channel}:{chat_id}")
|
self._session_key.set(effective_key or f"{channel}:{chat_id}")
|
||||||
|
|
||||||
|
def set_origin_message_id(self, message_id: str | None) -> None:
|
||||||
|
"""Set the source message id for downstream deduplication."""
|
||||||
|
self._origin_message_id = message_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "spawn"
|
return "spawn"
|
||||||
@ -54,4 +59,5 @@ class SpawnTool(Tool):
|
|||||||
origin_channel=self._origin_channel.get(),
|
origin_channel=self._origin_channel.get(),
|
||||||
origin_chat_id=self._origin_chat_id.get(),
|
origin_chat_id=self._origin_chat_id.get(),
|
||||||
session_key=self._session_key.get(),
|
session_key=self._session_key.get(),
|
||||||
|
origin_message_id=self._origin_message_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@ -37,6 +38,12 @@ _BOOL_CAMEL_ALIASES: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _RecentOutbound:
|
||||||
|
fingerprint: str
|
||||||
|
ts: float
|
||||||
|
|
||||||
|
|
||||||
class ChannelManager:
|
class ChannelManager:
|
||||||
"""
|
"""
|
||||||
Manages chat channels and coordinates message routing.
|
Manages chat channels and coordinates message routing.
|
||||||
@ -59,6 +66,7 @@ class ChannelManager:
|
|||||||
self._session_manager = session_manager
|
self._session_manager = session_manager
|
||||||
self.channels: dict[str, BaseChannel] = {}
|
self.channels: dict[str, BaseChannel] = {}
|
||||||
self._dispatch_task: asyncio.Task | None = None
|
self._dispatch_task: asyncio.Task | None = None
|
||||||
|
self._recent_outbound: dict[tuple[str, str], _RecentOutbound] = {}
|
||||||
|
|
||||||
self._init_channels()
|
self._init_channels()
|
||||||
|
|
||||||
@ -233,6 +241,25 @@ class ChannelManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error stopping {}: {}", name, e)
|
logger.error("Error stopping {}: {}", name, e)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fingerprint_content(content: str) -> str:
|
||||||
|
normalized = " ".join(content.split())
|
||||||
|
return hashlib.sha1(normalized.encode("utf-8")).hexdigest() if normalized else ""
|
||||||
|
|
||||||
|
def _should_suppress_outbound(self, msg: OutboundMessage) -> bool:
|
||||||
|
if msg.metadata.get("_progress"):
|
||||||
|
return False
|
||||||
|
fingerprint = self._fingerprint_content(msg.content)
|
||||||
|
if not fingerprint:
|
||||||
|
return False
|
||||||
|
key = (msg.channel, msg.chat_id)
|
||||||
|
recent = self._recent_outbound.get(key)
|
||||||
|
now = asyncio.get_running_loop().time()
|
||||||
|
if recent and recent.fingerprint == fingerprint and now - recent.ts <= 8.0:
|
||||||
|
return True
|
||||||
|
self._recent_outbound[key] = _RecentOutbound(fingerprint=fingerprint, ts=now)
|
||||||
|
return False
|
||||||
|
|
||||||
async def _dispatch_outbound(self) -> None:
|
async def _dispatch_outbound(self) -> None:
|
||||||
"""Dispatch outbound messages to the appropriate channel."""
|
"""Dispatch outbound messages to the appropriate channel."""
|
||||||
logger.info("Outbound dispatcher started")
|
logger.info("Outbound dispatcher started")
|
||||||
@ -273,6 +300,11 @@ class ChannelManager:
|
|||||||
|
|
||||||
channel = self.channels.get(msg.channel)
|
channel = self.channels.get(msg.channel)
|
||||||
if channel:
|
if channel:
|
||||||
|
# Duplicate suppression (non-streaming only)
|
||||||
|
if not msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end") and not msg.metadata.get("_streamed"):
|
||||||
|
if self._should_suppress_outbound(msg):
|
||||||
|
logger.info("Suppressing duplicate outbound message to {}:{}", msg.channel, msg.chat_id)
|
||||||
|
continue
|
||||||
await self._send_with_retry(channel, msg)
|
await self._send_with_retry(channel, msg)
|
||||||
else:
|
else:
|
||||||
logger.warning("Unknown channel: {}", msg.channel)
|
logger.warning("Unknown channel: {}", msg.channel)
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -166,3 +166,35 @@ class TestMessageToolTurnTracking:
|
|||||||
tool._sent_in_turn = True
|
tool._sent_in_turn = True
|
||||||
tool.start_turn()
|
tool.start_turn()
|
||||||
assert not tool._sent_in_turn
|
assert not tool._sent_in_turn
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemReplySuppression:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subagent_system_reply_suppressed_when_duplicate(self, tmp_path: Path) -> None:
|
||||||
|
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||||
|
patch("nanobot.agent.loop.SessionManager") as MockSessionManager, \
|
||||||
|
patch("nanobot.agent.loop.SubagentManager"):
|
||||||
|
session = MagicMock()
|
||||||
|
session.get_history.return_value = []
|
||||||
|
MockSessionManager.return_value.get_or_create.return_value = session
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||||
|
|
||||||
|
loop._remember_visible_reply("feishu:chat123", "Done")
|
||||||
|
loop._run_agent_loop = AsyncMock(return_value=("Done", [], []))
|
||||||
|
loop._save_turn = MagicMock()
|
||||||
|
loop.sessions.save = MagicMock()
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="system",
|
||||||
|
sender_id="subagent",
|
||||||
|
chat_id="feishu:chat123",
|
||||||
|
content="background result",
|
||||||
|
metadata={"source": "subagent"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
assert result is None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user