fix(agent): persist cross-channel messages into target session history

When session A (e.g. websocket) uses the `message` tool to send to
channel B (e.g. feishu), the outbound message is delivered to the user
but was never recorded in session B's history. This caused session B to
lose context when the user replied on that channel.

Add `_persist_cross_channel_calls()` to detect cross-channel `message`
tool calls during `_save_turn()` and append a lightweight assistant
entry (with `_cross_channel: True` marker) to the target session.
This commit is contained in:
chengyongru 2026-04-14 00:14:30 +08:00
parent 3c06db7e4e
commit 1747ed7885
2 changed files with 205 additions and 0 deletions

View File

@ -843,8 +843,68 @@ class AgentLoop:
entry["content"] = filtered
entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry)
# Persist cross-channel message tool calls into target sessions so
# that the target session has context when the user replies there.
self._persist_cross_channel_calls(session, messages[skip:])
session.updated_at = datetime.now()
def _persist_cross_channel_calls(
self, source_session: Session, new_messages: list[dict[str, Any]]
) -> None:
"""Record cross-channel ``message`` tool calls into the target session.
When session A (e.g. websocket) uses the *message* tool to send to
channel B (e.g. feishu), the outbound message is delivered to the user
but is not recorded in session B's history. This causes session B to
lose context when the user replies on channel B.
This method detects such cross-channel sends and appends a lightweight
assistant entry to the target session so it has the necessary context.
"""
from datetime import datetime
for m in new_messages:
if m.get("role") != "assistant":
continue
tool_calls = m.get("tool_calls") or []
for tc in tool_calls:
func = tc.get("function", {})
if func.get("name") != "message":
continue
try:
args = json.loads(func.get("arguments", "{}"))
except (json.JSONDecodeError, TypeError):
continue
target_channel = args.get("channel") or source_session.key.split(":", 1)[0]
target_chat_id = args.get("chat_id") or source_session.key.split(":", 1)[-1]
target_key = f"{target_channel}:{target_chat_id}"
if target_key == source_session.key:
continue # same session, nothing to do
content = args.get("content", "")
if not content:
continue
target_session = self.sessions._cache.get(target_key)
if target_session is None:
continue # target session doesn't exist yet, skip silently
target_session.messages.append({
"role": "assistant",
"content": content,
"timestamp": datetime.now().isoformat(),
"_cross_channel": True,
})
target_session.updated_at = datetime.now()
self.sessions.save(target_session)
logger.info(
"Cross-channel message persisted: {} -> {}",
source_session.key, target_key,
)
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
"""Persist the latest in-flight turn state into session metadata."""
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload

View File

@ -1,3 +1,4 @@
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
@ -308,3 +309,147 @@ async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(t
{"role": "assistant", "content": "new answer"},
]
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
def _cross_channel_messages(
source_channel: str = "websocket",
source_chat_id: str = "ws-uuid-123",
target_channel: str = "feishu",
target_chat_id: str = "ou_abc123",
) -> list[dict]:
"""Build a message list with a cross-channel message tool call."""
return [
{"role": "user", "content": "send report to feishu"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_x1",
"type": "function",
"function": {
"name": "message",
"arguments": json.dumps({
"content": "Report: audit complete",
"channel": target_channel,
"chat_id": target_chat_id,
}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_x1",
"name": "message",
"content": f"Message sent to {target_channel}:{target_chat_id}",
},
{"role": "assistant", "content": "Done, sent to feishu."},
]
def test_cross_channel_message_persisted_in_target_session(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
source_key = "websocket:ws-uuid-123"
target_key = "feishu:ou_abc123"
# Pre-create the target session (simulate an existing feishu conversation)
target_session = loop.sessions.get_or_create(target_key)
target_session.add_message("user", "hello from feishu")
loop.sessions.save(target_session)
source_session = loop.sessions.get_or_create(source_key)
msgs = _cross_channel_messages()
loop._save_turn(source_session, msgs, skip=1) # skip user message
# Source session has its own messages
source_session = loop.sessions.get_or_create(source_key)
assert len(source_session.messages) >= 2 # assistant + tool + final
# Target session now has the cross-channel message appended
loop.sessions.invalidate(target_key)
target = loop.sessions.get_or_create(target_key)
cross_msg = [m for m in target.messages if m.get("_cross_channel")]
assert len(cross_msg) == 1
assert cross_msg[0]["content"] == "Report: audit complete"
assert cross_msg[0]["role"] == "assistant"
def test_cross_channel_same_session_not_duplicated(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
key = "feishu:ou_same"
session = loop.sessions.get_or_create(key)
# message tool call targeting the same session — should NOT create a duplicate
msgs = [
{"role": "user", "content": "hello"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_s1",
"type": "function",
"function": {
"name": "message",
"arguments": json.dumps({
"content": "same channel msg",
"channel": "feishu",
"chat_id": "ou_same",
}),
},
}
],
},
{"role": "tool", "tool_call_id": "call_s1", "name": "message", "content": "ok"},
]
loop._save_turn(session, msgs, skip=1)
# No _cross_channel entries should exist
assert all(not m.get("_cross_channel") for m in session.messages)
def test_cross_channel_target_session_not_exist(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
source_session = loop.sessions.get_or_create("websocket:ws-xyz")
msgs = _cross_channel_messages(
target_channel="feishu", target_chat_id="ou_nonexistent"
)
loop._save_turn(source_session, msgs, skip=1)
# Target session was never created, so no error should occur
assert "feishu:ou_nonexistent" not in loop.sessions._cache
def test_cross_channel_non_message_tools_ignored(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
source_session = loop.sessions.get_or_create("websocket:ws-abc")
target_session = loop.sessions.get_or_create("feishu:ou_tgt")
target_session.add_message("user", "hi")
loop.sessions.save(target_session)
msgs = [
{"role": "user", "content": "do stuff"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_e1",
"type": "function",
"function": {
"name": "exec",
"arguments": json.dumps({"command": "echo hi"}),
},
}
],
},
{"role": "tool", "tool_call_id": "call_e1", "name": "exec", "content": "hi"},
]
loop._save_turn(source_session, msgs, skip=1)
# exec tool should NOT produce cross-channel entries
target = loop.sessions.get_or_create("feishu:ou_tgt")
cross_msgs = [m for m in target.messages if m.get("_cross_channel")]
assert len(cross_msgs) == 0