mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 08:32:25 +00:00
test(agent): expand coverage and refactor test structure
- Add 42 tests for ContextBuilder (context.py: 0→42 tests)
- Add 37 tests for SubagentManager lifecycle (subagent.py: 2→37 tests)
- Add 42 unit tests for AutoCompact in isolation
- Split monolithic test_runner.py (3313 lines) into 9 focused files: test_runner_core, test_runner_hooks, test_runner_errors, test_runner_safety, test_runner_persistence, test_runner_governance, test_runner_tool_execution, test_runner_injections, test_loop_runner_integration
- Add 3 config passthrough tests (temperature/max_tokens/reasoning_effort)
- Fix fragile patch.object(__init__) in test_stop_preserves_context
- Create shared conftest.py with make_provider/make_loop factories

Total: 934 tests passing, 0 regressions
This commit is contained in:
parent
00597fccd6
commit
99cc6ee808
93
tests/agent/conftest.py
Normal file
93
tests/agent/conftest.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
"""Shared fixtures and helpers for agent tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
def make_provider(
|
||||||
|
default_model: str = "test-model",
|
||||||
|
*,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
spec: bool = True,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Create a spec-limited LLM provider mock."""
|
||||||
|
mock_type = MagicMock(spec=LLMProvider) if spec else MagicMock()
|
||||||
|
provider = mock_type
|
||||||
|
provider.get_default_model.return_value = default_model
|
||||||
|
provider.generation = SimpleNamespace(
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=0.1,
|
||||||
|
reasoning_effort=None,
|
||||||
|
)
|
||||||
|
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def make_loop(
    tmp_path: Path,
    *,
    model: str = "test-model",
    context_window_tokens: int = 128_000,
    session_ttl_minutes: int = 0,
    max_messages: int = 120,
    unified_session: bool = False,
    mcp_servers: dict | None = None,
    tools_config=None,
    model_presets: dict | None = None,
    hooks: list | None = None,
    provider: MagicMock | None = None,
    patch_deps: bool = False,
) -> AgentLoop:
    """Build a real AgentLoop whose workspace is ``tmp_path``.

    A fresh ``MessageBus`` is always created. When ``provider`` is omitted,
    one is obtained from ``make_provider`` with ``model`` as its default
    model. ``mcp_servers``, ``tools_config``, ``model_presets`` and ``hooks``
    are forwarded to ``AgentLoop`` only when they are not ``None``.

    Args:
        patch_deps: If True, patch ContextBuilder/SessionManager/SubagentManager
            in ``nanobot.agent.loop`` while the loop is constructed (needed
            when workspace has no real files).
    """
    message_bus = MessageBus()
    llm = provider if provider is not None else make_provider(default_model=model)

    loop_args = {
        "bus": message_bus,
        "provider": llm,
        "workspace": tmp_path,
        "model": model,
        "context_window_tokens": context_window_tokens,
        "session_ttl_minutes": session_ttl_minutes,
        "max_messages": max_messages,
        "unified_session": unified_session,
    }
    # Optional settings are left out entirely (not passed as None) when unset.
    optional_args = (
        ("mcp_servers", mcp_servers),
        ("tools_config", tools_config),
        ("model_presets", model_presets),
        ("hooks", hooks),
    )
    for arg_name, arg_value in optional_args:
        if arg_value is not None:
            loop_args[arg_name] = arg_value

    if not patch_deps:
        return AgentLoop(**loop_args)

    # The patches are only active for the duration of the constructor call.
    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as sub_manager_cls:
                sub_manager_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
                return AgentLoop(**loop_args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def loop_factory(tmp_path):
    """Fixture returning a callable that builds AgentLoop instances.

    The callable forwards its keyword arguments to ``make_loop`` and always
    supplies this test's ``tmp_path`` as the workspace.
    """
    return lambda **overrides: make_loop(tmp_path, **overrides)
|
||||||
554
tests/agent/test_autocompact_unit.py
Normal file
554
tests/agent/test_autocompact_unit.py
Normal file
@ -0,0 +1,554 @@
|
|||||||
|
"""Direct unit tests for AutoCompact class methods in isolation."""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.autocompact import AutoCompact
|
||||||
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session(
    key: str = "cli:test",
    messages: list | None = None,
    last_consolidated: int = 0,
    updated_at: datetime | None = None,
    metadata: dict | None = None,
) -> Session:
    """Build a Session for tests, substituting empty containers where omitted.

    Falsy ``messages`` / ``metadata`` are replaced by a new empty list / dict.
    ``updated_at`` is assigned onto the session only when it is not ``None``.
    """
    built = Session(
        key=key,
        messages=messages if messages else [],
        metadata=metadata if metadata else {},
        last_consolidated=last_consolidated,
    )
    if updated_at is None:
        return built
    built.updated_at = updated_at
    return built
|
||||||
|
|
||||||
|
|
||||||
|
def _make_autocompact(
    ttl: int = 15,
    sessions: SessionManager | None = None,
    consolidator: MagicMock | None = None,
) -> AutoCompact:
    """Build an AutoCompact, mocking any collaborator that was not supplied.

    A missing ``sessions`` becomes ``MagicMock(spec=SessionManager)``. A
    missing ``consolidator`` becomes a ``MagicMock`` whose ``archive`` is an
    ``AsyncMock`` returning ``"Summary."``.
    """
    session_store = sessions if sessions is not None else MagicMock(spec=SessionManager)

    archiver = consolidator
    if archiver is None:
        archiver = MagicMock()
        archiver.archive = AsyncMock(return_value="Summary.")

    return AutoCompact(
        sessions=session_store,
        consolidator=archiver,
        session_ttl_minutes=ttl,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _add_turns(session: Session, turns: int, *, prefix: str = "msg") -> None:
    """Add ``turns`` user/assistant message pairs to ``session``.

    Each pair is appended user-first via ``session.add_message``; the content
    is ``"<prefix> <role> <index>"``.
    """
    for index in range(turns):
        for role in ("user", "assistant"):
            session.add_message(role, f"{prefix} {role} {index}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# __init__
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestInit:
    """Checks on the state AutoCompact exposes right after construction."""

    def test_stores_ttl(self):
        """The session_ttl_minutes argument is exposed as _ttl."""
        compactor = _make_autocompact(ttl=30)
        assert compactor._ttl == 30

    def test_default_ttl_is_zero(self):
        """A TTL of 0 passed to the constructor is stored unchanged."""
        compactor = _make_autocompact(ttl=0)
        assert compactor._ttl == 0

    def test_archiving_set_is_empty(self):
        """_archiving compares equal to an empty set on a new instance."""
        compactor = _make_autocompact()
        assert compactor._archiving == set()

    def test_summaries_dict_is_empty(self):
        """_summaries compares equal to an empty dict on a new instance."""
        compactor = _make_autocompact()
        assert compactor._summaries == {}

    def test_stores_sessions_reference(self):
        """The very SessionManager object passed in is kept as .sessions."""
        session_store = MagicMock(spec=SessionManager)
        compactor = _make_autocompact(sessions=session_store)
        assert compactor.sessions is session_store

    def test_stores_consolidator_reference(self):
        """The very consolidator object passed in is kept as .consolidator."""
        archiver = MagicMock()
        compactor = _make_autocompact(consolidator=archiver)
        assert compactor.consolidator is archiver
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _is_expired
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsExpired:
    """Boundary and input-shape checks for AutoCompact._is_expired."""

    def test_ttl_zero_always_false(self):
        """With TTL=0 even a year-old timestamp is reported as not expired."""
        compactor = _make_autocompact(ttl=0)
        a_year_ago = datetime.now() - timedelta(days=365)
        assert compactor._is_expired(a_year_ago) is False

    def test_none_timestamp_returns_false(self):
        """A None timestamp is reported as not expired."""
        compactor = _make_autocompact(ttl=15)
        assert compactor._is_expired(None) is False

    def test_empty_string_timestamp_returns_false(self):
        """An empty-string timestamp is reported as not expired."""
        compactor = _make_autocompact(ttl=15)
        assert compactor._is_expired("") is False

    def test_exactly_at_boundary_is_expired(self):
        """An age of exactly the TTL counts as expired."""
        compactor = _make_autocompact(ttl=15)
        reference = datetime(2026, 1, 1, 12, 0, 0)
        last_seen = reference - timedelta(minutes=15)
        assert compactor._is_expired(last_seen, now=reference) is True

    def test_just_under_boundary_not_expired(self):
        """An age one second short of the TTL does not count as expired."""
        compactor = _make_autocompact(ttl=15)
        reference = datetime(2026, 1, 1, 12, 0, 0)
        last_seen = reference - timedelta(minutes=14, seconds=59)
        assert compactor._is_expired(last_seen, now=reference) is False

    def test_iso_string_parses_correctly(self):
        """A timestamp given as an ISO-format string is accepted and evaluated."""
        compactor = _make_autocompact(ttl=15)
        reference = datetime(2026, 1, 1, 12, 0, 0)
        twenty_minutes_earlier = reference - timedelta(minutes=20)
        assert compactor._is_expired(twenty_minutes_earlier.isoformat(), now=reference) is True

    def test_custom_now_parameter(self):
        """The explicit 'now' argument is the clock used for the comparison."""
        compactor = _make_autocompact(ttl=10)
        last_seen = datetime(2026, 1, 1, 10, 0, 0)

        # 9 minutes after last_seen: still inside the 10-minute TTL.
        assert compactor._is_expired(last_seen, now=last_seen + timedelta(minutes=9)) is False

        # 10 minutes after last_seen: TTL reached.
        assert compactor._is_expired(last_seen, now=last_seen + timedelta(minutes=10)) is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _format_summary
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatSummary:
    """Checks on the text produced by AutoCompact._format_summary."""

    def test_contains_isoformat_timestamp(self):
        """The last-active time appears in the output in ISO format."""
        stamp = datetime(2026, 5, 13, 14, 30, 0)
        rendered = AutoCompact._format_summary("Some text", stamp)
        assert "2026-05-13T14:30:00" in rendered

    def test_contains_summary_text(self):
        """The summary text passed in appears unchanged in the output."""
        body = "User discussed Python."
        rendered = AutoCompact._format_summary(body, datetime(2026, 1, 1))
        assert "User discussed Python." in rendered

    def test_output_starts_with_label(self):
        """The output begins with the fixed 'Previous conversation summary' label."""
        rendered = AutoCompact._format_summary("text", datetime(2026, 1, 1))
        assert rendered.startswith("Previous conversation summary (last active ")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _split_unconsolidated
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitUnconsolidated:
    """Checks on how AutoCompact._split_unconsolidated divides a session."""

    @staticmethod
    def _user_messages(count):
        """Return ``count`` user messages with contents ``u0`` .. ``u<count-1>``."""
        return [{"role": "user", "content": f"u{i}"} for i in range(count)]

    def test_empty_session_returns_both_empty(self):
        """A session without messages yields an empty archive and an empty kept list."""
        compactor = _make_autocompact()
        blank = _make_session(messages=[])
        archive, kept = compactor._split_unconsolidated(blank)
        assert archive == []
        assert kept == []

    def test_all_messages_archivable_when_more_than_suffix(self):
        """Twenty messages produce a non-empty archive and a bounded kept suffix."""
        compactor = _make_autocompact()
        session = _make_session(messages=self._user_messages(20))
        archive, kept = compactor._split_unconsolidated(session)
        assert len(archive) > 0
        assert len(kept) <= AutoCompact._RECENT_SUFFIX_MESSAGES

    def test_fewer_messages_than_suffix_returns_empty_archive(self):
        """Three messages are all kept and nothing is archived."""
        compactor = _make_autocompact()
        history = self._user_messages(3)
        session = _make_session(messages=history)
        archive, kept = compactor._split_unconsolidated(session)
        assert archive == []
        assert len(kept) == len(history)

    def test_respects_last_consolidated_offset(self):
        """Messages before last_consolidated never show up in either half."""
        compactor = _make_autocompact()
        # The first 10 of the 20 messages are marked as already consolidated.
        session = _make_session(messages=self._user_messages(20), last_consolidated=10)
        archive, kept = compactor._split_unconsolidated(session)
        # Everything returned must come from the remaining tail u10..u19.
        unconsolidated = [f"u{i}" for i in range(10, 20)]
        assert all(m["content"] in unconsolidated for m in kept)
        assert all(m["content"] in unconsolidated for m in archive)

    def test_retain_recent_legal_suffix_keeps_last_n(self):
        """Kept is bounded by _RECENT_SUFFIX_MESSAGES and archive holds the rest."""
        compactor = _make_autocompact()
        # 20 user messages, none consolidated yet (last_consolidated defaults to 0).
        history = self._user_messages(20)
        session = _make_session(messages=history)
        archive, kept = compactor._split_unconsolidated(session)
        assert len(kept) <= AutoCompact._RECENT_SUFFIX_MESSAGES
        assert len(archive) == len(history) - len(kept)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_expired
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckExpired:
    """Checks on when AutoCompact.check_expired invokes the scheduler callback."""

    @staticmethod
    def _compactor_listing(entries):
        """Return a TTL=15 AutoCompact whose session manager lists ``entries``."""
        compactor = _make_autocompact(ttl=15)
        store = MagicMock(spec=SessionManager)
        store.list_sessions.return_value = entries
        compactor.sessions = store
        return compactor

    @staticmethod
    def _stale_timestamp():
        """Return an ISO timestamp 20 minutes in the past (older than the TTL)."""
        return (datetime.now() - timedelta(minutes=20)).isoformat()

    def test_empty_sessions_list(self):
        """With no sessions listed, the scheduler is never called."""
        compactor = self._compactor_listing([])
        schedule = MagicMock()
        compactor.check_expired(schedule)
        schedule.assert_not_called()

    def test_expired_session_schedules_background(self):
        """An expired session triggers one scheduler call and is marked as archiving."""
        compactor = self._compactor_listing(
            [{"key": "cli:old", "updated_at": self._stale_timestamp()}]
        )
        schedule = MagicMock()
        compactor.check_expired(schedule)
        schedule.assert_called_once()
        assert "cli:old" in compactor._archiving

    def test_active_session_key_skips(self):
        """An expired session listed in active_session_keys is left alone."""
        compactor = self._compactor_listing(
            [{"key": "cli:busy", "updated_at": self._stale_timestamp()}]
        )
        schedule = MagicMock()
        compactor.check_expired(schedule, active_session_keys={"cli:busy"})
        schedule.assert_not_called()

    def test_session_already_in_archiving_skips(self):
        """An expired session whose key is already in _archiving is left alone."""
        compactor = self._compactor_listing(
            [{"key": "cli:dup", "updated_at": self._stale_timestamp()}]
        )
        compactor._archiving.add("cli:dup")
        schedule = MagicMock()
        compactor.check_expired(schedule)
        schedule.assert_not_called()

    def test_session_with_no_key_skips(self):
        """A session entry whose key is an empty string is left alone."""
        compactor = self._compactor_listing([{"key": "", "updated_at": "old"}])
        schedule = MagicMock()
        compactor.check_expired(schedule)
        schedule.assert_not_called()

    def test_session_with_missing_key_field_skips(self):
        """A session entry with no 'key' field at all is left alone."""
        compactor = self._compactor_listing([{"updated_at": "old"}])
        schedule = MagicMock()
        compactor.check_expired(schedule)
        schedule.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _archive
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestArchive:
    """Exercise the async AutoCompact._archive method against mocked collaborators."""

    _KEY = "cli:test"

    @staticmethod
    def _backlog():
        """Return twenty user messages with contents ``u0`` .. ``u19``."""
        return [{"role": "user", "content": f"u{i}"} for i in range(20)]

    @staticmethod
    def _attach(compactor, session):
        """Give ``compactor`` a spec'd SessionManager mock that returns ``session``."""
        store = MagicMock(spec=SessionManager)
        store.get_or_create.return_value = session
        compactor.sessions = store
        return store

    @pytest.mark.asyncio
    async def test_empty_session_updates_timestamp_no_archive_call(self):
        """An empty session is saved with a fresh updated_at and never archived."""
        compactor = _make_autocompact()
        blank = _make_session(messages=[])
        store = self._attach(compactor, blank)
        compactor.consolidator.archive = AsyncMock(return_value="Summary.")

        await compactor._archive(self._KEY)

        compactor.consolidator.archive.assert_not_called()
        store.save.assert_called_once_with(blank)
        # updated_at must have been moved to (roughly) the current time.
        assert blank.updated_at > datetime.now() - timedelta(seconds=5)

    @pytest.mark.asyncio
    async def test_archive_returns_empty_string_no_summary_stored(self):
        """An empty-string result from the consolidator leaves _summaries untouched."""
        compactor = _make_autocompact()
        self._attach(compactor, _make_session(messages=self._backlog()))
        compactor.consolidator.archive = AsyncMock(return_value="")

        await compactor._archive(self._KEY)

        assert self._KEY not in compactor._summaries

    @pytest.mark.asyncio
    async def test_archive_returns_nothing_no_summary_stored(self):
        """A '(nothing)' result from the consolidator leaves _summaries untouched."""
        compactor = _make_autocompact()
        self._attach(compactor, _make_session(messages=self._backlog()))
        compactor.consolidator.archive = AsyncMock(return_value="(nothing)")

        await compactor._archive(self._KEY)

        assert self._KEY not in compactor._summaries

    @pytest.mark.asyncio
    async def test_archive_exception_caught_key_removed_from_archiving(self):
        """A consolidator failure does not propagate and the key is not left in _archiving."""
        compactor = _make_autocompact()
        self._attach(compactor, _make_session(messages=self._backlog()))
        compactor.consolidator.archive = AsyncMock(side_effect=RuntimeError("LLM down"))

        # The awaited call itself must complete without raising.
        await compactor._archive(self._KEY)

        assert self._KEY not in compactor._archiving

    @pytest.mark.asyncio
    async def test_successful_archive_stores_summary_in_summaries_and_metadata(self):
        """A successful archive records the summary in _summaries and in session metadata."""
        compactor = _make_autocompact()
        last_seen = datetime(2026, 5, 13, 10, 0, 0)
        session = _make_session(messages=self._backlog(), updated_at=last_seen)
        self._attach(compactor, session)
        compactor.consolidator.archive = AsyncMock(return_value="User discussed AI.")

        await compactor._archive(self._KEY)

        # In-memory entry: (summary text, last-active time).
        cached = compactor._summaries.get(self._KEY)
        assert cached is not None
        assert cached[0] == "User discussed AI."
        assert cached[1] == last_seen
        # Entry written into the session's metadata.
        stored = session.metadata.get("_last_summary")
        assert stored is not None
        assert stored["text"] == "User discussed AI."
        assert "last_active" in stored

    @pytest.mark.asyncio
    async def test_finally_block_always_removes_from_archiving(self):
        """A key that was in _archiving is removed even when the consolidator raises."""
        compactor = _make_autocompact()
        self._attach(compactor, _make_session(messages=self._backlog()))
        compactor.consolidator.archive = AsyncMock(side_effect=RuntimeError("fail"))

        # Seed the key so its removal is actually observable.
        compactor._archiving.add(self._KEY)
        await compactor._archive(self._KEY)
        assert self._KEY not in compactor._archiving

    @pytest.mark.asyncio
    async def test_finally_removes_from_archiving_on_success(self):
        """A key that was in _archiving is also removed after a successful archive."""
        compactor = _make_autocompact()
        self._attach(compactor, _make_session(messages=self._backlog()))
        compactor.consolidator.archive = AsyncMock(return_value="Summary.")

        compactor._archiving.add(self._KEY)
        await compactor._archive(self._KEY)
        assert self._KEY not in compactor._archiving
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# prepare_session
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrepareSession:
    """Checks on the (session, summary) pair returned by AutoCompact.prepare_session."""

    _KEY = "cli:test"

    @staticmethod
    def _attach(compactor, session):
        """Give ``compactor`` a spec'd SessionManager mock that returns ``session``."""
        store = MagicMock(spec=SessionManager)
        store.get_or_create.return_value = session
        compactor.sessions = store
        return store

    def test_key_in_archiving_reloads_session(self):
        """While the key is in _archiving, the session comes from get_or_create."""
        compactor = _make_autocompact()
        fresh = _make_session(key="cli:test")
        store = self._attach(compactor, fresh)
        compactor._archiving.add(self._KEY)

        returned, _ = compactor.prepare_session(_make_session(), self._KEY)

        store.get_or_create.assert_called_once_with("cli:test")
        assert returned is fresh

    def test_expired_session_reloads(self):
        """A session older than the TTL is replaced by the one from get_or_create."""
        compactor = _make_autocompact(ttl=15)
        fresh = _make_session(key="cli:test", updated_at=datetime.now())
        store = self._attach(compactor, fresh)

        stale = _make_session(updated_at=datetime.now() - timedelta(minutes=20))
        returned, _ = compactor.prepare_session(stale, self._KEY)

        store.get_or_create.assert_called_once_with("cli:test")
        assert returned is fresh

    def test_hot_path_summary_from_summaries(self):
        """An entry in _summaries is returned as a formatted summary (hot path)."""
        compactor = _make_autocompact()
        current = _make_session()
        compactor._summaries[self._KEY] = ("Hot summary.", datetime(2026, 5, 13, 14, 0, 0))

        returned, summary = compactor.prepare_session(current, self._KEY)

        assert returned is current
        assert summary is not None
        assert "Hot summary." in summary
        assert "Previous conversation summary" in summary

    def test_hot_path_pops_summary_one_shot(self):
        """A _summaries entry is delivered once; the next call yields None."""
        compactor = _make_autocompact()
        current = _make_session()
        compactor._summaries[self._KEY] = ("One-shot.", datetime(2026, 1, 1))

        _, first = compactor.prepare_session(current, self._KEY)
        assert first is not None
        # The entry was consumed by the first call, so nothing is left to return.
        _, second = compactor.prepare_session(current, self._KEY)
        assert second is None

    def test_cold_path_summary_from_metadata(self):
        """With _summaries empty, the summary is taken from session metadata (cold path)."""
        compactor = _make_autocompact()
        stamp = datetime(2026, 5, 13, 14, 0, 0)
        persisted = {"text": "Cold summary.", "last_active": stamp.isoformat()}
        current = _make_session(metadata={"_last_summary": persisted})

        returned, summary = compactor.prepare_session(current, self._KEY)

        assert returned is current
        assert summary is not None
        assert "Cold summary." in summary

    def test_no_summary_available_returns_none(self):
        """Without any stored summary the result is (session, None)."""
        compactor = _make_autocompact()
        current = _make_session()

        returned, summary = compactor.prepare_session(current, self._KEY)

        assert returned is current
        assert summary is None

    def test_cold_path_metadata_not_dict_returns_none(self):
        """A non-dict _last_summary in metadata yields no summary."""
        compactor = _make_autocompact()
        current = _make_session(metadata={"_last_summary": "not a dict"})

        returned, summary = compactor.prepare_session(current, self._KEY)

        assert returned is current
        assert summary is None

    def test_hot_path_takes_priority_over_metadata(self):
        """When both sources exist, the _summaries entry wins over metadata."""
        compactor = _make_autocompact()
        persisted = {
            "text": "Cold summary.",
            "last_active": datetime(2026, 1, 1).isoformat(),
        }
        current = _make_session(metadata={"_last_summary": persisted})
        compactor._summaries[self._KEY] = ("Hot summary.", datetime(2026, 5, 13, 14, 0, 0))

        _, summary = compactor.prepare_session(current, self._KEY)
        assert "Hot summary." in summary
        # Once the hot entry has been popped, a later call would use the cold path.
|
||||||
333
tests/agent/test_context_builder.py
Normal file
333
tests/agent/test_context_builder.py
Normal file
@ -0,0 +1,333 @@
|
|||||||
|
"""Tests for ContextBuilder — system prompt and message assembly."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.context import ContextBuilder
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _builder(tmp_path: Path, **kw) -> ContextBuilder:
    """Construct a ContextBuilder whose workspace is ``tmp_path``.

    Any extra keyword arguments are passed straight to the constructor.
    """
    built = ContextBuilder(workspace=tmp_path, **kw)
    return built
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _build_runtime_context (static)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildRuntimeContext:
    """Checks on which lines ContextBuilder._build_runtime_context emits."""

    def test_time_only(self):
        """With no channel or chat id, only the wrapper and the time line appear."""
        rendered = ContextBuilder._build_runtime_context(None, None)
        for expected in ("[Runtime Context", "[/Runtime Context]", "Current Time:"):
            assert expected in rendered
        assert "Channel:" not in rendered

    def test_with_channel_and_chat_id(self):
        """Channel and chat id are each rendered on a labelled line."""
        rendered = ContextBuilder._build_runtime_context("telegram", "chat123")
        for expected in ("Channel: telegram", "Chat ID: chat123"):
            assert expected in rendered

    def test_with_sender_id(self):
        """A sender_id keyword argument is rendered as a 'Sender ID' line."""
        rendered = ContextBuilder._build_runtime_context("cli", "direct", sender_id="user1")
        assert "Sender ID: user1" in rendered

    def test_with_timezone(self):
        """Passing a timezone still produces a 'Current Time' line."""
        rendered = ContextBuilder._build_runtime_context(None, None, timezone="Asia/Shanghai")
        assert "Current Time:" in rendered

    def test_no_channel_no_chat_id_omits_both(self):
        """Neither a 'Channel' nor a 'Chat ID' line appears when both are None."""
        rendered = ContextBuilder._build_runtime_context(None, None)
        for absent in ("Channel:", "Chat ID:"):
            assert absent not in rendered

    def test_no_sender_id_omits(self):
        """No 'Sender ID' line appears when sender_id is not given."""
        rendered = ContextBuilder._build_runtime_context("cli", "direct")
        assert "Sender ID:" not in rendered
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _merge_message_content (static)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeMessageContent:
    """Exercise the static ``ContextBuilder._merge_message_content`` helper."""

    def test_str_plus_str(self):
        """Two strings are joined with a blank line in between."""
        merged = ContextBuilder._merge_message_content("hello", "world")
        assert merged == "hello\n\nworld"

    def test_empty_left_plus_str(self):
        """An empty left string leaves only the right string."""
        merged = ContextBuilder._merge_message_content("", "world")
        assert merged == "world"

    def test_list_plus_list(self):
        """Two block lists are concatenated in order."""
        first = [{"type": "text", "text": "a"}]
        second = [{"type": "text", "text": "b"}]
        merged = ContextBuilder._merge_message_content(first, second)
        assert len(merged) == 2
        assert merged[0]["text"] == "a"
        assert merged[1]["text"] == "b"

    def test_str_plus_list(self):
        """A left string becomes a text block placed ahead of the right-hand blocks."""
        blocks = [{"type": "text", "text": "b"}]
        merged = ContextBuilder._merge_message_content("hello", blocks)
        assert len(merged) == 2
        texts = (merged[0]["text"], merged[1]["text"])
        assert texts[0] == "hello"
        assert texts[1] == "b"

    def test_list_plus_str(self):
        """A right string becomes a text block placed after the left-hand blocks."""
        blocks = [{"type": "text", "text": "a"}]
        merged = ContextBuilder._merge_message_content(blocks, "world")
        assert len(merged) == 2
        texts = (merged[0]["text"], merged[1]["text"])
        assert texts[0] == "a"
        assert texts[1] == "world"

    def test_none_plus_str(self):
        """``None`` on the left yields a single text block for the right string."""
        expected = [{"type": "text", "text": "hello"}]
        assert ContextBuilder._merge_message_content(None, "hello") == expected

    def test_str_plus_none(self):
        """``None`` on the right yields a single text block for the left string."""
        expected = [{"type": "text", "text": "hello"}]
        assert ContextBuilder._merge_message_content("hello", None) == expected

    def test_none_plus_none(self):
        """Merging two ``None`` values gives an empty list."""
        merged = ContextBuilder._merge_message_content(None, None)
        assert merged == []

    def test_list_items_not_dicts_wrapped(self):
        """Non-dict list items are wrapped into text blocks."""
        merged = ContextBuilder._merge_message_content(["raw_item"], None)
        expected = [{"type": "text", "text": "raw_item"}]
        assert merged == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _load_bootstrap_files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadBootstrapFiles:
    """Exercise ``ContextBuilder._load_bootstrap_files`` against a temp workspace."""

    def test_no_bootstrap_files(self, tmp_path):
        """An empty workspace produces an empty string."""
        loaded = _builder(tmp_path)._load_bootstrap_files()
        assert loaded == ""

    def test_agents_md(self, tmp_path):
        """AGENTS.md is rendered under its own heading together with its content."""
        agents_file = tmp_path / "AGENTS.md"
        agents_file.write_text("Be helpful.", encoding="utf-8")
        loaded = _builder(tmp_path)._load_bootstrap_files()
        for fragment in ("## AGENTS.md", "Be helpful."):
            assert fragment in loaded

    def test_multiple_bootstrap_files(self, tmp_path):
        """Both AGENTS.md and SOUL.md appear with headings and content."""
        (tmp_path / "AGENTS.md").write_text("Rules.", encoding="utf-8")
        (tmp_path / "SOUL.md").write_text("Soul.", encoding="utf-8")
        loaded = _builder(tmp_path)._load_bootstrap_files()
        for fragment in ("## AGENTS.md", "## SOUL.md", "Rules.", "Soul."):
            assert fragment in loaded

    def test_all_bootstrap_files(self, tmp_path):
        """Every name in ``BOOTSTRAP_FILES`` gets a heading once the file exists."""
        for filename in ContextBuilder.BOOTSTRAP_FILES:
            target = tmp_path / filename
            target.write_text(f"Content of {filename}", encoding="utf-8")
        loaded = _builder(tmp_path)._load_bootstrap_files()
        for filename in ContextBuilder.BOOTSTRAP_FILES:
            heading = f"## {filename}"
            assert heading in loaded

    def test_utf8_content(self, tmp_path):
        """Non-ASCII content written as UTF-8 survives the load."""
        agents_file = tmp_path / "AGENTS.md"
        agents_file.write_text("用中文回复", encoding="utf-8")
        loaded = _builder(tmp_path)._load_bootstrap_files()
        assert "用中文回复" in loaded
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _is_template_content (static)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsTemplateContent:
    """Exercise the static ``ContextBuilder._is_template_content`` helper."""

    def test_nonexistent_template_returns_false(self):
        """An unknown template path is reported as not matching."""
        verdict = ContextBuilder._is_template_content("anything", "nonexistent/path.md")
        assert verdict is False

    def test_content_matching_template(self):
        """The bundled MEMORY.md text is recognised as template content."""
        from importlib.resources import files as pkg_files

        template = pkg_files("nanobot") / "templates" / "memory" / "MEMORY.md"
        if not template.is_file():
            pytest.skip("MEMORY.md template not bundled")
        pristine = template.read_text(encoding="utf-8")
        verdict = ContextBuilder._is_template_content(pristine, "memory/MEMORY.md")
        assert verdict is True

    def test_modified_content_returns_false(self):
        """Text that differs from the bundled MEMORY.md is not template content."""
        from importlib.resources import files as pkg_files

        template = pkg_files("nanobot") / "templates" / "memory" / "MEMORY.md"
        if not template.is_file():
            pytest.skip("MEMORY.md template not bundled")
        verdict = ContextBuilder._is_template_content("totally different", "memory/MEMORY.md")
        assert verdict is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _build_user_content
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildUserContent:
    """Exercise ``ContextBuilder._build_user_content`` with and without media."""

    @staticmethod
    def _write_png(path):
        """Write a PNG signature followed by 16 zero bytes to *path* and return it."""
        path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16)
        return path

    def test_no_media_returns_string(self, tmp_path):
        """``None`` media leaves the text untouched."""
        content = _builder(tmp_path)._build_user_content("hello", None)
        assert content == "hello"

    def test_empty_media_returns_string(self, tmp_path):
        """An empty media list leaves the text untouched."""
        content = _builder(tmp_path)._build_user_content("hello", [])
        assert content == "hello"

    def test_nonexistent_media_file_returns_string(self, tmp_path):
        """A media path that does not exist leaves the text untouched."""
        missing = ["/nonexistent/image.png"]
        content = _builder(tmp_path)._build_user_content("hello", missing)
        assert content == "hello"

    def test_non_image_file_returns_string(self, tmp_path):
        """A plain text attachment leaves the text untouched."""
        document = tmp_path / "doc.txt"
        document.write_text("not an image", encoding="utf-8")
        content = _builder(tmp_path)._build_user_content("hello", [str(document)])
        assert content == "hello"

    def test_valid_image_returns_list(self, tmp_path):
        """A PNG attachment yields an image block followed by the text block."""
        image = self._write_png(tmp_path / "test.png")
        content = _builder(tmp_path)._build_user_content("hello", [str(image)])
        assert isinstance(content, list)
        assert len(content) == 2
        image_block = content[0]
        text_block = content[1]
        assert image_block["type"] == "image_url"
        assert image_block["image_url"]["url"].startswith("data:image/png;base64,")
        assert text_block["type"] == "text"
        assert text_block["text"] == "hello"

    def test_image_meta_includes_path(self, tmp_path):
        """The image block carries a ``_meta`` mapping with a ``path`` key."""
        image = self._write_png(tmp_path / "test.png")
        content = _builder(tmp_path)._build_user_content("hello", [str(image)])
        image_block = content[0]
        assert "_meta" in image_block
        assert "path" in image_block["_meta"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_system_prompt
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildSystemPrompt:
    """Exercise ``ContextBuilder.build_system_prompt``."""

    def test_returns_nonempty_string(self, tmp_path):
        """The prompt is a non-empty ``str`` even for an empty workspace."""
        prompt = _builder(tmp_path).build_system_prompt()
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    def test_includes_identity_section(self, tmp_path):
        """The prompt mentions "workspace" or "python" (case-insensitive)."""
        lowered = _builder(tmp_path).build_system_prompt().lower()
        assert "workspace" in lowered or "python" in lowered

    def test_includes_bootstrap_files(self, tmp_path):
        """AGENTS.md content is carried into the prompt."""
        agents_file = tmp_path / "AGENTS.md"
        agents_file.write_text("Be helpful and concise.", encoding="utf-8")
        prompt = _builder(tmp_path).build_system_prompt()
        assert "Be helpful and concise." in prompt

    def test_includes_session_summary(self, tmp_path):
        """A session summary appears together with its archive heading."""
        prompt = _builder(tmp_path).build_system_prompt(
            session_summary="Previous chat about Python.",
        )
        assert "Previous chat about Python." in prompt
        assert "[Archived Context Summary]" in prompt

    def test_sections_separated_by_separator(self, tmp_path):
        """With bootstrap content and a summary, the ``---`` separator is present."""
        (tmp_path / "AGENTS.md").write_text("Rules.", encoding="utf-8")
        prompt = _builder(tmp_path).build_system_prompt(session_summary="Summary.")
        assert "\n\n---\n\n" in prompt

    def test_no_bootstrap_no_summary(self, tmp_path):
        """Without bootstrap files or a summary neither heading is emitted."""
        prompt = _builder(tmp_path).build_system_prompt()
        for heading in ("## AGENTS.md", "[Archived Context Summary]"):
            assert heading not in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_messages
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildMessages:
    """Exercise ``ContextBuilder.build_messages``."""

    def test_basic_empty_history(self, tmp_path):
        """Empty history gives a system message followed by the user message."""
        built = _builder(tmp_path).build_messages([], "hello")
        assert len(built) == 2
        system_msg = built[0]
        user_msg = built[1]
        assert system_msg["role"] == "system"
        assert user_msg["role"] == "user"
        assert "hello" in str(user_msg["content"])

    def test_runtime_context_injected(self, tmp_path):
        """The last message holds both the runtime-context block and the user text."""
        built = _builder(tmp_path).build_messages(
            [],
            "hello",
            channel="cli",
            chat_id="direct",
        )
        last_content = str(built[-1]["content"])
        for fragment in ("[Runtime Context", "hello"):
            assert fragment in last_content

    def test_consecutive_same_role_merged(self, tmp_path):
        """A trailing user message in history is merged with the new user message."""
        prior = [{"role": "user", "content": "previous user message"}]
        built = _builder(tmp_path).build_messages(prior, "new message")
        assert len(built) == 2  # system + merged user
        assert "previous user message" in str(built[1]["content"])
        assert "new message" in str(built[1]["content"])

    def test_different_role_appended(self, tmp_path):
        """A trailing assistant message in history stays separate from the new one."""
        prior = [{"role": "assistant", "content": "previous response"}]
        built = _builder(tmp_path).build_messages(prior, "new message")
        assert len(built) == 3  # system + assistant + user

    def test_media_with_history(self, tmp_path):
        """An image attachment produces a list-valued last message with an image block."""
        image = tmp_path / "img.png"
        image.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16)
        prior = [{"role": "assistant", "content": "see this"}]
        built = _builder(tmp_path).build_messages(prior, "check image", media=[str(image)])
        last_content = built[-1]["content"]
        assert isinstance(last_content, list)
        assert any(block.get("type") == "image_url" for block in last_content)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# add_tool_result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddToolResult:
    """Exercise ``ContextBuilder.add_tool_result``."""

    def test_appends_tool_message(self, tmp_path):
        """A tool message with id, name and content is appended after existing messages."""
        conversation = [{"role": "user", "content": "hello"}]
        updated = _builder(tmp_path).add_tool_result(
            conversation,
            "call_123",
            "read_file",
            "file content",
        )
        assert len(updated) == 2
        appended = updated[1]
        assert appended["role"] == "tool"
        assert appended["tool_call_id"] == "call_123"
        assert appended["name"] == "read_file"
        assert appended["content"] == "file content"

    def test_returns_same_list(self, tmp_path):
        """The returned object is the very list that was passed in."""
        conversation = []
        updated = _builder(tmp_path).add_tool_result(conversation, "id", "tool", "ok")
        assert updated is conversation
|
||||||
301
tests/agent/test_loop_runner_integration.py
Normal file
301
tests/agent/test_loop_runner_integration.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
"""Tests for AgentLoop integration with AgentRunner: streaming, think-filter, error handling, subagent."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path):
    """Build an ``AgentLoop`` for *tmp_path* with its collaborators patched out.

    ``ContextBuilder``, ``SessionManager`` and ``SubagentManager`` are patched in
    ``nanobot.agent.loop`` while the loop is constructed. The patched subagent
    manager gets an async ``cancel_by_session`` that returns 0.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"

    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as sub_manager_cls:
                sub_manager_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
                agent_loop = AgentLoop(
                    bus=message_bus,
                    provider=fake_provider,
                    workspace=tmp_path,
                )
    return agent_loop
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path):
    """A provider that always requests a tool call ends with the fixed max-iterations text."""
    loop = _make_loop(tmp_path)
    endless_tool_call = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    )
    loop.provider.chat_with_retry = AsyncMock(return_value=endless_tool_call)
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.tools.execute = AsyncMock(return_value="ok")
    loop.max_iterations = 2

    final_content, _, _, _, _ = await loop._run_agent_loop([])

    expected_message = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert final_content == expected_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
    """A stream that opens with a ``<think>`` section only forwards the visible text."""
    loop = _make_loop(tmp_path)
    seen_deltas: list[str] = []
    seen_endings: list[bool] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        # First chunk is entirely inside the think tag; the second closes it.
        for chunk in ("<think>hidden", "</think>Hello"):
            await on_content_delta(chunk)
        return LLMResponse(content="<think>hidden</think>Hello", tool_calls=[], usage={})

    async def record_delta(delta: str) -> None:
        seen_deltas.append(delta)

    async def record_end(*, resuming: bool = False) -> None:
        seen_endings.append(resuming)

    loop.provider.chat_stream_with_retry = fake_stream

    final_content, _, _, _, _ = await loop._run_agent_loop(
        [],
        on_stream=record_delta,
        on_stream_end=record_end,
    )

    assert final_content == "Hello"
    assert seen_deltas == ["Hello"]
    assert seen_endings == [False]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_stream_filter_hides_partial_trailing_think_prefix(tmp_path):
    """A ``<think>`` tag split across two chunks is never forwarded to the stream callback."""
    loop = _make_loop(tmp_path)
    seen_deltas: list[str] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        # The opening tag is cut in the middle: "<thin" + "k>".
        for chunk in ("Hello <thin", "k>hidden</think>World"):
            await on_content_delta(chunk)
        return LLMResponse(content="Hello <think>hidden</think>World", tool_calls=[], usage={})

    async def record_delta(delta: str) -> None:
        seen_deltas.append(delta)

    loop.provider.chat_stream_with_retry = fake_stream

    final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=record_delta)

    assert final_content == "Hello World"
    assert seen_deltas == ["Hello", " World"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_stream_filter_hides_complete_trailing_think_tag(tmp_path):
    """A chunk that ends with a complete ``<think>`` tag does not leak the tag."""
    loop = _make_loop(tmp_path)
    seen_deltas: list[str] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        # The first chunk ends exactly on the opening tag.
        for chunk in ("Hello <think>", "hidden</think>World"):
            await on_content_delta(chunk)
        return LLMResponse(content="Hello <think>hidden</think>World", tool_calls=[], usage={})

    async def record_delta(delta: str) -> None:
        seen_deltas.append(delta)

    loop.provider.chat_stream_with_retry = fake_stream

    final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=record_delta)

    assert final_content == "Hello World"
    assert seen_deltas == ["Hello", " World"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_retries_think_only_final_response(tmp_path):
    """A think-only first reply leads to a second provider call whose answer is returned."""
    loop = _make_loop(tmp_path)
    calls_made = 0

    async def fake_chat(**kwargs):
        nonlocal calls_made
        calls_made += 1
        # Only the very first reply is think-only; later ones carry a real answer.
        reply_text = "<think>hidden</think>" if calls_made == 1 else "Recovered answer"
        return LLMResponse(content=reply_text, tool_calls=[], usage={})

    loop.provider.chat_with_retry = fake_chat

    final_content, _, _, _, _ = await loop._run_agent_loop([])

    assert final_content == "Recovered answer"
    assert calls_made == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_streamed_flag_not_set_on_llm_error(tmp_path):
    """An LLM error on a streaming-capable channel must leave ``_streamed`` unset.

    The flag has to stay off so that ChannelManager delivers the error itself.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    loop = AgentLoop(
        bus=message_bus,
        provider=fake_provider,
        workspace=tmp_path,
        model="test-model",
    )

    failure = LLMResponse(
        content="503 service unavailable",
        finish_reason="error",
        tool_calls=[],
        usage={},
    )
    # Both the plain and the streaming entry points report the same error.
    loop.provider.chat_with_retry = AsyncMock(return_value=failure)
    loop.provider.chat_stream_with_retry = AsyncMock(return_value=failure)
    loop.tools.get_definitions = MagicMock(return_value=[])

    inbound = InboundMessage(
        channel="feishu",
        sender_id="u1",
        chat_id="c1",
        content="hi",
    )
    outbound = await loop._process_message(
        inbound,
        on_stream=AsyncMock(),
        on_stream_end=AsyncMock(),
    )

    assert outbound is not None
    assert "503" in outbound.content
    streamed_flag = outbound.metadata.get("_streamed")
    assert not streamed_flag, "_streamed must not be set when stop_reason is error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_ssrf_soft_block_can_finalize_after_streamed_tool_call(tmp_path):
    """After a tool call returns the safety-guard error, the next streamed reply is final.

    The streaming provider first requests an ``exec`` tool call and then returns a
    plain answer; the outbound message must carry that answer and be marked
    ``_streamed``.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"

    blocked_call = ToolCallRequest(
        id="call_ssrf",
        name="exec",
        arguments={"command": "curl http://169.254.169.254/latest/meta-data/"},
    )
    first_reply = LLMResponse(
        content="checking metadata",
        tool_calls=[blocked_call],
        usage={},
    )
    second_reply = LLMResponse(
        content="I cannot access private URLs. Please share the local file.",
        tool_calls=[],
        usage={},
    )
    fake_provider.chat_stream_with_retry = AsyncMock(side_effect=[first_reply, second_reply])

    loop = AgentLoop(
        bus=message_bus,
        provider=fake_provider,
        workspace=tmp_path,
        model="test-model",
    )
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.tools.prepare_call = MagicMock(return_value=(None, {}, None))
    guard_error = "Error: Command blocked by safety guard (internal/private URL detected)"
    loop.tools.execute = AsyncMock(return_value=guard_error)

    inbound = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="hi")
    outbound = await loop._process_message(
        inbound,
        on_stream=AsyncMock(),
        on_stream_end=AsyncMock(),
    )

    assert outbound is not None
    assert outbound.content == "I cannot access private URLs. Please share the local file."
    assert outbound.metadata.get("_streamed") is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
    """An errored turn is persisted as user + placeholder, and the next turn stays separate.

    The first provider reply is an error and the second a normal answer. The
    session must hold the first question followed by the placeholder assistant
    message, and the second request must alternate user / assistant / user.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    scripted_replies = [
        LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}),
        LLMResponse(content="Recovered answer", tool_calls=[], usage={}),
    ]
    fake_provider.chat_with_retry = AsyncMock(side_effect=scripted_replies)

    loop = AgentLoop(
        bus=MessageBus(),
        provider=fake_provider,
        workspace=tmp_path,
        model="test-model",
    )
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    first = await loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question")
    )
    assert first is not None
    assert first.content == "429 rate limit exceeded"

    # Compare only role/content so other per-message keys are ignored.
    session = loop.sessions.get_or_create("cli:test")
    compared_keys = {"role", "content"}
    persisted = [
        {key: value for key, value in entry.items() if key in compared_keys}
        for entry in session.messages
    ]
    assert persisted == [
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER},
    ]

    second = await loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question")
    )
    assert second is not None
    assert second.content == "Recovered answer"

    second_request = fake_provider.chat_with_retry.await_args_list[1].kwargs["messages"]
    non_system = [entry for entry in second_request if entry.get("role") != "system"]
    expected_turns = (
        ("user", "first question"),
        ("assistant", _PERSISTED_MODEL_ERROR_PLACEHOLDER),
        ("user", "second question"),
    )
    for position, (role, fragment) in enumerate(expected_turns):
        assert non_system[position]["role"] == role
        assert fragment in non_system[position]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
    """A subagent whose provider keeps requesting tool calls announces the fallback text.

    ``_announce_result`` must be awaited exactly once, with the fallback sentence
    as its fourth positional argument and ``"ok"`` as its sixth.
    """
    from nanobot.agent.subagent import SubagentManager, SubagentStatus
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    endless_tool_call = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
    )
    fake_provider.chat_with_retry = AsyncMock(return_value=endless_tool_call)

    manager = SubagentManager(
        provider=fake_provider,
        workspace=tmp_path,
        bus=message_bus,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    manager._announce_result = AsyncMock()

    async def fake_execute(self, **kwargs):
        return "tool result"

    # Replace the list_dir tool's execute with the stub above.
    monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)

    status = SubagentStatus(
        task_id="sub-1",
        label="label",
        task_description="do task",
        started_at=time.monotonic(),
    )
    origin = {"channel": "test", "chat_id": "c1"}
    await manager._run_subagent("sub-1", "do task", "label", origin, status)

    manager._announce_result.assert_awaited_once()
    announced = manager._announce_result.await_args.args
    assert announced[3] == "Task completed but no final response was generated."
    assert announced[5] == "ok"
