mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 09:32:33 +00:00
Merge PR #3561: fix: origin_message_id support and outbound deduplication
fix: origin_message_id support and outbound deduplication
This commit is contained in:
commit
e16fa7c6b1
@ -441,6 +441,8 @@ class AgentLoop:
|
|||||||
if hasattr(tool, "set_context"):
|
if hasattr(tool, "set_context"):
|
||||||
if name == "spawn":
|
if name == "spawn":
|
||||||
tool.set_context(channel, chat_id, effective_key=effective_key)
|
tool.set_context(channel, chat_id, effective_key=effective_key)
|
||||||
|
if hasattr(tool, "set_origin_message_id"):
|
||||||
|
tool.set_origin_message_id(message_id)
|
||||||
elif name == "cron":
|
elif name == "cron":
|
||||||
tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
|
tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
|
||||||
elif name == "message":
|
elif name == "message":
|
||||||
@ -957,6 +959,8 @@ class AgentLoop:
|
|||||||
outbound_metadata: dict[str, Any] = {}
|
outbound_metadata: dict[str, Any] = {}
|
||||||
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
||||||
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
||||||
|
if origin_message_id := msg.metadata.get("origin_message_id"):
|
||||||
|
outbound_metadata["origin_message_id"] = origin_message_id
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=channel,
|
channel=channel,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
|
|||||||
@ -114,6 +114,7 @@ class SubagentManager:
|
|||||||
origin_channel: str = "cli",
|
origin_channel: str = "cli",
|
||||||
origin_chat_id: str = "direct",
|
origin_chat_id: str = "direct",
|
||||||
session_key: str | None = None,
|
session_key: str | None = None,
|
||||||
|
origin_message_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Spawn a subagent to execute a task in the background."""
|
"""Spawn a subagent to execute a task in the background."""
|
||||||
task_id = str(uuid.uuid4())[:8]
|
task_id = str(uuid.uuid4())[:8]
|
||||||
@ -129,7 +130,7 @@ class SubagentManager:
|
|||||||
self._task_statuses[task_id] = status
|
self._task_statuses[task_id] = status
|
||||||
|
|
||||||
bg_task = asyncio.create_task(
|
bg_task = asyncio.create_task(
|
||||||
self._run_subagent(task_id, task, display_label, origin, status)
|
self._run_subagent(task_id, task, display_label, origin, status, origin_message_id)
|
||||||
)
|
)
|
||||||
self._running_tasks[task_id] = bg_task
|
self._running_tasks[task_id] = bg_task
|
||||||
if session_key:
|
if session_key:
|
||||||
@ -155,6 +156,7 @@ class SubagentManager:
|
|||||||
label: str,
|
label: str,
|
||||||
origin: dict[str, str],
|
origin: dict[str, str],
|
||||||
status: SubagentStatus,
|
status: SubagentStatus,
|
||||||
|
origin_message_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute the subagent task and announce the result."""
|
"""Execute the subagent task and announce the result."""
|
||||||
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||||
@ -228,24 +230,24 @@ class SubagentManager:
|
|||||||
await self._announce_result(
|
await self._announce_result(
|
||||||
task_id, label, task,
|
task_id, label, task,
|
||||||
self._format_partial_progress(result),
|
self._format_partial_progress(result),
|
||||||
origin, "error",
|
origin, "error", origin_message_id,
|
||||||
)
|
)
|
||||||
elif result.stop_reason == "error":
|
elif result.stop_reason == "error":
|
||||||
await self._announce_result(
|
await self._announce_result(
|
||||||
task_id, label, task,
|
task_id, label, task,
|
||||||
result.error or "Error: subagent execution failed.",
|
result.error or "Error: subagent execution failed.",
|
||||||
origin, "error",
|
origin, "error", origin_message_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
final_result = result.final_content or "Task completed but no final response was generated."
|
final_result = result.final_content or "Task completed but no final response was generated."
|
||||||
logger.info("Subagent [{}] completed successfully", task_id)
|
logger.info("Subagent [{}] completed successfully", task_id)
|
||||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
await self._announce_result(task_id, label, task, final_result, origin, "ok", origin_message_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
status.phase = "error"
|
status.phase = "error"
|
||||||
status.error = str(e)
|
status.error = str(e)
|
||||||
logger.error("Subagent [{}] failed: {}", task_id, e)
|
logger.error("Subagent [{}] failed: {}", task_id, e)
|
||||||
await self._announce_result(task_id, label, task, f"Error: {e}", origin, "error")
|
await self._announce_result(task_id, label, task, f"Error: {e}", origin, "error", origin_message_id)
|
||||||
|
|
||||||
async def _announce_result(
|
async def _announce_result(
|
||||||
self,
|
self,
|
||||||
@ -255,6 +257,7 @@ class SubagentManager:
|
|||||||
result: str,
|
result: str,
|
||||||
origin: dict[str, str],
|
origin: dict[str, str],
|
||||||
status: str,
|
status: str,
|
||||||
|
origin_message_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Announce the subagent result to the main agent via the message bus."""
|
"""Announce the subagent result to the main agent via the message bus."""
|
||||||
status_text = "completed successfully" if status == "ok" else "failed"
|
status_text = "completed successfully" if status == "ok" else "failed"
|
||||||
@ -273,16 +276,19 @@ class SubagentManager:
|
|||||||
# routed to the correct pending queue (mid-turn injection) instead of
|
# routed to the correct pending queue (mid-turn injection) instead of
|
||||||
# being dispatched as a competing independent task.
|
# being dispatched as a competing independent task.
|
||||||
override = origin.get("session_key") or f"{origin['channel']}:{origin['chat_id']}"
|
override = origin.get("session_key") or f"{origin['channel']}:{origin['chat_id']}"
|
||||||
|
metadata: dict[str, Any] = {
|
||||||
|
"injected_event": "subagent_result",
|
||||||
|
"subagent_task_id": task_id,
|
||||||
|
}
|
||||||
|
if origin_message_id:
|
||||||
|
metadata["origin_message_id"] = origin_message_id
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(
|
||||||
channel="system",
|
channel="system",
|
||||||
sender_id="subagent",
|
sender_id="subagent",
|
||||||
chat_id=f"{origin['channel']}:{origin['chat_id']}",
|
chat_id=f"{origin['channel']}:{origin['chat_id']}",
|
||||||
content=announce_content,
|
content=announce_content,
|
||||||
session_key_override=override,
|
session_key_override=override,
|
||||||
metadata={
|
metadata=metadata,
|
||||||
"injected_event": "subagent_result",
|
|
||||||
"subagent_task_id": task_id,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
|
|||||||
@ -25,6 +25,10 @@ class SpawnTool(Tool):
|
|||||||
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
|
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
|
||||||
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
|
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
|
||||||
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
|
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
|
||||||
|
self._origin_message_id: ContextVar[str | None] = ContextVar(
|
||||||
|
"spawn_origin_message_id",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
|
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
|
||||||
"""Set the origin context for subagent announcements."""
|
"""Set the origin context for subagent announcements."""
|
||||||
@ -32,6 +36,10 @@ class SpawnTool(Tool):
|
|||||||
self._origin_chat_id.set(chat_id)
|
self._origin_chat_id.set(chat_id)
|
||||||
self._session_key.set(effective_key or f"{channel}:{chat_id}")
|
self._session_key.set(effective_key or f"{channel}:{chat_id}")
|
||||||
|
|
||||||
|
def set_origin_message_id(self, message_id: str | None) -> None:
|
||||||
|
"""Set the source message id for downstream deduplication."""
|
||||||
|
self._origin_message_id.set(message_id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "spawn"
|
return "spawn"
|
||||||
@ -54,4 +62,5 @@ class SpawnTool(Tool):
|
|||||||
origin_channel=self._origin_channel.get(),
|
origin_channel=self._origin_channel.get(),
|
||||||
origin_chat_id=self._origin_chat_id.get(),
|
origin_chat_id=self._origin_chat_id.get(),
|
||||||
session_key=self._session_key.get(),
|
session_key=self._session_key.get(),
|
||||||
|
origin_message_id=self._origin_message_id.get(),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@ -37,7 +38,6 @@ _BOOL_CAMEL_ALIASES: dict[str, str] = {
|
|||||||
"send_tool_hints": "sendToolHints",
|
"send_tool_hints": "sendToolHints",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ChannelManager:
|
class ChannelManager:
|
||||||
"""
|
"""
|
||||||
Manages chat channels and coordinates message routing.
|
Manages chat channels and coordinates message routing.
|
||||||
@ -60,6 +60,7 @@ class ChannelManager:
|
|||||||
self._session_manager = session_manager
|
self._session_manager = session_manager
|
||||||
self.channels: dict[str, BaseChannel] = {}
|
self.channels: dict[str, BaseChannel] = {}
|
||||||
self._dispatch_task: asyncio.Task | None = None
|
self._dispatch_task: asyncio.Task | None = None
|
||||||
|
self._origin_reply_fingerprints: dict[tuple[str, str, str], str] = {}
|
||||||
|
|
||||||
self._init_channels()
|
self._init_channels()
|
||||||
|
|
||||||
@ -232,6 +233,33 @@ class ChannelManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error stopping {}: {}", name, e)
|
logger.error("Error stopping {}: {}", name, e)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fingerprint_content(content: str) -> str:
|
||||||
|
normalized = " ".join(content.split())
|
||||||
|
return hashlib.sha1(normalized.encode("utf-8")).hexdigest() if normalized else ""
|
||||||
|
|
||||||
|
def _should_suppress_outbound(self, msg: OutboundMessage) -> bool:
|
||||||
|
metadata = msg.metadata or {}
|
||||||
|
if metadata.get("_progress"):
|
||||||
|
return False
|
||||||
|
fingerprint = self._fingerprint_content(msg.content)
|
||||||
|
if not fingerprint:
|
||||||
|
return False
|
||||||
|
|
||||||
|
origin_message_id = metadata.get("origin_message_id")
|
||||||
|
if isinstance(origin_message_id, str) and origin_message_id:
|
||||||
|
key = (msg.channel, msg.chat_id, origin_message_id)
|
||||||
|
if self._origin_reply_fingerprints.get(key) == fingerprint:
|
||||||
|
return True
|
||||||
|
self._origin_reply_fingerprints[key] = fingerprint
|
||||||
|
|
||||||
|
message_id = metadata.get("message_id")
|
||||||
|
if isinstance(message_id, str) and message_id:
|
||||||
|
key = (msg.channel, msg.chat_id, message_id)
|
||||||
|
self._origin_reply_fingerprints[key] = fingerprint
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
async def _dispatch_outbound(self) -> None:
|
async def _dispatch_outbound(self) -> None:
|
||||||
"""Dispatch outbound messages to the appropriate channel."""
|
"""Dispatch outbound messages to the appropriate channel."""
|
||||||
logger.info("Outbound dispatcher started")
|
logger.info("Outbound dispatcher started")
|
||||||
@ -272,6 +300,16 @@ class ChannelManager:
|
|||||||
|
|
||||||
channel = self.channels.get(msg.channel)
|
channel = self.channels.get(msg.channel)
|
||||||
if channel:
|
if channel:
|
||||||
|
# Duplicate suppression is scoped to a known source message
|
||||||
|
# so repeated content from separate turns is still delivered.
|
||||||
|
if (
|
||||||
|
not msg.metadata.get("_stream_delta")
|
||||||
|
and not msg.metadata.get("_stream_end")
|
||||||
|
and not msg.metadata.get("_streamed")
|
||||||
|
):
|
||||||
|
if self._should_suppress_outbound(msg):
|
||||||
|
logger.info("Suppressing duplicate outbound message to {}:{}", msg.channel, msg.chat_id)
|
||||||
|
continue
|
||||||
await self._send_with_retry(channel, msg)
|
await self._send_with_retry(channel, msg)
|
||||||
else:
|
else:
|
||||||
logger.warning("Unknown channel: {}", msg.channel)
|
logger.warning("Unknown channel: {}", msg.channel)
|
||||||
|
|||||||
@ -727,6 +727,7 @@ def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) ->
|
|||||||
loop._set_tool_context(
|
loop._set_tool_context(
|
||||||
"slack",
|
"slack",
|
||||||
"C123",
|
"C123",
|
||||||
|
message_id="msg-123",
|
||||||
metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
|
metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
|
||||||
session_key="slack:C123:1700.42",
|
session_key="slack:C123:1700.42",
|
||||||
)
|
)
|
||||||
@ -734,6 +735,7 @@ def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) ->
|
|||||||
spawn_tool = loop.tools.get("spawn")
|
spawn_tool = loop.tools.get("spawn")
|
||||||
assert spawn_tool is not None
|
assert spawn_tool is not None
|
||||||
assert spawn_tool._session_key.get() == "slack:C123:1700.42"
|
assert spawn_tool._session_key.get() == "slack:C123:1700.42"
|
||||||
|
assert spawn_tool._origin_message_id.get() == "msg-123"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -766,14 +768,17 @@ async def test_system_subagent_followup_uses_thread_session_and_slack_metadata(t
|
|||||||
chat_id="slack:C123",
|
chat_id="slack:C123",
|
||||||
content="subagent result",
|
content="subagent result",
|
||||||
session_key_override="slack:C123:1700.42",
|
session_key_override="slack:C123:1700.42",
|
||||||
metadata={"subagent_task_id": "sub-1"},
|
metadata={"subagent_task_id": "sub-1", "origin_message_id": "msg-123"},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert outbound is not None
|
assert outbound is not None
|
||||||
assert outbound.channel == "slack"
|
assert outbound.channel == "slack"
|
||||||
assert outbound.chat_id == "C123"
|
assert outbound.chat_id == "C123"
|
||||||
assert outbound.metadata == {"slack": {"thread_ts": "1700.42"}}
|
assert outbound.metadata == {
|
||||||
|
"slack": {"thread_ts": "1700.42"},
|
||||||
|
"origin_message_id": "msg-123",
|
||||||
|
}
|
||||||
assert "thread question" in seen["initial_messages"][1]["content"]
|
assert "thread question" in seen["initial_messages"][1]["content"]
|
||||||
|
|
||||||
loop.sessions.invalidate("slack:C123:1700.42")
|
loop.sessions.invalidate("slack:C123:1700.42")
|
||||||
|
|||||||
@ -13,6 +13,8 @@ from nanobot.bus.queue import MessageBus
|
|||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.channels.manager import ChannelManager
|
from nanobot.channels.manager import ChannelManager
|
||||||
from nanobot.config.schema import ChannelsConfig
|
from nanobot.config.schema import ChannelsConfig
|
||||||
|
from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
|
||||||
|
from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
|
||||||
from nanobot.utils.restart import RestartNotice
|
from nanobot.utils.restart import RestartNotice
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -338,9 +340,6 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
|
|||||||
# Transcription provider HTTP tests
|
# Transcription provider HTTP tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
|
|
||||||
from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
|
|
||||||
|
|
||||||
|
|
||||||
class _StubResponse:
|
class _StubResponse:
|
||||||
def raise_for_status(self):
|
def raise_for_status(self):
|
||||||
@ -791,6 +790,50 @@ async def test_send_with_retry_skips_send_when_streamed():
|
|||||||
assert send_delta_called is False
|
assert send_delta_called is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_outbound_duplicate_suppression_is_scoped_to_origin_message() -> None:
|
||||||
|
fake_config = SimpleNamespace(
|
||||||
|
channels=ChannelsConfig(send_max_retries=3),
|
||||||
|
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||||
|
)
|
||||||
|
|
||||||
|
mgr = ChannelManager.__new__(ChannelManager)
|
||||||
|
mgr.config = fake_config
|
||||||
|
mgr.bus = MessageBus()
|
||||||
|
mgr.channels = {}
|
||||||
|
mgr._dispatch_task = None
|
||||||
|
mgr._origin_reply_fingerprints = {}
|
||||||
|
|
||||||
|
first = OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="chat123",
|
||||||
|
content="Done",
|
||||||
|
metadata={"message_id": "msg-1"},
|
||||||
|
)
|
||||||
|
duplicate = OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="chat123",
|
||||||
|
content=" Done ",
|
||||||
|
metadata={"origin_message_id": "msg-1"},
|
||||||
|
)
|
||||||
|
separate_turn = OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="chat123",
|
||||||
|
content="Done",
|
||||||
|
metadata={"message_id": "msg-2"},
|
||||||
|
)
|
||||||
|
new_origin_content = OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="chat123",
|
||||||
|
content="Done with extra details",
|
||||||
|
metadata={"origin_message_id": "msg-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mgr._should_suppress_outbound(first) is False
|
||||||
|
assert mgr._should_suppress_outbound(duplicate) is True
|
||||||
|
assert mgr._should_suppress_outbound(separate_turn) is False
|
||||||
|
assert mgr._should_suppress_outbound(new_origin_content) is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_with_retry_propagates_cancelled_error():
|
async def test_send_with_retry_propagates_cancelled_error():
|
||||||
"""_send_with_retry should re-raise CancelledError for graceful shutdown."""
|
"""_send_with_retry should re-raise CancelledError for graceful shutdown."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user