diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 39e1ce23a..d973d860c 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -843,8 +843,68 @@ class AgentLoop: entry["content"] = filtered entry.setdefault("timestamp", datetime.now().isoformat()) session.messages.append(entry) + + # Persist cross-channel message tool calls into target sessions so + # that the target session has context when the user replies there. + self._persist_cross_channel_calls(session, messages[skip:]) session.updated_at = datetime.now() + def _persist_cross_channel_calls( + self, source_session: Session, new_messages: list[dict[str, Any]] + ) -> None: + """Record cross-channel ``message`` tool calls into the target session. + + When session A (e.g. websocket) uses the *message* tool to send to + channel B (e.g. feishu), the outbound message is delivered to the user + but is not recorded in session B's history. This causes session B to + lose context when the user replies on channel B. + + This method detects such cross-channel sends and appends a lightweight + assistant entry to the target session so it has the necessary context. + """ + from datetime import datetime + + for m in new_messages: + if m.get("role") != "assistant": + continue + tool_calls = m.get("tool_calls") or [] + for tc in tool_calls: + func = tc.get("function", {}) + if func.get("name") != "message": + continue + try: + args = json.loads(func.get("arguments", "{}")) + except (json.JSONDecodeError, TypeError): + continue + + target_channel = args.get("channel") or source_session.key.split(":", 1)[0] + target_chat_id = args.get("chat_id") or source_session.key.split(":", 1)[-1] + target_key = f"{target_channel}:{target_chat_id}" + + if target_key == source_session.key: + continue # same session, nothing to do + + content = args.get("content", "") + if not content: + continue + + target_session = self.sessions._cache.get(target_key) + if target_session is None: + continue # target session doesn't exist yet, skip silently + + target_session.messages.append({ + "role": "assistant", + "content": content, + "timestamp": datetime.now().isoformat(), + "_cross_channel": True, + }) + target_session.updated_at = datetime.now() + self.sessions.save(target_session) + logger.info( + "Cross-channel message persisted: {} -> {}", + source_session.key, target_key, + ) + def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None: """Persist the latest in-flight turn state into session metadata.""" session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index c965ccd8c..ae1402e0e 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -1,3 +1,4 @@ +import json from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -308,3 +309,147 @@ async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(t {"role": "assistant", "content": "new answer"}, ] assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + + +def _cross_channel_messages( + source_channel: str = "websocket", + source_chat_id: str = "ws-uuid-123", + target_channel: str = "feishu", + target_chat_id: str = "ou_abc123", +) -> list[dict]: + """Build a message list with a cross-channel message tool call.""" + return [ + {"role": "user", "content": "send report to feishu"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_x1", + "type": "function", + "function": { + "name": "message", + "arguments": json.dumps({ + "content": "Report: audit complete", + "channel": target_channel, + "chat_id": target_chat_id, + }), + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_x1", + "name": "message", + "content": f"Message sent to {target_channel}:{target_chat_id}", + }, + {"role": "assistant", "content": "Done, sent to feishu."}, + ] + + +def test_cross_channel_message_persisted_in_target_session(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + source_key = "websocket:ws-uuid-123" + target_key = "feishu:ou_abc123" + + # Pre-create the target session (simulate an existing feishu conversation) + target_session = loop.sessions.get_or_create(target_key) + target_session.add_message("user", "hello from feishu") + loop.sessions.save(target_session) + + source_session = loop.sessions.get_or_create(source_key) + msgs = _cross_channel_messages() + loop._save_turn(source_session, msgs, skip=1) # skip user message + + # Source session has its own messages + source_session = loop.sessions.get_or_create(source_key) + assert len(source_session.messages) >= 2 # assistant + tool + final + + # Target session now has the cross-channel message appended + loop.sessions.invalidate(target_key) + target = loop.sessions.get_or_create(target_key) + cross_msg = [m for m in target.messages if m.get("_cross_channel")] + assert len(cross_msg) == 1 + assert cross_msg[0]["content"] == "Report: audit complete" + assert cross_msg[0]["role"] == "assistant" + + +def test_cross_channel_same_session_not_duplicated(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + key = "feishu:ou_same" + + session = loop.sessions.get_or_create(key) + # message tool call targeting the same session — should NOT create a duplicate + msgs = [ + {"role": "user", "content": "hello"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_s1", + "type": "function", + "function": { + "name": "message", + "arguments": json.dumps({ + "content": "same channel msg", + "channel": "feishu", + "chat_id": "ou_same", + }), + }, + } + ], + }, + {"role": "tool", "tool_call_id": "call_s1", "name": "message", "content": "ok"}, + ] + loop._save_turn(session, msgs, skip=1) + + # No _cross_channel entries should exist + assert all(not m.get("_cross_channel") for m in session.messages) + + +def test_cross_channel_target_session_not_exist(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + source_session = loop.sessions.get_or_create("websocket:ws-xyz") + + msgs = _cross_channel_messages( + target_channel="feishu", target_chat_id="ou_nonexistent" + ) + loop._save_turn(source_session, msgs, skip=1) + + # Target session was never created, so no error should occur + assert "feishu:ou_nonexistent" not in loop.sessions._cache + + +def test_cross_channel_non_message_tools_ignored(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + source_session = loop.sessions.get_or_create("websocket:ws-abc") + target_session = loop.sessions.get_or_create("feishu:ou_tgt") + target_session.add_message("user", "hi") + loop.sessions.save(target_session) + + msgs = [ + {"role": "user", "content": "do stuff"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_e1", + "type": "function", + "function": { + "name": "exec", + "arguments": json.dumps({"command": "echo hi"}), + }, + } + ], + }, + {"role": "tool", "tool_call_id": "call_e1", "name": "exec", "content": "hi"}, + ] + loop._save_turn(source_session, msgs, skip=1) + + # exec tool should NOT produce cross-channel entries + target = loop.sessions.get_or_create("feishu:ou_tgt") + cross_msgs = [m for m in target.messages if m.get("_cross_channel")] + assert len(cross_msgs) == 0