|
||||||
File diff suppressed because it is too large
Load Diff
481
tests/agent/test_runner_core.py
Normal file
481
tests/agent/test_runner_core.py
Normal file
@ -0,0 +1,481 @@
|
|||||||
|
"""Tests for core AgentRunner behavior: message passing, iteration limits,
|
||||||
|
timeouts, empty-response handling, usage accumulation, and config passthrough."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_preserves_reasoning_fields_and_tool_results():
    """Reasoning fields and the tool result are carried into the second provider call."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    second_call_messages: list[dict] = []
    calls_made = 0

    async def fake_chat(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            # Snapshot what the runner sends after the tool round-trip.
            second_call_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
            reasoning_content="hidden reasoning",
            thinking_blocks=[{"type": "thinking", "thinking": "step"}],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = fake_chat
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "do task"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "done"
    assert result.tools_used == ["list_dir"]
    expected_events = [{"name": "list_dir", "status": "ok", "detail": "tool result"}]
    assert result.tool_events == expected_events

    tool_calling_replies = []
    for entry in second_call_messages:
        if entry.get("role") == "assistant" and entry.get("tool_calls"):
            tool_calling_replies.append(entry)
    assert len(tool_calling_replies) == 1
    replayed = tool_calling_replies[0]
    assert replayed["reasoning_content"] == "hidden reasoning"
    assert replayed["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]

    assert any(
        entry.get("role") == "tool" and entry.get("content") == "tool result"
        for entry in second_call_messages
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_returns_max_iterations_fallback():
    """Hitting the iteration cap yields the fallback text as the last assistant message."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    endless_tool_call = LLMResponse(
        content="still working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
    )
    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = AsyncMock(return_value=endless_tool_call)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    expected_message = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert result.stop_reason == "max_iterations"
    assert result.final_content == expected_message
    last_message = result.messages[-1]
    assert last_message["role"] == "assistant"
    assert last_message["content"] == result.final_content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_times_out_hung_llm_request():
    """A provider call that never returns ends quickly as an error mentioning a timeout."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def hanging_chat(**kwargs):
        # Far longer than the 0.05 s timeout configured on the spec below.
        await asyncio.sleep(3600)

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = hanging_chat
    tools = MagicMock()
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        llm_timeout_s=0.05,
    )
    started = time.monotonic()
    result = await runner.run(spec)
    elapsed = time.monotonic() - started

    assert elapsed < 1.0
    assert result.stop_reason == "error"
    final_text = result.final_content or ""
    assert "timed out" in final_text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_replaces_empty_tool_result_with_marker():
    """An empty tool result reaches the model as an explicit marker string.

    The first model turn requests the ``noop`` tool, whose mock returns ``""``.
    The messages passed to the following model call are captured, and the tool
    message among them must read ``"(noop completed with no output)"``.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    followup_messages: list[dict] = []
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal llm_calls
        llm_calls += 1
        if llm_calls != 1:
            # Snapshot what the model is shown after the tool has run.
            followup_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})],
            usage={},
        )

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(return_value="")

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=registry,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    assert outcome.final_content == "done"
    first_tool_message = next(
        filter(lambda entry: entry.get("role") == "tool", followup_messages)
    )
    assert first_tool_message["content"] == "(noop completed with no output)"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_retries_empty_final_response_with_summary_prompt():
    """Two empty replies are retried before a tool-less finalization call.

    The fake provider answers the first two requests with no content and no
    tool calls, then returns real text. The test expects exactly three
    provider calls, with ``tools`` supplied on the first two and ``None`` on
    the third, and usage summed over all three.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    seen_requests: list[dict] = []

    async def chat_with_retry(*, messages, tools=None, **kwargs):
        seen_requests.append({"messages": messages, "tools": tools})
        still_empty = len(seen_requests) <= 2
        if still_empty:
            return LLMResponse(
                content=None,
                tool_calls=[],
                usage={"prompt_tokens": 5, "completion_tokens": 1},
            )
        return LLMResponse(
            content="final answer",
            tool_calls=[],
            usage={"prompt_tokens": 3, "completion_tokens": 7},
        )

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    assert outcome.final_content == "final answer"
    # Two empty attempts plus the call that finally produced text.
    assert len(seen_requests) == 3
    first, second, third = seen_requests
    assert first["tools"] is not None
    assert second["tools"] is not None
    # The last request is issued without tools.
    assert third["tools"] is None
    # Usage is summed over every call: 5 + 5 + 3 prompt, 1 + 1 + 7 completion.
    assert outcome.usage["prompt_tokens"] == 13
    assert outcome.usage["completion_tokens"] == 9
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_uses_specific_message_after_empty_finalization_retry():
    """A provider that only ever returns empty replies yields the canned message.

    Every call returns no content and no tool calls, so the run must end with
    ``EMPTY_FINAL_RESPONSE_MESSAGE`` as its final content and
    ``"empty_final_response"`` as its stop reason.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec
    from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE

    async def always_empty(*, messages, **kwargs):
        return LLMResponse(content=None, tool_calls=[], usage={})

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = always_empty
    registry = MagicMock()
    registry.get_definitions.return_value = []

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    assert outcome.final_content == EMPTY_FINAL_RESPONSE_MESSAGE
    assert outcome.stop_reason == "empty_final_response"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_empty_response_does_not_break_tool_chain():
    """An empty reply in the middle of a tool chain must not end the run.

    Scripted provider turns: tool call, empty reply, tool call, final text.
    The run is expected to complete normally after exactly four provider
    calls, with ``read_file`` recorded among the tools used.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    calls_made = 0

    def _read_request(call_id, path):
        # A content-less turn that asks for a single read_file call.
        return LLMResponse(
            content=None,
            tool_calls=[ToolCallRequest(id=call_id, name="read_file", arguments={"path": path})],
            usage={"prompt_tokens": 10, "completion_tokens": 5},
        )

    async def chat_with_retry(*, messages, tools=None, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made == 1:
            return _read_request("tc1", "a.txt")
        if calls_made == 2:
            # The empty turn the runner has to recover from.
            return LLMResponse(
                content=None,
                tool_calls=[],
                usage={"prompt_tokens": 10, "completion_tokens": 1},
            )
        if calls_made == 3:
            return _read_request("tc2", "b.txt")
        return LLMResponse(
            content="Here are the results.",
            tool_calls=[],
            usage={"prompt_tokens": 10, "completion_tokens": 10},
        )

    async def run_tool(name, args, **kw):
        return "file content"

    llm = MagicMock(spec=LLMProvider)
    # The same script backs both the plain and the streaming entry point.
    llm.chat_with_retry = chat_with_retry
    llm.chat_stream_with_retry = chat_with_retry

    registry = MagicMock()
    registry.get_definitions.return_value = [
        {"type": "function", "function": {"name": "read_file"}}
    ]
    registry.execute = AsyncMock(side_effect=run_tool)

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "read both files"}],
        tools=registry,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    assert outcome.final_content == "Here are the results."
    assert outcome.stop_reason == "completed"
    assert calls_made == 4
    assert "read_file" in outcome.tools_used
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
    """Usage counters, including ``cached_tokens``, are summed across turns.

    The first provider turn requests a tool and reports 100/10/80
    (prompt/completion/cached); the second finishes with 200/20/150. The
    run's usage must report the per-key totals.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    turn = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal turn
        turn += 1
        if turn != 1:
            return LLMResponse(
                content="done",
                tool_calls=[],
                usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
            )
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
            usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
        )

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(return_value="file content")

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    totals = outcome.usage
    assert totals["prompt_tokens"] == 300  # 100 + 200
    assert totals["completion_tokens"] == 30  # 10 + 20
    assert totals["cached_tokens"] == 230  # 80 + 150
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_binds_on_retry_wait_to_retry_callback_not_progress():
    """Regression: ``on_retry_wait`` must be the spec's ``retry_wait_callback``.

    An earlier runtime refactor bound provider retry heartbeats to
    ``progress_callback``, which leaked internal diagnostics such as
    "Model request failed, retry in 1s" to end-user channels as ordinary
    progress updates. The test captures the keyword arguments handed to the
    provider and checks which callback object was bound.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    seen_kwargs: dict = {}

    async def chat_with_retry(**kwargs):
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []

    on_progress = AsyncMock()
    on_retry_wait = AsyncMock()

    run_spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "hi"},
        ],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=on_progress,
        retry_wait_callback=on_retry_wait,
    )
    await AgentRunner(llm).run(run_spec)

    bound = seen_kwargs["on_retry_wait"]
    assert bound is on_retry_wait
    assert bound is not on_progress
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config passthrough tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_passes_temperature_to_provider():
    """``AgentRunSpec.temperature`` is forwarded to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    seen_kwargs: dict = {}

    async def record_call(**kwargs):
        # Keep every keyword the runner passes so the test can inspect one.
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = record_call
    registry = MagicMock()
    registry.get_definitions.return_value = []

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        temperature=0.7,
    )
    await AgentRunner(llm).run(run_spec)

    assert seen_kwargs["temperature"] == 0.7
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_passes_max_tokens_to_provider():
    """``AgentRunSpec.max_tokens`` is forwarded to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    seen_kwargs: dict = {}

    async def record_call(**kwargs):
        # Keep every keyword the runner passes so the test can inspect one.
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = record_call
    registry = MagicMock()
    registry.get_definitions.return_value = []

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        max_tokens=8192,
    )
    await AgentRunner(llm).run(run_spec)

    assert seen_kwargs["max_tokens"] == 8192
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_passes_reasoning_effort_to_provider():
    """``AgentRunSpec.reasoning_effort`` is forwarded to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    seen_kwargs: dict = {}

    async def record_call(**kwargs):
        # Keep every keyword the runner passes so the test can inspect one.
        seen_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = record_call
    registry = MagicMock()
    registry.get_definitions.return_value = []

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        reasoning_effort="high",
    )
    await AgentRunner(llm).run(run_spec)

    assert seen_kwargs["reasoning_effort"] == "high"
|
||||||
171
tests/agent/test_runner_errors.py
Normal file
171
tests/agent/test_runner_errors.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
"""Tests for AgentRunner error handling: tool errors, LLM errors,
|
||||||
|
session message isolation, and tool result preservation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
# Value of ``max_tool_result_chars`` on a default-constructed ``AgentDefaults``;
# the tests in this module forward it unchanged to each ``AgentRunSpec`` they build.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_returns_structured_tool_error():
    """With ``fail_on_tool_error`` set, a raising tool ends the run as ``tool_error``.

    The provider asks for ``list_dir`` and the tool mock raises
    ``RuntimeError("boom")``. The result must carry the formatted error string
    and a single error entry in ``tool_events``.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    scripted_reply = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    )
    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = AsyncMock(return_value=scripted_reply)

    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(side_effect=RuntimeError("boom"))

    run_spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    expected_events = [{"name": "list_dir", "status": "error", "detail": "boom"}]
    assert outcome.stop_reason == "tool_error"
    assert outcome.error == "Error: RuntimeError: boom"
    assert outcome.tool_events == expected_events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_llm_error_not_appended_to_session_messages():
    """Provider error text must not be written into the message history.

    When the model reply has ``finish_reason="error"``, the error text is
    still returned as ``final_content``, but the last assistant message must
    hold the persisted placeholder instead, so session history is not polluted.
    """
    from nanobot.agent.runner import (
        _PERSISTED_MODEL_ERROR_PLACEHOLDER,
        AgentRunner,
        AgentRunSpec,
    )

    error_reply = LLMResponse(
        content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={},
    )
    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = AsyncMock(return_value=error_reply)
    registry = MagicMock()
    registry.get_definitions.return_value = []

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    assert outcome.stop_reason == "error"
    assert outcome.final_content == "429 rate limit exceeded"

    assistant_msgs = [
        entry for entry in outcome.messages if entry.get("role") == "assistant"
    ]
    leaked = [
        entry for entry in assistant_msgs if "429" in (entry.get("content") or "")
    ]
    assert not leaked, "Error content should not appear in session messages"
    assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_tool_error_sets_final_content():
    """A fatal tool error is surfaced as the run's ``final_content``.

    The provider always asks for ``read_file`` and the tool mock raises
    ``RuntimeError("boom")``. With ``fail_on_tool_error=True`` the formatted
    error becomes the final content and the stop reason is ``"tool_error"``.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    async def always_request_tool(*, messages, **kwargs):
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
            usage={},
        )

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = always_request_tool
    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(side_effect=RuntimeError("boom"))

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    assert outcome.final_content == "Error: RuntimeError: boom"
    assert outcome.stop_reason == "tool_error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_tool_error_preserves_tool_results_in_messages():
    """Tool results are kept in the history even when one tool fails fatally.

    The provider requests two tools in one turn; the second execution raises.
    Both tool results must still appear in ``result.messages``, after the
    assistant message that issued the calls, so the session never holds
    orphan tool_calls (#2943).
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    async def chat_with_retry(*, messages, **kwargs):
        requested = [
            ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a"}),
            ToolCallRequest(id="tc2", name="exec", arguments={"cmd": "bad"}),
        ]
        return LLMResponse(content=None, tool_calls=requested, usage={})

    llm = MagicMock(spec=LLMProvider)
    llm.chat_with_retry = chat_with_retry
    llm.chat_stream_with_retry = chat_with_retry

    executions = 0

    async def run_tool(name, args, **kw):
        nonlocal executions
        executions += 1
        # Only the second execution blows up; the first returns normally.
        if executions == 2:
            raise RuntimeError("boom")
        return "file content"

    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(side_effect=run_tool)

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do stuff"}],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    assert outcome.stop_reason == "tool_error"

    # Both results are recorded, in request order, despite tc2's failure.
    tool_msgs = [entry for entry in outcome.messages if entry.get("role") == "tool"]
    assert len(tool_msgs) == 2
    assert tool_msgs[0]["tool_call_id"] == "tc1"
    assert tool_msgs[1]["tool_call_id"] == "tc2"

    # The assistant turn carrying the tool_calls comes before every tool result.
    calling_turn_pos = next(
        pos
        for pos, entry in enumerate(outcome.messages)
        if entry.get("role") == "assistant" and entry.get("tool_calls")
    )
    tool_positions = [
        pos for pos, entry in enumerate(outcome.messages) if entry.get("role") == "tool"
    ]
    for pos in tool_positions:
        assert pos > calling_turn_pos
|
||||||
643
tests/agent/test_runner_governance.py
Normal file
643
tests/agent/test_runner_governance.py
Normal file
@ -0,0 +1,643 @@
|
|||||||
|
"""Tests for AgentRunner context governance: backfill, orphan cleanup, microcompact, snip_history."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
# Value of ``max_tool_result_chars`` on a default-constructed ``AgentDefaults``;
# tests in this module that build an ``AgentRunSpec`` forward it unchanged.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path):
    """Build an ``AgentLoop`` rooted at *tmp_path* with collaborators patched out.

    ``ContextBuilder``, ``SessionManager`` and ``SubagentManager`` are replaced
    by mocks in ``nanobot.agent.loop`` while the loop is constructed. The
    patched ``SubagentManager`` instance gets an awaitable
    ``cancel_by_session`` that returns ``0``.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    message_bus = MessageBus()

    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as subagent_cls:
                subagent_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
                return AgentLoop(bus=message_bus, provider=fake_provider, workspace=tmp_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
    """If history trimming raises, the provider still receives the raw messages.

    ``_snip_history`` is replaced with a mock that raises ``RuntimeError``.
    The run must still complete, and the messages captured from the provider
    call must equal the initial messages.
    """
    # The asyncio marker is required so this coroutine test is actually awaited
    # under pytest-asyncio strict mode; every other async test here carries it.
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_messages: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Slice-assign so the outer list holds a snapshot of the request.
        captured_messages[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
    ]

    runner = AgentRunner(provider)
    # Force the governance step to fail.
    runner._snip_history = MagicMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    result = await runner.run(AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "done"
    assert captured_messages == initial_messages
|
||||||
|
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
    """``_snip_history`` must return a history that starts system -> user.

    Both token estimators in ``nanobot.agent.runner`` are stubbed with fixed
    sizes so that trimming is driven by the test's numbers. The trimmed
    result must keep the system message first and have a user message as its
    first non-system entry, which providers that require system -> user need
    (e.g. GLM error 1214).
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    history = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        {
            "role": "assistant",
            "content": "tool call",
            "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "tool output"},
        {"role": "assistant", "content": "after tool"},
    ]

    registry = MagicMock()
    registry.get_definitions.return_value = []
    runner = AgentRunner(MagicMock())
    run_spec = AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    # Fixed per-message sizes, keyed by message content; anything else counts as 40.
    tokens_by_content = {
        "old user": 120,
        "tool call": 120,
        "tool output": 40,
        "after tool": 40,
        "system": 0,
    }

    def fake_prompt_estimate(*_args, **_kwargs):
        return (500, None)

    def fake_message_estimate(msg):
        return tokens_by_content.get(str(msg.get("content")), 40)

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", fake_prompt_estimate)
    monkeypatch.setattr("nanobot.agent.runner.estimate_message_tokens", fake_message_estimate)

    trimmed = runner._snip_history(run_spec, history)

    assert trimmed[0]["role"] == "system"
    non_system = [entry for entry in trimmed if entry["role"] != "system"]
    first_role = non_system[0]["role"]
    assert first_role == "user", f"Expected user after system, got {first_role}"
