mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 00:52:34 +00:00
When session A (e.g. websocket) used the `message` tool to send to channel B (e.g. feishu), the outbound message was delivered to the user but never recorded in session B's history, so session B lost context when the user replied on that channel. This change adds `_persist_cross_channel_calls()`, which detects cross-channel `message` tool calls during `_save_turn()` and appends a lightweight assistant entry (marked with `_cross_channel: True`) to the target session.
456 lines
16 KiB
Python
456 lines
16 KiB
Python
import json
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from nanobot.agent.context import ContextBuilder
|
|
from nanobot.agent.loop import AgentLoop
|
|
from nanobot.bus.events import InboundMessage
|
|
from nanobot.bus.queue import MessageBus
|
|
from nanobot.session.manager import Session
|
|
|
|
|
|
def _mk_loop() -> AgentLoop:
    """Return an AgentLoop created without running ``__init__``.

    Only ``max_tool_result_chars`` is populated (from ``AgentDefaults``); every
    other attribute is left unset, so callers must stick to methods that read
    nothing else.
    """
    from nanobot.config.schema import AgentDefaults

    bare = AgentLoop.__new__(AgentLoop)
    bare.max_tool_result_chars = AgentDefaults().max_tool_result_chars
    return bare
|
|
|
|
|
|
def _make_full_loop(tmp_path: Path) -> AgentLoop:
    """Build a fully initialised AgentLoop over a mocked provider, rooted at *tmp_path*."""
    model_name = "test-model"
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = model_name
    return AgentLoop(
        bus=MessageBus(),
        provider=fake_provider,
        workspace=tmp_path,
        model=model_name,
    )
|
|
|
|
|
|
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
    """A multimodal user message carrying nothing but runtime context is not saved."""
    agent = _mk_loop()
    sess = Session(key="test:runtime-only")
    runtime_block = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
    turn = [{"role": "user", "content": [{"type": "text", "text": runtime_block}]}]

    agent._save_turn(sess, turn, skip=0)

    assert sess.messages == []
|
|
|
|
|
|
def test_save_turn_keeps_image_placeholder_with_path_after_runtime_strip() -> None:
    """An image part with ``_meta.path`` is stored as an ``[image: <path>]`` text placeholder."""
    agent = _mk_loop()
    sess = Session(key="test:image")
    runtime_block = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
    image_part = {
        "type": "image_url",
        "image_url": {"url": "data:image/png;base64,abc"},
        "_meta": {"path": "/media/feishu/photo.jpg"},
    }
    user_msg = {
        "role": "user",
        "content": [{"type": "text", "text": runtime_block}, image_part],
    }

    agent._save_turn(sess, [user_msg], skip=0)

    expected = [{"type": "text", "text": "[image: /media/feishu/photo.jpg]"}]
    assert sess.messages[0]["content"] == expected
|
|
|
|
|
|
def test_save_turn_keeps_image_placeholder_without_meta() -> None:
    """An image part with no ``_meta`` is stored as a bare ``[image]`` text placeholder."""
    agent = _mk_loop()
    sess = Session(key="test:image-no-meta")
    runtime_block = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
    image_part = {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}
    user_msg = {
        "role": "user",
        "content": [{"type": "text", "text": runtime_block}, image_part],
    }

    agent._save_turn(sess, [user_msg], skip=0)

    expected = [{"type": "text", "text": "[image]"}]
    assert sess.messages[0]["content"] == expected
|
|
|
|
|
|
def test_save_turn_keeps_tool_results_under_16k() -> None:
    """A 12 000-character tool result is stored verbatim (no truncation applied)."""
    agent = _mk_loop()
    sess = Session(key="test:tool-result")
    payload = "x" * 12_000
    tool_msg = {
        "role": "tool",
        "tool_call_id": "call_1",
        "name": "read_file",
        "content": payload,
    }

    agent._save_turn(sess, [tool_msg], skip=0)

    assert sess.messages[0]["content"] == payload
