mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
fix(agent): persist cross-channel messages into target session history
When session A (e.g. websocket) uses the `message` tool to send to channel B (e.g. feishu), the outbound message is delivered to the user but was never recorded in session B's history. This caused session B to lose context when the user replied on that channel. Add `_persist_cross_channel_calls()` to detect cross-channel `message` tool calls during `_save_turn()` and append a lightweight assistant entry (with `_cross_channel: True` marker) to the target session.
This commit is contained in:
parent
3c06db7e4e
commit
1747ed7885
@ -843,8 +843,68 @@ class AgentLoop:
|
||||
entry["content"] = filtered
|
||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||
session.messages.append(entry)
|
||||
|
||||
# Persist cross-channel message tool calls into target sessions so
|
||||
# that the target session has context when the user replies there.
|
||||
self._persist_cross_channel_calls(session, messages[skip:])
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
def _persist_cross_channel_calls(
|
||||
self, source_session: Session, new_messages: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Record cross-channel ``message`` tool calls into the target session.
|
||||
|
||||
When session A (e.g. websocket) uses the *message* tool to send to
|
||||
channel B (e.g. feishu), the outbound message is delivered to the user
|
||||
but is not recorded in session B's history. This causes session B to
|
||||
lose context when the user replies on channel B.
|
||||
|
||||
This method detects such cross-channel sends and appends a lightweight
|
||||
assistant entry to the target session so it has the necessary context.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
for m in new_messages:
|
||||
if m.get("role") != "assistant":
|
||||
continue
|
||||
tool_calls = m.get("tool_calls") or []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
if func.get("name") != "message":
|
||||
continue
|
||||
try:
|
||||
args = json.loads(func.get("arguments", "{}"))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
target_channel = args.get("channel") or source_session.key.split(":", 1)[0]
|
||||
target_chat_id = args.get("chat_id") or source_session.key.split(":", 1)[-1]
|
||||
target_key = f"{target_channel}:{target_chat_id}"
|
||||
|
||||
if target_key == source_session.key:
|
||||
continue # same session, nothing to do
|
||||
|
||||
content = args.get("content", "")
|
||||
if not content:
|
||||
continue
|
||||
|
||||
target_session = self.sessions._cache.get(target_key)
|
||||
if target_session is None:
|
||||
continue # target session doesn't exist yet, skip silently
|
||||
|
||||
target_session.messages.append({
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"_cross_channel": True,
|
||||
})
|
||||
target_session.updated_at = datetime.now()
|
||||
self.sessions.save(target_session)
|
||||
logger.info(
|
||||
"Cross-channel message persisted: {} -> {}",
|
||||
source_session.key, target_key,
|
||||
)
|
||||
|
||||
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
|
||||
"""Persist the latest in-flight turn state into session metadata."""
|
||||
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
@ -308,3 +309,147 @@ async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(t
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
]
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
|
||||
def _cross_channel_messages(
|
||||
source_channel: str = "websocket",
|
||||
source_chat_id: str = "ws-uuid-123",
|
||||
target_channel: str = "feishu",
|
||||
target_chat_id: str = "ou_abc123",
|
||||
) -> list[dict]:
|
||||
"""Build a message list with a cross-channel message tool call."""
|
||||
return [
|
||||
{"role": "user", "content": "send report to feishu"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_x1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "message",
|
||||
"arguments": json.dumps({
|
||||
"content": "Report: audit complete",
|
||||
"channel": target_channel,
|
||||
"chat_id": target_chat_id,
|
||||
}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_x1",
|
||||
"name": "message",
|
||||
"content": f"Message sent to {target_channel}:{target_chat_id}",
|
||||
},
|
||||
{"role": "assistant", "content": "Done, sent to feishu."},
|
||||
]
|
||||
|
||||
|
||||
def test_cross_channel_message_persisted_in_target_session(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
source_key = "websocket:ws-uuid-123"
|
||||
target_key = "feishu:ou_abc123"
|
||||
|
||||
# Pre-create the target session (simulate an existing feishu conversation)
|
||||
target_session = loop.sessions.get_or_create(target_key)
|
||||
target_session.add_message("user", "hello from feishu")
|
||||
loop.sessions.save(target_session)
|
||||
|
||||
source_session = loop.sessions.get_or_create(source_key)
|
||||
msgs = _cross_channel_messages()
|
||||
loop._save_turn(source_session, msgs, skip=1) # skip user message
|
||||
|
||||
# Source session has its own messages
|
||||
source_session = loop.sessions.get_or_create(source_key)
|
||||
assert len(source_session.messages) >= 2 # assistant + tool + final
|
||||
|
||||
# Target session now has the cross-channel message appended
|
||||
loop.sessions.invalidate(target_key)
|
||||
target = loop.sessions.get_or_create(target_key)
|
||||
cross_msg = [m for m in target.messages if m.get("_cross_channel")]
|
||||
assert len(cross_msg) == 1
|
||||
assert cross_msg[0]["content"] == "Report: audit complete"
|
||||
assert cross_msg[0]["role"] == "assistant"
|
||||
|
||||
|
||||
def test_cross_channel_same_session_not_duplicated(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
key = "feishu:ou_same"
|
||||
|
||||
session = loop.sessions.get_or_create(key)
|
||||
# message tool call targeting the same session — should NOT create a duplicate
|
||||
msgs = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_s1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "message",
|
||||
"arguments": json.dumps({
|
||||
"content": "same channel msg",
|
||||
"channel": "feishu",
|
||||
"chat_id": "ou_same",
|
||||
}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_s1", "name": "message", "content": "ok"},
|
||||
]
|
||||
loop._save_turn(session, msgs, skip=1)
|
||||
|
||||
# No _cross_channel entries should exist
|
||||
assert all(not m.get("_cross_channel") for m in session.messages)
|
||||
|
||||
|
||||
def test_cross_channel_target_session_not_exist(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
source_session = loop.sessions.get_or_create("websocket:ws-xyz")
|
||||
|
||||
msgs = _cross_channel_messages(
|
||||
target_channel="feishu", target_chat_id="ou_nonexistent"
|
||||
)
|
||||
loop._save_turn(source_session, msgs, skip=1)
|
||||
|
||||
# Target session was never created, so no error should occur
|
||||
assert "feishu:ou_nonexistent" not in loop.sessions._cache
|
||||
|
||||
|
||||
def test_cross_channel_non_message_tools_ignored(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
source_session = loop.sessions.get_or_create("websocket:ws-abc")
|
||||
target_session = loop.sessions.get_or_create("feishu:ou_tgt")
|
||||
target_session.add_message("user", "hi")
|
||||
loop.sessions.save(target_session)
|
||||
|
||||
msgs = [
|
||||
{"role": "user", "content": "do stuff"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_e1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "exec",
|
||||
"arguments": json.dumps({"command": "echo hi"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_e1", "name": "exec", "content": "hi"},
|
||||
]
|
||||
loop._save_turn(source_session, msgs, skip=1)
|
||||
|
||||
# exec tool should NOT produce cross-channel entries
|
||||
target = loop.sessions.get_or_create("feishu:ou_tgt")
|
||||
cross_msgs = [m for m in target.messages if m.get("_cross_channel")]
|
||||
assert len(cross_msgs) == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user