|
||||||
|
@pytest.mark.asyncio
async def test_backfill_missing_tool_results_inserts_error():
    """Orphaned tool_use (no matching tool_result) should get a synthetic error.

    The assistant turn issues ``call_a`` and ``call_b`` but only ``call_a``
    has a tool message. After backfilling there must be two tool messages,
    and the one for ``call_b`` must carry ``_BACKFILL_CONTENT`` and the name
    of the tool that was called.
    """
    # Marked like the sibling async backfill test so the coroutine is awaited
    # under pytest-asyncio strict mode rather than silently not run.
    from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT

    messages = [
        {"role": "user", "content": "hi"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"id": "call_a", "type": "function", "function": {"name": "exec", "arguments": "{}"}},
                {"id": "call_b", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
            ],
        },
        {"role": "tool", "tool_call_id": "call_a", "name": "exec", "content": "ok"},
    ]
    result = AgentRunner._backfill_missing_tool_results(messages)

    tool_msgs = [m for m in result if m.get("role") == "tool"]
    assert len(tool_msgs) == 2
    backfilled = [m for m in tool_msgs if m.get("tool_call_id") == "call_b"]
    assert len(backfilled) == 1
    assert backfilled[0]["content"] == _BACKFILL_CONTENT
    assert backfilled[0]["name"] == "read_file"
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_orphan_tool_results_removes_unmatched_tool_messages():
    """A tool message with no matching assistant tool_call is removed.

    The input history holds one matched tool result (``call_ok``) and one
    stale result (``call_orphan``). The cleaned history must equal the same
    conversation without the stale entry.
    """
    from nanobot.agent.runner import AgentRunner

    def build_history(*, with_orphan):
        # Fresh dicts on every call, so input and expectation never share state.
        history = [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "old user"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
                ],
            },
            {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
        ]
        if with_orphan:
            history.append(
                {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"}
            )
        history.append({"role": "assistant", "content": "after tool"})
        return history

    cleaned = AgentRunner._drop_orphan_tool_results(build_history(with_orphan=True))

    assert cleaned == build_history(with_orphan=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_backfill_noop_when_complete():
    """A history where every tool call has its result is returned untouched."""
    from nanobot.agent.runner import AgentRunner

    exec_call = {"id": "call_x", "type": "function", "function": {"name": "exec", "arguments": "{}"}}
    complete_history = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "", "tool_calls": [exec_call]},
        {"role": "tool", "tool_call_id": "call_x", "name": "exec", "content": "done"},
        {"role": "assistant", "content": "all good"},
    ]

    returned = AgentRunner._backfill_missing_tool_results(complete_history)

    # Identity, not equality: nothing to repair means no copy is made.
    assert returned is complete_history
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_drops_orphan_tool_results_before_model_request():
    """Orphan tool results are hidden from the model but kept in the result.

    The initial history contains a tool message (``call_orphan``) with no
    assistant tool_call behind it. The messages sent to the provider must not
    include it, while ``result.messages`` still has it at index 2.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    sent_to_model: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        sent_to_model[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    llm = MagicMock()
    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []

    history = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
        {"role": "assistant", "content": "after orphan"},
        {"role": "user", "content": "new prompt"},
    ]
    run_spec = AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(llm).run(run_spec)

    orphans_sent = [
        entry
        for entry in sent_to_model
        if entry.get("role") == "tool" and entry.get("tool_call_id") == "call_orphan"
    ]
    assert not orphans_sent
    assert outcome.messages[2]["tool_call_id"] == "call_orphan"
    assert outcome.final_content == "done"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
    """Historical backfill should not duplicate old tail messages on persist.

    The saved session contains an assistant tool call (``call_missing``) that
    never got a tool result.  The request handed to the provider must contain
    exactly one synthetic result for it, and the session afterwards must hold
    the three original messages followed by the new user/assistant turn only.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _BACKFILL_CONTENT
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    compared_keys = {"role", "content", "tool_calls", "tool_call_id", "name"}

    def unanswered_call() -> dict:
        # Fresh dict per use so the fixture and the expectation share no objects.
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_missing",
                    "type": "function",
                    "function": {"name": "read_file", "arguments": "{}"},
                }
            ],
        }

    reply = LLMResponse(content="new answer", tool_calls=[], usage={})
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = AsyncMock(return_value=reply)
    provider.chat_stream_with_retry = AsyncMock(return_value=reply)

    loop = AgentLoop(
        bus=MessageBus(),
        provider=provider,
        workspace=tmp_path,
        model="test-model",
    )
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "old user", "timestamp": "2026-01-01T00:00:00"},
        {**unanswered_call(), "timestamp": "2026-01-01T00:00:01"},
        {"role": "assistant", "content": "old tail", "timestamp": "2026-01-01T00:00:02"},
    ]
    loop.sessions.save(session)

    inbound = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new prompt")
    result = await loop._process_message(inbound)

    assert result is not None
    assert result.content == "new answer"

    # What the model was shown: one synthetic tool result for the dangling call.
    sent = provider.chat_with_retry.await_args.kwargs["messages"]
    backfilled = [
        entry
        for entry in sent
        if entry.get("role") == "tool" and entry.get("tool_call_id") == "call_missing"
    ]
    assert len(backfilled) == 1
    assert backfilled[0]["content"] == _BACKFILL_CONTENT

    # What was persisted: compared on the listed keys only (timestamps ignored).
    persisted = loop.sessions.get_or_create("cli:test").messages
    projected = [
        {field: value for field, value in entry.items() if field in compared_keys}
        for entry in persisted
    ]
    assert projected == [
        {"role": "user", "content": "old user"},
        unanswered_call(),
        {"role": "assistant", "content": "old tail"},
        {"role": "user", "content": "new prompt"},
        {"role": "assistant", "content": "new answer"},
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_backfill_only_mutates_model_context_not_returned_messages():
    """Runner should repair orphaned tool calls for the model without rewriting result.messages.

    The provider stub must see exactly one synthetic tool result for
    ``call_missing``; ``result.messages`` must equal the input history plus the
    final assistant reply, with no synthetic entry.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _BACKFILL_CONTENT

    compared_keys = {"role", "content", "tool_calls", "tool_call_id", "name"}

    def unanswered_call() -> dict:
        # Fresh dict per use so input and expectation never alias each other.
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_missing",
                    "type": "function",
                    "function": {"name": "read_file", "arguments": "{}"},
                }
            ],
        }

    seen_by_model: list[dict] = []

    async def fake_chat(*, messages, **kwargs):
        seen_by_model[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = fake_chat
    tools = MagicMock()
    tools.get_definitions.return_value = []

    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        unanswered_call(),
        {"role": "assistant", "content": "old tail"},
        {"role": "user", "content": "new prompt"},
    ]

    spec = AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    backfilled = [
        entry
        for entry in seen_by_model
        if entry.get("role") == "tool" and entry.get("tool_call_id") == "call_missing"
    ]
    assert len(backfilled) == 1
    assert backfilled[0]["content"] == _BACKFILL_CONTENT

    projected = [
        {field: value for field, value in entry.items() if field in compared_keys}
        for entry in result.messages
    ]
    assert projected == [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        unanswered_call(),
        {"role": "assistant", "content": "old tail"},
        {"role": "user", "content": "new prompt"},
        {"role": "assistant", "content": "done"},
    ]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Microcompact (stale tool result compaction)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_microcompact_replaces_old_tool_results():
    """Tool results beyond _MICROCOMPACT_KEEP_RECENT should be summarized.

    Builds ``_MICROCOMPACT_KEEP_RECENT + 5`` ``read_file`` call/result pairs
    with 600-character results and checks that the surplus results carry the
    "omitted from context" marker while the rest keep their content.
    """
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    total = _MICROCOMPACT_KEEP_RECENT + 5
    long_content = "x" * 600

    history: list[dict] = [{"role": "system", "content": "sys"}]
    for idx in range(total):
        call_id = f"c{idx}"
        history += [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "read_file", "arguments": "{}"}}],
            },
            {
                "role": "tool", "tool_call_id": call_id, "name": "read_file",
                "content": long_content,
            },
        ]

    tool_msgs = [
        entry
        for entry in AgentRunner._microcompact(history)
        if entry.get("role") == "tool"
    ]
    summarized = sum(
        1 for entry in tool_msgs if "omitted from context" in str(entry.get("content", ""))
    )
    untouched = sum(1 for entry in tool_msgs if entry.get("content") == long_content)

    assert summarized == total - _MICROCOMPACT_KEEP_RECENT
    assert untouched == _MICROCOMPACT_KEEP_RECENT
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_microcompact_preserves_short_results():
    """Short tool results (< _MICROCOMPACT_MIN_CHARS) should not be replaced.

    With only ``"short"`` results in the history, ``_microcompact`` is expected
    to hand back the very same list object.
    """
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    pair_count = _MICROCOMPACT_KEEP_RECENT + 5

    messages: list[dict] = []
    for idx in range(pair_count):
        call_id = f"c{idx}"
        messages += [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "exec", "arguments": "{}"}}],
            },
            {
                "role": "tool", "tool_call_id": call_id, "name": "exec",
                "content": "short",
            },
        ]

    # Identity, not equality: nothing to compact means no copy is made.
    assert AgentRunner._microcompact(messages) is messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_microcompact_skips_non_compactable_tools():
    """Non-compactable tools (e.g. 'message') should never be replaced.

    Every result here is 1000 characters long but belongs to the ``message``
    tool, so ``_microcompact`` is expected to return the input list unchanged
    (same object).
    """
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    pair_count = _MICROCOMPACT_KEEP_RECENT + 5
    long_content = "y" * 1000

    messages: list[dict] = []
    for idx in range(pair_count):
        call_id = f"c{idx}"
        messages += [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "message", "arguments": "{}"}}],
            },
            {
                "role": "tool", "tool_call_id": call_id, "name": "message",
                "content": long_content,
            },
        ]

    # Identity check: no compactable tool results means no copy.
    assert AgentRunner._microcompact(messages) is messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_governance_repairs_orphans_after_snip():
    """After _snip_history clips an assistant+tool_calls, the second
    _drop_orphan_tool_results pass must clean up the resulting orphans.

    The snip itself is simulated: the test feeds ``_drop_orphan_tool_results``
    a hand-built history in which the assistant message that issued ``tc_old``
    is already gone but its tool result is still present.
    """
    from nanobot.agent.runner import AgentRunner

    # Stand-in for a snipped history: the tool result for "tc_old" survives
    # while the assistant message carrying the matching tool call does not.
    snipped = [
        {"role": "system", "content": "system"},
        {"role": "tool", "tool_call_id": "tc_old", "name": "search",
         "content": "old result"},
        {"role": "assistant", "content": "old answer"},
        {"role": "user", "content": "new msg"},
    ]

    cleaned = AgentRunner._drop_orphan_tool_results(snipped)

    # The orphan tool result should be removed.
    assert not any(
        m.get("role") == "tool" and m.get("tool_call_id") == "tc_old"
        for m in cleaned
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_governance_fallback_still_repairs_orphans():
    """When full governance fails, the fallback must still run
    _drop_orphan_tool_results and _backfill_missing_tool_results.

    The two repair steps are applied directly, in that order, to a history
    containing a tool result that no assistant message asked for.
    """
    from nanobot.agent.runner import AgentRunner

    # "orphan_tc" has no matching assistant tool_call anywhere in the history.
    history = [
        {"role": "user", "content": "hello"},
        {"role": "tool", "tool_call_id": "orphan_tc", "name": "read",
         "content": "stale"},
        {"role": "assistant", "content": "hi"},
    ]

    without_orphans = AgentRunner._drop_orphan_tool_results(history)
    repaired = AgentRunner._backfill_missing_tool_results(without_orphans)

    # Orphan tool result should be gone.
    surviving_ids = [entry.get("tool_call_id") for entry in repaired]
    assert "orphan_tc" not in surviving_ids
|
||||||
|
def test_snip_history_preserves_user_message_after_truncation(monkeypatch):
    """_snip_history must keep a user message at the front of the kept window.

    When truncation leaves the only user message outside the kept window, the
    method must recover the nearest user message so the resulting sequence is
    valid for providers like GLM (which reject system→assistant with
    error 1214).

    Scenario from the bug report:
    - Normal interaction: user asks, assistant calls tool, tool returns,
      assistant replies.
    - Injection adds a phantom user message, triggering more tool calls.
    - _snip_history activates, keeping only recent assistant/tool pairs.
    - The injected user message is in the truncated prefix and gets lost.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    def exec_round(call_id: str, output: str) -> list[dict]:
        # One assistant tool call to ``exec`` followed by its tool result.
        return [
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "exec", "arguments": "{}"}}],
            },
            {"role": "tool", "tool_call_id": call_id, "content": output},
        ]

    tools = MagicMock()
    tools.get_definitions.return_value = []
    runner = AgentRunner(MagicMock())

    messages = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "previous reply"},
        {"role": "user", "content": ".nanobot的同目录"},
        *exec_round("tc_1", "tool output 1"),
        *exec_round("tc_2", "tool output 2"),
    ]

    spec = AgentRunSpec(
        initial_messages=messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    # Stub the whole-prompt estimate at 500 tokens (intended to exceed the
    # budget so that _snip_history actually trims).
    def fake_prompt_estimate(*_a, **_kw):
        return (500, None)

    # Per-message sizes keyed by stringified content; anything not listed
    # (e.g. the ``None``-content tool-call messages) counts as 100.
    token_sizes = {
        "system": 0,
        "previous reply": 200,
        ".nanobot的同目录": 80,
        "tool output 1": 80,
        "tool output 2": 80,
    }

    def fake_message_estimate(msg):
        return token_sizes.get(str(msg.get("content")), 100)

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", fake_prompt_estimate)
    monkeypatch.setattr("nanobot.agent.runner.estimate_message_tokens", fake_message_estimate)

    trimmed = runner._snip_history(spec, messages)

    # The first non-system message MUST be user (not assistant).
    non_system = [m for m in trimmed if m.get("role") != "system"]
    assert non_system, "trimmed should contain at least one non-system message"
    assert non_system[0]["role"] == "user", (
        f"First non-system message must be 'user', got '{non_system[0]['role']}'. "
        f"Roles: {[m['role'] for m in trimmed]}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_snip_history_no_user_at_all_falls_back_gracefully(monkeypatch):
    """Edge case: if non_system has zero user messages, _snip_history should
    still return a valid sequence (not crash or produce system→assistant).

    The history holds only system, assistant and tool messages.  The result
    is then passed through ``LLMProvider._enforce_role_alternation`` and the
    first non-system role of that output is checked.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.providers.base import LLMProvider

    tools = MagicMock()
    tools.get_definitions.return_value = []
    runner = AgentRunner(MagicMock())

    messages = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "reply"},
        {"role": "tool", "tool_call_id": "tc_1", "content": "result"},
        {"role": "assistant", "content": "reply 2"},
        {"role": "tool", "tool_call_id": "tc_2", "content": "result 2"},
    ]

    spec = AgentRunSpec(
        initial_messages=messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    # Fixed estimates: 500 tokens for the whole prompt, 100 for every message.
    def fake_prompt_estimate(*_a, **_kw):
        return (500, None)

    def fake_message_estimate(msg):
        return 100

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", fake_prompt_estimate)
    monkeypatch.setattr("nanobot.agent.runner.estimate_message_tokens", fake_message_estimate)

    trimmed = runner._snip_history(spec, messages)

    # Should not crash. The result should still be a valid list.
    assert isinstance(trimmed, list)
    # Must have at least system.
    roles_kept = [m.get("role") for m in trimmed]
    assert "system" in roles_kept

    # The _enforce_role_alternation safety net must be able to fix whatever
    # _snip_history returns here — verify it produces a valid sequence.
    fixed = LLMProvider._enforce_role_alternation(trimmed)
    non_system = [m for m in fixed if m["role"] != "system"]
    if non_system:
        assert non_system[0]["role"] in ("user", "tool"), (
            f"Safety net should ensure first non-system is user/tool, got {non_system[0]['role']}"
        )
|
||||||
172
tests/agent/test_runner_hooks.py
Normal file
172
tests/agent/test_runner_hooks.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
"""Tests for AgentRunner hook lifecycle: ordering, streaming deltas,
|
||||||
|
cached-token propagation, and hook context."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
    """Hook callbacks fire in the expected sequence across a tool turn and a final turn.

    The provider stub answers the first request with one ``list_dir`` tool
    call and every later request with plain ``"done"``.  A recording hook logs
    each callback; the log is compared against the full expected timeline and
    ``finalize_content`` is shown to rewrite the final content (upper-cased).
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    events: list[tuple] = []
    requests_made = 0

    async def scripted_chat(**kwargs):
        nonlocal requests_made
        requests_made += 1
        if requests_made > 1:
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
        )

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = scripted_chat
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    class RecordingHook(AgentHook):
        """Appends one tuple to ``events`` per callback invocation."""

        async def before_iteration(self, context: AgentHookContext) -> None:
            events.append(("before_iteration", context.iteration))

        async def before_execute_tools(self, context: AgentHookContext) -> None:
            requested = [tc.name for tc in context.tool_calls]
            events.append(("before_execute_tools", context.iteration, requested))

        async def after_iteration(self, context: AgentHookContext) -> None:
            # Copy the lists so later mutation of the context cannot alter the log.
            snapshot = (
                "after_iteration",
                context.iteration,
                context.final_content,
                list(context.tool_results),
                list(context.tool_events),
                context.stop_reason,
            )
            events.append(snapshot)

        def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
            events.append(("finalize_content", context.iteration, content))
            if content:
                return content.upper()
            return content

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=RecordingHook(),
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "DONE"
    assert events == [
        ("before_iteration", 0),
        ("before_execute_tools", 0, ["list_dir"]),
        (
            "after_iteration",
            0,
            None,
            ["tool result"],
            [{"name": "list_dir", "status": "ok", "detail": "tool result"}],
            None,
        ),
        ("before_iteration", 1),
        ("finalize_content", 1, "done"),
        ("after_iteration", 1, "DONE", [], [], "completed"),
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
    """A streaming hook gets each content delta and one end-of-stream signal.

    The hook opts in via ``wants_streaming``.  The streaming provider stub
    emits two deltas; the test checks they arrive in order, that
    ``on_stream_end`` is called once with ``resuming=False``, and that the
    non-streaming ``chat_with_retry`` is never awaited.
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    streamed: list[str] = []
    endings: list[bool] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        for piece in ("he", "llo"):
            await on_content_delta(piece)
        return LLMResponse(content="hello", tool_calls=[], usage={})

    provider = MagicMock(spec=LLMProvider)
    provider.chat_stream_with_retry = fake_stream
    provider.chat_with_retry = AsyncMock()
    tools = MagicMock()
    tools.get_definitions.return_value = []

    class StreamingHook(AgentHook):
        """Collects streamed deltas and end-of-stream flags."""

        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            streamed.append(delta)

        async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
            endings.append(resuming)

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=StreamingHook(),
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "hello"
    assert streamed == ["he", "llo"]
    assert endings == [False]
    provider.chat_with_retry.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
    """Hook context.usage should contain cached_tokens.

    The provider stub reports ``cached_tokens=150``; a hook copies
    ``context.usage`` in ``after_iteration`` and the copy is inspected.
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    usage_snapshots: list[dict] = []

    class UsageHook(AgentHook):
        """Stores a copy of the usage mapping after each iteration."""

        async def after_iteration(self, context: AgentHookContext) -> None:
            usage_snapshots.append(dict(context.usage))

    async def fake_chat(**kwargs):
        reported = {"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}
        return LLMResponse(content="done", tool_calls=[], usage=reported)

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = fake_chat
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=UsageHook(),
    )
    await AgentRunner(provider).run(spec)

    assert len(usage_snapshots) == 1
    assert usage_snapshots[0]["cached_tokens"] == 150
|
||||||
1038
tests/agent/test_runner_injections.py
Normal file
1038
tests/agent/test_runner_injections.py
Normal file
File diff suppressed because it is too large
Load Diff
161
tests/agent/test_runner_persistence.py
Normal file
161
tests/agent/test_runner_persistence.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
"""Tests for tool result persistence: large results, pruning, temp files, cleanup."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
    """An oversized tool result is persisted and replaced by a reference.

    The tool returns 20 000 characters while ``max_tool_result_chars`` is
    2048.  The second provider request must carry a tool message containing
    the ``[tool output persisted]`` marker and a ``tool-results`` path, and the
    file must exist under the workspace.

    The explicit ``pytest.mark.asyncio`` marker matches the other coroutine
    tests and keeps this test awaited under pytest-asyncio strict mode.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First request: ask for one tool call.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Later requests: record what the model is shown and finish.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="x" * 20_000)

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        workspace=tmp_path,
        session_key="test:runner",
        max_tool_result_chars=2048,
    ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert "[tool output persisted]" in tool_message["content"]
    assert "tool-results" in tool_message["content"]
    # The session key "test:runner" shows up as the "test_runner" directory.
    assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a result removes a stale session bucket and keeps a recent one.

    One bucket (and the file in it) is back-dated by eight days before
    ``maybe_persist_tool_result`` is called for a different session.
    """
    from nanobot.utils.helpers import maybe_persist_tool_result

    root = tmp_path / ".nanobot" / "tool-results"
    old_bucket = root / "old_session"
    recent_bucket = root / "recent_session"
    for bucket in (old_bucket, recent_bucket):
        bucket.mkdir(parents=True)
    old_file = old_bucket / "old.txt"
    old_file.write_text("old", encoding="utf-8")
    (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")

    # Back-date after writing, so creating the file cannot refresh the
    # directory timestamp again.
    eight_days = 8 * 24 * 60 * 60
    stale = time.time() - eight_days
    os.utime(old_bucket, (stale, stale))
    os.utime(old_file, (stale, stale))

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert not old_bucket.exists()
    assert recent_bucket.exists()
    assert (root / "current_session" / "call_big.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """After persisting, the session bucket holds the result and no ``*.tmp`` files."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    bucket = tmp_path / ".nanobot" / "tool-results" / "current_session"

    maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert (bucket / "call_big.txt").exists()
    leftovers = list(bucket.glob("*.tmp"))
    assert leftovers == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
    """A failing bucket cleanup is logged and does not stop persistence.

    ``_cleanup_tool_result_buckets`` is replaced by a stub raising ``OSError``
    and ``logger.exception`` by a recorder; the result must still be persisted
    and the first recorded message must mention the cleanup failure.
    """
    from nanobot.utils.helpers import maybe_persist_tool_result

    warnings: list[str] = []

    def failing_cleanup(*_args, **_kwargs):
        raise OSError("busy")

    def record_exception(message, *args):
        warnings.append(message.format(*args))

    monkeypatch.setattr("nanobot.utils.helpers._cleanup_tool_result_buckets", failing_cleanup)
    monkeypatch.setattr("nanobot.utils.helpers.logger.exception", record_exception)

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert warnings
    assert "Failed to clean stale tool result buckets" in warnings[0]
|
||||||
|
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
    """A persistence failure must not abort the run or alter the tool result.

    ``maybe_persist_tool_result`` is patched to raise ``RuntimeError``.  The
    run must still finish with ``"done"`` and the second provider request must
    carry the raw tool output.

    The explicit ``pytest.mark.asyncio`` marker matches the other coroutine
    tests and keeps this test awaited under pytest-asyncio strict mode.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First request: ask for one tool call.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Later requests: record what the model is shown and finish.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
        result = await runner.run(AgentRunSpec(
            initial_messages=[{"role": "user", "content": "do task"}],
            tools=tools,
            model="test-model",
            max_iterations=2,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert tool_message["content"] == "tool result"
|
||||||
244
tests/agent/test_runner_safety.py
Normal file
244
tests/agent/test_runner_safety.py
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
"""Tests for AgentRunner security: workspace violations, SSRF, shell guard, throttling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_does_not_abort_on_workspace_violation_anymore():
    """v2 behavior: workspace-bound rejections are *soft* tool errors.

    Previously (PR #3493) any workspace boundary error became a fatal
    RuntimeError that aborted the turn.  That silently killed legitimate
    workspace commands once the heuristic guard misfired (#3599 #3605), so
    we now hand the error back to the LLM as a recoverable tool result and
    rely on ``repeated_workspace_violation_error`` to throttle bypass loops.

    Here the tool raises ``PermissionError`` for a path outside the workspace;
    the provider must still be asked a second time and the run must end
    without an error.

    The explicit ``pytest.mark.asyncio`` marker matches the other coroutine
    tests and keeps this test awaited under pytest-asyncio strict mode.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    # Two scripted replies: a tool call outside the workspace, then plain text.
    provider.chat_with_retry = AsyncMock(side_effect=[
        LLMResponse(
            content="trying outside",
            tool_calls=[ToolCallRequest(
                id="call_1", name="read_file", arguments={"path": "/tmp/outside.md"},
            )],
        ),
        LLMResponse(content="ok, telling the user instead", tool_calls=[]),
    ])
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        side_effect=PermissionError(
            "Path /tmp/outside.md is outside allowed directory /workspace"
        )
    )

    runner = AgentRunner(provider)

    result = await runner.run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert provider.chat_with_retry.await_count == 2, (
        "workspace violation must NOT short-circuit the loop"
    )
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "ok, telling the user instead"
    assert result.tool_events and result.tool_events[0]["status"] == "error"
    # Detail still carries the workspace_violation breadcrumb for telemetry,
    # but the runner did not raise.
    assert "workspace_violation" in result.tool_events[0]["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_ssrf_violation_recognizes_private_url_blocks():
    """Only private/internal URL blocks classify as SSRF; other guard errors do not."""
    from nanobot.agent.runner import AgentRunner

    classify = AgentRunner._is_ssrf_violation

    ssrf_messages = (
        "Error: Command blocked by safety guard (internal/private URL detected)",
        "URL validation failed: Blocked: host resolves to private/internal address 192.168.1.2",
    )
    for message in ssrf_messages:
        assert classify(message) is True

    non_ssrf_messages = (
        # Workspace-bound markers are NOT classified as SSRF.
        "Error: Command blocked by safety guard (path outside working dir)",
        "Path /tmp/x is outside allowed directory /ws",
        # Deny / allowlist filter messages stay non-fatal too.
        "Error: Command blocked by deny pattern filter",
    )
    for message in non_ssrf_messages:
        assert classify(message) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_returns_non_retryable_hint_on_ssrf_violation():
    """An SSRF block is reported to the LLM with a do-not-retry hint, and the turn completes."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    refusal = "I cannot access that private URL. Please share local files."
    metadata_probe = ToolCallRequest(
        id="call_ssrf",
        name="exec",
        arguments={"command": "curl http://169.254.169.254"},
    )

    # Turn 1 issues the blocked command, turn 2 answers without tools.
    llm = MagicMock()
    llm.chat_with_retry = AsyncMock(side_effect=[
        LLMResponse(content="curl-ing metadata", tool_calls=[metadata_probe]),
        LLMResponse(content=refusal, tool_calls=[]),
    ])

    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (internal/private URL detected)"
    )

    agent = AgentRunner(llm)
    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await agent.run(spec)

    assert llm.chat_with_retry.await_count == 2
    assert result.stop_reason == "completed"
    assert result.error is None
    assert result.final_content == refusal
    assert result.tool_events and result.tool_events[0]["detail"].startswith("ssrf_violation:")

    tool_messages = [entry for entry in result.messages if entry.get("role") == "tool"]
    assert tool_messages
    hint = tool_messages[0]["content"]
    assert "non-bypassable security boundary" in hint
    assert "Do not retry" in hint
    assert "tools.ssrfWhitelist" in hint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_lets_llm_recover_from_shell_guard_path_outside():
    """Reporter scenario for #3599 / #3605 -- guard hit, agent recovers.

    The shell `_guard_command` heuristic fires on `2>/dev/null`-style
    redirects and other shell idioms. Before v2 that aborted the whole
    turn (silent hang on Telegram per #3605); now the LLM gets the soft
    error back and can finalize on the next iteration.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # AsyncMock increments await_count before invoking side_effect, so the
        # first LLM call observes a count of 1.
        if provider.chat_with_retry.await_count == 1:
            return LLMResponse(
                content="trying noisy cleanup",
                tool_calls=[ToolCallRequest(
                    id="call_blocked",
                    name="exec",
                    arguments={"command": "rm scratch.txt 2>/dev/null"},
                )],
            )
        # Snapshot what the LLM is shown on the recovery turn.
        captured_second_call[:] = list(messages)
        return LLMResponse(content="recovered final answer", tool_calls=[])

    provider.chat_with_retry = AsyncMock(side_effect=chat_with_retry)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (path outside working dir)"
    )

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert provider.chat_with_retry.await_count == 2, (
        "guard hit must NOT short-circuit the loop -- LLM should get a second turn"
    )
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "recovered final answer"
    assert result.tool_events and result.tool_events[0]["status"] == "error"
    # v2: detail keeps the breadcrumb but the runner did not raise.
    assert "workspace_violation" in result.tool_events[0]["detail"]
    # The soft error must actually be handed back to the LLM: the second call
    # has to carry a tool-role message answering the blocked tool call.
    assert any(
        msg.get("role") == "tool" and msg.get("tool_call_id") == "call_blocked"
        for msg in captured_second_call
    ), "second LLM call must include the tool result for the blocked command"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_throttles_repeated_workspace_bypass_attempts():
    """#3493 motivation: stop the LLM bypass loop without aborting the turn.

    The scripted LLM issues four ``exec`` calls against the same outside
    path, each of which the tool layer rejects, and then stops asking. The
    run must finish normally and at least one tool event must be tagged
    ``workspace_violation_escalated:``.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    attempt_count = 4
    scripted: list[LLMResponse] = []
    for attempt in range(attempt_count):
        bypass_call = ToolCallRequest(
            id=f"a{attempt}", name="exec",
            arguments={"command": f"cat /Users/x/Downloads/01.md # try {attempt}"},
        )
        scripted.append(LLMResponse(content=f"try {attempt}", tool_calls=[bypass_call]))
    scripted.append(LLMResponse(content="ok telling user", tool_calls=[]))

    llm = MagicMock()
    llm.chat_with_retry = AsyncMock(side_effect=scripted)
    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (path outside working dir)"
    )

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(llm).run(spec)

    # All 4 bypass attempts surface to the LLM (no fatal abort), and the
    # runner finally completes once the LLM stops asking.
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "ok telling user"

    # The third+ attempts must have been escalated -- look at the events.
    escalated = []
    for event in result.tool_events:
        if event["status"] != "error":
            continue
        if event["detail"].startswith("workspace_violation_escalated:"):
            escalated.append(event)
    assert escalated, (
        "expected at least one escalated workspace_violation event, got: "
        f"{result.tool_events}"
    )
|
||||||
181
tests/agent/test_runner_tool_execution.py
Normal file
181
tests/agent/test_runner_tool_execution.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
"""Tests for AgentRunner tool execution: batching, concurrency, exclusive tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
# Default ``max_tool_result_chars`` taken from the agent config schema; the
# tests in this module pass it as ``AgentRunSpec.max_tool_result_chars``.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
class _DelayTool(Tool):
    """Test tool that sleeps for a fixed delay and logs start/end markers.

    Each ``execute`` call appends ``start:<name>`` and ``end:<name>`` to the
    shared event list, so tests can inspect the order in which tools ran.
    """

    def __init__(
        self,
        name: str,
        *,
        delay: float,
        read_only: bool,
        shared_events: list[str],
        exclusive: bool = False,
    ) -> None:
        self._name = name
        self._delay = delay
        self._read_only = read_only
        self._exclusive = exclusive
        # Kept by reference: the caller reads the same list after execution.
        self._shared_events = shared_events

    @property
    def name(self) -> str:
        """Tool name as given at construction."""
        return self._name

    @property
    def description(self) -> str:
        """Description; simply reuses the tool name."""
        return self._name

    @property
    def parameters(self) -> dict:
        """Schema for a tool that takes no arguments."""
        return {"type": "object", "properties": {}, "required": []}

    @property
    def read_only(self) -> bool:
        """Read-only flag as given at construction."""
        return self._read_only

    @property
    def exclusive(self) -> bool:
        """Exclusive flag as given at construction."""
        return self._exclusive

    async def execute(self, **kwargs) -> str:
        """Log a start marker, sleep for the delay, log an end marker, return the name."""
        label = self._name
        record = self._shared_events.append
        record(f"start:{label}")
        await asyncio.sleep(self._delay)
        record(f"end:{label}")
        return label
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
    """Two read-only tools start together and both finish before the write tool starts."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    shared_events: list[str] = []
    registry = ToolRegistry()
    for tool in (
        _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events),
        _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events),
        _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events),
    ):
        registry.register(tool)

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    requested = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
        ToolCallRequest(id="rw1", name="write_a", arguments={}),
    ]

    agent = AgentRunner(MagicMock())
    await agent._execute_tools(spec, requested, {}, {})

    # Both reads are started before either finishes.
    assert shared_events[:2] == ["start:read_a", "start:read_b"]
    assert "end:read_a" in shared_events
    assert "end:read_b" in shared_events
    # The write only begins once both reads are done, and it runs last.
    assert shared_events.index("end:read_a") < shared_events.index("start:write_a")
    assert shared_events.index("end:read_b") < shared_events.index("start:write_a")
    assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_does_not_batch_exclusive_read_only_tools():
    """An exclusive read-only tool runs strictly between its read-only neighbours."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    shared_events: list[str] = []

    def make_tool(name: str, delay: float, **extra) -> _DelayTool:
        return _DelayTool(
            name, delay=delay, read_only=True, shared_events=shared_events, **extra
        )

    read_a = make_tool("read_a", 0.03)
    read_b = make_tool("read_b", 0.03)
    ddg_like = make_tool("ddg_like", 0.01, exclusive=True)

    registry = ToolRegistry()
    for tool in (read_a, ddg_like, read_b):
        registry.register(tool)

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    requested = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ddg1", name="ddg_like", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
    ]

    agent = AgentRunner(MagicMock())
    await agent._execute_tools(spec, requested, {}, {})

    assert shared_events[0] == "start:read_a"
    # read_a must finish before the exclusive tool starts ...
    assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like")
    # ... and the exclusive tool must finish before read_b starts.
    assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches():
    """The third identical ``web_fetch`` call is answered with a block notice, not executed."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    llm = MagicMock()
    captured_final_call: list[dict] = []
    turns_seen = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal turns_seen
        turns_seen += 1
        if turns_seen > 3:
            # Fourth turn: remember what the LLM was shown, then finish.
            captured_final_call[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        fetch = ToolCallRequest(
            id=f"call_{turns_seen}",
            name="web_fetch",
            arguments={"url": "https://example.com"},
        )
        return LLMResponse(content="working", tool_calls=[fetch], usage={})

    llm.chat_with_retry = chat_with_retry
    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(return_value="page content")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "research task"}],
        tools=registry,
        model="test-model",
        max_iterations=4,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(llm).run(spec)

    assert result.final_content == "done"
    # Only the first two fetches reach the tool layer.
    assert registry.execute.await_count == 2

    replies_to_third_call = [
        entry for entry in captured_final_call
        if entry.get("role") == "tool" and entry.get("tool_call_id") == "call_3"
    ]
    blocked_tool_message = replies_to_third_call[0]
    assert "repeated external lookup blocked" in blocked_tool_message["content"]