|
|
|
|
|
|
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
    """Restoring a checkpoint replays finished tool results and stubs out pending ones."""
    agent = _mk_loop()

    def tool_call(call_id: str, name: str) -> dict:
        # Build a fresh dict each time so no two checkpoint entries share an object.
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": name, "arguments": "{}"},
        }

    checkpoint = {
        "assistant_message": {
            "role": "assistant",
            "content": "working",
            "tool_calls": [
                tool_call("call_done", "read_file"),
                tool_call("call_pending", "exec"),
            ],
        },
        "completed_tool_results": [
            {
                "role": "tool",
                "tool_call_id": "call_done",
                "name": "read_file",
                "content": "ok",
            }
        ],
        "pending_tool_calls": [tool_call("call_pending", "exec")],
    }
    sess = Session(
        key="test:checkpoint",
        metadata={AgentLoop._RUNTIME_CHECKPOINT_KEY: checkpoint},
    )

    restored = agent._restore_runtime_checkpoint(sess)

    msgs = sess.messages
    assert restored is True
    # The checkpoint is consumed once it has been replayed.
    assert sess.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    assert msgs[0]["role"] == "assistant"
    assert msgs[1]["tool_call_id"] == "call_done"
    assert msgs[2]["tool_call_id"] == "call_pending"
    assert "interrupted before this tool finished" in msgs[2]["content"].lower()
|
|
|
|
|
|
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
    """Checkpoint entries already present at the end of the history are not appended twice."""
    agent = _mk_loop()

    def tool_calls() -> list[dict]:
        # Fresh dicts on every call so history and checkpoint never share objects.
        return [
            {
                "id": "call_done",
                "type": "function",
                "function": {"name": "read_file", "arguments": "{}"},
            },
            {
                "id": "call_pending",
                "type": "function",
                "function": {"name": "exec", "arguments": "{}"},
            },
        ]

    def done_result() -> dict:
        return {
            "role": "tool",
            "tool_call_id": "call_done",
            "name": "read_file",
            "content": "ok",
        }

    # History already holds the assistant message and the finished tool result.
    history = [
        {"role": "assistant", "content": "working", "tool_calls": tool_calls()},
        done_result(),
    ]
    checkpoint = {
        "assistant_message": {
            "role": "assistant",
            "content": "working",
            "tool_calls": tool_calls(),
        },
        "completed_tool_results": [done_result()],
        "pending_tool_calls": [tool_calls()[1]],
    }
    sess = Session(
        key="test:checkpoint-overlap",
        messages=history,
        metadata={AgentLoop._RUNTIME_CHECKPOINT_KEY: checkpoint},
    )

    restored = agent._restore_runtime_checkpoint(sess)

    msgs = sess.messages
    assert restored is True
    assert sess.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    # Only the pending tool stub is new; the overlapping tail is not duplicated.
    assert len(msgs) == 3
    assert msgs[0]["role"] == "assistant"
    assert msgs[1]["tool_call_id"] == "call_done"
    assert msgs[2]["tool_call_id"] == "call_pending"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_process_message_persists_user_message_before_turn_completes(tmp_path: Path) -> None:
    """The inbound user message is saved even when the agent run raises mid-turn."""
    agent = _make_full_loop(tmp_path)
    agent.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]
    agent._run_agent_loop = AsyncMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    inbound = InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="persist me")

    with pytest.raises(RuntimeError, match="boom"):
        await agent._process_message(inbound)

    # Drop the cached session so it is re-read (presumably from disk) below.
    session_key = "feishu:c1"
    agent.sessions.invalidate(session_key)
    reloaded = agent.sessions.get_or_create(session_key)

    assert [m["role"] for m in reloaded.messages] == ["user"]
    assert reloaded.messages[0]["content"] == "persist me"
    assert reloaded.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
    assert reloaded.updated_at >= reloaded.created_at
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_process_message_does_not_duplicate_early_persisted_user_message(tmp_path: Path) -> None:
    """The early-persisted user message is not stored a second time when the turn completes."""
    agent = _make_full_loop(tmp_path)
    agent.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]
    transcript = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "done"},
    ]
    agent._run_agent_loop = AsyncMock(  # type: ignore[method-assign]
        return_value=("done", None, transcript, "stop", False)
    )

    outcome = await agent._process_message(
        InboundMessage(channel="feishu", sender_id="u1", chat_id="c2", content="hello")
    )

    assert outcome is not None
    assert outcome.content == "done"
    sess = agent.sessions.get_or_create("feishu:c2")
    # Compare only role/content; other keys (e.g. timestamps) are ignored.
    trimmed = [
        {k: v for k, v in m.items() if k in {"role", "content"}}
        for m in sess.messages
    ]
    assert trimmed == [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "done"},
    ]
    assert AgentLoop._PENDING_USER_TURN_KEY not in sess.metadata
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
    """After a crash, the dangling user turn is closed with an error reply before the new input."""
    agent = _make_full_loop(tmp_path)
    agent.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]
    agent.provider.chat_with_retry = AsyncMock(return_value=MagicMock())  # unused because _run_agent_loop is stubbed

    key = "feishu:c3"
    interrupted_reply = "Error: Task interrupted before a response was generated."

    # Simulate an earlier crash: a saved user message whose pending-turn flag is still set.
    crashed = agent.sessions.get_or_create(key)
    crashed.add_message("user", "old question")
    crashed.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True
    agent.sessions.save(crashed)

    agent._run_agent_loop = AsyncMock(  # type: ignore[method-assign]
        return_value=(
            "new answer",
            None,
            [
                {"role": "system", "content": "system"},
                {"role": "user", "content": "old question"},
                {"role": "assistant", "content": interrupted_reply},
                {"role": "user", "content": "new question"},
                {"role": "assistant", "content": "new answer"},
            ],
            "stop",
            False,
        )
    )

    outcome = await agent._process_message(
        InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question")
    )

    assert outcome is not None
    assert outcome.content == "new answer"
    sess = agent.sessions.get_or_create(key)
    trimmed = [
        {k: v for k, v in m.items() if k in {"role", "content"}}
        for m in sess.messages
    ]
    assert trimmed == [
        {"role": "user", "content": "old question"},
        {"role": "assistant", "content": interrupted_reply},
        {"role": "user", "content": "new question"},
        {"role": "assistant", "content": "new answer"},
    ]
    assert AgentLoop._PENDING_USER_TURN_KEY not in sess.metadata
