mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 09:32:33 +00:00
Merge PR #3561: fix: origin_message_id support and outbound deduplication
fix: origin_message_id support and outbound deduplication
This commit is contained in:
commit
e16fa7c6b1
@ -441,6 +441,8 @@ class AgentLoop:
|
||||
if hasattr(tool, "set_context"):
|
||||
if name == "spawn":
|
||||
tool.set_context(channel, chat_id, effective_key=effective_key)
|
||||
if hasattr(tool, "set_origin_message_id"):
|
||||
tool.set_origin_message_id(message_id)
|
||||
elif name == "cron":
|
||||
tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
|
||||
elif name == "message":
|
||||
@ -957,6 +959,8 @@ class AgentLoop:
|
||||
outbound_metadata: dict[str, Any] = {}
|
||||
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
||||
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
||||
if origin_message_id := msg.metadata.get("origin_message_id"):
|
||||
outbound_metadata["origin_message_id"] = origin_message_id
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
|
||||
@ -114,6 +114,7 @@ class SubagentManager:
|
||||
origin_channel: str = "cli",
|
||||
origin_chat_id: str = "direct",
|
||||
session_key: str | None = None,
|
||||
origin_message_id: str | None = None,
|
||||
) -> str:
|
||||
"""Spawn a subagent to execute a task in the background."""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
@ -129,7 +130,7 @@ class SubagentManager:
|
||||
self._task_statuses[task_id] = status
|
||||
|
||||
bg_task = asyncio.create_task(
|
||||
self._run_subagent(task_id, task, display_label, origin, status)
|
||||
self._run_subagent(task_id, task, display_label, origin, status, origin_message_id)
|
||||
)
|
||||
self._running_tasks[task_id] = bg_task
|
||||
if session_key:
|
||||
@ -155,6 +156,7 @@ class SubagentManager:
|
||||
label: str,
|
||||
origin: dict[str, str],
|
||||
status: SubagentStatus,
|
||||
origin_message_id: str | None = None,
|
||||
) -> None:
|
||||
"""Execute the subagent task and announce the result."""
|
||||
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||
@ -228,24 +230,24 @@ class SubagentManager:
|
||||
await self._announce_result(
|
||||
task_id, label, task,
|
||||
self._format_partial_progress(result),
|
||||
origin, "error",
|
||||
origin, "error", origin_message_id,
|
||||
)
|
||||
elif result.stop_reason == "error":
|
||||
await self._announce_result(
|
||||
task_id, label, task,
|
||||
result.error or "Error: subagent execution failed.",
|
||||
origin, "error",
|
||||
origin, "error", origin_message_id,
|
||||
)
|
||||
else:
|
||||
final_result = result.final_content or "Task completed but no final response was generated."
|
||||
logger.info("Subagent [{}] completed successfully", task_id)
|
||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
||||
await self._announce_result(task_id, label, task, final_result, origin, "ok", origin_message_id)
|
||||
|
||||
except Exception as e:
|
||||
status.phase = "error"
|
||||
status.error = str(e)
|
||||
logger.error("Subagent [{}] failed: {}", task_id, e)
|
||||
await self._announce_result(task_id, label, task, f"Error: {e}", origin, "error")
|
||||
await self._announce_result(task_id, label, task, f"Error: {e}", origin, "error", origin_message_id)
|
||||
|
||||
async def _announce_result(
|
||||
self,
|
||||
@ -255,6 +257,7 @@ class SubagentManager:
|
||||
result: str,
|
||||
origin: dict[str, str],
|
||||
status: str,
|
||||
origin_message_id: str | None = None,
|
||||
) -> None:
|
||||
"""Announce the subagent result to the main agent via the message bus."""
|
||||
status_text = "completed successfully" if status == "ok" else "failed"
|
||||
@ -273,16 +276,19 @@ class SubagentManager:
|
||||
# routed to the correct pending queue (mid-turn injection) instead of
|
||||
# being dispatched as a competing independent task.
|
||||
override = origin.get("session_key") or f"{origin['channel']}:{origin['chat_id']}"
|
||||
metadata: dict[str, Any] = {
|
||||
"injected_event": "subagent_result",
|
||||
"subagent_task_id": task_id,
|
||||
}
|
||||
if origin_message_id:
|
||||
metadata["origin_message_id"] = origin_message_id
|
||||
msg = InboundMessage(
|
||||
channel="system",
|
||||
sender_id="subagent",
|
||||
chat_id=f"{origin['channel']}:{origin['chat_id']}",
|
||||
content=announce_content,
|
||||
session_key_override=override,
|
||||
metadata={
|
||||
"injected_event": "subagent_result",
|
||||
"subagent_task_id": task_id,
|
||||
},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
@ -25,6 +25,10 @@ class SpawnTool(Tool):
|
||||
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
|
||||
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
|
||||
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
|
||||
self._origin_message_id: ContextVar[str | None] = ContextVar(
|
||||
"spawn_origin_message_id",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
|
||||
"""Set the origin context for subagent announcements."""
|
||||
@ -32,6 +36,10 @@ class SpawnTool(Tool):
|
||||
self._origin_chat_id.set(chat_id)
|
||||
self._session_key.set(effective_key or f"{channel}:{chat_id}")
|
||||
|
||||
def set_origin_message_id(self, message_id: str | None) -> None:
|
||||
"""Set the source message id for downstream deduplication."""
|
||||
self._origin_message_id.set(message_id)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "spawn"
|
||||
@ -54,4 +62,5 @@ class SpawnTool(Tool):
|
||||
origin_channel=self._origin_channel.get(),
|
||||
origin_chat_id=self._origin_chat_id.get(),
|
||||
session_key=self._session_key.get(),
|
||||
origin_message_id=self._origin_message_id.get(),
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -37,7 +38,6 @@ _BOOL_CAMEL_ALIASES: dict[str, str] = {
|
||||
"send_tool_hints": "sendToolHints",
|
||||
}
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manages chat channels and coordinates message routing.
|
||||
@ -60,6 +60,7 @@ class ChannelManager:
|
||||
self._session_manager = session_manager
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
self._origin_reply_fingerprints: dict[tuple[str, str, str], str] = {}
|
||||
|
||||
self._init_channels()
|
||||
|
||||
@ -232,6 +233,33 @@ class ChannelManager:
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint_content(content: str) -> str:
|
||||
normalized = " ".join(content.split())
|
||||
return hashlib.sha1(normalized.encode("utf-8")).hexdigest() if normalized else ""
|
||||
|
||||
def _should_suppress_outbound(self, msg: OutboundMessage) -> bool:
|
||||
metadata = msg.metadata or {}
|
||||
if metadata.get("_progress"):
|
||||
return False
|
||||
fingerprint = self._fingerprint_content(msg.content)
|
||||
if not fingerprint:
|
||||
return False
|
||||
|
||||
origin_message_id = metadata.get("origin_message_id")
|
||||
if isinstance(origin_message_id, str) and origin_message_id:
|
||||
key = (msg.channel, msg.chat_id, origin_message_id)
|
||||
if self._origin_reply_fingerprints.get(key) == fingerprint:
|
||||
return True
|
||||
self._origin_reply_fingerprints[key] = fingerprint
|
||||
|
||||
message_id = metadata.get("message_id")
|
||||
if isinstance(message_id, str) and message_id:
|
||||
key = (msg.channel, msg.chat_id, message_id)
|
||||
self._origin_reply_fingerprints[key] = fingerprint
|
||||
|
||||
return False
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""Dispatch outbound messages to the appropriate channel."""
|
||||
logger.info("Outbound dispatcher started")
|
||||
@ -272,6 +300,16 @@ class ChannelManager:
|
||||
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
# Duplicate suppression is scoped to a known source message
|
||||
# so repeated content from separate turns is still delivered.
|
||||
if (
|
||||
not msg.metadata.get("_stream_delta")
|
||||
and not msg.metadata.get("_stream_end")
|
||||
and not msg.metadata.get("_streamed")
|
||||
):
|
||||
if self._should_suppress_outbound(msg):
|
||||
logger.info("Suppressing duplicate outbound message to {}:{}", msg.channel, msg.chat_id)
|
||||
continue
|
||||
await self._send_with_retry(channel, msg)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
@ -727,6 +727,7 @@ def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) ->
|
||||
loop._set_tool_context(
|
||||
"slack",
|
||||
"C123",
|
||||
message_id="msg-123",
|
||||
metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
|
||||
session_key="slack:C123:1700.42",
|
||||
)
|
||||
@ -734,6 +735,7 @@ def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) ->
|
||||
spawn_tool = loop.tools.get("spawn")
|
||||
assert spawn_tool is not None
|
||||
assert spawn_tool._session_key.get() == "slack:C123:1700.42"
|
||||
assert spawn_tool._origin_message_id.get() == "msg-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -766,14 +768,17 @@ async def test_system_subagent_followup_uses_thread_session_and_slack_metadata(t
|
||||
chat_id="slack:C123",
|
||||
content="subagent result",
|
||||
session_key_override="slack:C123:1700.42",
|
||||
metadata={"subagent_task_id": "sub-1"},
|
||||
metadata={"subagent_task_id": "sub-1", "origin_message_id": "msg-123"},
|
||||
)
|
||||
)
|
||||
|
||||
assert outbound is not None
|
||||
assert outbound.channel == "slack"
|
||||
assert outbound.chat_id == "C123"
|
||||
assert outbound.metadata == {"slack": {"thread_ts": "1700.42"}}
|
||||
assert outbound.metadata == {
|
||||
"slack": {"thread_ts": "1700.42"},
|
||||
"origin_message_id": "msg-123",
|
||||
}
|
||||
assert "thread question" in seen["initial_messages"][1]["content"]
|
||||
|
||||
loop.sessions.invalidate("slack:C123:1700.42")
|
||||
|
||||
@ -13,6 +13,8 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.schema import ChannelsConfig
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
|
||||
from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
|
||||
from nanobot.utils.restart import RestartNotice
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -338,9 +340,6 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
|
||||
# Transcription provider HTTP tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
|
||||
from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
|
||||
|
||||
|
||||
class _StubResponse:
|
||||
def raise_for_status(self):
|
||||
@ -791,6 +790,50 @@ async def test_send_with_retry_skips_send_when_streamed():
|
||||
assert send_delta_called is False
|
||||
|
||||
|
||||
def test_outbound_duplicate_suppression_is_scoped_to_origin_message() -> None:
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {}
|
||||
mgr._dispatch_task = None
|
||||
mgr._origin_reply_fingerprints = {}
|
||||
|
||||
first = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="chat123",
|
||||
content="Done",
|
||||
metadata={"message_id": "msg-1"},
|
||||
)
|
||||
duplicate = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="chat123",
|
||||
content=" Done ",
|
||||
metadata={"origin_message_id": "msg-1"},
|
||||
)
|
||||
separate_turn = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="chat123",
|
||||
content="Done",
|
||||
metadata={"message_id": "msg-2"},
|
||||
)
|
||||
new_origin_content = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="chat123",
|
||||
content="Done with extra details",
|
||||
metadata={"origin_message_id": "msg-1"},
|
||||
)
|
||||
|
||||
assert mgr._should_suppress_outbound(first) is False
|
||||
assert mgr._should_suppress_outbound(duplicate) is True
|
||||
assert mgr._should_suppress_outbound(separate_turn) is False
|
||||
assert mgr._should_suppress_outbound(new_origin_content) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_propagates_cancelled_error():
|
||||
"""_send_with_retry should re-raise CancelledError for graceful shutdown."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user