|
||||||
@ -10,6 +10,7 @@ See: https://github.com/HKUDS/nanobot/issues/2966
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch, AsyncMock
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
@ -17,42 +18,47 @@ from unittest.mock import MagicMock, patch, AsyncMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def _make_provider():
|
||||||
def mock_loop():
|
"""Create an LLM provider mock with required attributes."""
|
||||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
from types import SimpleNamespace
|
||||||
with patch.object(AgentLoop, "__init__", lambda self: None):
|
provider = MagicMock()
|
||||||
loop = AgentLoop()
|
provider.get_default_model.return_value = "test-model"
|
||||||
loop.sessions = MagicMock()
|
provider.generation = SimpleNamespace(max_tokens=4096, temperature=0.1, reasoning_effort=None)
|
||||||
loop._pending_queues = {}
|
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||||
loop._session_locks = {}
|
return provider
|
||||||
loop._active_tasks = {}
|
|
||||||
loop._concurrency_gate = None
|
|
||||||
loop._RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||||
loop._PENDING_USER_TURN_KEY = "pending_user_turn"
|
"""Create a real AgentLoop with mocked provider — avoids patching __init__."""
|
||||||
loop.bus = MagicMock()
|
bus = MessageBus()
|
||||||
loop.bus.publish_outbound = AsyncMock()
|
provider = _make_provider()
|
||||||
loop.bus.publish_inbound = AsyncMock()
|
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||||
loop.commands = MagicMock()
|
patch("nanobot.agent.loop.SessionManager"), \
|
||||||
loop.commands.dispatch_priority = AsyncMock(return_value=None)
|
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||||
return loop
|
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||||
|
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path)
|
||||||
|
|
||||||
|
|
||||||
class TestStopPreservesContext:
|
class TestStopPreservesContext:
|
||||||
"""Verify that /stop restores partial context via checkpoint."""
|
"""Verify that /stop restores partial context via checkpoint."""
|
||||||
|
|
||||||
def test_restore_checkpoint_method_exists(self, mock_loop):
|
def test_restore_checkpoint_method_exists(self, tmp_path):
|
||||||
"""AgentLoop should have _restore_runtime_checkpoint."""
|
"""AgentLoop should have _restore_runtime_checkpoint."""
|
||||||
assert hasattr(mock_loop, "_restore_runtime_checkpoint")
|
loop = _make_loop(tmp_path)
|
||||||
|
assert hasattr(loop, "_restore_runtime_checkpoint")
|
||||||
|
|
||||||
def test_checkpoint_key_constant(self, mock_loop):
|
def test_checkpoint_key_constant(self, tmp_path):
|
||||||
"""The runtime checkpoint key should be defined."""
|
"""The runtime checkpoint key should be defined."""
|
||||||
assert mock_loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint"
|
loop = _make_loop(tmp_path)
|
||||||
|
assert loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint"
|
||||||
|
|
||||||
def test_cancel_dispatch_restores_checkpoint(self, mock_loop):
|
def test_cancel_dispatch_restores_checkpoint(self, tmp_path):
|
||||||
"""When a task is cancelled, the checkpoint should be restored."""
|
"""When a task is cancelled, the checkpoint should be restored."""
|
||||||
# Create a mock session with a checkpoint
|
loop = _make_loop(tmp_path)
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.metadata = {
|
session.metadata = {
|
||||||
"runtime_checkpoint": {
|
"runtime_checkpoint": {
|
||||||
@ -74,14 +80,11 @@ class TestStopPreservesContext:
|
|||||||
session.messages = [
|
session.messages = [
|
||||||
{"role": "user", "content": "Search for something"},
|
{"role": "user", "content": "Search for something"},
|
||||||
]
|
]
|
||||||
mock_loop.sessions.get_or_create.return_value = session
|
loop.sessions.get_or_create.return_value = session
|
||||||
|
|
||||||
# The restore method should add checkpoint messages to session history
|
restored = loop._restore_runtime_checkpoint(session)
|
||||||
restored = mock_loop._restore_runtime_checkpoint(session)
|
|
||||||
assert restored is True
|
assert restored is True
|
||||||
# After restore, session should have more messages
|
|
||||||
assert len(session.messages) > 1
|
assert len(session.messages) > 1
|
||||||
# The checkpoint should be cleared
|
|
||||||
assert "runtime_checkpoint" not in session.metadata
|
assert "runtime_checkpoint" not in session.metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
558
tests/agent/test_subagent_lifecycle.py
Normal file
558
tests/agent/test_subagent_lifecycle.py
Normal file
@ -0,0 +1,558 @@
|
|||||||
|
"""Tests for SubagentManager lifecycle — spawn, run, announce, cancel."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.hook import AgentHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunResult
|
||||||
|
from nanobot.agent.subagent import (
|
||||||
|
SubagentManager,
|
||||||
|
SubagentStatus,
|
||||||
|
_SubagentHook,
|
||||||
|
)
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _manager(tmp_path: Path, **kw) -> SubagentManager:
    """Build a SubagentManager with a spec'd provider mock; ``kw`` overrides any default."""
    llm = MagicMock(spec=LLMProvider)
    llm.get_default_model.return_value = "test-model"
    base_kwargs = {
        "provider": llm,
        "workspace": tmp_path,
        "bus": MessageBus(),
        "model": "test-model",
        "max_tool_result_chars": 16_000,
    }
    return SubagentManager(**{**base_kwargs, **kw})
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hook_context(**overrides) -> AgentHookContext:
    """Build an AgentHookContext for a finished first iteration; ``overrides`` replace fields."""
    fields = {
        "iteration": 1,
        "tool_calls": [],
        "tool_events": [],
        "messages": [],
        "usage": {},
        "error": None,
        "stop_reason": "completed",
        "final_content": "ok",
    }
    return AgentHookContext(**{**fields, **overrides})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SubagentStatus defaults
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubagentStatus:
    """Field defaults of a freshly constructed SubagentStatus."""

    def test_defaults(self):
        fresh = SubagentStatus(
            task_id="abc",
            label="test",
            task_description="do stuff",
            started_at=time.monotonic(),
        )
        assert fresh.phase == "initializing"
        assert fresh.iteration == 0
        assert fresh.tool_events == []
        assert fresh.usage == {}
        assert fresh.stop_reason is None
        assert fresh.error is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# set_provider
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetProvider:
    """set_provider swaps the provider on the manager and on its runner."""

    def test_updates_provider_model_runner(self, tmp_path):
        manager = _manager(tmp_path)
        replacement = MagicMock(spec=LLMProvider)

        manager.set_provider(replacement, "new-model")

        assert manager.provider is replacement
        assert manager.model == "new-model"
        assert manager.runner.provider is replacement
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# spawn
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpawn:
    """spawn(): return text, task/status/session bookkeeping, labels, cleanup."""

    @staticmethod
    def _done_result() -> AgentRunResult:
        """Fresh successful run result used by every test in this class."""
        return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

    @classmethod
    def _gated_run(cls):
        """Return ``(gate, run)`` where ``run`` blocks until ``gate`` is set."""
        gate = asyncio.Event()

        async def run(spec):
            await gate.wait()
            return cls._done_result()

        return gate, run

    @pytest.mark.asyncio
    async def test_returns_string_with_task_id(self, tmp_path):
        manager = _manager(tmp_path)
        manager.runner.run = AsyncMock(return_value=self._done_result())

        reply = await manager.spawn("do something")

        assert "started" in reply
        assert "id:" in reply

    @pytest.mark.asyncio
    async def test_creates_task_in_running_tasks(self, tmp_path):
        manager = _manager(tmp_path)
        gate, manager.runner.run = self._gated_run()

        await manager.spawn("task", session_key="s1")
        assert len(manager._running_tasks) == 1

        # Release the run and give the done-callback time to fire.
        gate.set()
        await asyncio.sleep(0.1)
        assert len(manager._running_tasks) == 0

    @pytest.mark.asyncio
    async def test_creates_status(self, tmp_path):
        manager = _manager(tmp_path)
        manager.runner.run = AsyncMock(return_value=self._done_result())

        await manager.spawn("my task")
        await asyncio.sleep(0.1)

        # Status cleaned up after task completes
        assert len(manager._task_statuses) == 0

    @pytest.mark.asyncio
    async def test_registers_in_session_tasks(self, tmp_path):
        manager = _manager(tmp_path)
        gate, manager.runner.run = self._gated_run()

        await manager.spawn("task", session_key="s1")
        assert "s1" in manager._session_tasks
        assert len(manager._session_tasks["s1"]) == 1

        gate.set()
        await asyncio.sleep(0.1)
        assert "s1" not in manager._session_tasks

    @pytest.mark.asyncio
    async def test_no_session_key_no_registration(self, tmp_path):
        manager = _manager(tmp_path)
        gate, manager.runner.run = self._gated_run()

        await manager.spawn("task")
        assert len(manager._session_tasks) == 0

        gate.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_label_defaults_to_truncated_task(self, tmp_path):
        manager = _manager(tmp_path)
        gate, manager.runner.run = self._gated_run()

        long_task = "A" * 50
        await manager.spawn(long_task, session_key="s1")
        status = next(iter(manager._task_statuses.values()))
        assert status.label == long_task[:30] + "..."

        gate.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_custom_label(self, tmp_path):
        manager = _manager(tmp_path)
        gate, manager.runner.run = self._gated_run()

        await manager.spawn("task", label="Custom Label", session_key="s1")
        status = next(iter(manager._task_statuses.values()))
        assert status.label == "Custom Label"

        gate.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_cleanup_callback_removes_all_entries(self, tmp_path):
        manager = _manager(tmp_path)
        manager.runner.run = AsyncMock(return_value=self._done_result())

        await manager.spawn("task", session_key="s1")
        await asyncio.sleep(0.1)

        assert len(manager._running_tasks) == 0
        assert len(manager._task_statuses) == 0
        assert len(manager._session_tasks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _run_subagent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunSubagent:
    """_run_subagent(): what gets announced and how the status object is updated."""

    @pytest.mark.asyncio
    async def test_successful_run(self, tmp_path):
        manager = _manager(tmp_path)
        outcome = AgentRunResult(
            final_content="Task done!", messages=[], stop_reason="completed",
        )
        manager.runner.run = AsyncMock(return_value=outcome)
        origin = {"channel": "cli", "chat_id": "direct"}

        with patch.object(manager, "_announce_result", new_callable=AsyncMock) as announce:
            tracker = SubagentStatus(
                task_id="t1", label="label", task_description="do task",
                started_at=time.monotonic(),
            )
            await manager._run_subagent("t1", "do task", "label", origin, tracker)

        announce.assert_called_once()
        assert announce.call_args.args[-2] == "ok"

    @pytest.mark.asyncio
    async def test_tool_error_run(self, tmp_path):
        manager = _manager(tmp_path)
        outcome = AgentRunResult(
            final_content=None, messages=[], stop_reason="tool_error",
            tool_events=[{"name": "read_file", "status": "error", "detail": "not found"}],
        )
        manager.runner.run = AsyncMock(return_value=outcome)
        tracker = SubagentStatus(
            task_id="t1", label="label", task_description="do task",
            started_at=time.monotonic(),
        )
        origin = {"channel": "cli", "chat_id": "direct"}

        with patch.object(manager, "_announce_result", new_callable=AsyncMock) as announce:
            await manager._run_subagent("t1", "do task", "label", origin, tracker)

        assert announce.call_args.args[-2] == "error"

    @pytest.mark.asyncio
    async def test_exception_run(self, tmp_path):
        manager = _manager(tmp_path)
        manager.runner.run = AsyncMock(side_effect=RuntimeError("LLM down"))
        tracker = SubagentStatus(
            task_id="t1", label="label", task_description="do task",
            started_at=time.monotonic(),
        )
        origin = {"channel": "cli", "chat_id": "direct"}

        with patch.object(manager, "_announce_result", new_callable=AsyncMock) as announce:
            await manager._run_subagent("t1", "do task", "label", origin, tracker)

        assert tracker.phase == "error"
        assert "LLM down" in tracker.error
        assert announce.call_args.args[-2] == "error"

    @pytest.mark.asyncio
    async def test_status_updated_on_success(self, tmp_path):
        manager = _manager(tmp_path)
        outcome = AgentRunResult(
            final_content="ok", messages=[], stop_reason="completed",
        )
        manager.runner.run = AsyncMock(return_value=outcome)
        tracker = SubagentStatus(
            task_id="t1", label="label", task_description="do task",
            started_at=time.monotonic(),
        )
        origin = {"channel": "cli", "chat_id": "direct"}

        with patch.object(manager, "_announce_result", new_callable=AsyncMock):
            await manager._run_subagent("t1", "do task", "label", origin, tracker)

        assert tracker.phase == "done"
        assert tracker.stop_reason == "completed"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _announce_result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnnounceResult:
    """Tests for the message that ``SubagentManager._announce_result`` publishes on the bus."""

    def _capture_inbound(self, manager):
        """Replace ``manager.bus.publish_inbound`` with a recording ``AsyncMock``.

        Returns the list that every published message is appended to.
        """
        captured = []

        def _record(msg):
            captured.append(msg)

        manager.bus.publish_inbound = AsyncMock(side_effect=_record)
        return captured

    @pytest.mark.asyncio
    async def test_publishes_inbound_message(self, tmp_path):
        """Exactly one message is published, carrying the expected channel, sender and metadata."""
        manager = _manager(tmp_path)
        captured = self._capture_inbound(manager)
        origin = {"channel": "cli", "chat_id": "direct"}

        await manager._announce_result("t1", "label", "task", "result text", origin, "ok")

        assert len(captured) == 1
        message = captured[0]
        assert message.channel == "system"
        assert message.sender_id == "subagent"
        assert message.metadata["injected_event"] == "subagent_result"
        assert message.metadata["subagent_task_id"] == "t1"

    @pytest.mark.asyncio
    async def test_session_key_override(self, tmp_path):
        """A ``session_key`` present in the origin is used as the session key override."""
        manager = _manager(tmp_path)
        captured = self._capture_inbound(manager)
        origin = {"channel": "telegram", "chat_id": "123", "session_key": "s1"}

        await manager._announce_result("t1", "label", "task", "result", origin, "ok")

        assert captured[0].session_key_override == "s1"

    @pytest.mark.asyncio
    async def test_session_key_override_fallback(self, tmp_path):
        """Without ``session_key`` the override is "<channel>:<chat_id>" for this origin."""
        manager = _manager(tmp_path)
        captured = self._capture_inbound(manager)
        origin = {"channel": "telegram", "chat_id": "123"}

        await manager._announce_result("t1", "label", "task", "result", origin, "ok")

        assert captured[0].session_key_override == "telegram:123"

    @pytest.mark.asyncio
    async def test_ok_status_text(self, tmp_path):
        """An "ok" status produces content mentioning successful completion."""
        manager = _manager(tmp_path)
        captured = self._capture_inbound(manager)
        origin = {"channel": "cli", "chat_id": "direct"}

        await manager._announce_result("t1", "label", "task", "result", origin, "ok")

        assert "completed successfully" in captured[0].content

    @pytest.mark.asyncio
    async def test_error_status_text(self, tmp_path):
        """An "error" status produces content mentioning failure."""
        manager = _manager(tmp_path)
        captured = self._capture_inbound(manager)
        origin = {"channel": "cli", "chat_id": "direct"}

        await manager._announce_result("t1", "label", "task", "error details", origin, "error")

        assert "failed" in captured[0].content

    @pytest.mark.asyncio
    async def test_origin_message_id_in_metadata(self, tmp_path):
        """The ``origin_message_id`` keyword argument is copied into the message metadata."""
        manager = _manager(tmp_path)
        captured = self._capture_inbound(manager)
        origin = {"channel": "cli", "chat_id": "direct"}

        await manager._announce_result(
            "t1", "label", "task", "result", origin, "ok",
            origin_message_id="msg-123",
        )

        assert captured[0].metadata["origin_message_id"] == "msg-123"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _format_partial_progress
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatPartialProgress:
    """Tests for the text produced by ``SubagentManager._format_partial_progress``."""

    def _make_result(self, tool_events=None, error=None):
        """Build a stand-in run result exposing only ``tool_events`` and ``error``."""
        events = tool_events or []
        return MagicMock(tool_events=events, error=error)

    def test_completed_only(self):
        """Successful tool events are listed under a "Completed steps:" heading."""
        events = [
            {"name": "read_file", "status": "ok", "detail": "file content"},
            {"name": "exec", "status": "ok", "detail": "output"},
        ]
        summary = SubagentManager._format_partial_progress(
            self._make_result(tool_events=events)
        )
        for expected in ("Completed steps:", "read_file", "exec"):
            assert expected in summary

    def test_failure_only(self):
        """A failed tool event yields a "Failure:" section containing its detail."""
        events = [{"name": "read_file", "status": "error", "detail": "not found"}]
        summary = SubagentManager._format_partial_progress(
            self._make_result(tool_events=events)
        )
        assert "Failure:" in summary
        assert "not found" in summary

    def test_completed_and_failure(self):
        """Mixed events produce both the completed and the failure sections."""
        events = [
            {"name": "read_file", "status": "ok", "detail": "content"},
            {"name": "exec", "status": "error", "detail": "timeout"},
        ]
        summary = SubagentManager._format_partial_progress(
            self._make_result(tool_events=events)
        )
        assert "Completed steps:" in summary
        assert "Failure:" in summary

    def test_limited_to_last_three(self):
        """Of five successful events only the last three appear in the text."""
        events = []
        for i in range(5):
            events.append({"name": f"tool_{i}", "status": "ok", "detail": f"result_{i}"})
        summary = SubagentManager._format_partial_progress(
            self._make_result(tool_events=events)
        )
        for kept in ("tool_2", "tool_3", "tool_4"):
            assert kept in summary
        for dropped in ("tool_0", "tool_1"):
            assert dropped not in summary

    def test_error_without_failure_event(self):
        """The result's ``error`` text is included when no tool event failed."""
        outcome = self._make_result(
            tool_events=[{"name": "read_file", "status": "ok", "detail": "ok"}],
            error="Something went wrong",
        )
        summary = SubagentManager._format_partial_progress(outcome)
        assert "Something went wrong" in summary

    def test_empty_events_with_error(self):
        """With no tool events the ``error`` text still appears."""
        outcome = self._make_result(error="Total failure")
        summary = SubagentManager._format_partial_progress(outcome)
        assert "Total failure" in summary

    def test_empty_no_error_returns_fallback(self):
        """With neither events nor error the returned text contains "Error"."""
        outcome = self._make_result()
        summary = SubagentManager._format_partial_progress(outcome)
        assert "Error" in summary
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cancel_by_session
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancelBySession:
    """Tests for ``SubagentManager.cancel_by_session``."""

    @pytest.mark.asyncio
    async def test_cancels_running_tasks(self, tmp_path):
        """Two tasks blocked in the runner for one session are both reported as cancelled."""
        manager = _manager(tmp_path)
        gate = asyncio.Event()

        async def _slow_run(spec):
            # Hold the task open until the test releases the gate.
            await gate.wait()
            return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

        manager.runner.run = _slow_run

        for description in ("task1", "task2"):
            await manager.spawn(description, session_key="s1")
        tracked = manager._session_tasks.get("s1", set())
        assert len(tracked) == 2

        cancelled = await manager.cancel_by_session("s1")
        assert cancelled == 2

        # Release the gate and yield for a moment before the test returns.
        gate.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_no_tasks_returns_zero(self, tmp_path):
        """Cancelling a session that has no tasks returns 0."""
        manager = _manager(tmp_path)
        cancelled = await manager.cancel_by_session("nonexistent")
        assert cancelled == 0

    @pytest.mark.asyncio
    async def test_already_done_not_counted(self, tmp_path):
        """A task given time to finish before cancellation is not counted."""
        manager = _manager(tmp_path)
        finished = AgentRunResult(
            final_content="done", messages=[], stop_reason="completed",
        )
        manager.runner.run = AsyncMock(return_value=finished)

        await manager.spawn("task1", session_key="s1")
        await asyncio.sleep(0.1)  # Wait for completion

        cancelled = await manager.cancel_by_session("s1")
        assert cancelled == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_running_count / get_running_count_by_session
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunningCounts:
    """Tests for ``get_running_count`` and ``get_running_count_by_session``."""

    @pytest.mark.asyncio
    async def test_running_count_zero(self, tmp_path):
        """A freshly built manager reports no running tasks."""
        manager = _manager(tmp_path)
        assert manager.get_running_count() == 0

    @pytest.mark.asyncio
    async def test_running_count_tracks_tasks(self, tmp_path):
        """Counts are 2 while two spawned tasks are blocked, then 0 after release and a short wait."""
        manager = _manager(tmp_path)
        gate = asyncio.Event()

        async def _slow_run(spec):
            # Hold the task open until the test releases the gate.
            await gate.wait()
            return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

        manager.runner.run = _slow_run

        for description in ("t1", "t2"):
            await manager.spawn(description, session_key="s1")
        assert manager.get_running_count() == 2
        assert manager.get_running_count_by_session("s1") == 2

        gate.set()
        await asyncio.sleep(0.1)
        assert manager.get_running_count() == 0

    @pytest.mark.asyncio
    async def test_running_count_by_session_nonexistent(self, tmp_path):
        """An unknown session key reports zero running tasks."""
        manager = _manager(tmp_path)
        assert manager.get_running_count_by_session("nonexistent") == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _SubagentHook
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubagentHook:
    """Tests for the ``_SubagentHook`` callbacks ``before_execute_tools`` and ``after_iteration``."""

    @pytest.mark.asyncio
    async def test_before_execute_tools_logs(self, tmp_path):
        """``before_execute_tools`` completes without raising for a single tool call."""
        subagent_hook = _SubagentHook("t1")
        # ``name`` is assigned after construction: passing name= to MagicMock
        # would name the mock itself instead of setting the attribute.
        call = MagicMock(arguments={"path": "/tmp/test"})
        call.name = "read_file"
        context = _make_hook_context(tool_calls=[call])
        # Should not raise
        await subagent_hook.before_execute_tools(context)

    @pytest.mark.asyncio
    async def test_after_iteration_updates_status(self):
        """``after_iteration`` copies iteration, tool events and usage from the context onto the status."""
        tracked = SubagentStatus(
            task_id="t1",
            label="test",
            task_description="do",
            started_at=time.monotonic(),
        )
        subagent_hook = _SubagentHook("t1", tracked)
        context = _make_hook_context(
            iteration=3,
            tool_events=[{"name": "read_file", "status": "ok", "detail": ""}],
            usage={"prompt_tokens": 100},
        )

        await subagent_hook.after_iteration(context)

        assert tracked.iteration == 3
        assert len(tracked.tool_events) == 1
        assert tracked.usage == {"prompt_tokens": 100}

    @pytest.mark.asyncio
    async def test_after_iteration_no_status_noop(self):
        """With ``status=None`` the ``after_iteration`` call completes without raising."""
        subagent_hook = _SubagentHook("t1", status=None)
        context = _make_hook_context(iteration=5)
        # Should not raise
        await subagent_hook.after_iteration(context)

    @pytest.mark.asyncio
    async def test_after_iteration_sets_error(self):
        """An error on the context is stored on the status."""
        tracked = SubagentStatus(
            task_id="t1",
            label="test",
            task_description="do",
            started_at=time.monotonic(),
        )
        subagent_hook = _SubagentHook("t1", tracked)
        context = _make_hook_context(error="something broke")

        await subagent_hook.after_iteration(context)

        assert tracked.error == "something broke"
|
||||||
Loading…
x
Reference in New Issue
Block a user