|
|
|
|
|
|
def _cross_channel_messages(
    source_channel: str = "websocket",
    source_chat_id: str = "ws-uuid-123",
    target_channel: str = "feishu",
    target_chat_id: str = "ou_abc123",
) -> list[dict]:
    """Build a turn in which the agent uses the ``message`` tool to reach another channel.

    The returned list holds, in order:
    the user request, an assistant message with a single ``message`` tool call
    (id ``call_x1``) addressed to ``target_channel``/``target_chat_id``, the
    matching tool result, and a final assistant reply. All user-visible texts
    name ``target_channel`` so they stay consistent when it is overridden.

    ``source_channel`` and ``source_chat_id`` are currently unused by the body:
    the originating session is whichever session the caller passes to
    ``_save_turn``.
    """
    # Key order is kept stable so the serialized arguments are deterministic.
    tool_args = {
        "content": "Report: audit complete",
        "channel": target_channel,
        "chat_id": target_chat_id,
    }
    return [
        {"role": "user", "content": f"send report to {target_channel}"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_x1",
                    "type": "function",
                    "function": {
                        "name": "message",
                        "arguments": json.dumps(tool_args),
                    },
                }
            ],
        },
        {
            "role": "tool",
            "tool_call_id": "call_x1",
            "name": "message",
            "content": f"Message sent to {target_channel}:{target_chat_id}",
        },
        {"role": "assistant", "content": f"Done, sent to {target_channel}."},
    ]
|
|
|
|
|
|
def test_cross_channel_message_persisted_in_target_session(tmp_path: Path) -> None:
    """A ``message`` tool call aimed at another channel is mirrored into that channel's session."""
    loop = _make_full_loop(tmp_path)
    source_key = "websocket:ws-uuid-123"
    target_key = "feishu:ou_abc123"

    # Pre-create the target session (simulate an existing feishu conversation)
    target_session = loop.sessions.get_or_create(target_key)
    target_session.add_message("user", "hello from feishu")
    loop.sessions.save(target_session)

    source_session = loop.sessions.get_or_create(source_key)
    msgs = _cross_channel_messages()
    loop._save_turn(source_session, msgs, skip=1)  # skip user message

    # The source session keeps its own turn (at least the tool call and its
    # result) and must not itself receive a cross-channel entry.
    source_session = loop.sessions.get_or_create(source_key)
    assert len(source_session.messages) >= 2
    assert not any(m.get("_cross_channel") for m in source_session.messages)

    # Invalidate first so the target is re-read from storage rather than
    # taken from the in-memory cache.
    loop.sessions.invalidate(target_key)
    target = loop.sessions.get_or_create(target_key)

    # Pre-existing history survives the append.
    assert target.messages[0]["content"] == "hello from feishu"

    cross_msg = [m for m in target.messages if m.get("_cross_channel")]
    assert len(cross_msg) == 1
    assert cross_msg[0]["content"] == "Report: audit complete"
    assert cross_msg[0]["role"] == "assistant"
|
|
|
|
|
|
def test_cross_channel_same_session_not_duplicated(tmp_path: Path) -> None:
    """A ``message`` call addressed to the session that issued it adds no cross-channel copy."""
    agent = _make_full_loop(tmp_path)
    sess = agent.sessions.get_or_create("feishu:ou_same")

    # message tool call targeting the same session — should NOT create a duplicate
    self_addressed_call = {
        "id": "call_s1",
        "type": "function",
        "function": {
            "name": "message",
            "arguments": json.dumps({
                "content": "same channel msg",
                "channel": "feishu",
                "chat_id": "ou_same",
            }),
        },
    }
    turn = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "", "tool_calls": [self_addressed_call]},
        {"role": "tool", "tool_call_id": "call_s1", "name": "message", "content": "ok"},
    ]
    agent._save_turn(sess, turn, skip=1)

    assert not any(m.get("_cross_channel") for m in sess.messages)
|
|
|
|
|
|
def test_cross_channel_target_session_not_exist(tmp_path: Path) -> None:
    """Cross-channel persistence does not create a target session that did not already exist."""
    loop = _make_full_loop(tmp_path)
    source_session = loop.sessions.get_or_create("websocket:ws-xyz")

    msgs = _cross_channel_messages(
        target_channel="feishu", target_chat_id="ou_nonexistent"
    )
    # Must complete without raising even though the target session is missing.
    loop._save_turn(source_session, msgs, skip=1)

    # The target session must be neither cached in memory...
    assert "feishu:ou_nonexistent" not in loop.sessions._cache
    # ...nor written anywhere under the workspace. This assumes the session
    # file name embeds the chat id — NOTE(review): confirm against SessionManager.
    assert not any("ou_nonexistent" in p.name for p in tmp_path.rglob("*"))
|
|
|
|
|
|
def test_cross_channel_non_message_tools_ignored(tmp_path: Path) -> None:
    """Tool calls other than ``message`` never produce cross-channel entries."""
    loop = _make_full_loop(tmp_path)
    source_session = loop.sessions.get_or_create("websocket:ws-abc")
    target_session = loop.sessions.get_or_create("feishu:ou_tgt")
    target_session.add_message("user", "hi")
    loop.sessions.save(target_session)

    msgs = [
        {"role": "user", "content": "do stuff"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_e1",
                    "type": "function",
                    "function": {
                        "name": "exec",
                        "arguments": json.dumps({"command": "echo hi"}),
                    },
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_e1", "name": "exec", "content": "hi"},
    ]
    loop._save_turn(source_session, msgs, skip=1)

    # exec tool should NOT produce cross-channel entries. Invalidate first so
    # the check also covers what was saved, not only the cached object.
    loop.sessions.invalidate("feishu:ou_tgt")
    target = loop.sessions.get_or_create("feishu:ou_tgt")
    cross_msgs = [m for m in target.messages if m.get("_cross_channel")]
    assert len(cross_msgs) == 0
|